diff --git a/.assets/libra.png b/.assets/libra.png new file mode 100644 index 0000000000000..a159ebf19494f Binary files /dev/null and b/.assets/libra.png differ diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000000000..733a0f0d3959d --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,39 @@ +version: 2 +jobs: + build: + docker: + - image: circleci/rust:stretch + resource_class: xlarge + steps: + - checkout + - run: + name: Version Information + command: rustc --version; cargo --version; rustup --version + - run: + name: Install Dependencies + command: | + sudo sh -c 'echo "deb http://deb.debian.org/debian stretch-backports main" > /etc/apt/sources.list.d/backports.list' + sudo apt-get update + sudo apt-get install -y protobuf-compiler/stretch-backports cmake golang curl + sudo apt-get clean + sudo rm -r /var/lib/apt/lists/* + rustup component add clippy rustfmt + - run: + name: Setup Env + command: | + echo 'export TAG=0.1.${CIRCLE_BUILD_NUM}' >> $BASH_ENV + echo 'export IMAGE_NAME=myapp' >> $BASH_ENV + - run: + name: Linting + command: | + ./scripts/clippy.sh + cargo fmt -- --check + - run: + name: Build All Targets + command: RUST_BACKTRACE=1 cargo build -j 16 --all --all-targets + - run: + name: Run All Unit Tests + command: cargo test --all --exclude testsuite + - run: + name: Run All End to End Tests + command: RUST_TEST_THREADS=2 cargo test --package testsuite diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000..a39d1062424a8 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +.git/ +**/.terraform/ +target/ diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000..8bc545050f9a6 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,12 @@ +# Ensure that text files that any contributor introduces to the repository +# have their line endings normalized to LF +* text=auto + +# All known text filetypes +*.md text +*.proto text +*.rs text +*.sh text eol=lf +*.toml text +*.txt text +*.yml text diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000000000..9cbf62661250f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,43 @@ +--- +name: "\U0001F41B Bug report" +about: Create a bug report to help improve Libra Core +title: "[Bug]" +labels: bug +assignees: '' + +--- + +# πŸ› Bug + + + +## To reproduce + +** Code snippet to reproduce ** +```rust +# Your code goes here +# Please make sure it does not require any external dependencies +``` + +** Stack trace/error message ** +``` +// Paste the output here +``` + +## Expected Behavior + + + +## System information + +**Please complete the following information:** +- +- +- + + +## Additional context + +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000000000..ae9bb986680f9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,32 @@ +--- +name: "\U0001F680 Feature request" +about: Suggest a new feature in Libra Core +title: "[Feature Request]" +labels: enhancement +assignees: '' + +--- + +# πŸš€ Feature Request + + + +## Motivation + +**Is your feature request related to a problem? 
Please describe.**
+
+
+
+## Pitch
+
+**Describe the solution you'd like**
+
+
+**Describe alternatives you've considered**
+
+
+**Are you willing to open a pull request?** (See [CONTRIBUTING](../CONTRIBUTING.md))
+
+## Additional context
+
+
diff --git a/.github/ISSUE_TEMPLATE/questions.md b/.github/ISSUE_TEMPLATE/questions.md
new file mode 100644
index 0000000000000..efb1f94bf73f5
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/questions.md
@@ -0,0 +1,10 @@
+---
+name: ❓ Questions/Help
+about: If you have questions, please check Discourse
+---
+
+## ❓ Questions and Help
+
+### Please note that this issue tracker is not a help form and this issue will be closed.
+
+Please contact the development team on [Discourse](https://community.libra.org).
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 0000000000000..b4824b51dd64d
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,21 @@
+
+
+## Motivation
+
+(Write your motivation for proposed changes here.)
+
+### Have you read the [Contributing Guidelines on pull requests](https://github.com/libra/libra/blob/master/CONTRIBUTING.md#pull-requests)?
+
+(Write your answer here.)
+
+## Test Plan
+
+(Share your test plan here. If you changed code, please provide us with clear instructions for verifying that your changes work.)
+
+## Related PRs
+
+(If this PR adds or changes functionality, please take some time to update the docs at https://github.com/libra/website, and link to your PR here.)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000..c22208bb5b3e0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,33 @@
+# Rust specific ignores
+/target
+**/*.rs.bk
+# Cargo.lock is needed for deterministic testing and repeatable builds;
+# however, having it in the repo slows down the development cycle.
+#
+# More information here: https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
+Cargo.lock
+
+# Ignore generated files in proto folders
+**/proto/*.rs
+!**/proto/mod.rs
+!**/proto/converter.rs
+**/proto/*/*.rs
+!**/proto/*/mod.rs
+!**/proto/*/converter.rs
+
+# IDE
+.idea
+.idea/*
+*.iml
+.vscode
+
+# Ignore wallet mnemonic files used for deterministic key derivation
+*.mnemonic
+
+# Parser file generated by LALRPOP
+language/compiler/src/parser/syntax.rs
+language/move_ir/
+
+# GDB related
+**/.gdb_history
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000..659acef601caa
--- /dev/null
+++ b/CODE_OF_CONDUCT.md
@@ -0,0 +1,3 @@
+# Code of Conduct
+
+The Libra Core project has adopted a Code of Conduct that we expect project participants to adhere to. Please [read the full text](https://developers.libra.org/docs/policies/code-of-conduct) so that you can understand what actions will and will not be tolerated.
\ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000000000..fce178f4668e3
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,55 @@
+# Contributing to Libra
+
+Our goal is to make contributing to the Libra project easy and transparent.
+
+> **Note**: As the Libra Core project is currently an early-stage prototype, it is undergoing rapid development. While we welcome contributions, before making substantial contributions, be sure to discuss them in the Discourse forum to ensure that they fit into the project roadmap.
+
+## On Contributing
+
+
+### Libra Core
+
+To contribute to the Libra Core implementation, first start with the proper development copy.
+
+To get the development installation with all the necessary dependencies for linting, testing, and building the documentation, run the following:
+```bash
+git clone https://github.com/libra/libra.git
+cd libra
+./scripts/dev_setup.sh
+cargo build
+cargo test
+```
+
+## Our Development Process
+
+#### Code Style, Hints, and Testing
+
+Refer to our [Coding Guidelines](https://developers.libra.org/docs/coding-guidelines) for detailed guidance on how to contribute to the project.
+
+#### Documentation
+
+Libra's website is also open source (the code can be found in this [repository](https://github.com/libra/website/)). It is built using [Docusaurus](https://docusaurus.io/).
+
+If you know Markdown, you can already contribute! The documentation lives in the [website repo](https://github.com/libra/website).
+
+## Pull Requests
+During the initial phase of heavy development, we plan to only audit and review pull requests. As the codebase stabilizes, we will be better able to accept pull requests from the community.
+
+1. Fork the repo and create your branch from `master`.
+2. If you have added code that should be tested, add unit tests.
+3. If you have changed APIs, update the documentation. Make sure the documentation builds.
+4. Ensure the test suite passes.
+5. Make sure your code passes both linters.
+6. If you haven't already, complete the Contributor License Agreement (CLA).
+7. Submit your pull request.
+
+## Contributor License Agreement
+
+For a pull request to be accepted by any Libra project, a CLA must be signed. You will only need to do this once to work on any of Libra's open source projects. Individuals contributing on their own behalf can sign the [Individual CLA](https://github.com/libra/libra/blob/master/contributing/individual-cla.pdf). If you are contributing on behalf of your employer, please ask them to sign the [Corporate CLA](https://github.com/libra/libra/blob/master/contributing/corporate-cla.pdf).
+
+## Issues
+
+Libra uses [GitHub issues](https://github.com/libra/libra/issues) to track bugs. Please include the information and instructions necessary to reproduce your issue.
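Before opening a pull request, it helps to run locally the same lint and test steps that CircleCI runs (see `.circleci/config.yml` earlier in this diff). A minimal sketch, assuming it is run from the repo root after `./scripts/dev_setup.sh`:

```bash
# Mirror the CI pipeline's "Linting", "Build All Targets", and test steps.
./scripts/clippy.sh                     # clippy lints, via the same script CI invokes
cargo fmt -- --check                    # fail on unformatted code without rewriting files
cargo build --all --all-targets         # build every workspace target
cargo test --all --exclude testsuite    # unit tests
RUST_TEST_THREADS=2 cargo test --package testsuite   # end-to-end tests
```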
diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000000000..33133ba28c779 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,67 @@ +[workspace] + +members = [ + "admission_control/admission_control_service", + "admission_control/admission_control_proto", + "client", + "client/libra_wallet", + "common/canonical_serialization", + "common/crash_handler", + "common/debug_interface", + "common/executable_helpers", + "common/failure_ext", + "common/grpcio-client", + "common/jemalloc", + "common/logger", + "common/metrics", + "common/proptest_helpers", + "common/proto_conv", + "config", + "config/config_builder", + "config/generate_keypair", + "consensus", + "crypto/legacy_crypto", + "crypto/nextgen_crypto", + "crypto/secret_service", + "execution/execution_client", + "execution/execution_proto", + "execution/execution_service", + "execution/executor", + "language/bytecode_verifier", + "language/bytecode_verifier/invalid_mutations", + "language/functional_tests", + "language/compiler", + "language/stdlib/natives", + "language/vm", + "language/vm/vm_runtime", + "language/vm/cost_synthesis", + "language/vm/vm_runtime/vm_cache_map", + "language/vm/vm_genesis", + "language/vm/vm_runtime/vm_runtime_tests", + "libra_node", + "libra_swarm", + "network", + "network/memsocket", + "network/netcore", + "network/noise", + "mempool", + "storage/accumulator", + "storage/libradb", + "storage/schemadb", + "storage/scratchpad", + "storage/sparse_merkle", + "storage/storage_client", + "storage/state_view", + "storage/storage_proto", + "storage/storage_service", + "testsuite", + "testsuite/libra_fuzzer", + "types", + "vm_validator", +] + +[profile.release] +debug = true + +[profile.bench] +debug = true diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000..261eeb9e9f8b2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000..28e92e862356e --- /dev/null +++ b/README.md @@ -0,0 +1,63 @@ + + Libra Logo + + +
+
+[![CircleCI](https://circleci.com/gh/libra/libra.svg?style=shield)](https://circleci.com/gh/libra/libra)
+[![License](https://img.shields.io/badge/license-Apache-green.svg)](LICENSE)
+
+Libra Core implements a decentralized, programmable database which provides a financial infrastructure that can empower billions of people.
+
+## Note to Developers
+* Libra Core is a prototype.
+* The APIs are constantly evolving and designed to demonstrate types of functionality. Expect substantial changes before the release.
+* We’ve launched a testnet that is a live demonstration of an early prototype of the Libra Blockchain software.
+
+## Contributing
+
+Read our [Contributing guide](https://developers.libra.org/docs/community/contributing). Find out what’s coming on our [blog](https://developers.libra.org/blog/2019/06/18/The-Path-Forward).
+
+## Getting Started
+
+### Learn About Libra
+* [Welcome](https://developers.libra.org/docs/welcome-to-libra)
+* [Libra Protocol: Key Concepts](https://developers.libra.org/docs/libra-protocol)
+* [Life of a Transaction](https://developers.libra.org/docs/life-of-a-transaction)
+
+### Try Libra Core
+* [My First Transaction](https://developers.libra.org/docs/my-first-transaction)
+* [Getting Started With Move](https://developers.libra.org/docs/move-overview)
+
+### Technical Papers
+* [The Libra Blockchain](https://developers.libra.org/docs/the-libra-blockchain-paper)
+* [Move: A Language With Programmable Resources](https://developers.libra.org/docs/move-paper)
+* [State Machine Replication in the Libra Blockchain](https://developers.libra.org/docs/state-machine-replication-paper)
+
+### Blog
+* [Libra: The Path Forward](https://developers.libra.org/blog/2019/06/18/the-path-forward/)
+
+### Libra Codebase
+
+* [Libra Core Overview](https://developers.libra.org/docs/libra-core-overview)
+* [Admission Control](https://developers.libra.org/docs/crates/admission-control)
+* [Bytecode Verifier](https://developers.libra.org/docs/crates/bytecode-verifier)
+* [Consensus](https://developers.libra.org/docs/crates/consensus)
+* [Crypto](https://developers.libra.org/docs/crates/crypto)
+* [Execution](https://developers.libra.org/docs/crates/execution)
+* [Mempool](https://developers.libra.org/docs/crates/mempool)
+* [Move IR Compiler](https://developers.libra.org/docs/crates/ir-to-bytecode)
+* [Move Language](https://developers.libra.org/docs/crates/move-language)
+* [Network](https://developers.libra.org/docs/crates/network)
+* [Storage](https://developers.libra.org/docs/crates/storage)
+* [Virtual Machine](https://developers.libra.org/docs/crates/vm)
+
+
+## Community
+
+* Join us on the [Libra Discourse](https://community.libra.org).
+* Get the latest updates to our project by signing up for our [newsletter](https://developers.libra.org/newsletter_form).
+
+## License
+
+Libra Core is licensed as [Apache 2.0](https://github.com/libra/libra/blob/master/LICENSE).
\ No newline at end of file
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 0000000000000..98febf18056f7
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,5 @@
+# Security Policies and Procedures
+
+Please see Libra's
+[security policies](https://developers.libra.org/docs/policies/security) and
+procedures for reporting vulnerabilities.
diff --git a/admission_control/README.md b/admission_control/README.md
new file mode 100644
index 0000000000000..2e6e9018454f4
--- /dev/null
+++ b/admission_control/README.md
@@ -0,0 +1,43 @@
+---
+id: admission-control
+title: Admission Control
+custom_edit_url: https://github.com/libra/libra/edit/master/admission_control/README.md
+---
+# Admission Control
+
+Admission Control (AC) is the public API endpoint for Libra; it takes public gRPC requests from clients.
+
+## Overview
+Admission Control (AC) serves two types of requests from clients:
+1. SubmitTransaction - To submit a transaction to the associated validator.
+2. UpdateToLatestLedger - To query storage, e.g., account state, transaction log, proofs, etc.
+
+## Implementation Details
+Admission Control (AC) implements two public APIs:
+1. SubmitTransaction(SubmitTransactionRequest)
+   * Multiple validations are performed against the request:
+     * The transaction signature is checked first. If this check fails, AdmissionControlStatus::Rejected is returned to the client.
+     * The transaction is then validated by vm_validator. If this fails, the corresponding VMStatus is returned to the client.
+   * Once the transaction passes all validations, AC queries the sender's account balance and the latest sequence number from storage and sends them to Mempool along with the client request.
+   * If Mempool returns MempoolAddTransactionStatus::Valid, AdmissionControlStatus::Accepted is returned to the client, indicating successful submission. Otherwise, the corresponding AdmissionControlStatus is returned.
+2. UpdateToLatestLedger(UpdateToLatestLedgerRequest)
+   * No extra processing is performed in AC; the request is passed directly to storage for the query.
+
+## Folder Structure
+```
+.
+β”œβ”€β”€ README.md
+β”œβ”€β”€ admission_control_proto
+β”‚   └── src
+β”‚       └── proto                        # Protobuf definition files
+└── admission_control_service
+    └── src                              # gRPC service source files
+        β”œβ”€β”€ admission_control_node.rs    # Wrapper to run AC in a separate thread
+        β”œβ”€β”€ admission_control_service.rs # gRPC service and main logic
+        β”œβ”€β”€ main.rs                      # Main entry to run AC as a binary
+        └── unit_tests                   # Tests
+```
+
+## This module interacts with:
+* The Mempool component, to submit transactions from clients.
+* The Storage component, to query validator storage.
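To make the request/response flow above concrete, here is a minimal client-side sketch against the gRPC stubs generated from `admission_control.proto` (which appears later in this diff). The endpoint address is a placeholder, and the stub names assume grpcio's usual codegen conventions, the same pattern the service code below uses for its `MempoolClient`:

```rust
// Sketch only: submit a signed transaction through AC's public API and
// inspect which arm of the response's `status` oneof was set.
use std::sync::Arc;

use admission_control_proto::proto::{
    admission_control::SubmitTransactionRequest,
    admission_control_grpc::AdmissionControlClient,
};
use grpcio::{ChannelBuilder, EnvBuilder};
use types::proto::transaction::SignedTransaction;

fn submit(signed_txn: SignedTransaction) -> grpcio::Result<()> {
    // Build a channel to one AC instance (placeholder address) and wrap it
    // in the generated client.
    let env = Arc::new(EnvBuilder::new().name_prefix("grpc-ac-client-").build());
    let channel = ChannelBuilder::new(env).connect("127.0.0.1:8000");
    let client = AdmissionControlClient::new(channel);

    // Wrap the signed transaction in a request and make the unary call.
    let mut req = SubmitTransactionRequest::new();
    req.set_signed_txn(signed_txn);
    let resp = client.submit_transaction(&req)?;

    // Exactly one of the three statuses is set (see the proto definition).
    if resp.has_ac_status() {
        println!("admission control status: {:?}", resp.get_ac_status());
    } else if resp.has_vm_status() {
        println!("rejected during VM validation: {:?}", resp.get_vm_status());
    } else if resp.has_mempool_status() {
        println!("mempool status: {:?}", resp.get_mempool_status());
    }
    Ok(())
}
```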
diff --git a/admission_control/admission_control_proto/Cargo.toml b/admission_control/admission_control_proto/Cargo.toml new file mode 100644 index 0000000000000..455fa4e9248db --- /dev/null +++ b/admission_control/admission_control_proto/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "admission_control_proto" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +futures = "0.1.25" +futures03 = { version = "=0.3.0-alpha.16", package = "futures-preview" } +grpcio = "0.4.3" +protobuf = "2.6" + +failure = { package = "failure_ext", path = "../../common/failure_ext" } +logger = { path = "../../common/logger" } +mempool = { path = "../../mempool" } +proto_conv = { path = "../../common/proto_conv" } +types = { path = "../../types" } + +[build-dependencies] +build_helpers = { path = "../../common/build_helpers" } diff --git a/admission_control/admission_control_proto/build.rs b/admission_control/admission_control_proto/build.rs new file mode 100644 index 0000000000000..0003f434b5494 --- /dev/null +++ b/admission_control/admission_control_proto/build.rs @@ -0,0 +1,19 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This compiles all the `.proto` files under `src/` directory. +//! +//! For example, if there is a file `src/a/b/c.proto`, it will generate `src/a/b/c.rs` and +//! `src/a/b/c_grpc.rs`. + +fn main() { + let proto_root = "src"; + let dependent_root = "../../types/src/proto"; + let mempool_dependent_root = "../../mempool/src/proto/shared"; + + build_helpers::build_helpers::compile_proto( + proto_root, + vec![dependent_root, mempool_dependent_root], + true, + ); +} diff --git a/admission_control/admission_control_proto/src/lib.rs b/admission_control/admission_control_proto/src/lib.rs new file mode 100644 index 0000000000000..6a0a8b66fe910 --- /dev/null +++ b/admission_control/admission_control_proto/src/lib.rs @@ -0,0 +1,110 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +pub mod proto; + +use crate::proto::admission_control::AdmissionControlStatus as ProtoAdmissionControlStatus; +use failure::prelude::*; +use logger::prelude::*; +use mempool::MempoolAddTransactionStatus; +use proto_conv::{FromProto, IntoProto}; +use types::vm_error::VMStatus; + +/// AC response status of submit_transaction to clients. +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum AdmissionControlStatus { + /// Validator accepted the transaction. + Accepted = 0, + /// The sender is blacklisted. + Blacklisted = 1, + /// The transaction is rejected, e.g. due to incorrect signature. 
+    Rejected = 2,
+}
+
+impl IntoProto for AdmissionControlStatus {
+    type ProtoType = crate::proto::admission_control::AdmissionControlStatus;
+
+    fn into_proto(self) -> Self::ProtoType {
+        match self {
+            AdmissionControlStatus::Accepted => ProtoAdmissionControlStatus::Accepted,
+            AdmissionControlStatus::Blacklisted => ProtoAdmissionControlStatus::Blacklisted,
+            AdmissionControlStatus::Rejected => ProtoAdmissionControlStatus::Rejected,
+        }
+    }
+}
+
+impl FromProto for AdmissionControlStatus {
+    type ProtoType = crate::proto::admission_control::AdmissionControlStatus;
+
+    fn from_proto(object: Self::ProtoType) -> Result<Self> {
+        let ret = match object {
+            ProtoAdmissionControlStatus::Accepted => AdmissionControlStatus::Accepted,
+            ProtoAdmissionControlStatus::Blacklisted => AdmissionControlStatus::Blacklisted,
+            ProtoAdmissionControlStatus::Rejected => AdmissionControlStatus::Rejected,
+        };
+        Ok(ret)
+    }
+}
+
+/// Rust structure for SubmitTransactionResponse protobuf definition.
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub struct SubmitTransactionResponse {
+    /// AC status returned to the client, if any; it can be either an error or an accepted status.
+    pub ac_status: Option<AdmissionControlStatus>,
+    /// Mempool error status, if any.
+    pub mempool_error: Option<MempoolAddTransactionStatus>,
+    /// VM error status, if any.
+    pub vm_error: Option<VMStatus>,
+    /// The id of the validator associated with this AC.
+    pub validator_id: Vec<u8>,
+}
+
+impl IntoProto for SubmitTransactionResponse {
+    type ProtoType = crate::proto::admission_control::SubmitTransactionResponse;
+
+    fn into_proto(self) -> Self::ProtoType {
+        let mut proto = Self::ProtoType::new();
+        if let Some(ac_st) = self.ac_status {
+            proto.set_ac_status(ac_st.into_proto());
+        } else if let Some(mem_err) = self.mempool_error {
+            proto.set_mempool_status(mem_err.into_proto());
+        } else if let Some(vm_st) = self.vm_error {
+            proto.set_vm_status(vm_st.into_proto());
+        } else {
+            error!("No status is available in SubmitTransactionResponse!");
+        }
+        proto.set_validator_id(self.validator_id);
+        proto
+    }
+}
+
+impl FromProto for SubmitTransactionResponse {
+    type ProtoType = crate::proto::admission_control::SubmitTransactionResponse;
+
+    fn from_proto(mut object: Self::ProtoType) -> Result<Self> {
+        let ac_status = if object.has_ac_status() {
+            Some(AdmissionControlStatus::from_proto(object.get_ac_status())?)
+        } else {
+            None
+        };
+        let mempool_error = if object.has_mempool_status() {
+            Some(MempoolAddTransactionStatus::from_proto(
+                object.get_mempool_status(),
+            )?)
+        } else {
+            None
+        };
+        let vm_error = if object.has_vm_status() {
+            Some(VMStatus::from_proto(object.take_vm_status())?)
+ } else { + None + }; + + Ok(SubmitTransactionResponse { + ac_status, + mempool_error, + vm_error, + validator_id: object.take_validator_id(), + }) + } +} diff --git a/admission_control/admission_control_proto/src/proto/admission_control.proto b/admission_control/admission_control_proto/src/proto/admission_control.proto new file mode 100644 index 0000000000000..c1c4358273618 --- /dev/null +++ b/admission_control/admission_control_proto/src/proto/admission_control.proto @@ -0,0 +1,76 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package admission_control; + +import "get_with_proof.proto"; +import "transaction.proto"; +import "proof.proto"; +import "ledger_info.proto"; +import "vm_errors.proto"; +import "mempool_status.proto"; + +// ----------------------------------------------------------------------------- +// ---------------- Submit transaction +// ----------------------------------------------------------------------------- +// The request for transaction submission. +message SubmitTransactionRequest { + // Transaction signed by wallet. + types.SignedTransaction signed_txn = 1; +} + +// Additional statuses that are possible from admission control in addition +// to VM statuses. +enum AdmissionControlStatus { + Accepted = 0; + Blacklisted = 1; + Rejected = 2; +} + +// The response for transaction submission. +// +// How does a client know if their transaction was included? +// A response from the transaction submission only means that the transaction +// was successfully added to mempool, but not that it is guaranteed to be +// included in the chain. Each transaction should include an expiration time in +// the signed transaction. Let's call this T0. As a client, I submit my +// transaction to a validator. I now need to poll for the transaction I +// submitted. I can use the query that takes my account and sequence number. If +// I receive back that the transaction is completed, I will verify the proofs to +// ensure that this is the transaction I expected. If I receive a response that +// my transaction is not yet completed, I must check the latest timestamp in the +// ledgerInfo that I receive back from the query. If this time is greater than +// T0, I can be certain that my transaction will never be included. If this +// time is less than T0, I need to continue polling. +message SubmitTransactionResponse { + // The status of a transaction submission can either be a VM status, or + // some other admission control/mempool specific status e.g. Blacklisted. + oneof status { + types.VMStatus vm_status = 1; + AdmissionControlStatus ac_status = 2; + mempool.MempoolAddTransactionStatus mempool_status = 3; + } + // Public key(id) of the validator that processed this transaction + bytes validator_id = 4; +} + +// ----------------------------------------------------------------------------- +// ---------------- Service definition +// ----------------------------------------------------------------------------- +service AdmissionControl { + // Public API to submit transaction to a validator. + rpc SubmitTransaction(SubmitTransactionRequest) + returns (SubmitTransactionResponse) {} + + // This API is used to update the client to the latest ledger version and + // optionally also request 1..n other pieces of data. This allows for batch + // queries. All queries return proofs that a client should check to validate + // the data. 
Note that if a client only wishes to update to the latest + // LedgerInfo and receive the proof of this latest version, they can simply + // omit the requested_items (or pass an empty list) + rpc UpdateToLatestLedger( + types.UpdateToLatestLedgerRequest) + returns (types.UpdateToLatestLedgerResponse) {} +} diff --git a/admission_control/admission_control_proto/src/proto/mod.rs b/admission_control/admission_control_proto/src/proto/mod.rs new file mode 100644 index 0000000000000..385a1dafa974d --- /dev/null +++ b/admission_control/admission_control_proto/src/proto/mod.rs @@ -0,0 +1,10 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use mempool::proto::shared::mempool_status; +use types::proto::*; + +/// Auto generated proto src files +pub mod admission_control; +/// Auto generated proto src files +pub mod admission_control_grpc; diff --git a/admission_control/admission_control_service/Cargo.toml b/admission_control/admission_control_service/Cargo.toml new file mode 100644 index 0000000000000..be012b1aec287 --- /dev/null +++ b/admission_control/admission_control_service/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "admission_control_service" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +futures = "0.1.25" +futures03 = { version = "=0.3.0-alpha.16", package = "futures-preview" } +grpcio = "0.4.3" +lazy_static = "1.3.0" +protobuf = "2.6" + +admission_control_proto = { path = "../admission_control_proto" } +config = { path = "../../config" } +crypto = { path = "../../crypto/legacy_crypto" } +debug_interface = { path = "../../common/debug_interface" } +failure = { package = "failure_ext", path = "../../common/failure_ext" } +executable_helpers = { path = "../../common/executable_helpers"} +grpc_helpers = { path = "../../common/grpc_helpers"} +logger = { path = "../../common/logger" } +mempool = { path = "../../mempool" } +metrics = { path = "../../common/metrics" } +proto_conv = { path = "../../common/proto_conv" } +storage_client = { path = "../../storage/storage_client" } +types = { path = "../../types" } +vm_validator = { path = "../../vm_validator" } + +[dev-dependencies] +storage_service = { path = "../../storage/storage_service" } + +[build-dependencies] +build_helpers = { path = "../../common/build_helpers" } diff --git a/admission_control/admission_control_service/src/admission_control_node.rs b/admission_control/admission_control_service/src/admission_control_node.rs new file mode 100644 index 0000000000000..64ab520067e50 --- /dev/null +++ b/admission_control/admission_control_service/src/admission_control_node.rs @@ -0,0 +1,133 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::admission_control_service::AdmissionControlService; +use admission_control_proto::proto::admission_control_grpc; +use config::config::NodeConfig; +use debug_interface::{node_debug_service::NodeDebugService, proto::node_debug_interface_grpc}; +use failure::prelude::*; +use grpc_helpers::spawn_service_thread; +use grpcio::{ChannelBuilder, EnvBuilder, Environment}; +use logger::prelude::*; +use mempool::proto::{mempool_client::MempoolClientTrait, mempool_grpc::MempoolClient}; +use std::{sync::Arc, thread}; +use storage_client::{StorageRead, StorageReadServiceClient}; +use vm_validator::vm_validator::VMValidator; + +/// Struct to run Admission Control service in a dedicated process. 
It will be used to spin up +/// extra AC instances to talk to the same validator. +pub struct AdmissionControlNode { + /// Config used to setup environment for this Admission Control service instance. + node_config: NodeConfig, +} + +impl Drop for AdmissionControlNode { + fn drop(&mut self) { + info!("Drop AdmissionControl node"); + } +} + +impl AdmissionControlNode { + /// Construct a new AdmissionControlNode instance using NodeConfig. + pub fn new(node_config: NodeConfig) -> Self { + AdmissionControlNode { node_config } + } + + /// Setup environment and start a new Admission Control service. + pub fn run(&self) -> Result<()> { + logger::set_global_log_collector( + self.node_config + .log_collector + .get_log_collector_type() + .unwrap(), + self.node_config.log_collector.is_async, + self.node_config.log_collector.chan_size, + ); + info!("Starting AdmissionControl node",); + // Start receiving requests + let client_env = Arc::new(EnvBuilder::new().name_prefix("grpc-ac-mem-").build()); + let mempool_connection_str = format!( + "{}:{}", + self.node_config.mempool.address, self.node_config.mempool.mempool_service_port + ); + let mempool_channel = + ChannelBuilder::new(Arc::clone(&client_env)).connect(&mempool_connection_str); + + self.run_with_clients( + Arc::clone(&client_env), + Arc::new(MempoolClient::new(mempool_channel)), + Some(Arc::new(StorageReadServiceClient::new( + Arc::clone(&client_env), + &self.node_config.storage.address, + self.node_config.storage.port, + ))), + ) + } + + /// This method will start a node using the provided clients to external services. + /// For now, mempool is a mandatory argument, and storage is Option. If it doesn't exist, + /// it'll be generated before starting the node. + pub fn run_with_clients( + &self, + env: Arc, + mp_client: Arc, + storage_client: Option>, + ) -> Result<()> { + // create storage client if doesnt exist + let storage_client: Arc = match storage_client { + Some(c) => c, + None => Arc::new(StorageReadServiceClient::new( + env, + &self.node_config.storage.address, + self.node_config.storage.port, + )), + }; + + let vm_validator = Arc::new(VMValidator::new( + &self.node_config, + Arc::clone(&storage_client), + )); + + let handle = AdmissionControlService::new( + mp_client, + storage_client, + vm_validator, + self.node_config + .admission_control + .need_to_check_mempool_before_validation, + ); + let service = admission_control_grpc::create_admission_control(handle); + + let _ac_service_handle = spawn_service_thread( + service, + self.node_config.admission_control.address.clone(), + self.node_config + .admission_control + .admission_control_service_port, + "admission_control", + ); + + // Start Debug interface + let debug_service = + node_debug_interface_grpc::create_node_debug_interface(NodeDebugService::new()); + let _debug_handle = spawn_service_thread( + debug_service, + self.node_config.admission_control.address.clone(), + self.node_config + .debug_interface + .admission_control_node_debug_port, + "debug_service", + ); + + info!( + "Started AdmissionControl node on port {}", + self.node_config + .admission_control + .admission_control_service_port + ); + + loop { + thread::park(); + } + } +} diff --git a/admission_control/admission_control_service/src/admission_control_service.rs b/admission_control/admission_control_service/src/admission_control_service.rs new file mode 100644 index 0000000000000..ac5e9a477e709 --- /dev/null +++ b/admission_control/admission_control_service/src/admission_control_service.rs @@ -0,0 +1,230 @@ +// 
Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Admission Control (AC) is a module acting as the only public end point. It receives api requests +//! from external clients (such as wallets) and performs necessary processing before sending them to +//! next step. + +use crate::OP_COUNTERS; +use admission_control_proto::proto::{ + admission_control::{ + AdmissionControlStatus, SubmitTransactionRequest, SubmitTransactionResponse, + }, + admission_control_grpc::AdmissionControl, +}; +use failure::prelude::*; +use futures::future::Future; +use futures03::executor::block_on; +use grpc_helpers::provide_grpc_response; +use logger::prelude::*; +use mempool::proto::{ + mempool::{AddTransactionWithValidationRequest, HealthCheckRequest}, + mempool_client::MempoolClientTrait, + shared::mempool_status::MempoolAddTransactionStatus::{self, MempoolIsFull}, +}; +use metrics::counters::SVC_COUNTERS; +use proto_conv::{FromProto, IntoProto}; +use std::sync::Arc; +use storage_client::StorageRead; +use types::{ + proto::get_with_proof::{UpdateToLatestLedgerRequest, UpdateToLatestLedgerResponse}, + transaction::SignedTransaction, +}; +use vm_validator::vm_validator::{get_account_state, TransactionValidation}; + +#[cfg(test)] +#[path = "unit_tests/admission_control_service_test.rs"] +mod admission_control_service_test; + +/// Struct implementing trait (service handle) AdmissionControlService. +#[derive(Clone)] +pub struct AdmissionControlService { + /// gRPC client connecting Mempool. + mempool_client: Arc, + /// gRPC client to send read requests to Storage. + storage_read_client: Arc, + /// VM validator instance to validate transactions sent from wallets. + vm_validator: Arc, + /// Flag indicating whether we need to check mempool before validation, drop txn if check + /// fails. + need_to_check_mempool_before_validation: bool, +} + +impl AdmissionControlService +where + M: MempoolClientTrait, + V: TransactionValidation, +{ + /// Constructs a new AdmissionControlService instance. + pub fn new( + mempool_client: Arc, + storage_read_client: Arc, + vm_validator: Arc, + need_to_check_mempool_before_validation: bool, + ) -> Self { + AdmissionControlService { + mempool_client, + storage_read_client, + vm_validator, + need_to_check_mempool_before_validation, + } + } + + /// Validate transaction signature, then via VM, and add it to Mempool if it passes VM check. + pub(crate) fn submit_transaction_inner( + &self, + req: SubmitTransactionRequest, + ) -> Result { + // Drop requests first if mempool is full (validator is lagging behind) so not to consume + // unnecessary resources. + if !self.can_send_txn_to_mempool()? 
{ + debug!("Mempool is full"); + OP_COUNTERS.inc_by("submit_txn.rejected.mempool_full", 1); + let mut response = SubmitTransactionResponse::new(); + response.set_mempool_status(MempoolIsFull); + return Ok(response); + } + + let signed_txn_proto = req.get_signed_txn(); + + let signed_txn = match SignedTransaction::from_proto(signed_txn_proto.clone()) { + Ok(t) => t, + Err(e) => { + security_log(SecurityEvent::InvalidTransactionAC) + .error(&e) + .data(&signed_txn_proto) + .log(); + let mut response = SubmitTransactionResponse::new(); + response.set_ac_status(AdmissionControlStatus::Rejected); + OP_COUNTERS.inc_by("submit_txn.rejected.invalid_txn", 1); + return Ok(response); + } + }; + + let gas_cost = signed_txn.max_gas_amount(); + let validation_status = self + .vm_validator + .validate_transaction(signed_txn.clone()) + .wait() + .map_err(|e| { + security_log(SecurityEvent::InvalidTransactionAC) + .error(&e) + .data(&signed_txn) + .log(); + e + })?; + if let Some(validation_status) = validation_status { + let mut response = SubmitTransactionResponse::new(); + OP_COUNTERS.inc_by("submit_txn.vm_validation.failure", 1); + debug!( + "txn failed in vm validation, status: {:?}, txn: {:?}", + validation_status, signed_txn + ); + response.set_vm_status(validation_status.into_proto()); + return Ok(response); + } + let sender = signed_txn.sender(); + let account_state = block_on(get_account_state(self.storage_read_client.clone(), sender)); + let mut add_transaction_request = AddTransactionWithValidationRequest::new(); + add_transaction_request.signed_txn = req.signed_txn.clone(); + add_transaction_request.set_max_gas_cost(gas_cost); + + if let Ok((sequence_number, balance)) = account_state { + add_transaction_request.set_account_balance(balance); + add_transaction_request.set_latest_sequence_number(sequence_number); + } + + self.add_txn_to_mempool(add_transaction_request) + } + + fn can_send_txn_to_mempool(&self) -> Result { + if self.need_to_check_mempool_before_validation { + let req = HealthCheckRequest::new(); + let is_mempool_healthy = self.mempool_client.health_check(&req)?.get_is_healthy(); + return Ok(is_mempool_healthy); + } + Ok(true) + } + + /// Add signed transaction to mempool once it passes vm check + fn add_txn_to_mempool( + &self, + add_transaction_request: AddTransactionWithValidationRequest, + ) -> Result { + let mempool_result = self + .mempool_client + .add_transaction_with_validation(&add_transaction_request)?; + + debug!("[GRPC] Done with transaction submission request"); + let mut response = SubmitTransactionResponse::new(); + if mempool_result.get_status() == MempoolAddTransactionStatus::Valid { + OP_COUNTERS.inc_by("submit_txn.txn_accepted", 1); + response.set_ac_status(AdmissionControlStatus::Accepted); + } else { + debug!( + "txn failed in mempool, status: {:?}, txn: {:?}", + mempool_result, + add_transaction_request.get_signed_txn() + ); + OP_COUNTERS.inc_by("submit_txn.mempool.failure", 1); + response.set_mempool_status(mempool_result.get_status()); + } + Ok(response) + } + + /// Pass the UpdateToLatestLedgerRequest to Storage for read query. 
+ fn update_to_latest_ledger_inner( + &self, + req: UpdateToLatestLedgerRequest, + ) -> Result { + let rust_req = types::get_with_proof::UpdateToLatestLedgerRequest::from_proto(req)?; + let (response_items, ledger_info_with_sigs, validator_change_events) = self + .storage_read_client + .update_to_latest_ledger(rust_req.client_known_version, rust_req.requested_items)?; + let rust_resp = types::get_with_proof::UpdateToLatestLedgerResponse::new( + response_items, + ledger_info_with_sigs, + validator_change_events, + ); + Ok(rust_resp.into_proto()) + } +} + +impl AdmissionControl for AdmissionControlService +where + M: MempoolClientTrait, + V: TransactionValidation, +{ + /// Submit a transaction to the validator this AC instance connecting to. + /// The specific transaction will be first validated by VM and then passed + /// to Mempool for further processing. + fn submit_transaction( + &mut self, + ctx: ::grpcio::RpcContext<'_>, + req: SubmitTransactionRequest, + sink: ::grpcio::UnarySink, + ) { + debug!("[GRPC] AdmissionControl::submit_transaction"); + let _timer = SVC_COUNTERS.req(&ctx); + let resp = self.submit_transaction_inner(req); + provide_grpc_response(resp, ctx, sink); + } + + /// This API is used to update the client to the latest ledger version and optionally also + /// request 1..n other pieces of data. This allows for batch queries. All queries return + /// proofs that a client should check to validate the data. + /// Note that if a client only wishes to update to the latest LedgerInfo and receive the proof + /// of this latest version, they can simply omit the requested_items (or pass an empty list). + /// AC will not directly process this request but pass it to Storage instead. + fn update_to_latest_ledger( + &mut self, + ctx: grpcio::RpcContext<'_>, + req: types::proto::get_with_proof::UpdateToLatestLedgerRequest, + sink: grpcio::UnarySink, + ) { + debug!("[GRPC] AdmissionControl::update_to_latest_ledger"); + let _timer = SVC_COUNTERS.req(&ctx); + let resp = self.update_to_latest_ledger_inner(req); + provide_grpc_response(resp, ctx, sink); + } +} diff --git a/admission_control/admission_control_service/src/lib.rs b/admission_control/admission_control_service/src/lib.rs new file mode 100644 index 0000000000000..eb7c46f72393d --- /dev/null +++ b/admission_control/admission_control_service/src/lib.rs @@ -0,0 +1,25 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![deny(missing_docs)] + +//! Admission Control +//! +//! Admission Control (AC) is the public API end point taking public gRPC requests from clients. +//! AC serves two types of request from clients: +//! 1. SubmitTransaction, to submit transaction to associated validator. +//! 2. UpdateToLatestLedger, to query storage, e.g. account state, transaction log, and proofs. + +/// Wrapper to run AC in a separate process. +pub mod admission_control_node; +/// AC gRPC service. +pub mod admission_control_service; +use lazy_static::lazy_static; +use metrics::OpMetrics; + +lazy_static! 
{ + static ref OP_COUNTERS: OpMetrics = OpMetrics::new_and_registered("admission_control"); +} + +#[cfg(test)] +mod unit_tests; diff --git a/admission_control/admission_control_service/src/main.rs b/admission_control/admission_control_service/src/main.rs new file mode 100644 index 0000000000000..e7c3c478f947c --- /dev/null +++ b/admission_control/admission_control_service/src/main.rs @@ -0,0 +1,22 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use admission_control_service::admission_control_node; +use executable_helpers::helpers::{ + setup_executable, ARG_CONFIG_PATH, ARG_DISABLE_LOGGING, ARG_PEER_ID, +}; + +/// Run a Admission Control service in its own process. +/// It will also setup global logger and initialize config. +fn main() { + let (config, _logger, _args) = setup_executable( + "Libra AdmissionControl node".to_string(), + vec![ARG_PEER_ID, ARG_CONFIG_PATH, ARG_DISABLE_LOGGING], + ); + + let admission_control_node = admission_control_node::AdmissionControlNode::new(config); + + admission_control_node + .run() + .expect("Unable to run AdmissionControl node"); +} diff --git a/admission_control/admission_control_service/src/unit_tests/admission_control_service_test.rs b/admission_control/admission_control_service/src/unit_tests/admission_control_service_test.rs new file mode 100644 index 0000000000000..037752de737de --- /dev/null +++ b/admission_control/admission_control_service/src/unit_tests/admission_control_service_test.rs @@ -0,0 +1,280 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + admission_control_service::{ + AdmissionControlService, SubmitTransactionRequest, + SubmitTransactionResponse as ProtoSubmitTransactionResponse, + }, + unit_tests::LocalMockMempool, +}; +use admission_control_proto::{AdmissionControlStatus, SubmitTransactionResponse}; +use crypto::{ + hash::CryptoHash, + signing::{generate_keypair, sign_message}, +}; +use mempool::MempoolAddTransactionStatus; +use proto_conv::FromProto; +use protobuf::{Message, UnknownFields}; +use std::sync::Arc; +use storage_service::mocks::mock_storage_client::MockStorageReadClient; +use types::{ + account_address::{AccountAddress, ADDRESS_LENGTH}, + test_helpers::transaction_test_helpers::get_test_signed_txn, + transaction::RawTransactionBytes, + vm_error::{ExecutionStatus, VMStatus, VMValidationStatus}, +}; +use vm_validator::mocks::mock_vm_validator::MockVMValidator; + +fn create_ac_service_for_ut() -> AdmissionControlService { + AdmissionControlService::new( + Arc::new(LocalMockMempool::new()), + Arc::new(MockStorageReadClient), + Arc::new(MockVMValidator), + false, + ) +} + +fn assert_status(response: ProtoSubmitTransactionResponse, status: VMStatus) { + let rust_resp = SubmitTransactionResponse::from_proto(response).unwrap(); + if rust_resp.ac_status.is_some() { + assert_eq!( + rust_resp.ac_status.unwrap(), + AdmissionControlStatus::Accepted + ); + } else { + let decoded_response = rust_resp.vm_error.unwrap(); + assert_eq!(decoded_response, status) + } +} + +#[test] +fn test_submit_txn_inner_vm() { + let ac_service = create_ac_service_for_ut(); + // create request + let mut req: SubmitTransactionRequest = SubmitTransactionRequest::new(); + let sender = AccountAddress::new([0; ADDRESS_LENGTH]); + let keypair = generate_keypair(); + req.set_signed_txn(get_test_signed_txn( + sender, + 0, + keypair.0.clone(), + keypair.1, + None, + )); + let response = ac_service.submit_transaction_inner(req.clone()).unwrap(); + 
assert_status( + response, + VMStatus::Validation(VMValidationStatus::SendingAccountDoesNotExist( + "TEST".to_string(), + )), + ); + let sender = AccountAddress::new([1; ADDRESS_LENGTH]); + req.set_signed_txn(get_test_signed_txn( + sender, + 0, + keypair.0.clone(), + keypair.1, + None, + )); + let response = ac_service.submit_transaction_inner(req.clone()).unwrap(); + assert_status( + response, + VMStatus::Validation(VMValidationStatus::InvalidSignature), + ); + let sender = AccountAddress::new([2; ADDRESS_LENGTH]); + req.set_signed_txn(get_test_signed_txn( + sender, + 0, + keypair.0.clone(), + keypair.1, + None, + )); + let response = ac_service.submit_transaction_inner(req.clone()).unwrap(); + assert_status( + response, + VMStatus::Validation(VMValidationStatus::InsufficientBalanceForTransactionFee), + ); + let sender = AccountAddress::new([3; ADDRESS_LENGTH]); + req.set_signed_txn(get_test_signed_txn( + sender, + 0, + keypair.0.clone(), + keypair.1, + None, + )); + let response = ac_service.submit_transaction_inner(req.clone()).unwrap(); + assert_status( + response, + VMStatus::Validation(VMValidationStatus::SequenceNumberTooNew), + ); + let sender = AccountAddress::new([4; ADDRESS_LENGTH]); + req.set_signed_txn(get_test_signed_txn( + sender, + 0, + keypair.0.clone(), + keypair.1, + None, + )); + let response = ac_service.submit_transaction_inner(req.clone()).unwrap(); + assert_status( + response, + VMStatus::Validation(VMValidationStatus::SequenceNumberTooOld), + ); + let sender = AccountAddress::new([5; ADDRESS_LENGTH]); + req.set_signed_txn(get_test_signed_txn( + sender, + 0, + keypair.0.clone(), + keypair.1, + None, + )); + let response = ac_service.submit_transaction_inner(req.clone()).unwrap(); + assert_status( + response, + VMStatus::Validation(VMValidationStatus::TransactionExpired), + ); + let sender = AccountAddress::new([6; ADDRESS_LENGTH]); + req.set_signed_txn(get_test_signed_txn( + sender, + 0, + keypair.0.clone(), + keypair.1, + None, + )); + let response = ac_service.submit_transaction_inner(req.clone()).unwrap(); + assert_status( + response, + VMStatus::Validation(VMValidationStatus::InvalidAuthKey), + ); + let sender = AccountAddress::new([8; ADDRESS_LENGTH]); + req.set_signed_txn(get_test_signed_txn( + sender, + 0, + keypair.0.clone(), + keypair.1, + None, + )); + let response = ac_service.submit_transaction_inner(req.clone()).unwrap(); + assert_status(response, VMStatus::Execution(ExecutionStatus::Executed)); + let sender = AccountAddress::new([8; ADDRESS_LENGTH]); + let test_key = generate_keypair(); + req.set_signed_txn(get_test_signed_txn( + sender, + 0, + keypair.0.clone(), + test_key.1, + None, + )); + let response = SubmitTransactionResponse::from_proto( + ac_service.submit_transaction_inner(req.clone()).unwrap(), + ) + .unwrap(); + assert_eq!( + response.ac_status.unwrap(), + AdmissionControlStatus::Rejected, + ); +} + +#[test] +fn test_reject_unknown_fields() { + let ac_service = create_ac_service_for_ut(); + let mut req: SubmitTransactionRequest = SubmitTransactionRequest::new(); + let keypair = generate_keypair(); + let sender = AccountAddress::random(); + let mut signed_txn = get_test_signed_txn(sender, 0, keypair.0.clone(), keypair.1, None); + let mut raw_txn = protobuf::parse_from_bytes::<::types::proto::transaction::RawTransaction>( + signed_txn.raw_txn_bytes.as_ref(), + ) + .unwrap(); + let mut unknown_fields = UnknownFields::new(); + unknown_fields.add_fixed32(1, 2); + raw_txn.unknown_fields = unknown_fields; + + let bytes = 
raw_txn.write_to_bytes().unwrap(); + let hash = RawTransactionBytes(&bytes).hash(); + let signature = sign_message(hash, &keypair.0).unwrap(); + + signed_txn.set_raw_txn_bytes(bytes); + signed_txn.set_sender_signature(signature.to_compact().to_vec()); + req.set_signed_txn(signed_txn); + let response = SubmitTransactionResponse::from_proto( + ac_service.submit_transaction_inner(req.clone()).unwrap(), + ) + .unwrap(); + assert_eq!( + response.ac_status.unwrap(), + AdmissionControlStatus::Rejected + ); +} + +#[test] +fn test_submit_txn_inner_mempool() { + let ac_service = create_ac_service_for_ut(); + let mut req: SubmitTransactionRequest = SubmitTransactionRequest::new(); + let keypair = generate_keypair(); + let insufficient_balance_add = AccountAddress::new([100; ADDRESS_LENGTH]); + req.set_signed_txn(get_test_signed_txn( + insufficient_balance_add, + 0, + keypair.0.clone(), + keypair.1, + None, + )); + let response = SubmitTransactionResponse::from_proto( + ac_service.submit_transaction_inner(req.clone()).unwrap(), + ) + .unwrap(); + assert_eq!( + response.mempool_error.unwrap(), + MempoolAddTransactionStatus::InsufficientBalance, + ); + let invalid_seq_add = AccountAddress::new([101; ADDRESS_LENGTH]); + req.set_signed_txn(get_test_signed_txn( + invalid_seq_add, + 0, + keypair.0.clone(), + keypair.1, + None, + )); + let response = SubmitTransactionResponse::from_proto( + ac_service.submit_transaction_inner(req.clone()).unwrap(), + ) + .unwrap(); + assert_eq!( + response.mempool_error.unwrap(), + MempoolAddTransactionStatus::InvalidSeqNumber, + ); + let sys_error_add = AccountAddress::new([102; ADDRESS_LENGTH]); + req.set_signed_txn(get_test_signed_txn( + sys_error_add, + 0, + keypair.0.clone(), + keypair.1, + None, + )); + let response = SubmitTransactionResponse::from_proto( + ac_service.submit_transaction_inner(req.clone()).unwrap(), + ) + .unwrap(); + assert_eq!( + response.mempool_error.unwrap(), + MempoolAddTransactionStatus::InvalidUpdate, + ); + let accepted_add = AccountAddress::new([103; ADDRESS_LENGTH]); + req.set_signed_txn(get_test_signed_txn( + accepted_add, + 0, + keypair.0.clone(), + keypair.1, + None, + )); + let response = SubmitTransactionResponse::from_proto( + ac_service.submit_transaction_inner(req.clone()).unwrap(), + ) + .unwrap(); + assert_eq!( + response.ac_status.unwrap(), + AdmissionControlStatus::Accepted, + ); +} diff --git a/admission_control/admission_control_service/src/unit_tests/mod.rs b/admission_control/admission_control_service/src/unit_tests/mod.rs new file mode 100644 index 0000000000000..b6c6541b4ad00 --- /dev/null +++ b/admission_control/admission_control_service/src/unit_tests/mod.rs @@ -0,0 +1,65 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use mempool::proto::{ + mempool::{ + AddTransactionWithValidationRequest, AddTransactionWithValidationResponse, + HealthCheckRequest, HealthCheckResponse, + }, + mempool_client::MempoolClientTrait, + shared::mempool_status::MempoolAddTransactionStatus, +}; +use proto_conv::FromProto; +use std::time::SystemTime; +use types::{account_address::ADDRESS_LENGTH, transaction::SignedTransaction}; + +// Define a local mempool to use for unit tests here, ignore methods not used by the test +#[derive(Clone)] +pub struct LocalMockMempool { + created_time: SystemTime, +} + +impl LocalMockMempool { + pub fn new() -> Self { + Self { + created_time: SystemTime::now(), + } + } +} + +impl MempoolClientTrait for LocalMockMempool { + fn add_transaction_with_validation( + &self, + 
req: &AddTransactionWithValidationRequest,
+    ) -> ::grpcio::Result<AddTransactionWithValidationResponse> {
+        let mut resp = AddTransactionWithValidationResponse::new();
+        let insufficient_balance_add = [100_u8; ADDRESS_LENGTH];
+        let invalid_seq_add = [101_u8; ADDRESS_LENGTH];
+        let sys_error_add = [102_u8; ADDRESS_LENGTH];
+        let accepted_add = [103_u8; ADDRESS_LENGTH];
+        let mempool_full = [104_u8; ADDRESS_LENGTH];
+        let signed_txn = SignedTransaction::from_proto(req.get_signed_txn().clone()).unwrap();
+        let sender = signed_txn.sender();
+        if sender.as_ref() == insufficient_balance_add {
+            resp.set_status(MempoolAddTransactionStatus::InsufficientBalance);
+        } else if sender.as_ref() == invalid_seq_add {
+            resp.set_status(MempoolAddTransactionStatus::InvalidSeqNumber);
+        } else if sender.as_ref() == sys_error_add {
+            resp.set_status(MempoolAddTransactionStatus::InvalidUpdate);
+        } else if sender.as_ref() == accepted_add {
+            resp.set_status(MempoolAddTransactionStatus::Valid);
+        } else if sender.as_ref() == mempool_full {
+            resp.set_status(MempoolAddTransactionStatus::MempoolIsFull);
+        }
+        Ok(resp)
+    }
+    fn health_check(&self, _req: &HealthCheckRequest) -> ::grpcio::Result<HealthCheckResponse> {
+        let mut ret = HealthCheckResponse::new();
+        let duration_ms = SystemTime::now()
+            .duration_since(self.created_time)
+            .unwrap()
+            .as_millis();
+        ret.set_is_healthy(duration_ms > 500 || duration_ms < 300);
+        Ok(ret)
+    }
+}
diff --git a/client/Cargo.toml b/client/Cargo.toml
new file mode 100644
index 0000000000000..ce8267c78948b
--- /dev/null
+++ b/client/Cargo.toml
@@ -0,0 +1,41 @@
+[package]
+name = "client"
+version = "0.1.0"
+authors = ["Libra Association "]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+bincode = "1.1.1"
+chrono = "0.4.6"
+futures = "0.1.23"
+grpcio = "0.4.3"
+hex = "0.3.2"
+hyper = "0.12"
+itertools = "0.8.0"
+proptest = "0.9.2"
+protobuf = "2.6"
+rand = "0.6.5"
+rustyline = "4.1.0"
+tokio = "0.1.16"
+rust_decimal = "1.0.1"
+num-traits = "0.2"
+serde = { version = "1.0.89", features = ["derive"] }
+structopt = "0.2.15"
+
+admission_control_proto = { version = "0.1.0", path = "../admission_control/admission_control_proto" }
+config = { path = "../config" }
+crash_handler = { path = "../common/crash_handler" }
+crypto = { path = "../crypto/legacy_crypto" }
+failure = { package = "failure_ext", path = "../common/failure_ext" }
+libc = "0.2.48"
+libra_wallet = { path = "./libra_wallet" }
+logger = { path = "../common/logger" }
+metrics = { path = "../common/metrics" }
+proto_conv = { path = "../common/proto_conv" }
+types = { path = "../types" }
+vm_genesis = { path = "../language/vm/vm_genesis" }
+
+[dev-dependencies]
+tempfile = "3.0.6"
diff --git a/client/libra_wallet/Cargo.toml b/client/libra_wallet/Cargo.toml
new file mode 100644
index 0000000000000..93ba790647dac
--- /dev/null
+++ b/client/libra_wallet/Cargo.toml
@@ -0,0 +1,43 @@
+[package]
+name = "libra_wallet"
+version = "0.1.0"
+authors = ["Libra Association "]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies.ed25519-dalek]
+version = "1.0.0-pre.1"
+
+[dependencies.types]
+path = "../../types"
+
+[dependencies.libra_crypto]
+path = "../../crypto/legacy_crypto"
+package = "crypto"
+
+[dependencies.proto_conv]
+path = "../../common/proto_conv"
+
+[dependencies.failure]
+path = "../../common/failure_ext"
+package = "failure_ext"
+
+[dependencies]
+rust-crypto = "0.2"
+log = "0.4"
+simple_logger = "0.5"
+rand = "0.6.5"
+rand_chacha = "0.1.1"
+rand_core = "0.4.0"
+hex = "0.3"
+byteorder = "1.2.6"
+serde = "1"
+serde_derive = "1"
"1" +serde_json = "1.0.31" +tiny-keccak = "1.4.2" +protobuf = "2.6" +sha3 = "0.8.2" + +[dev-dependencies] +tempfile = "3.0.6" diff --git a/client/libra_wallet/README.md b/client/libra_wallet/README.md new file mode 100644 index 0000000000000..aaabecfe1475b --- /dev/null +++ b/client/libra_wallet/README.md @@ -0,0 +1,15 @@ +# Libra Wallet + +Libra Wallet is a pure-rust implementation of hierarchical key derivation for SecretKey material in Libra. + +# Overview + +`libra_wallet` is a library providing hierarchical key derivation for SecretKey material in Libra. The following crate is largely inspired by [`rust-wallet`](https://github.com/rust-bitcoin/rust-wallet) with minor modifications to the key derivation function. Note that Libra makes use of the ed25519 Edwards Curve Digital Signature Algorithm (EdDSA) over the Edwards Cruve cruve25519. Therefore, BIP32-like PublicKey derivation is not possible without falling back to a traditional non-deterministic Schnorr signature algorithm. For this reason, we modified the key derivation function to a simpler alternative. + +The `internal_macros.rs` is taken from [`rust-bitcoin`](https://github.com/rust-bitcoin/rust-bitcoin/blob/master/src/internal_macros.rs) and `mnemonic.rs` is a slightly modified version of the file with the same name from [`rust-wallet`](https://github.com/rust-bitcoin/rust-wallet/blob/master/wallet/src/mnemonic.rs), while `error.rs`, `key_factor.rs` and `wallet_library.rs` are modified to present a minimalist wallet library for the Libra Client. Note that `mnemonic.rs` from `rust-wallet` adheres to the [`BIP39`](https://github.com/bitcoin/bips/blob/master/bip-0039.mediawiki) spec. + +# Implementation Details + +`key_factory.rs` implements the key derivation functions. The `KeyFactory` struct holds the Master Secret Material used to derive the Child Key(s). The constructor of a particular `KeyFactory` accepts a `[u8; 64]` `Seed` and computes both the `Master` Secret Material as well as the `ChainCode` from the HMAC-512 of the `Seed`. Finally, the `KeyFactory` allows to derive a child PrivateKey at a particular `ChildNumber` from the Master and ChainCode, as well as the `ChildNumber`'s u64 member. + +`wallet_library.rs` is a thin wrapper around `KeyFactory` which enables to keep track of Libra `AccountAddresses` and the information required to restore the current wallet from a `Mnemonic` backup. The `WalletLibrary` struct includes constructors that allow to generate a new `WalletLibrary` from OS randomness or generate a `WalletLibrary` from an instance of `Mnemonic`. `WalletLibrary` also allows to generate new addresses in-order or out-of-order via the `fn new_address` and `fn new_address_at_child_number`. Finally, `WalletLibrary` is capable of signing a Libra `RawTransaction` withe PrivateKey associated to the `AccountAddress` submitted. 
+Note that in the future, Libra will support rotating authentication keys and, therefore, `WalletLibrary` will need to understand more general inputs when mapping `AuthenticationKeys` to `PrivateKeys`.
diff --git a/client/libra_wallet/src/error.rs b/client/libra_wallet/src/error.rs
new file mode 100644
index 0000000000000..e71f01f281498
--- /dev/null
+++ b/client/libra_wallet/src/error.rs
@@ -0,0 +1,83 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use failure;
+use libra_crypto::hkdf::HkdfError;
+use std::{convert, error::Error, fmt, io};
+
+/// We define our own Result type in order to not have to import libra/common/failure_ext
+pub type Result<T> = ::std::result::Result<T, WalletError>;
+
+/// Libra Wallet Error is a convenience enum for generating arbitrary WalletErrors. Currently, only
+/// the LibraWalletGeneric error is being used, but there are plans to add more specific errors as
+/// LibraWallet matures
+pub enum WalletError {
+    /// generic error message
+    LibraWalletGeneric(String),
+}
+
+impl Error for WalletError {
+    fn description(&self) -> &str {
+        match *self {
+            WalletError::LibraWalletGeneric(ref s) => s,
+        }
+    }
+
+    fn cause(&self) -> Option<&Error> {
+        match *self {
+            WalletError::LibraWalletGeneric(_) => None,
+        }
+    }
+}
+
+impl fmt::Display for WalletError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match *self {
+            WalletError::LibraWalletGeneric(ref s) => write!(f, "LibraWalletGeneric: {}", s),
+        }
+    }
+}
+
+impl fmt::Debug for WalletError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        (self as &fmt::Display).fmt(f)
+    }
+}
+
+impl convert::From<WalletError> for io::Error {
+    fn from(_err: WalletError) -> io::Error {
+        match _err {
+            WalletError::LibraWalletGeneric(s) => io::Error::new(io::ErrorKind::Other, s),
+        }
+    }
+}
+
+impl convert::From<io::Error> for WalletError {
+    fn from(err: io::Error) -> WalletError {
+        WalletError::LibraWalletGeneric(err.description().to_string())
+    }
+}
+
+impl convert::From<failure::prelude::Error> for WalletError {
+    fn from(err: failure::prelude::Error) -> WalletError {
+        WalletError::LibraWalletGeneric(format!("{}", err))
+    }
+}
+
+impl convert::From<protobuf::error::ProtobufError> for WalletError {
+    fn from(err: protobuf::error::ProtobufError) -> WalletError {
+        WalletError::LibraWalletGeneric(err.description().to_string())
+    }
+}
+
+impl convert::From<ed25519_dalek::SignatureError> for WalletError {
+    fn from(err: ed25519_dalek::SignatureError) -> WalletError {
+        WalletError::LibraWalletGeneric(format!("{}", err))
+    }
+}
+
+impl convert::From<HkdfError> for WalletError {
+    fn from(err: HkdfError) -> WalletError {
+        WalletError::LibraWalletGeneric(format!("{}", err))
+    }
+}
diff --git a/client/libra_wallet/src/internal_macros.rs b/client/libra_wallet/src/internal_macros.rs
new file mode 100644
index 0000000000000..594eb53e8f82d
--- /dev/null
+++ b/client/libra_wallet/src/internal_macros.rs
@@ -0,0 +1,245 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! The following macros are slightly modified from rust-bitcoin. The original file may be found
+//! here:
+//!
+//! https://github.com/rust-bitcoin/rust-bitcoin/blob/master/src/internal_macros.rs
+
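+// For example, key_factory.rs invokes `impl_array_newtype!(Master, u8, 32);`, which equips the
+// `Master` newtype with the accessor methods and trait impls generated by the macro below.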
+macro_rules! impl_array_newtype {
+    ($thing:ident, $ty:ty, $len:expr) => {
+        impl $thing {
+            #[inline]
+            /// Converts the object to a raw pointer
+            pub fn as_ptr(&self) -> *const $ty {
+                let &$thing(ref dat) = self;
+                dat.as_ptr()
+            }
+
+            #[inline]
+            /// Converts the object to a mutable raw pointer
+            pub fn as_mut_ptr(&mut self) -> *mut $ty {
+                let &mut $thing(ref mut dat) = self;
+                dat.as_mut_ptr()
+            }
+
+            #[inline]
+            /// Returns the length of the object as an array
+            pub fn len(&self) -> usize {
+                $len
+            }
+
+            #[inline]
+            /// Returns whether the object, as an array, is empty. Always false.
+            pub fn is_empty(&self) -> bool {
+                false
+            }
+
+            #[inline]
+            /// Returns the underlying bytes.
+            pub fn as_bytes(&self) -> &[$ty; $len] {
+                &self.0
+            }
+
+            #[inline]
+            /// Returns the underlying bytes.
+            pub fn to_bytes(&self) -> [$ty; $len] {
+                self.0.clone()
+            }
+
+            #[inline]
+            /// Returns the underlying bytes.
+            pub fn into_bytes(self) -> [$ty; $len] {
+                self.0
+            }
+        }
+
+        impl<'a> From<&'a [$ty]> for $thing {
+            fn from(data: &'a [$ty]) -> $thing {
+                assert_eq!(data.len(), $len);
+                let mut ret = [0; $len];
+                ret.copy_from_slice(&data[..]);
+                $thing(ret)
+            }
+        }
+
+        impl ::std::ops::Index<usize> for $thing {
+            type Output = $ty;
+
+            #[inline]
+            fn index(&self, index: usize) -> &$ty {
+                let &$thing(ref dat) = self;
+                &dat[index]
+            }
+        }
+
+        impl_index_newtype!($thing, $ty);
+
+        impl PartialEq for $thing {
+            #[inline]
+            fn eq(&self, other: &$thing) -> bool {
+                &self[..] == &other[..]
+            }
+        }
+
+        impl Eq for $thing {}
+
+        impl PartialOrd for $thing {
+            #[inline]
+            fn partial_cmp(&self, other: &$thing) -> Option<::std::cmp::Ordering> {
+                Some(self.cmp(&other))
+            }
+        }
+
+        impl Ord for $thing {
+            #[inline]
+            fn cmp(&self, other: &$thing) -> ::std::cmp::Ordering {
+                // manually implement comparison to get little-endian ordering
+                // (we need this for our numeric types; non-numeric ones shouldn't
+                // be ordered anyway except to put them in BTrees or whatever, and
+                // they don't care how we order as long as we're consistent).
+                for i in 0..$len {
+                    if self[$len - 1 - i] < other[$len - 1 - i] {
+                        return ::std::cmp::Ordering::Less;
+                    }
+                    if self[$len - 1 - i] > other[$len - 1 - i] {
+                        return ::std::cmp::Ordering::Greater;
+                    }
+                }
+                ::std::cmp::Ordering::Equal
+            }
+        }
+
+        #[cfg_attr(feature = "clippy", allow(expl_impl_clone_on_copy))] // we don't define the `struct`, we have to explicitly impl
+        impl Clone for $thing {
+            #[inline]
+            fn clone(&self) -> $thing {
+                $thing::from(&self[..])
+            }
+        }
+
+        impl Copy for $thing {}
+
+        impl ::std::hash::Hash for $thing {
+            #[inline]
+            fn hash<H>(&self, state: &mut H)
+            where
+                H: ::std::hash::Hasher,
+            {
+                (&self[..]).hash(state);
+            }
+
+            fn hash_slice<H>(data: &[$thing], state: &mut H)
+            where
+                H: ::std::hash::Hasher,
+            {
+                for d in data.iter() {
+                    (&d[..]).hash(state);
+                }
+            }
+        }
+    };
+}
+
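+// The impls generated below are gated behind the optional "serde" feature and encode the
+// newtype as a fixed-size sequence of its elements.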
+macro_rules! impl_array_newtype_encodable {
+    ($thing:ident, $ty:ty, $len:expr) => {
+        #[cfg(feature = "serde")]
+        impl<'de> $crate::serde::Deserialize<'de> for $thing {
+            fn deserialize<D>(deserializer: D) -> Result<$thing, D::Error>
+            where
+                D: $crate::serde::Deserializer<'de>,
+            {
+                use $crate::std::fmt::{self, Formatter};
+
+                struct Visitor;
+                impl<'de> $crate::serde::de::Visitor<'de> for Visitor {
+                    type Value = $thing;
+
+                    fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
+                        formatter.write_str("a fixed size array")
+                    }
+
+                    #[inline]
+                    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+                    where
+                        A: $crate::serde::de::SeqAccess<'de>,
+                    {
+                        let mut ret: [$ty; $len] = [0; $len];
+                        for item in ret.iter_mut() {
+                            *item = match seq.next_element()? {
+                                Some(c) => c,
+                                None => {
+                                    return Err($crate::serde::de::Error::custom("end of stream"))
+                                }
+                            };
+                        }
+                        Ok($thing(ret))
+                    }
+                }
+
+                deserializer.deserialize_seq(Visitor)
+            }
+        }
+
+        #[cfg(feature = "serde")]
+        impl $crate::serde::Serialize for $thing {
+            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+            where
+                S: $crate::serde::Serializer,
+            {
+                let &$thing(ref dat) = self;
+                (&dat[..]).serialize(serializer)
+            }
+        }
+    };
+}
+
+macro_rules! impl_array_newtype_show {
+    ($thing:ident) => {
+        impl ::std::fmt::Debug for $thing {
+            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+                write!(f, concat!(stringify!($thing), "({:?})"), &self[..])
+            }
+        }
+    };
+}
+
+macro_rules! impl_index_newtype {
+    ($thing:ident, $ty:ty) => {
+        impl ::std::ops::Index<::std::ops::Range<usize>> for $thing {
+            type Output = [$ty];
+
+            #[inline]
+            fn index(&self, index: ::std::ops::Range<usize>) -> &[$ty] {
+                &self.0[index]
+            }
+        }
+
+        impl ::std::ops::Index<::std::ops::RangeTo<usize>> for $thing {
+            type Output = [$ty];
+
+            #[inline]
+            fn index(&self, index: ::std::ops::RangeTo<usize>) -> &[$ty] {
+                &self.0[index]
+            }
+        }
+
+        impl ::std::ops::Index<::std::ops::RangeFrom<usize>> for $thing {
+            type Output = [$ty];
+
+            #[inline]
+            fn index(&self, index: ::std::ops::RangeFrom<usize>) -> &[$ty] {
+                &self.0[index]
+            }
+        }
+
+        impl ::std::ops::Index<::std::ops::RangeFull> for $thing {
+            type Output = [$ty];
+
+            #[inline]
+            fn index(&self, _: ::std::ops::RangeFull) -> &[$ty] {
+                &self.0[..]
+            }
+        }
+    };
+}
diff --git a/client/libra_wallet/src/io_utils.rs b/client/libra_wallet/src/io_utils.rs
new file mode 100644
index 0000000000000..d53679cad01c6
--- /dev/null
+++ b/client/libra_wallet/src/io_utils.rs
@@ -0,0 +1,47 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! A module to generate, store and load known user accounts.
+//! The concept of known users can be helpful for testing to provide reproducible results.
+
+use crate::*;
+use failure::prelude::*;
+use std::{
+    fs::File,
+    io::{BufRead, BufReader, Write},
+    path::Path,
+};
+
+/// Delimiter used to ser/deserialize account data.
+pub const DELIMITER: &str = ";";
+
+/// Recover wallet from the path specified.
+pub fn recover<P: AsRef<Path>>(path: &P) -> Result<WalletLibrary> {
+    let input = File::open(path)?;
+    let mut buffered = BufReader::new(input);
+
+    let mut line = String::new();
+    let _ = buffered.read_line(&mut line)?;
+    let parts: Vec<&str> = line.split(DELIMITER).collect();
+    ensure!(parts.len() == 2, format!("Invalid entry '{}'", line));
+
+    let mnemonic = Mnemonic::from(&parts[0].to_string()[..])?;
+    let mut wallet = WalletLibrary::new_from_mnemonic(mnemonic);
+    wallet.generate_addresses(parts[1].trim().to_string().parse::<u64>()?)?;
+
+    Ok(wallet)
+}
+
+/// Write wallet seed to file.
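+///
+/// The file holds a single line of the form `<mnemonic>;<key leaf>` (see `DELIMITER`), which is
+/// exactly what `recover` above parses back into a `WalletLibrary`.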
+pub fn write_recovery<P: AsRef<Path>>(wallet: &WalletLibrary, path: &P) -> Result<()> {
+    let mut output = File::create(path)?;
+    writeln!(
+        output,
+        "{}{}{}",
+        wallet.mnemonic().to_string(),
+        DELIMITER,
+        wallet.key_leaf()
+    )?;
+
+    Ok(())
+}
diff --git a/client/libra_wallet/src/key_factory.rs b/client/libra_wallet/src/key_factory.rs
new file mode 100644
index 0000000000000..4978e9e63a5a4
--- /dev/null
+++ b/client/libra_wallet/src/key_factory.rs
@@ -0,0 +1,240 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! The following is a minimalist version of a hierarchical key derivation library for the
+//! LibraWallet.
+//!
+//! Note that the Libra Blockchain makes use of the ed25519 Edwards Digital Signature Algorithm
+//! (EdDSA) and therefore, BIP32 Public Key derivation is not available without falling back to
+//! a non-deterministic Schnorr signature scheme. As LibraWallet is meant to be a minimalist
+//! reference implementation of a simple wallet, the following does not deviate from the
+//! ed25519 spec. In a future iteration of this wallet, we will also provide an implementation
+//! of a Schnorr variant over curve25519 and demonstrate our proposal for BIP32-like public key
+//! derivation.
+//!
+//! Note further that the Key Derivation Function (KDF) chosen in the derivation of Child
+//! Private Keys adheres to [HKDF RFC 5869](https://tools.ietf.org/html/rfc5869).
+
+use byteorder::{ByteOrder, LittleEndian};
+use crypto::{hmac::Hmac as CryptoHmac, pbkdf2::pbkdf2, sha3::Sha3};
+use ed25519_dalek;
+use libra_crypto::{hash::HashValue, hkdf::Hkdf};
+use serde::{Deserialize, Serialize};
+use sha3::Sha3_256;
+use std::{convert::TryFrom, ops::AddAssign};
+use tiny_keccak::Keccak;
+use types::account_address::AccountAddress;
+
+use crate::{error::Result, mnemonic::Mnemonic};
+
+/// Master is a set of raw bytes that are used for child key derivation
+pub struct Master([u8; 32]);
+impl_array_newtype!(Master, u8, 32);
+impl_array_newtype_show!(Master);
+impl_array_newtype_encodable!(Master, u8, 32);
+
+/// A child number for a derived key, used to derive a certain private key from the Master
+#[derive(Default, Copy, Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
+pub struct ChildNumber(pub(crate) u64);
+
+impl ChildNumber {
+    /// Constructor from u64
+    pub fn new(child_number: u64) -> Self {
+        Self(child_number)
+    }
+
+    /// Bump the ChildNumber
+    pub fn increment(&mut self) {
+        self.add_assign(Self(1));
+    }
+}
+
+impl std::ops::AddAssign for ChildNumber {
+    fn add_assign(&mut self, other: Self) {
+        *self = Self(self.0 + other.0)
+    }
+}
+
+impl std::convert::AsRef<u64> for ChildNumber {
+    fn as_ref(&self) -> &u64 {
+        &self.0
+    }
+}
+
+impl std::convert::AsMut<u64> for ChildNumber {
+    fn as_mut(&mut self) -> &mut u64 {
+        &mut self.0
+    }
+}
+
+/// Derived private key.
+pub struct ExtendedPrivKey {
+    /// Child number of the key used to derive from Parent.
+    _child_number: ChildNumber,
+    /// Private key.
+    private_key: ed25519_dalek::SecretKey,
+}
+
+impl ExtendedPrivKey {
+    /// Constructor for creating an ExtendedPrivKey from an ed25519 PrivateKey. Note that the
+    /// ChildNumber is not used in this iteration of LibraWallet, but in order to
+    /// enable more general Hierarchical KeyDerivation schemes, we include it for completeness.
+    pub fn new(_child_number: ChildNumber, private_key: ed25519_dalek::SecretKey) -> Self {
+        Self {
+            _child_number,
+            private_key,
+        }
+    }
+
+    /// Returns the PublicKey associated with a particular ExtendedPrivKey
+    pub fn get_public(&self) -> ed25519_dalek::PublicKey {
+        (&self.private_key).into()
+    }
+
+    /// Computes the sha3 hash of the PublicKey and attempts to construct a Libra AccountAddress
+    /// from the raw bytes of the pubkey hash
+    pub fn get_address(&self) -> Result<AccountAddress> {
+        let public_key = self.get_public();
+        let mut keccak = Keccak::new_sha3_256();
+        let mut hash = [0u8; 32];
+        keccak.update(&public_key.to_bytes());
+        keccak.finalize(&mut hash);
+        let addr = AccountAddress::try_from(&hash[..])?;
+        Ok(addr)
+    }
+
+    /// Libra-specific sign function that is capable of signing an arbitrary HashValue
+    /// NOTE: In Libra, we do not sign the raw bytes of a transaction; instead we sign the raw
+    /// bytes of the sha3 hash of the raw bytes of a transaction. It is important to note that the
+    /// raw bytes of the sha3 hash will be hashed again as part of the ed25519 signature algorithm.
+    /// In other words: in Libra, the message used for signature and verification is the sha3 hash
+    /// of the transaction. This sha3 hash is then hashed again using SHA512 to arrive at the
+    /// deterministic nonce for the EdDSA.
+    pub fn sign(&self, msg: HashValue) -> ed25519_dalek::Signature {
+        let public_key: ed25519_dalek::PublicKey = (&self.private_key).into();
+        let expanded_secret_key: ed25519_dalek::ExpandedSecretKey =
+            ed25519_dalek::ExpandedSecretKey::from(&self.private_key);
+        expanded_secret_key.sign(msg.as_ref(), &public_key)
+    }
+}
+
+/// Wrapper struct from which we derive child keys
+pub struct KeyFactory {
+    master: Master,
+}
+
+impl KeyFactory {
+    const MNEMONIC_SALT_PREFIX: &'static [u8] = b"LIBRA WALLET: mnemonic salt prefix$";
+    const MASTER_KEY_SALT: &'static [u8] = b"LIBRA WALLET: master key salt$";
+    const INFO_PREFIX: &'static [u8] = b"LIBRA WALLET: derived key$";
+    /// Instantiate a new KeyFactory from a Seed, where the raw bytes of the Seed are used to
+    /// derive the Master secret material
+    pub fn new(seed: &Seed) -> Result<Self> {
+        let hkdf_extract = Hkdf::<Sha3_256>::extract(Some(KeyFactory::MASTER_KEY_SALT), &seed.0)?;
+
+        Ok(Self {
+            master: Master::from(&hkdf_extract[..32]),
+        })
+    }
+
+    /// Getter for the Master
+    pub fn master(&self) -> &[u8] {
+        &self.master.0[..]
+    }
+
+    /// Derive a particular PrivateKey at a certain ChildNumber
+    ///
+    /// Note that the function below adheres to [HKDF RFC 5869](https://tools.ietf.org/html/rfc5869).
+    pub fn private_child(&self, child: ChildNumber) -> Result<ExtendedPrivKey> {
+        // application info in the HKDF context is defined as Libra derived key$child_number.
+        let mut le_n = [0u8; 8];
+        LittleEndian::write_u64(&mut le_n, child.0);
+        let mut info = KeyFactory::INFO_PREFIX.to_vec();
+        info.extend_from_slice(&le_n);
+
+        let hkdf_expand = Hkdf::<Sha3_256>::expand(&self.master(), Some(&info), 32)?;
+        let sk = ed25519_dalek::SecretKey::from_bytes(&hkdf_expand)?;
+
+        Ok(ExtendedPrivKey::new(child, sk))
+    }
+}
+
+/// Seed is the output of a one-way function, which accepts a Mnemonic as input
+pub struct Seed([u8; 32]);
+
+impl Seed {
+    /// Get the underlying Seed internal data
+    pub fn data(&self) -> Vec<u8> {
+        self.0.to_vec()
+    }
+}
+
+impl Seed {
+    /// This constructor implements the one-way function that generates a Seed from a
+    /// particular Mnemonic and salt.
WalletLibrary implements a fixed salt, but a user could + /// choose a user-defined salt instead of the hardcoded one. + pub fn new(mnemonic: &Mnemonic, salt: &str) -> Seed { + let mut mac = CryptoHmac::new(Sha3::sha3_256(), mnemonic.to_string().as_bytes()); + let mut output = [0u8; 32]; + + let mut msalt = KeyFactory::MNEMONIC_SALT_PREFIX.to_vec(); + msalt.extend_from_slice(salt.as_bytes()); + + pbkdf2(&mut mac, &msalt, 2048, &mut output); + Seed(output) + } +} + +#[test] +fn assert_default_child_number() { + assert_eq!(ChildNumber::default(), ChildNumber(0)); +} + +#[test] +fn test_key_derivation() { + let data = hex::decode("7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f7f").unwrap(); + let mnemonic = Mnemonic::from("legal winner thank year wave sausage worth useful legal winner thank year wave sausage worth useful legal will").unwrap(); + assert_eq!( + mnemonic.to_string(), + Mnemonic::mnemonic(&data).unwrap().to_string() + ); + let seed = Seed::new(&mnemonic, "LIBRA"); + + let key_factory = KeyFactory::new(&seed).unwrap(); + assert_eq!( + "16274c9618ed59177ca948529c1884ba65c57984d562ec2b4e5aa1ee3e3903be", + hex::encode(&key_factory.master()) + ); + + // Check child_0 key derivation. + let child_private_0 = key_factory.private_child(ChildNumber(0)).unwrap(); + assert_eq!( + "358a375f36d74c30b7f3299b62d712b307725938f8cc931100fbd10a434fc8b9", + hex::encode(&child_private_0.private_key.to_bytes()[..]) + ); + + // Check determinism, regenerate child_0. + let child_private_0_again = key_factory.private_child(ChildNumber(0)).unwrap(); + assert_eq!( + hex::encode(&child_private_0.private_key.to_bytes()[..]), + hex::encode(&child_private_0_again.private_key.to_bytes()[..]) + ); + + // Check child_1 key derivation. + let child_private_1 = key_factory.private_child(ChildNumber(1)).unwrap(); + assert_eq!( + "a325fe7d27b1b49f191cc03525951fec41b6ffa2d4b3007bb1d9dd353b7e56a6", + hex::encode(&child_private_1.private_key.to_bytes()[..]) + ); + + let mut child_1_again = ChildNumber(0); + child_1_again.increment(); + assert_eq!(ChildNumber(1), child_1_again); + + // Check determinism, regenerate child_1, but by incrementing ChildNumber(0). + let child_private_1_from_increment = key_factory.private_child(child_1_again).unwrap(); + assert_eq!( + "a325fe7d27b1b49f191cc03525951fec41b6ffa2d4b3007bb1d9dd353b7e56a6", + hex::encode(&child_private_1_from_increment.private_key.to_bytes()[..]) + ); +} diff --git a/client/libra_wallet/src/lib.rs b/client/libra_wallet/src/lib.rs new file mode 100644 index 0000000000000..70372e83f5990 --- /dev/null +++ b/client/libra_wallet/src/lib.rs @@ -0,0 +1,24 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +/// Error crate +pub mod error; + +/// Internal macros +#[macro_use] +pub mod internal_macros; + +/// Utils for read/write +pub mod io_utils; + +/// Utils for key derivation +pub mod key_factory; + +/// Utils for mnemonic seed +pub mod mnemonic; + +/// Utils for wallet library +pub mod wallet_library; + +/// Default imports +pub use crate::{mnemonic::Mnemonic, wallet_library::WalletLibrary}; diff --git a/client/libra_wallet/src/mnemonic.rs b/client/libra_wallet/src/mnemonic.rs new file mode 100644 index 0000000000000..e3f18d9bc1c42 --- /dev/null +++ b/client/libra_wallet/src/mnemonic.rs @@ -0,0 +1,340 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! The following is a slightly modified version of the file with the same name in the +//! rust-wallet crate. 
The original file may be found here:
+//!
+//! https://github.com/rust-bitcoin/rust-wallet/blob/master/wallet/src/mnemonic.rs
+use crate::error::*;
+use crypto::{digest::Digest, sha2::Sha256};
+#[cfg(test)]
+use rand::rngs::EntropyRng;
+#[cfg(test)]
+use rand_core::RngCore;
+use std::{
+    fs::{self, File},
+    io::Write,
+    path::Path,
+};
+
+#[cfg(test)]
+use tempfile::NamedTempFile;
+
+/// Mnemonic seed for deterministic key derivation
+pub struct Mnemonic(Vec<&'static str>);
+
+impl ToString for Mnemonic {
+    fn to_string(&self) -> String {
+        self.0.as_slice().join(" ")
+    }
+}
+
+impl Mnemonic {
+    /// Generate mnemonic from string
+    pub fn from(s: &str) -> Result<Mnemonic> {
+        let words: Vec<_> = s.split(' ').collect();
+        if words.len() < 6 || words.len() % 6 != 0 {
+            return Err(WalletError::LibraWalletGeneric(
+                "Mnemonic must have a word count divisible by 6".to_string(),
+            ));
+        }
+        let mut mnemonic = Vec::new();
+        for word in &words {
+            if let Ok(idx) = WORDS.binary_search(word) {
+                mnemonic.push(WORDS[idx]);
+            } else {
+                return Err(WalletError::LibraWalletGeneric(
+                    "Mnemonic contains an unknown word".to_string(),
+                ));
+            }
+        }
+        Ok(Mnemonic(mnemonic))
+    }
+
+    /// Generate mnemonic from byte-array
+    pub fn mnemonic(data: &[u8]) -> Result<Mnemonic> {
+        if data.len() % 4 != 0 {
+            return Err(WalletError::LibraWalletGeneric(
+                "Data for mnemonic should have a length divisible by 4".to_string(),
+            ));
+        }
+        let mut check = [0u8; 32];
+
+        let mut sha2 = Sha256::new();
+        sha2.input(data);
+        sha2.result(&mut check);
+
+        let mut bits = vec![false; data.len() * 8 + data.len() / 4];
+        for i in 0..data.len() {
+            for j in 0..8 {
+                bits[i * 8 + j] = (data[i] & (1 << (7 - j))) > 0;
+            }
+        }
+        for i in 0..data.len() / 4 {
+            bits[8 * data.len() + i] = (check[i / 8] & (1 << (7 - (i % 8)))) > 0;
+        }
+        let mlen = data.len() * 3 / 4;
+        let mut memo = Vec::new();
+        for i in 0..mlen {
+            let mut idx = 0;
+            for j in 0..11 {
+                if bits[i * 11 + j] {
+                    idx += 1 << (10 - j);
+                }
+            }
+            memo.push(WORDS[idx]);
+        }
+        Ok(Mnemonic(memo))
+    }
+
+    /// Write mnemonic to output_file_path
+    pub fn write(&self, output_file_path: &Path) -> Result<()> {
+        if output_file_path.exists() && !output_file_path.is_file() {
+            return Err(WalletError::LibraWalletGeneric(format!(
+                "Output file {:?} for mnemonic backup is reserved",
+                output_file_path.to_str(),
+            )));
+        }
+        let mut file = File::create(output_file_path)?;
+        file.write_all(&self.to_string().as_bytes())?;
+        Ok(())
+    }
+
+    /// Read mnemonic from input_file_path
+    pub fn read(input_file_path: &Path) -> Result<Mnemonic> {
+        if input_file_path.exists() && input_file_path.is_file() {
+            let mnemonic_string: String = fs::read_to_string(input_file_path)?;
+            return Self::from(&mnemonic_string[..]);
+        }
+        Err(WalletError::LibraWalletGeneric(
+            "Input file for mnemonic backup does not exist".to_string(),
+        ))
+    }
+}
+
+#[test]
+fn test_roundtrip_mnemonic() {
+    let mut rng = EntropyRng::new();
+    let mut buf = [0u8; 32];
+    rng.fill_bytes(&mut buf[..]);
+    let file = NamedTempFile::new().unwrap();
+    let path = file.into_temp_path();
+    let mnemonic = Mnemonic::mnemonic(&buf[..]).unwrap();
+    mnemonic.write(&path).unwrap();
+    let other_mnemonic = Mnemonic::read(&path).unwrap();
+    assert_eq!(mnemonic.to_string(), other_mnemonic.to_string());
+}
+
+const WORDS: [&str; 2048] = [
+    "abandon", "ability", "able", "about", "above", "absent", "absorb", "abstract", "absurd",
+    "abuse", "access", "accident", "account", "accuse", "achieve", "acid", "acoustic", "acquire",
+    "across", "act", "action", "actor", "actress", "actual",
"adapt", "add", "addict", "address", + "adjust", "admit", "adult", "advance", "advice", "aerobic", "affair", "afford", "afraid", + "again", "age", "agent", "agree", "ahead", "aim", "air", "airport", "aisle", "alarm", "album", + "alcohol", "alert", "alien", "all", "alley", "allow", "almost", "alone", "alpha", "already", + "also", "alter", "always", "amateur", "amazing", "among", "amount", "amused", "analyst", + "anchor", "ancient", "anger", "angle", "angry", "animal", "ankle", "announce", "annual", + "another", "answer", "antenna", "antique", "anxiety", "any", "apart", "apology", "appear", + "apple", "approve", "april", "arch", "arctic", "area", "arena", "argue", "arm", "armed", + "armor", "army", "around", "arrange", "arrest", "arrive", "arrow", "art", "artefact", "artist", + "artwork", "ask", "aspect", "assault", "asset", "assist", "assume", "asthma", "athlete", + "atom", "attack", "attend", "attitude", "attract", "auction", "audit", "august", "aunt", + "author", "auto", "autumn", "average", "avocado", "avoid", "awake", "aware", "away", "awesome", + "awful", "awkward", "axis", "baby", "bachelor", "bacon", "badge", "bag", "balance", "balcony", + "ball", "bamboo", "banana", "banner", "bar", "barely", "bargain", "barrel", "base", "basic", + "basket", "battle", "beach", "bean", "beauty", "because", "become", "beef", "before", "begin", + "behave", "behind", "believe", "below", "belt", "bench", "benefit", "best", "betray", "better", + "between", "beyond", "bicycle", "bid", "bike", "bind", "biology", "bird", "birth", "bitter", + "black", "blade", "blame", "blanket", "blast", "bleak", "bless", "blind", "blood", "blossom", + "blouse", "blue", "blur", "blush", "board", "boat", "body", "boil", "bomb", "bone", "bonus", + "book", "boost", "border", "boring", "borrow", "boss", "bottom", "bounce", "box", "boy", + "bracket", "brain", "brand", "brass", "brave", "bread", "breeze", "brick", "bridge", "brief", + "bright", "bring", "brisk", "broccoli", "broken", "bronze", "broom", "brother", "brown", + "brush", "bubble", "buddy", "budget", "buffalo", "build", "bulb", "bulk", "bullet", "bundle", + "bunker", "burden", "burger", "burst", "bus", "business", "busy", "butter", "buyer", "buzz", + "cabbage", "cabin", "cable", "cactus", "cage", "cake", "call", "calm", "camera", "camp", "can", + "canal", "cancel", "candy", "cannon", "canoe", "canvas", "canyon", "capable", "capital", + "captain", "car", "carbon", "card", "cargo", "carpet", "carry", "cart", "case", "cash", + "casino", "castle", "casual", "cat", "catalog", "catch", "category", "cattle", "caught", + "cause", "caution", "cave", "ceiling", "celery", "cement", "census", "century", "cereal", + "certain", "chair", "chalk", "champion", "change", "chaos", "chapter", "charge", "chase", + "chat", "cheap", "check", "cheese", "chef", "cherry", "chest", "chicken", "chief", "child", + "chimney", "choice", "choose", "chronic", "chuckle", "chunk", "churn", "cigar", "cinnamon", + "circle", "citizen", "city", "civil", "claim", "clap", "clarify", "claw", "clay", "clean", + "clerk", "clever", "click", "client", "cliff", "climb", "clinic", "clip", "clock", "clog", + "close", "cloth", "cloud", "clown", "club", "clump", "cluster", "clutch", "coach", "coast", + "coconut", "code", "coffee", "coil", "coin", "collect", "color", "column", "combine", "come", + "comfort", "comic", "common", "company", "concert", "conduct", "confirm", "congress", + "connect", "consider", "control", "convince", "cook", "cool", "copper", "copy", "coral", + "core", "corn", "correct", "cost", "cotton", 
"couch", "country", "couple", "course", "cousin", + "cover", "coyote", "crack", "cradle", "craft", "cram", "crane", "crash", "crater", "crawl", + "crazy", "cream", "credit", "creek", "crew", "cricket", "crime", "crisp", "critic", "crop", + "cross", "crouch", "crowd", "crucial", "cruel", "cruise", "crumble", "crunch", "crush", "cry", + "crystal", "cube", "culture", "cup", "cupboard", "curious", "current", "curtain", "curve", + "cushion", "custom", "cute", "cycle", "dad", "damage", "damp", "dance", "danger", "daring", + "dash", "daughter", "dawn", "day", "deal", "debate", "debris", "decade", "december", "decide", + "decline", "decorate", "decrease", "deer", "defense", "define", "defy", "degree", "delay", + "deliver", "demand", "demise", "denial", "dentist", "deny", "depart", "depend", "deposit", + "depth", "deputy", "derive", "describe", "desert", "design", "desk", "despair", "destroy", + "detail", "detect", "develop", "device", "devote", "diagram", "dial", "diamond", "diary", + "dice", "diesel", "diet", "differ", "digital", "dignity", "dilemma", "dinner", "dinosaur", + "direct", "dirt", "disagree", "discover", "disease", "dish", "dismiss", "disorder", "display", + "distance", "divert", "divide", "divorce", "dizzy", "doctor", "document", "dog", "doll", + "dolphin", "domain", "donate", "donkey", "donor", "door", "dose", "double", "dove", "draft", + "dragon", "drama", "drastic", "draw", "dream", "dress", "drift", "drill", "drink", "drip", + "drive", "drop", "drum", "dry", "duck", "dumb", "dune", "during", "dust", "dutch", "duty", + "dwarf", "dynamic", "eager", "eagle", "early", "earn", "earth", "easily", "east", "easy", + "echo", "ecology", "economy", "edge", "edit", "educate", "effort", "egg", "eight", "either", + "elbow", "elder", "electric", "elegant", "element", "elephant", "elevator", "elite", "else", + "embark", "embody", "embrace", "emerge", "emotion", "employ", "empower", "empty", "enable", + "enact", "end", "endless", "endorse", "enemy", "energy", "enforce", "engage", "engine", + "enhance", "enjoy", "enlist", "enough", "enrich", "enroll", "ensure", "enter", "entire", + "entry", "envelope", "episode", "equal", "equip", "era", "erase", "erode", "erosion", "error", + "erupt", "escape", "essay", "essence", "estate", "eternal", "ethics", "evidence", "evil", + "evoke", "evolve", "exact", "example", "excess", "exchange", "excite", "exclude", "excuse", + "execute", "exercise", "exhaust", "exhibit", "exile", "exist", "exit", "exotic", "expand", + "expect", "expire", "explain", "expose", "express", "extend", "extra", "eye", "eyebrow", + "fabric", "face", "faculty", "fade", "faint", "faith", "fall", "false", "fame", "family", + "famous", "fan", "fancy", "fantasy", "farm", "fashion", "fat", "fatal", "father", "fatigue", + "fault", "favorite", "feature", "february", "federal", "fee", "feed", "feel", "female", + "fence", "festival", "fetch", "fever", "few", "fiber", "fiction", "field", "figure", "file", + "film", "filter", "final", "find", "fine", "finger", "finish", "fire", "firm", "first", + "fiscal", "fish", "fit", "fitness", "fix", "flag", "flame", "flash", "flat", "flavor", "flee", + "flight", "flip", "float", "flock", "floor", "flower", "fluid", "flush", "fly", "foam", + "focus", "fog", "foil", "fold", "follow", "food", "foot", "force", "forest", "forget", "fork", + "fortune", "forum", "forward", "fossil", "foster", "found", "fox", "fragile", "frame", + "frequent", "fresh", "friend", "fringe", "frog", "front", "frost", "frown", "frozen", "fruit", + "fuel", "fun", "funny", "furnace", "fury", 
"future", "gadget", "gain", "galaxy", "gallery", + "game", "gap", "garage", "garbage", "garden", "garlic", "garment", "gas", "gasp", "gate", + "gather", "gauge", "gaze", "general", "genius", "genre", "gentle", "genuine", "gesture", + "ghost", "giant", "gift", "giggle", "ginger", "giraffe", "girl", "give", "glad", "glance", + "glare", "glass", "glide", "glimpse", "globe", "gloom", "glory", "glove", "glow", "glue", + "goat", "goddess", "gold", "good", "goose", "gorilla", "gospel", "gossip", "govern", "gown", + "grab", "grace", "grain", "grant", "grape", "grass", "gravity", "great", "green", "grid", + "grief", "grit", "grocery", "group", "grow", "grunt", "guard", "guess", "guide", "guilt", + "guitar", "gun", "gym", "habit", "hair", "half", "hammer", "hamster", "hand", "happy", + "harbor", "hard", "harsh", "harvest", "hat", "have", "hawk", "hazard", "head", "health", + "heart", "heavy", "hedgehog", "height", "hello", "helmet", "help", "hen", "hero", "hidden", + "high", "hill", "hint", "hip", "hire", "history", "hobby", "hockey", "hold", "hole", "holiday", + "hollow", "home", "honey", "hood", "hope", "horn", "horror", "horse", "hospital", "host", + "hotel", "hour", "hover", "hub", "huge", "human", "humble", "humor", "hundred", "hungry", + "hunt", "hurdle", "hurry", "hurt", "husband", "hybrid", "ice", "icon", "idea", "identify", + "idle", "ignore", "ill", "illegal", "illness", "image", "imitate", "immense", "immune", + "impact", "impose", "improve", "impulse", "inch", "include", "income", "increase", "index", + "indicate", "indoor", "industry", "infant", "inflict", "inform", "inhale", "inherit", + "initial", "inject", "injury", "inmate", "inner", "innocent", "input", "inquiry", "insane", + "insect", "inside", "inspire", "install", "intact", "interest", "into", "invest", "invite", + "involve", "iron", "island", "isolate", "issue", "item", "ivory", "jacket", "jaguar", "jar", + "jazz", "jealous", "jeans", "jelly", "jewel", "job", "join", "joke", "journey", "joy", "judge", + "juice", "jump", "jungle", "junior", "junk", "just", "kangaroo", "keen", "keep", "ketchup", + "key", "kick", "kid", "kidney", "kind", "kingdom", "kiss", "kit", "kitchen", "kite", "kitten", + "kiwi", "knee", "knife", "knock", "know", "lab", "label", "labor", "ladder", "lady", "lake", + "lamp", "language", "laptop", "large", "later", "latin", "laugh", "laundry", "lava", "law", + "lawn", "lawsuit", "layer", "lazy", "leader", "leaf", "learn", "leave", "lecture", "left", + "leg", "legal", "legend", "leisure", "lemon", "lend", "length", "lens", "leopard", "lesson", + "letter", "level", "liar", "liberty", "library", "license", "life", "lift", "light", "like", + "limb", "limit", "link", "lion", "liquid", "list", "little", "live", "lizard", "load", "loan", + "lobster", "local", "lock", "logic", "lonely", "long", "loop", "lottery", "loud", "lounge", + "love", "loyal", "lucky", "luggage", "lumber", "lunar", "lunch", "luxury", "lyrics", "machine", + "mad", "magic", "magnet", "maid", "mail", "main", "major", "make", "mammal", "man", "manage", + "mandate", "mango", "mansion", "manual", "maple", "marble", "march", "margin", "marine", + "market", "marriage", "mask", "mass", "master", "match", "material", "math", "matrix", + "matter", "maximum", "maze", "meadow", "mean", "measure", "meat", "mechanic", "medal", "media", + "melody", "melt", "member", "memory", "mention", "menu", "mercy", "merge", "merit", "merry", + "mesh", "message", "metal", "method", "middle", "midnight", "milk", "million", "mimic", "mind", + "minimum", "minor", "minute", 
"miracle", "mirror", "misery", "miss", "mistake", "mix", "mixed", + "mixture", "mobile", "model", "modify", "mom", "moment", "monitor", "monkey", "monster", + "month", "moon", "moral", "more", "morning", "mosquito", "mother", "motion", "motor", + "mountain", "mouse", "move", "movie", "much", "muffin", "mule", "multiply", "muscle", "museum", + "mushroom", "music", "must", "mutual", "myself", "mystery", "myth", "naive", "name", "napkin", + "narrow", "nasty", "nation", "nature", "near", "neck", "need", "negative", "neglect", + "neither", "nephew", "nerve", "nest", "net", "network", "neutral", "never", "news", "next", + "nice", "night", "noble", "noise", "nominee", "noodle", "normal", "north", "nose", "notable", + "note", "nothing", "notice", "novel", "now", "nuclear", "number", "nurse", "nut", "oak", + "obey", "object", "oblige", "obscure", "observe", "obtain", "obvious", "occur", "ocean", + "october", "odor", "off", "offer", "office", "often", "oil", "okay", "old", "olive", "olympic", + "omit", "once", "one", "onion", "online", "only", "open", "opera", "opinion", "oppose", + "option", "orange", "orbit", "orchard", "order", "ordinary", "organ", "orient", "original", + "orphan", "ostrich", "other", "outdoor", "outer", "output", "outside", "oval", "oven", "over", + "own", "owner", "oxygen", "oyster", "ozone", "pact", "paddle", "page", "pair", "palace", + "palm", "panda", "panel", "panic", "panther", "paper", "parade", "parent", "park", "parrot", + "party", "pass", "patch", "path", "patient", "patrol", "pattern", "pause", "pave", "payment", + "peace", "peanut", "pear", "peasant", "pelican", "pen", "penalty", "pencil", "people", + "pepper", "perfect", "permit", "person", "pet", "phone", "photo", "phrase", "physical", + "piano", "picnic", "picture", "piece", "pig", "pigeon", "pill", "pilot", "pink", "pioneer", + "pipe", "pistol", "pitch", "pizza", "place", "planet", "plastic", "plate", "play", "please", + "pledge", "pluck", "plug", "plunge", "poem", "poet", "point", "polar", "pole", "police", + "pond", "pony", "pool", "popular", "portion", "position", "possible", "post", "potato", + "pottery", "poverty", "powder", "power", "practice", "praise", "predict", "prefer", "prepare", + "present", "pretty", "prevent", "price", "pride", "primary", "print", "priority", "prison", + "private", "prize", "problem", "process", "produce", "profit", "program", "project", "promote", + "proof", "property", "prosper", "protect", "proud", "provide", "public", "pudding", "pull", + "pulp", "pulse", "pumpkin", "punch", "pupil", "puppy", "purchase", "purity", "purpose", + "purse", "push", "put", "puzzle", "pyramid", "quality", "quantum", "quarter", "question", + "quick", "quit", "quiz", "quote", "rabbit", "raccoon", "race", "rack", "radar", "radio", + "rail", "rain", "raise", "rally", "ramp", "ranch", "random", "range", "rapid", "rare", "rate", + "rather", "raven", "raw", "razor", "ready", "real", "reason", "rebel", "rebuild", "recall", + "receive", "recipe", "record", "recycle", "reduce", "reflect", "reform", "refuse", "region", + "regret", "regular", "reject", "relax", "release", "relief", "rely", "remain", "remember", + "remind", "remove", "render", "renew", "rent", "reopen", "repair", "repeat", "replace", + "report", "require", "rescue", "resemble", "resist", "resource", "response", "result", + "retire", "retreat", "return", "reunion", "reveal", "review", "reward", "rhythm", "rib", + "ribbon", "rice", "rich", "ride", "ridge", "rifle", "right", "rigid", "ring", "riot", "ripple", + "risk", "ritual", "rival", "river", 
"road", "roast", "robot", "robust", "rocket", "romance", + "roof", "rookie", "room", "rose", "rotate", "rough", "round", "route", "royal", "rubber", + "rude", "rug", "rule", "run", "runway", "rural", "sad", "saddle", "sadness", "safe", "sail", + "salad", "salmon", "salon", "salt", "salute", "same", "sample", "sand", "satisfy", "satoshi", + "sauce", "sausage", "save", "say", "scale", "scan", "scare", "scatter", "scene", "scheme", + "school", "science", "scissors", "scorpion", "scout", "scrap", "screen", "script", "scrub", + "sea", "search", "season", "seat", "second", "secret", "section", "security", "seed", "seek", + "segment", "select", "sell", "seminar", "senior", "sense", "sentence", "series", "service", + "session", "settle", "setup", "seven", "shadow", "shaft", "shallow", "share", "shed", "shell", + "sheriff", "shield", "shift", "shine", "ship", "shiver", "shock", "shoe", "shoot", "shop", + "short", "shoulder", "shove", "shrimp", "shrug", "shuffle", "shy", "sibling", "sick", "side", + "siege", "sight", "sign", "silent", "silk", "silly", "silver", "similar", "simple", "since", + "sing", "siren", "sister", "situate", "six", "size", "skate", "sketch", "ski", "skill", "skin", + "skirt", "skull", "slab", "slam", "sleep", "slender", "slice", "slide", "slight", "slim", + "slogan", "slot", "slow", "slush", "small", "smart", "smile", "smoke", "smooth", "snack", + "snake", "snap", "sniff", "snow", "soap", "soccer", "social", "sock", "soda", "soft", "solar", + "soldier", "solid", "solution", "solve", "someone", "song", "soon", "sorry", "sort", "soul", + "sound", "soup", "source", "south", "space", "spare", "spatial", "spawn", "speak", "special", + "speed", "spell", "spend", "sphere", "spice", "spider", "spike", "spin", "spirit", "split", + "spoil", "sponsor", "spoon", "sport", "spot", "spray", "spread", "spring", "spy", "square", + "squeeze", "squirrel", "stable", "stadium", "staff", "stage", "stairs", "stamp", "stand", + "start", "state", "stay", "steak", "steel", "stem", "step", "stereo", "stick", "still", + "sting", "stock", "stomach", "stone", "stool", "story", "stove", "strategy", "street", + "strike", "strong", "struggle", "student", "stuff", "stumble", "style", "subject", "submit", + "subway", "success", "such", "sudden", "suffer", "sugar", "suggest", "suit", "summer", "sun", + "sunny", "sunset", "super", "supply", "supreme", "sure", "surface", "surge", "surprise", + "surround", "survey", "suspect", "sustain", "swallow", "swamp", "swap", "swarm", "swear", + "sweet", "swift", "swim", "swing", "switch", "sword", "symbol", "symptom", "syrup", "system", + "table", "tackle", "tag", "tail", "talent", "talk", "tank", "tape", "target", "task", "taste", + "tattoo", "taxi", "teach", "team", "tell", "ten", "tenant", "tennis", "tent", "term", "test", + "text", "thank", "that", "theme", "then", "theory", "there", "they", "thing", "this", + "thought", "three", "thrive", "throw", "thumb", "thunder", "ticket", "tide", "tiger", "tilt", + "timber", "time", "tiny", "tip", "tired", "tissue", "title", "toast", "tobacco", "today", + "toddler", "toe", "together", "toilet", "token", "tomato", "tomorrow", "tone", "tongue", + "tonight", "tool", "tooth", "top", "topic", "topple", "torch", "tornado", "tortoise", "toss", + "total", "tourist", "toward", "tower", "town", "toy", "track", "trade", "traffic", "tragic", + "train", "transfer", "trap", "trash", "travel", "tray", "treat", "tree", "trend", "trial", + "tribe", "trick", "trigger", "trim", "trip", "trophy", "trouble", "truck", "true", "truly", + "trumpet", 
"trust", "truth", "try", "tube", "tuition", "tumble", "tuna", "tunnel", "turkey", + "turn", "turtle", "twelve", "twenty", "twice", "twin", "twist", "two", "type", "typical", + "ugly", "umbrella", "unable", "unaware", "uncle", "uncover", "under", "undo", "unfair", + "unfold", "unhappy", "uniform", "unique", "unit", "universe", "unknown", "unlock", "until", + "unusual", "unveil", "update", "upgrade", "uphold", "upon", "upper", "upset", "urban", "urge", + "usage", "use", "used", "useful", "useless", "usual", "utility", "vacant", "vacuum", "vague", + "valid", "valley", "valve", "van", "vanish", "vapor", "various", "vast", "vault", "vehicle", + "velvet", "vendor", "venture", "venue", "verb", "verify", "version", "very", "vessel", + "veteran", "viable", "vibrant", "vicious", "victory", "video", "view", "village", "vintage", + "violin", "virtual", "virus", "visa", "visit", "visual", "vital", "vivid", "vocal", "voice", + "void", "volcano", "volume", "vote", "voyage", "wage", "wagon", "wait", "walk", "wall", + "walnut", "want", "warfare", "warm", "warrior", "wash", "wasp", "waste", "water", "wave", + "way", "wealth", "weapon", "wear", "weasel", "weather", "web", "wedding", "weekend", "weird", + "welcome", "west", "wet", "whale", "what", "wheat", "wheel", "when", "where", "whip", + "whisper", "wide", "width", "wife", "wild", "will", "win", "window", "wine", "wing", "wink", + "winner", "winter", "wire", "wisdom", "wise", "wish", "witness", "wolf", "woman", "wonder", + "wood", "wool", "word", "work", "world", "worry", "worth", "wrap", "wreck", "wrestle", "wrist", + "write", "wrong", "yard", "year", "yellow", "you", "young", "youth", "zebra", "zero", "zone", + "zoo", +]; diff --git a/client/libra_wallet/src/wallet_library.rs b/client/libra_wallet/src/wallet_library.rs new file mode 100644 index 0000000000000..afcafd445532b --- /dev/null +++ b/client/libra_wallet/src/wallet_library.rs @@ -0,0 +1,178 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! The following document is a minimalist version of Libra Wallet. Note that this Wallet does +//! not promote security as the mnemonic is stored in unencrypted form. In future iterations, +//! we will be realesing more robust Wallet implementations. It is our intention to present a +//! foundation that is simple to understand and incrementally improve the LibraWallet +//! implementation and it's security guarantees throughout testnet. For a more robust wallet +//! reference, the authors suggest to audit the file of the same name in the rust-wallet crate. +//! That file can be found here: +//! +//! 
+//! https://github.com/rust-bitcoin/rust-wallet/blob/master/wallet/src/walletlibrary.rs
+
+use crate::{
+    error::*,
+    io_utils,
+    key_factory::{ChildNumber, KeyFactory, Seed},
+    mnemonic::Mnemonic,
+};
+use libra_crypto::hash::CryptoHash;
+use proto_conv::{FromProto, IntoProto};
+use protobuf::Message;
+use rand::{rngs::EntropyRng, Rng};
+use std::{collections::HashMap, path::Path};
+use types::{
+    account_address::AccountAddress,
+    proto::transaction::SignedTransaction as ProtoSignedTransaction,
+    transaction::{RawTransaction, RawTransactionBytes, SignedTransaction},
+};
+
+/// WalletLibrary contains all the information needed to recreate a particular wallet
+pub struct WalletLibrary {
+    mnemonic: Mnemonic,
+    key_factory: KeyFactory,
+    addr_map: HashMap<AccountAddress, ChildNumber>,
+    key_leaf: ChildNumber,
+}
+
+impl WalletLibrary {
+    /// Constructor that generates a Mnemonic from OS randomness and subsequently instantiates an
+    /// empty WalletLibrary from that Mnemonic
+    #[allow(clippy::new_without_default)]
+    pub fn new() -> Self {
+        let mut rng = EntropyRng::new();
+        let data: [u8; 32] = rng.gen();
+        let mnemonic = Mnemonic::mnemonic(&data).unwrap();
+        Self::new_from_mnemonic(mnemonic)
+    }
+
+    /// Constructor that instantiates a new WalletLibrary from a Mnemonic
+    pub fn new_from_mnemonic(mnemonic: Mnemonic) -> Self {
+        let seed = Seed::new(&mnemonic, "LIBRA");
+        WalletLibrary {
+            mnemonic,
+            key_factory: KeyFactory::new(&seed).unwrap(),
+            addr_map: HashMap::new(),
+            key_leaf: ChildNumber(0),
+        }
+    }
+
+    /// Function that returns the string representation of the WalletLibrary Mnemonic
+    /// NOTE: This is not secure, and in general the mnemonic should be stored in encrypted format
+    pub fn mnemonic(&self) -> String {
+        self.mnemonic.to_string()
+    }
+
+    /// Function that writes the wallet Mnemonic to file
+    /// NOTE: This is not secure; in general the Mnemonic should be encrypted before it is
+    /// written to file, so that only the encrypted Mnemonic ever reaches disk
+    pub fn write_recovery(&self, output_file_path: &Path) -> Result<()> {
+        io_utils::write_recovery(&self, &output_file_path)?;
+        Ok(())
+    }
+
+    /// Recover wallet from input_file_path
+    pub fn recover(input_file_path: &Path) -> Result<WalletLibrary> {
+        let wallet = io_utils::recover(&input_file_path)?;
+        Ok(wallet)
+    }
+
+    /// Get the current ChildNumber in u64 format
+    pub fn key_leaf(&self) -> u64 {
+        self.key_leaf.0
+    }
+
+    /// Function that iterates from the current key_leaf until the supplied depth
+    pub fn generate_addresses(&mut self, depth: u64) -> Result<()> {
+        let current = self.key_leaf.0;
+        if current > depth {
+            return Err(WalletError::LibraWalletGeneric(
+                "Addresses already generated up to the supplied depth".to_string(),
+            ));
+        }
+        while self.key_leaf != ChildNumber(depth) {
+            let _ = self.new_address();
+        }
+        Ok(())
+    }
+
+    /// Function that returns the address of a particular key at a certain ChildNumber
+    pub fn new_address_at_child_number(
+        &mut self,
+        child_number: ChildNumber,
+    ) -> Result<AccountAddress> {
+        let child = self.key_factory.private_child(child_number)?;
+        child.get_address()
+    }
+
+    /// Function that generates a new key and adds it to the addr_map, and subsequently returns
+    /// the AccountAddress associated with the PrivateKey, along with its ChildNumber
+    pub fn new_address(&mut self) -> Result<(AccountAddress, ChildNumber)> {
+        let child = self.key_factory.private_child(self.key_leaf)?;
+        let address = child.get_address()?;
+        let child = self.key_leaf;
+        self.key_leaf.increment();
+        match self.addr_map.insert(address, child) {
+            Some(_) => Err(WalletError::LibraWalletGeneric(
+                "This address is already in your wallet".to_string(),
+            )),
+            None => Ok((address, child)),
+        }
+    }
+
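+    // For illustration: after three calls to new_address(), addr_map holds the addresses
+    // derived at ChildNumber(0), ChildNumber(1), and ChildNumber(2), and get_addresses()
+    // below returns them in that derivation order.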
child) { + Some(_) => Err(WalletError::LibraWalletGeneric( + "This address is already in your wallet".to_string(), + )), + None => Ok((address, child)), + } + } + + /// Returns a list of all addresses controlled by this wallet that are currently held by the + /// addr_map + pub fn get_addresses(&self) -> Result> { + let mut ret = Vec::with_capacity(self.addr_map.len()); + let rev_map = self + .addr_map + .iter() + .map(|(&k, &v)| (v.as_ref().to_owned(), k.to_owned())) + .collect::>(); + for i in 0..self.addr_map.len() as u64 { + match rev_map.get(&i) { + Some(account_address) => { + ret.push(*account_address); + } + None => { + return Err(WalletError::LibraWalletGeneric(format!( + "Child num {} not exist while depth is {}", + i, + self.addr_map.len() + ))) + } + } + } + Ok(ret) + } + + /// Simple public function that allows to sign a Libra RawTransaction with the PrivateKey + /// associated to a particular AccountAddress. If the PrivateKey associated to an + /// AccountAddress is not contained in the addr_map, then this function will return an Error + pub fn sign_txn( + &self, + addr: &AccountAddress, + txn: RawTransaction, + ) -> Result { + if let Some(child) = self.addr_map.get(addr) { + let raw_bytes = txn.into_proto().write_to_bytes()?; + let txn_hashvalue = RawTransactionBytes(&raw_bytes).hash(); + + let child_key = self.key_factory.private_child(child.clone())?; + let signature = child_key.sign(txn_hashvalue); + let public_key = child_key.get_public(); + + let mut signed_txn = ProtoSignedTransaction::new(); + signed_txn.set_raw_txn_bytes(raw_bytes.to_vec()); + signed_txn.set_sender_public_key(public_key.to_bytes().to_vec()); + signed_txn.set_sender_signature(signature.to_bytes().to_vec()); + + Ok(SignedTransaction::from_proto(signed_txn)?) + } else { + Err(WalletError::LibraWalletGeneric( + "Well, that address is nowhere to be found... This is awkward".to_string(), + )) + } + } +} diff --git a/client/src/account_commands.rs b/client/src/account_commands.rs new file mode 100644 index 0000000000000..ef92e49654ffc --- /dev/null +++ b/client/src/account_commands.rs @@ -0,0 +1,148 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{client_proxy::ClientProxy, commands::*}; + +/// Major command for account related operations. +pub struct AccountCommand {} + +impl Command for AccountCommand { + fn get_aliases(&self) -> Vec<&'static str> { + vec!["account", "a"] + } + fn get_description(&self) -> &'static str { + "Account operations" + } + fn execute(&self, client: &mut ClientProxy, params: &[&str]) { + let commands: Vec> = vec![ + Box::new(AccountCommandCreate {}), + Box::new(AccountCommandListAccounts {}), + Box::new(AccountCommandRecoverWallet {}), + Box::new(AccountCommandWriteRecovery {}), + Box::new(AccountCommandMint {}), + ]; + + subcommand_execute(¶ms[0], commands, client, ¶ms[1..]); + } +} + +/// Sub command to create a random account. The account will not be saved on chain. +pub struct AccountCommandCreate {} + +impl Command for AccountCommandCreate { + fn get_aliases(&self) -> Vec<&'static str> { + vec!["create", "c"] + } + fn get_description(&self) -> &'static str { + "Create an account. 
Returns reference ID to use in other operations" + } + fn execute(&self, client: &mut ClientProxy, params: &[&str]) { + println!(">> Creating/retrieving next account from wallet"); + match client.create_next_account(¶ms) { + Ok(account_data) => println!( + "Created/retrieved account #{} address {}", + account_data.index, + hex::encode(account_data.address) + ), + Err(e) => report_error("Error creating account", e), + } + } +} + +/// Sub command to recover wallet from the file specified. +pub struct AccountCommandRecoverWallet {} + +impl Command for AccountCommandRecoverWallet { + fn get_aliases(&self) -> Vec<&'static str> { + vec!["recover", "r"] + } + fn get_params_help(&self) -> &'static str { + "" + } + fn get_description(&self) -> &'static str { + "Recover Libra wallet from the file path" + } + fn execute(&self, client: &mut ClientProxy, params: &[&str]) { + println!(">> Recovering Wallet"); + match client.recover_wallet_accounts(¶ms) { + Ok(account_data) => { + println!( + "Wallet recovered and the first {} child accounts were derived", + account_data.len() + ); + for data in account_data { + println!("#{} address {}", data.index, hex::encode(data.address)); + } + } + Err(e) => report_error("Error recovering Libra wallet", e), + } + } +} + +/// Sub command to backup wallet to the file specified. +pub struct AccountCommandWriteRecovery {} + +impl Command for AccountCommandWriteRecovery { + fn get_aliases(&self) -> Vec<&'static str> { + vec!["write", "w"] + } + fn get_params_help(&self) -> &'static str { + "" + } + fn get_description(&self) -> &'static str { + "Save Libra wallet mnemonic recovery seed to disk" + } + fn execute(&self, client: &mut ClientProxy, params: &[&str]) { + println!(">> Saving Libra wallet mnemonic recovery seed to disk"); + match client.write_recovery(¶ms) { + Ok(_) => println!("Saved mnemonic seed to disk"), + Err(e) => report_error("Error writing mnemonic recovery seed to file", e), + } + } +} + +/// Sub command to list all accounts information. +pub struct AccountCommandListAccounts {} + +impl Command for AccountCommandListAccounts { + fn get_aliases(&self) -> Vec<&'static str> { + vec!["list", "la"] + } + fn get_description(&self) -> &'static str { + "Print all accounts that were created or loaded" + } + fn execute(&self, client: &mut ClientProxy, _params: &[&str]) { + client.print_all_accounts(); + } +} + +/// Sub command to mint account. +pub struct AccountCommandMint {} + +impl Command for AccountCommandMint { + fn get_aliases(&self) -> Vec<&'static str> { + vec!["mint", "mintb", "m", "mb"] + } + fn get_params_help(&self) -> &'static str { + "| " + } + fn get_description(&self) -> &'static str { + "Mint coins to the account. 
Suffix 'b' is for blocking" + } + fn execute(&self, client: &mut ClientProxy, params: &[&str]) { + println!(">> Minting coins"); + let is_blocking = blocking_cmd(params[0]); + match client.mint_coins(¶ms, is_blocking) { + Ok(_) => { + if is_blocking { + println!("Finished minting!"); + } else { + // If this value is updated, it must also be changed in + // setup_scripts/docker/mint/server.py + println!("Mint request submitted"); + } + } + Err(e) => report_error("Error minting coins", e), + } + } +} diff --git a/client/src/client_proxy.rs b/client/src/client_proxy.rs new file mode 100644 index 0000000000000..10ecc8a07331f --- /dev/null +++ b/client/src/client_proxy.rs @@ -0,0 +1,890 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{commands::*, grpc_client::GRPCClient, AccountData}; +use admission_control_proto::proto::admission_control::SubmitTransactionRequest; +use chrono::Utc; +use config::trusted_peers::TrustedPeersConfig; +use crypto::{ + hash::CryptoHash, + signing::{sign_message, KeyPair, PrivateKey}, +}; +use failure::prelude::*; +use futures::{future::Future, stream::Stream}; +use hyper; +use libra_wallet::{io_utils, wallet_library::WalletLibrary}; +use num_traits::{ + cast::{FromPrimitive, ToPrimitive}, + identities::Zero, +}; +use proto_conv::IntoProto; +use protobuf::Message; +use rust_decimal::Decimal; +use std::{ + collections::HashMap, + convert::TryFrom, + fs, + io::{stdout, Write}, + path::Path, + str::FromStr, + sync::Arc, + thread, time, +}; +use tokio::{self, runtime::Runtime}; +use types::{ + access_path::AccessPath, + account_address::{AccountAddress, ADDRESS_LENGTH}, + account_config::{account_received_event_path, account_sent_event_path, association_address}, + account_state_blob::{AccountStateBlob, AccountStateWithProof}, + contract_event::{ContractEvent, EventWithProof}, + transaction::{Program, RawTransaction, RawTransactionBytes, SignedTransaction, Version}, + validator_verifier::ValidatorVerifier, +}; + +const CLIENT_WALLET_MNEMONIC_FILE: &str = "client.mnemonic"; +const GAS_UNIT_PRICE: u64 = 0; +const MAX_GAS_AMOUNT: u64 = 10_000; +const TX_EXPIRATION: i64 = 100; + +/// Enum used for error formatting. +#[derive(Debug)] +enum InputType { + Bool, + UnsignedInt, + Usize, +} + +/// Account data is stored in a map and referenced by an index. +#[derive(Debug)] +pub struct AddressAndIndex { + /// Address of the account. + pub address: AccountAddress, + /// The account_ref_id of this account in client. + pub index: usize, +} + +/// Used to return the sequence and sender account index submitted for a transfer +pub struct IndexAndSequence { + /// Index/key of the account in TestClient::accounts vector. + pub account_index: usize, + /// Sequence number of the account. + pub sequence_number: u64, +} + +/// Client used to test +pub struct ClientProxy { + /// client for admission control interface. + pub client: GRPCClient, + /// Created accounts. + pub accounts: Vec, + /// Address to account_ref_id map. + address_to_ref_id: HashMap, + /// We use an incrementing index to reference the accounts we create so it is easier + /// to use this from the command line. This is the index we are currently at. + cur_account_index: usize, + /// Host that operates a faucet service + faucet_server: String, + /// Account used for mint operations. + pub faucet_account: Option, + /// Wallet library managing user accounts. + pub wallet: WalletLibrary, +} + +impl ClientProxy { + /// Construct a new TestClient. 
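+    /// A hedged construction sketch for illustration only; the host, port, and
+    /// trusted peers path below are placeholder assumptions, not files that ship
+    /// with this commit:
+    ///
+    /// ```no_run
+    /// use client::client_proxy::ClientProxy;
+    ///
+    /// let proxy = ClientProxy::new(
+    ///     "localhost",                 // validator host
+    ///     "30307",                     // admission control port
+    ///     "trusted_peers.config.toml", // validator set used to verify responses
+    ///     "",                          // no faucet key file: mint via the faucet service
+    ///     None,                        // derive the faucet server name from the host
+    ///     None,                        // generate a fresh mnemonic file
+    /// )
+    /// .expect("unable to create client proxy");
+    /// ```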
+ pub fn new( + host: &str, + ac_port: &str, + validator_set_file: &str, + faucet_account_file: &str, + faucet_server: Option, + mnemonic_file: Option, + ) -> Result { + let validators_config = TrustedPeersConfig::load_config(Path::new(validator_set_file)); + let validators = validators_config.get_trusted_consensus_peers(); + ensure!( + !validators.is_empty(), + "Not able to load validators from trusted peers config!" + ); + // Total 3f + 1 validators, 2f + 1 correct signatures are required. + // If < 4 validators, all validators have to agree. + let quorum_size = validators.len() * 2 / 3 + 1; + let validator_verifier = Arc::new(ValidatorVerifier::new(validators, quorum_size)); + let client = GRPCClient::new(host, ac_port, validator_verifier)?; + + let accounts = vec![]; + + // If we have a faucet account file, then load it to get the keypair + let faucet_account = if faucet_account_file.is_empty() { + None + } else { + let faucet_account_keypair: KeyPair = + ClientProxy::load_faucet_account_file(faucet_account_file); + let faucet_address = association_address(); + + let faucet_seq_number = client.get_sequence_number(faucet_address); + + let mut faucet_account_data = Self::create_account( + faucet_account_keypair.private_key().clone(), + faucet_seq_number.unwrap_or(0), + ); + faucet_account_data.address = faucet_address; + // Load the keypair from file + Some(faucet_account_data) + }; + + let faucet_server = match faucet_server { + Some(server) => server.to_string(), + None => host.replace("ac", "faucet"), + }; + + let address_to_ref_id = accounts + .iter() + .enumerate() + .map(|(ref_id, acc_data): (usize, &AccountData)| (acc_data.address, ref_id)) + .collect::>(); + let cur_account_index = accounts.len(); + + Ok(ClientProxy { + client, + accounts, + address_to_ref_id, + cur_account_index, + faucet_server, + faucet_account, + wallet: Self::get_libra_wallet(mnemonic_file)?, + }) + } + + /// Returns the account index that should be used by user to reference this account + pub fn create_next_account(&mut self, space_delim_strings: &[&str]) -> Result { + ensure!( + space_delim_strings.len() == 1, + "Invalid number of arguments to account creation" + ); + let (address, _) = self.wallet.new_address()?; + + // Sync with validator for account sequence number in case it is already created on chain. + // This assumes we have a very low probability of mnemonic word conflict. + let sequence_number = self.client.get_sequence_number(address).unwrap_or(0); + + let account_data = AccountData { + address, + key_pair: None, + sequence_number, + }; + + Ok(self.insert_account_data(account_data)) + } + + /// Print index and address of all accounts. + pub fn print_all_accounts(&self) { + if self.accounts.is_empty() { + println!("No user accounts"); + } else { + for (ref index, ref account) in self.accounts.iter().enumerate() { + println!( + "User account index: {}, address: {}, sequence number: {}", + index, + hex::encode(&account.address), + account.sequence_number, + ); + } + } + + if let Some(faucet_account) = &self.faucet_account { + println!( + "Faucet account address: {}, sequence_number: {}", + hex::encode(&faucet_account.address), + faucet_account.sequence_number, + ); + } + } + + /// Clone all accounts held in the client. + pub fn copy_all_accounts(&self) -> Vec { + self.accounts.clone() + } + + /// Set the account of this client instance. 
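+    /// Replaces all cached accounts: clears the current account list and the
+    /// address-to-index map, then re-inserts the given accounts so their
+    /// reference ids are reassigned from 0. Returns the new AddressAndIndex
+    /// entries.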
+ pub fn set_accounts(&mut self, accounts: Vec) -> Vec { + self.accounts.clear(); + self.address_to_ref_id.clear(); + self.cur_account_index = 0; + let mut ret = vec![]; + for data in accounts { + ret.push(self.insert_account_data(data)); + } + ret + } + + /// Get balance from validator for the account sepcified. + pub fn get_balance(&mut self, space_delim_strings: &[&str]) -> Result { + ensure!( + space_delim_strings.len() == 2, + "Invalid number of arguments for getting balance" + ); + let account = self.get_account_address_from_parameter(space_delim_strings[1])?; + + self.client + .get_balance(account) + .map(|val| val as f64 / 1_000_000.) + } + + /// Get the latest sequence number from validator for the account specified. + pub fn get_sequence_number(&mut self, space_delim_strings: &[&str]) -> Result { + ensure!( + space_delim_strings.len() == 2 || space_delim_strings.len() == 3, + "Invalid number of arguments for getting sequence number" + ); + let account_address = self.get_account_address_from_parameter(space_delim_strings[1])?; + let sequence_number = self.client.get_sequence_number(account_address)?; + + let reset_sequence_number = if space_delim_strings.len() == 3 { + space_delim_strings[2].parse::().map_err(|error| { + format_parse_data_error( + "reset_sequence_number", + InputType::Bool, + space_delim_strings[2], + error, + ) + })? + } else { + false + }; + if reset_sequence_number { + let mut account = self.mut_account_from_parameter(space_delim_strings[1])?; + // Set sequence_number to latest one. + account.sequence_number = sequence_number; + } + Ok(sequence_number) + } + + /// Mints coins for the receiver specified. + pub fn mint_coins(&mut self, space_delim_strings: &[&str], is_blocking: bool) -> Result<()> { + ensure!( + space_delim_strings.len() == 3, + "Invalid number of arguments for mint" + ); + let receiver = self.get_account_address_from_parameter(space_delim_strings[1])?; + let num_coins = Self::convert_to_micro_libras(space_delim_strings[2])?; + + match self.faucet_account { + Some(_) => self.mint_coins_with_local_faucet_account(&receiver, num_coins, is_blocking), + None => self.mint_coins_with_faucet_service(&receiver, num_coins, is_blocking), + } + } + + /// Waits for the next transaction for a specific address and prints it + pub fn wait_for_transaction(&mut self, account: AccountAddress, sequence_number: u64) { + let mut max_iterations = 50; + print!("[waiting "); + loop { + stdout().flush().unwrap(); + max_iterations -= 1; + + match self.client.get_sequence_number(account) { + Ok(chain_seq_number) => { + if chain_seq_number == sequence_number { + println!( + "\nTransaction completed, found sequence number {}", + chain_seq_number + ); + break; + } + print!("*"); + } + Err(e) => { + if max_iterations == 0 { + panic!("wait_for_transaction timeout: {}", e); + } else { + print!("."); + } + } + } + + thread::sleep(time::Duration::from_millis(1000)); + } + } + + /// Transfer num_coins from sender account to receiver. If is_blocking = true, + /// it will keep querying validator till the sequence number is bumped up in validator. 
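+    /// Illustrative call; `proxy` and `receiver_address` are assumed bindings
+    /// rather than items defined in this file:
+    ///
+    /// ```ignore
+    /// // Transfer 1 libra (1_000_000 micro libras) from account 0 with the
+    /// // default gas settings, blocking until the validator confirms it.
+    /// let res = proxy.transfer_coins_int(0, &receiver_address, 1_000_000, None, None, true)?;
+    /// println!("Submitted transaction #{}", res.sequence_number);
+    /// ```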
+    pub fn transfer_coins_int(
+        &mut self,
+        sender_account_ref_id: usize,
+        receiver_address: &AccountAddress,
+        num_coins: u64,
+        gas_unit_price: Option<u64>,
+        max_gas_amount: Option<u64>,
+        is_blocking: bool,
+    ) -> Result<IndexAndSequence> {
+        let sender_address;
+        let sender_sequence;
+        let resp;
+        {
+            let sender = self
+                .accounts
+                .get_mut(sender_account_ref_id)
+                .ok_or_else(|| {
+                    format_err!("Unable to find sender account: {}", sender_account_ref_id)
+                })?;
+
+            let program = vm_genesis::encode_transfer_program(&receiver_address, num_coins);
+            let req = Self::create_submit_transaction_req(
+                program,
+                sender,
+                &self.wallet,
+                gas_unit_price, /* gas_unit_price */
+                max_gas_amount, /* max_gas_amount */
+            )?;
+            resp = self.client.submit_transaction(sender, &req);
+            sender_address = sender.address;
+            sender_sequence = sender.sequence_number;
+        }
+
+        if is_blocking {
+            self.wait_for_transaction(sender_address, sender_sequence);
+        }
+
+        resp.map(|_| IndexAndSequence {
+            account_index: sender_account_ref_id,
+            sequence_number: sender_sequence - 1,
+        })
+    }
+
+    /// Transfers coins from sender to receiver.
+    pub fn transfer_coins(
+        &mut self,
+        space_delim_strings: &[&str],
+        is_blocking: bool,
+    ) -> Result<IndexAndSequence> {
+        ensure!(
+            space_delim_strings.len() >= 4 && space_delim_strings.len() <= 6,
+            "Invalid number of arguments for transfer"
+        );
+
+        let sender_account_address =
+            self.get_account_address_from_parameter(space_delim_strings[1])?;
+        let receiver_address = self.get_account_address_from_parameter(space_delim_strings[2])?;
+
+        let num_coins = Self::convert_to_micro_libras(space_delim_strings[3])?;
+
+        let gas_unit_price = if space_delim_strings.len() > 4 {
+            Some(space_delim_strings[4].parse::<u64>().map_err(|error| {
+                format_parse_data_error(
+                    "gas_unit_price",
+                    InputType::UnsignedInt,
+                    space_delim_strings[4],
+                    error,
+                )
+            })?)
+        } else {
+            None
+        };
+
+        let max_gas_amount = if space_delim_strings.len() > 5 {
+            Some(space_delim_strings[5].parse::<u64>().map_err(|error| {
+                format_parse_data_error(
+                    "max_gas_amount",
+                    InputType::UnsignedInt,
+                    space_delim_strings[5],
+                    error,
+                )
+            })?)
+        } else {
+            None
+        };
+
+        let sender_account_ref_id = *self
+            .address_to_ref_id
+            .get(&sender_account_address)
+            .ok_or_else(|| {
+                format_err!(
+                    "Unable to find existing managing account by address: {}, to see all existing \
+                     accounts, run: 'account list'",
+                    sender_account_address
+                )
+            })?;
+
+        self.transfer_coins_int(
+            sender_account_ref_id,
+            &receiver_address,
+            num_coins,
+            gas_unit_price,
+            max_gas_amount,
+            is_blocking,
+        )
+    }
+
+    /// Get the latest account state from validator.
+    pub fn get_latest_account_state(
+        &mut self,
+        space_delim_strings: &[&str],
+    ) -> Result<(Option<AccountStateBlob>, Version)> {
+        ensure!(
+            space_delim_strings.len() == 2,
+            "Invalid number of arguments to get latest account state"
+        );
+        let account = self.get_account_address_from_parameter(space_delim_strings[1])?;
+        self.client.get_account_blob(account)
+    }
+
+    /// Get committed txn by account and sequence number.
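+    /// Expects the CLI arguments
+    /// `txn_acc_seq <account_ref_id|account_address> <sequence_number> <fetch_events:bool>`,
+    /// mirroring the parsing below.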
+    pub fn get_committed_txn_by_acc_seq(
+        &mut self,
+        space_delim_strings: &[&str],
+    ) -> Result<Option<(SignedTransaction, Option<Vec<ContractEvent>>)>> {
+        ensure!(
+            space_delim_strings.len() == 4,
+            "Invalid number of arguments to get transaction by account and sequence number"
+        );
+        let account = self.get_account_address_from_parameter(space_delim_strings[1])?;
+        let sequence_number = space_delim_strings[2].parse::<u64>().map_err(|error| {
+            format_parse_data_error(
+                "account_sequence_number",
+                InputType::UnsignedInt,
+                space_delim_strings[2],
+                error,
+            )
+        })?;
+
+        let fetch_events = space_delim_strings[3].parse::<bool>().map_err(|error| {
+            format_parse_data_error(
+                "fetch_events",
+                InputType::Bool,
+                space_delim_strings[3],
+                error,
+            )
+        })?;
+
+        self.client
+            .get_txn_by_acc_seq(account, sequence_number, fetch_events)
+    }
+
+    /// Get committed txns by version range from validator.
+    pub fn get_committed_txn_by_range(
+        &mut self,
+        space_delim_strings: &[&str],
+    ) -> Result<Vec<(SignedTransaction, Option<Vec<ContractEvent>>)>> {
+        ensure!(
+            space_delim_strings.len() == 4,
+            "Invalid number of arguments to get transaction by range"
+        );
+        let start_version = space_delim_strings[1].parse::<u64>().map_err(|error| {
+            format_parse_data_error(
+                "start_version",
+                InputType::UnsignedInt,
+                space_delim_strings[1],
+                error,
+            )
+        })?;
+        let limit = space_delim_strings[2].parse::<u64>().map_err(|error| {
+            format_parse_data_error(
+                "limit",
+                InputType::UnsignedInt,
+                space_delim_strings[2],
+                error,
+            )
+        })?;
+        let fetch_events = space_delim_strings[3].parse::<bool>().map_err(|error| {
+            format_parse_data_error(
+                "fetch_events",
+                InputType::Bool,
+                space_delim_strings[3],
+                error,
+            )
+        })?;
+
+        self.client
+            .get_txn_by_range(start_version, limit, fetch_events)
+    }
+
+    /// Get account address from parameter. If the parameter is an address string, convert it to
+    /// an address; otherwise, parse it as a u64 index into TestClient::accounts.
+    pub fn get_account_address_from_parameter(&self, para: &str) -> Result<AccountAddress> {
+        match is_address(para) {
+            true => ClientProxy::address_from_strings(para),
+            false => {
+                let account_ref_id = para.parse::<usize>().map_err(|error| {
+                    format_parse_data_error(
+                        "account_reference_id/account_address",
+                        InputType::Usize,
+                        para,
+                        error,
+                    )
+                })?;
+                let account_data = self.accounts.get(account_ref_id).ok_or_else(|| {
+                    format_err!(
+                        "Unable to find account by account reference id: {}, to see all existing \
+                         accounts, run: 'account list'",
+                        account_ref_id
+                    )
+                })?;
+                Ok(account_data.address)
+            }
+        }
+    }
+
+    /// Get events by account and event type with start sequence number and limit.
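+    /// Expects the CLI arguments
+    /// `event <account_ref_id|account_address> <sent|received> <start_seq_number> <ascending:bool> <limit>`,
+    /// mirroring the parsing below.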
+    pub fn get_events_by_account_and_type(
+        &mut self,
+        space_delim_strings: &[&str],
+    ) -> Result<(Vec<EventWithProof>, Option<AccountStateWithProof>)> {
+        ensure!(
+            space_delim_strings.len() == 6,
+            "Invalid number of arguments to get events by access path"
+        );
+        let account = self.get_account_address_from_parameter(space_delim_strings[1])?;
+        let path = match space_delim_strings[2] {
+            "sent" => account_sent_event_path(),
+            "received" => account_received_event_path(),
+            _ => bail!(
+                "Unknown event type: {:?}, only sent and received are supported",
+                space_delim_strings[2]
+            ),
+        };
+        let access_path = AccessPath::new(account, path);
+        let start_seq_number = space_delim_strings[3].parse::<u64>().map_err(|error| {
+            format_parse_data_error(
+                "start_seq_number",
+                InputType::UnsignedInt,
+                space_delim_strings[3],
+                error,
+            )
+        })?;
+        let ascending = space_delim_strings[4].parse::<bool>().map_err(|error| {
+            format_parse_data_error("ascending", InputType::Bool, space_delim_strings[4], error)
+        })?;
+        let limit = space_delim_strings[5].parse::<u64>().map_err(|error| {
+            format_parse_data_error(
+                "limit",
+                InputType::UnsignedInt,
+                space_delim_strings[5],
+                error,
+            )
+        })?;
+        self.client
+            .get_events_by_access_path(access_path, start_seq_number, ascending, limit)
+    }
+
+    /// Write mnemonic recovery seed to the file specified.
+    pub fn write_recovery(&self, space_delim_strings: &[&str]) -> Result<()> {
+        ensure!(
+            space_delim_strings.len() == 2,
+            "Invalid number of arguments for writing recovery"
+        );
+
+        self.wallet
+            .write_recovery(&Path::new(space_delim_strings[1]))?;
+        Ok(())
+    }
+
+    /// Recover wallet accounts from file and return vec<(account_address, index)>.
+    pub fn recover_wallet_accounts(
+        &mut self,
+        space_delim_strings: &[&str],
+    ) -> Result<Vec<AddressAndIndex>> {
+        ensure!(
+            space_delim_strings.len() == 2,
+            "Invalid number of arguments for recovering wallets"
+        );
+
+        let wallet = WalletLibrary::recover(&Path::new(space_delim_strings[1]))?;
+        let wallet_addresses = wallet.get_addresses()?;
+        let mut account_data = Vec::new();
+        for address in wallet_addresses {
+            let sequence_number = self.client.get_sequence_number(address)?;
+            account_data.push(AccountData {
+                address,
+                key_pair: None,
+                sequence_number,
+            });
+        }
+        self.set_wallet(wallet);
+        // Clear current cached AccountData as we always swap the entire wallet completely.
+        Ok(self.set_accounts(account_data))
+    }
+
+    /// Insert the account data to Client::accounts and return its address and index.
+    pub fn insert_account_data(&mut self, account_data: AccountData) -> AddressAndIndex {
+        let address = account_data.address;
+
+        self.accounts.push(account_data);
+        self.address_to_ref_id
+            .insert(address, self.cur_account_index);
+
+        self.cur_account_index += 1;
+
+        AddressAndIndex {
+            address,
+            index: self.cur_account_index - 1,
+        }
+    }
+
+    /// Test gRPC client connection with validator.
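+    /// Sends an empty UpdateToLatestLedger request; any response that passes
+    /// signature verification proves the validator is reachable.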
+ pub fn test_validator_connection(&self) -> Result<()> { + self.client.get_with_proof_sync(vec![])?; + Ok(()) + } + + fn get_libra_wallet(mnemonic_file: Option) -> Result { + let wallet_recovery_file_path = if let Some(input_mnemonic_word) = mnemonic_file { + Path::new(&input_mnemonic_word).to_path_buf() + } else { + let mut file_path = std::env::current_dir()?; + file_path.push(CLIENT_WALLET_MNEMONIC_FILE); + file_path + }; + + let wallet = if let Ok(recovered_wallet) = io_utils::recover(&wallet_recovery_file_path) { + recovered_wallet + } else { + let new_wallet = WalletLibrary::new(); + new_wallet.write_recovery(&wallet_recovery_file_path)?; + new_wallet + }; + Ok(wallet) + } + + /// Set wallet instance used by this client. + fn set_wallet(&mut self, wallet: WalletLibrary) { + self.wallet = wallet; + } + + fn load_faucet_account_file(faucet_account_file: &str) -> KeyPair { + match fs::read(faucet_account_file) { + Ok(data) => { + bincode::deserialize(&data[..]).expect("Unable to deserialize faucet account file") + } + Err(e) => { + panic!( + "Unable to read faucet account file: {}, {}", + faucet_account_file, e + ); + } + } + } + + fn address_from_strings(data: &str) -> Result { + let account_vec: Vec = hex::decode(data.parse::()?)?; + ensure!( + account_vec.len() == ADDRESS_LENGTH, + "The address {:?} is of invalid length. Addresses must be 32-bytes long" + ); + let account = match AccountAddress::try_from(&account_vec[..]) { + Ok(address) => address, + Err(error) => bail!( + "The address {:?} is invalid, error: {:?}", + &account_vec, + error, + ), + }; + Ok(account) + } + + fn mint_coins_with_local_faucet_account( + &mut self, + receiver: &AccountAddress, + num_coins: u64, + is_blocking: bool, + ) -> Result<()> { + ensure!(self.faucet_account.is_some(), "No faucet account loaded"); + let mut sender = self.faucet_account.as_mut().unwrap(); + let sender_address = sender.address; + let program = vm_genesis::encode_mint_program(&receiver, num_coins); + let req = Self::create_submit_transaction_req( + program, + sender, + &self.wallet, + None, /* gas_unit_price */ + None, /* max_gas_amount */ + )?; + let resp = self.client.submit_transaction(&mut sender, &req); + if is_blocking { + self.wait_for_transaction( + sender_address, + self.faucet_account.as_ref().unwrap().sequence_number, + ); + } + resp + } + + fn mint_coins_with_faucet_service( + &mut self, + receiver: &AccountAddress, + num_coins: u64, + is_blocking: bool, + ) -> Result<()> { + let mut runtime = Runtime::new().unwrap(); + let client = hyper::Client::new(); + + let url = format!( + "http://{}?amount={}&address={:?}", + self.faucet_server, num_coins, receiver + ) + .parse::()?; + + let response = runtime.block_on(client.get(url))?; + let status_code = response.status(); + let body = response.into_body().concat2().wait()?; + let raw_data = std::str::from_utf8(&body)?; + + if status_code != 200 { + return Err(format_err!( + "Failed to query remote faucet server[status={}]: {:?}", + status_code, + raw_data, + )); + } + let sequence_number = raw_data.parse::()?; + if is_blocking { + self.wait_for_transaction(AccountAddress::new([0; 32]), sequence_number); + } + Ok(()) + } + + fn convert_to_micro_libras(input: &str) -> Result { + ensure!(!input.is_empty(), "Empty input not allowed for libra unit"); + // This is not supposed to panic as it is used as constant here. 
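+        // One libra is 1_000_000 micro libras, so the largest representable
+        // amount is u64::MAX / 10^6, roughly 1.8 * 10^13 libras. The
+        // `scale <= 14` check below caps the number of digits before the
+        // decimal point so the Decimal arithmetic stays in range ahead of the
+        // explicit max_value comparison.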
+ let max_value = Decimal::from_u64(std::u64::MAX).unwrap() / Decimal::new(1_000_000, 0); + let scale = input.find('.').unwrap_or(input.len() - 1); + ensure!( + scale <= 14, + "Input value is too big: {:?}, max: {:?}", + input, + max_value + ); + let original = Decimal::from_str(input)?; + ensure!( + original <= max_value, + "Input value is too big: {:?}, max: {:?}", + input, + max_value + ); + let value = original * Decimal::new(1_000_000, 0); + ensure!(value.fract().is_zero(), "invalid value"); + value.to_u64().ok_or_else(|| format_err!("invalid value")) + } + + fn create_submit_transaction_req( + program: Program, + sender_account: &mut AccountData, + wallet: &WalletLibrary, + gas_unit_price: Option, + max_gas_amount: Option, + ) -> Result { + let raw_txn = RawTransaction::new( + sender_account.address, + sender_account.sequence_number, + program, + max_gas_amount.unwrap_or(MAX_GAS_AMOUNT), + gas_unit_price.unwrap_or(GAS_UNIT_PRICE), + std::time::Duration::new((Utc::now().timestamp() + TX_EXPIRATION) as u64, 0), + ); + + let signed_txn = match &sender_account.key_pair { + Some(key_pair) => { + let bytes = raw_txn.clone().into_proto().write_to_bytes()?; + let hash = RawTransactionBytes(&bytes).hash(); + let signature = sign_message(hash, &key_pair.private_key())?; + + SignedTransaction::new_for_test(raw_txn, key_pair.public_key(), signature) + } + None => wallet.sign_txn(&sender_account.address, raw_txn)?, + }; + + let mut req = SubmitTransactionRequest::new(); + req.set_signed_txn(signed_txn.into_proto()); + Ok(req) + } + + fn mut_account_from_parameter(&mut self, para: &str) -> Result<&mut AccountData> { + let account_ref_id = match is_address(para) { + true => { + let account_address = ClientProxy::address_from_strings(para)?; + *self + .address_to_ref_id + .get(&account_address) + .ok_or_else(|| { + format_err!( + "Unable to find local account by address: {:?}", + account_address + ) + })? + } + false => para.parse::()?, + }; + let account_data = self + .accounts + .get_mut(account_ref_id) + .ok_or_else(|| format_err!("Unable to find account by ref id: {}", account_ref_id))?; + Ok(account_data) + } + + /// Populate a AccountData struct using private key and sequence number. + fn create_account(private_key: PrivateKey, sequence_number: u64) -> AccountData { + let keypair = KeyPair::new(private_key); + let address: AccountAddress = keypair.public_key().into(); + AccountData { + address, + key_pair: Some(keypair), + sequence_number, + } + } +} + +fn format_parse_data_error( + field: &str, + input_type: InputType, + value: &str, + error: T, +) -> Error { + format_err!( + "Unable to parse input for {} - \ + please enter an {:?}. Input was: {}, error: {:?}", + field, + input_type, + value, + error + ) +} + +#[cfg(test)] +mod tests { + use crate::client_proxy::ClientProxy; + use proptest::prelude::*; + + #[test] + fn test_micro_libra_conversion() { + assert!(ClientProxy::convert_to_micro_libras("").is_err()); + assert!(ClientProxy::convert_to_micro_libras("-11").is_err()); + assert!(ClientProxy::convert_to_micro_libras("abc").is_err()); + assert!(ClientProxy::convert_to_micro_libras("11111112312321312321321321").is_err()); + assert!(ClientProxy::convert_to_micro_libras("0").is_ok()); + assert!(ClientProxy::convert_to_micro_libras("1").is_ok()); + assert!(ClientProxy::convert_to_micro_libras("0.1").is_ok()); + assert!(ClientProxy::convert_to_micro_libras("1.1").is_ok()); + // Max of micro libra is u64::MAX (18446744073709551615). 
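+        // 18446744073709.551615 libras * 10^6 is exactly u64::MAX micro libras,
+        // so it converts; anything larger, or with more fractional precision,
+        // must fail.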
+ assert!(ClientProxy::convert_to_micro_libras("18446744073709.551615").is_ok()); + assert!(ClientProxy::convert_to_micro_libras("184467440737095.51615").is_err()); + assert!(ClientProxy::convert_to_micro_libras("18446744073709.551616").is_err()); + } + + proptest! { + // Proptest is used to verify that the conversion will not panic with random input. + #[test] + fn test_micro_libra_conversion_random_string(req in any::()) { + let _res = ClientProxy::convert_to_micro_libras(&req); + } + #[test] + fn test_micro_libra_conversion_random_f64(req in any::()) { + let req_str = req.to_string(); + let _res = ClientProxy::convert_to_micro_libras(&req_str); + } + #[test] + fn test_micro_libra_conversion_random_u64(req in any::()) { + let req_str = req.to_string(); + let _res = ClientProxy::convert_to_micro_libras(&req_str); + } + } +} diff --git a/client/src/commands.rs b/client/src/commands.rs new file mode 100644 index 0000000000000..5973bb6a21740 --- /dev/null +++ b/client/src/commands.rs @@ -0,0 +1,138 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + account_commands::AccountCommand, client_proxy::ClientProxy, query_commands::QueryCommand, + transfer_commands::TransferCommand, +}; + +use failure::prelude::*; +use metrics::counters::*; +use std::{collections::HashMap, sync::Arc}; +use types::account_address::ADDRESS_LENGTH; + +/// Print the error and bump up error counter. +pub fn report_error(msg: &str, e: Error) { + println!("[ERROR] {}: {}", msg, pretty_format_error(e)); + COUNTER_CLIENT_ERRORS.inc(); +} + +fn pretty_format_error(e: Error) -> String { + if let Some(grpc_error) = e.downcast_ref::() { + if let grpcio::Error::RpcFailure(grpc_rpc_failure) = grpc_error { + match grpc_rpc_failure.status { + grpcio::RpcStatusCode::Unavailable | grpcio::RpcStatusCode::DeadlineExceeded => { + return "Server unavailable, please retry and/or check \ + if host passed to the client is running" + .to_string(); + } + _ => {} + } + } + } + + return format!("{}", e); +} + +/// Check whether a command is blocking. +pub fn blocking_cmd(cmd: &str) -> bool { + cmd.ends_with('b') +} + +/// Chech whether a command is debugging command. +pub fn debug_format_cmd(cmd: &str) -> bool { + cmd.ends_with('?') +} + +/// Check whether the input string is a valid libra address. +pub fn is_address(data: &str) -> bool { + match hex::decode(data) { + Ok(vec) => vec.len() == ADDRESS_LENGTH, + Err(_) => false, + } +} + +/// Returns all the commands available, as well as the reverse index from the aliases to the +/// commands. +pub fn get_commands() -> ( + Vec>, + HashMap<&'static str, Arc>, +) { + let commands: Vec> = vec![ + Arc::new(AccountCommand {}), + Arc::new(QueryCommand {}), + Arc::new(TransferCommand {}), + ]; + let mut alias_to_cmd = HashMap::new(); + for command in &commands { + for alias in command.get_aliases() { + alias_to_cmd.insert(alias, Arc::clone(command)); + } + } + (commands, alias_to_cmd) +} + +/// Parse a cmd string, the first element in the returned vector is the command to run +pub fn parse_cmd(cmd_str: &str) -> Vec<&str> { + let input = &cmd_str[..]; + input.trim().split(' ').map(str::trim).collect() +} + +/// Print the help message for all sub commands. 
+pub fn print_subcommand_help(parent_command: &str, commands: &[Box]) { + println!( + "usage: {} \n\nUse the following args for this command:\n", + parent_command + ); + for cmd in commands { + println!( + "{} {}\n\t{}", + cmd.get_aliases().join(" | "), + cmd.get_params_help(), + cmd.get_description() + ); + } + println!("\n"); +} + +/// Execute sub command. +// TODO: Convert subcommands arrays to lazy statics +pub fn subcommand_execute( + parent_command_name: &str, + commands: Vec>, + client: &mut ClientProxy, + params: &[&str], +) { + let mut commands_map = HashMap::new(); + for (i, cmd) in commands.iter().enumerate() { + for alias in cmd.get_aliases() { + if commands_map.insert(alias, i) != None { + panic!("Duplicate alias {}", alias); + } + } + } + + if params.is_empty() { + print_subcommand_help(parent_command_name, &commands); + return; + } + + match commands_map.get(¶ms[0]) { + Some(&idx) => commands[idx].execute(client, ¶ms), + _ => print_subcommand_help(parent_command_name, &commands), + } +} + +/// Trait to perform client operations. +pub trait Command { + /// all commands and aliases this command support. + fn get_aliases(&self) -> Vec<&'static str>; + /// string that describes params. + fn get_params_help(&self) -> &'static str { + "" + } + /// string that describes whet command does. + fn get_description(&self) -> &'static str; + /// code to execute. + fn execute(&self, client: &mut ClientProxy, params: &[&str]); +} diff --git a/client/src/grpc_client.rs b/client/src/grpc_client.rs new file mode 100644 index 0000000000000..61457ac6622dd --- /dev/null +++ b/client/src/grpc_client.rs @@ -0,0 +1,405 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::AccountData; +use admission_control_proto::{ + proto::{ + admission_control::{ + SubmitTransactionRequest, SubmitTransactionResponse as ProtoSubmitTransactionResponse, + }, + admission_control_grpc::AdmissionControlClient, + }, + AdmissionControlStatus, SubmitTransactionResponse, +}; +use failure::prelude::*; +use futures::Future; +use grpcio::{CallOption, ChannelBuilder, EnvBuilder}; +use logger::prelude::*; +use proto_conv::{FromProto, IntoProto}; +use std::sync::Arc; +use types::{ + access_path::AccessPath, + account_address::AccountAddress, + account_config::get_account_resource_or_default, + account_state_blob::{AccountStateBlob, AccountStateWithProof}, + contract_event::{ContractEvent, EventWithProof}, + get_with_proof::{ + RequestItem, ResponseItem, UpdateToLatestLedgerRequest, UpdateToLatestLedgerResponse, + }, + transaction::{SignedTransaction, Version}, + validator_verifier::ValidatorVerifier, + vm_error::{VMStatus, VMValidationStatus}, +}; + +const MAX_GRPC_RETRY_COUNT: u64 = 1; + +/// Struct holding dependencies of client. +pub struct GRPCClient { + client: AdmissionControlClient, + validator_verifier: Arc, +} + +impl GRPCClient { + /// Construct a new Client instance. 
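+    /// A hedged usage sketch; `validators`, `quorum_size`, and `address` are
+    /// assumed bindings shown for illustration:
+    ///
+    /// ```ignore
+    /// let verifier = Arc::new(ValidatorVerifier::new(validators, quorum_size));
+    /// let client = GRPCClient::new("localhost", "30307", verifier)?;
+    /// let seq = client.get_sequence_number(address)?;
+    /// ```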
+ pub fn new(host: &str, port: &str, validator_verifier: Arc) -> Result { + let conn_addr = format!("{}:{}", host, port); + + // Create a GRPC client + let env = Arc::new(EnvBuilder::new().name_prefix("grpc-client-").build()); + let ch = ChannelBuilder::new(env).connect(&conn_addr); + let client = AdmissionControlClient::new(ch); + + Ok(GRPCClient { + client, + validator_verifier, + }) + } + + /// Submits a transaction and bumps the sequence number for the sender + pub fn submit_transaction( + &self, + sender_account: &mut AccountData, + req: &SubmitTransactionRequest, + ) -> Result<()> { + let mut resp = self.submit_transaction_opt(req); + + let mut try_cnt = 0_u64; + while Self::need_to_retry(&mut try_cnt, &resp) { + resp = self.submit_transaction_opt(&req); + } + + let completed_resp = SubmitTransactionResponse::from_proto(resp?)?; + + if let Some(ac_status) = completed_resp.ac_status { + if ac_status == AdmissionControlStatus::Accepted { + // Bump up sequence_number if transaction is accepted. + sender_account.sequence_number += 1; + } else { + bail!("Transaction failed with AC status: {:?}", ac_status,); + } + } else if let Some(vm_error) = completed_resp.vm_error { + if vm_error == VMStatus::Validation(VMValidationStatus::SequenceNumberTooOld) { + sender_account.sequence_number = + self.get_sequence_number(sender_account.address)?; + bail!( + "Transaction failed with vm status: {:?}, please retry your transaction.", + vm_error + ); + } + bail!("Transaction failed with vm status: {:?}", vm_error); + } else if let Some(mempool_error) = completed_resp.mempool_error { + bail!( + "Transaction failed with mempool status: {:?}", + mempool_error, + ); + } else { + bail!( + "Malformed SubmitTransactionResponse which has no status set, {:?}", + completed_resp, + ); + } + Ok(()) + } + + /// Async version of submit_transaction + pub fn submit_transaction_async( + &self, + req: &SubmitTransactionRequest, + ) -> Result<(impl Future)> { + let resp = self + .client + .submit_transaction_async_opt(&req, Self::get_default_grpc_call_option())? + .then(|proto_resp| { + let ret = SubmitTransactionResponse::from_proto(proto_resp?)?; + Ok(ret) + }); + Ok(resp) + } + + fn submit_transaction_opt( + &self, + resp: &SubmitTransactionRequest, + ) -> Result { + Ok(self + .client + .submit_transaction_opt(resp, Self::get_default_grpc_call_option())?) + } + + fn get_with_proof_async( + &self, + requested_items: Vec, + ) -> Result> { + let req = UpdateToLatestLedgerRequest::new(0, requested_items.clone()); + debug!("get_with_proof with request: {:?}", req); + let proto_req = req.clone().into_proto(); + let arc_validator_verifier: Arc = Arc::clone(&self.validator_verifier); + let ret = self + .client + .update_to_latest_ledger_async_opt(&proto_req, Self::get_default_grpc_call_option())? + .then(move |get_with_proof_resp| { + // TODO: Cache/persist client_known_version to work with validator set change when + // the feature is available. + + let resp = UpdateToLatestLedgerResponse::from_proto(get_with_proof_resp?)?; + resp.verify(arc_validator_verifier, &req)?; + Ok(resp) + }); + Ok(ret) + } + + fn need_to_retry(try_cnt: &mut u64, ret: &Result) -> bool { + if *try_cnt <= MAX_GRPC_RETRY_COUNT { + *try_cnt += 1; + if let Err(error) = ret { + if let Some(grpc_error) = error.downcast_ref::() { + if let grpcio::Error::RpcFailure(grpc_rpc_failure) = grpc_error { + // Only retry when the connection is down to make sure we won't + // send one txn twice. 
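+                        // A DeadlineExceeded (or other) failure may mean the
+                        // request already reached the validator, so resending
+                        // could submit the transaction twice; Unavailable means
+                        // it never got through, which makes a retry safe.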
+ return grpc_rpc_failure.status == grpcio::RpcStatusCode::Unavailable; + } + } + } + } + false + } + /// Sync version of get_with_proof + pub fn get_with_proof_sync( + &self, + requested_items: Vec, + ) -> Result { + let mut resp: Result = + self.get_with_proof_async(requested_items.clone())?.wait(); + let mut try_cnt = 0_u64; + + while Self::need_to_retry(&mut try_cnt, &resp) { + resp = self.get_with_proof_async(requested_items.clone())?.wait(); + } + + Ok(resp?) + } + + fn get_balances_async( + &self, + addresses: &[AccountAddress], + ) -> Result, Error = failure::Error>> { + let requests = addresses + .iter() + .map(|addr| RequestItem::GetAccountState { address: *addr }) + .collect::>(); + + let num_addrs = addresses.len(); + let get_with_proof_resp = self.get_with_proof_async(requests)?; + Ok(get_with_proof_resp.then(move |get_with_proof_resp| { + let rust_resp = get_with_proof_resp?; + if rust_resp.response_items.len() != num_addrs { + bail!("Server returned wrong number of responses"); + } + + let mut balances = vec![]; + for value_with_proof in rust_resp.response_items { + debug!("get_balance response is: {:?}", value_with_proof); + match value_with_proof { + ResponseItem::GetAccountState { + account_state_with_proof, + } => { + let balance = + get_account_resource_or_default(&account_state_with_proof.blob)? + .balance(); + balances.push(balance); + } + _ => bail!( + "Incorrect type of response returned: {:?}", + value_with_proof + ), + } + } + Ok(balances) + })) + } + + pub(crate) fn get_balance(&self, address: AccountAddress) -> Result { + let mut ret = self.get_balances_async(&[address])?.wait(); + let mut try_cnt = 0_u64; + while Self::need_to_retry(&mut try_cnt, &ret) { + ret = self.get_balances_async(&[address])?.wait(); + } + + ret?.pop() + .ok_or_else(|| format_err!("Account is not available!")) + } + + /// Get the latest account sequence number for the account specified. + pub fn get_sequence_number(&self, address: AccountAddress) -> Result { + Ok(get_account_resource_or_default(&self.get_account_blob(address)?.0)?.sequence_number()) + } + + /// Get the latest account state blob from validator. + pub fn get_account_blob( + &self, + address: AccountAddress, + ) -> Result<(Option, Version)> { + let req_item = RequestItem::GetAccountState { address }; + + let mut response = self.get_with_proof_sync(vec![req_item])?; + let account_state_with_proof = response + .response_items + .remove(0) + .into_get_account_state_response()?; + + Ok(( + account_state_with_proof.blob, + response.ledger_info_with_sigs.ledger_info().version(), + )) + } + + /// Get transaction from validator by account and sequence number. + pub fn get_txn_by_acc_seq( + &self, + account: AccountAddress, + sequence_number: u64, + fetch_events: bool, + ) -> Result>)>> { + let req_item = RequestItem::GetAccountTransactionBySequenceNumber { + account, + sequence_number, + fetch_events, + }; + + let mut response = self.get_with_proof_sync(vec![req_item])?; + let (signed_txn_with_proof, _) = response + .response_items + .remove(0) + .into_get_account_txn_by_seq_num_response()?; + + Ok(signed_txn_with_proof.map(|t| (t.signed_transaction, t.events))) + } + + /// Get transactions in range (start_version..start_version + limit - 1) from validator. + pub fn get_txn_by_range( + &self, + start_version: u64, + limit: u64, + fetch_events: bool, + ) -> Result>)>> { + // Make the request. 
+ let req_item = RequestItem::GetTransactions { + start_version, + limit, + fetch_events, + }; + let mut response = self.get_with_proof_sync(vec![req_item])?; + let txn_list_with_proof = response + .response_items + .remove(0) + .into_get_transactions_response()?; + + // Transform the response. + let num_txns = txn_list_with_proof.transaction_and_infos.len(); + let event_lists = txn_list_with_proof + .events + .map(|event_lists| event_lists.into_iter().map(Some).collect()) + .unwrap_or_else(|| vec![None; num_txns]); + + let res = itertools::zip_eq(txn_list_with_proof.transaction_and_infos, event_lists) + .map(|((signed_txn, _), events)| (signed_txn, events)) + .collect(); + Ok(res) + } + + /// Get event by access path from validator. AccountStateWithProof will be returned if + /// 1. No event is available. 2. Ascending and available event number < limit. + /// 3. Descending and start_seq_num > latest account event sequence number. + pub fn get_events_by_access_path( + &self, + access_path: AccessPath, + start_event_seq_num: u64, + ascending: bool, + limit: u64, + ) -> Result<(Vec, Option)> { + let req_item = RequestItem::GetEventsByEventAccessPath { + access_path, + start_event_seq_num, + ascending, + limit, + }; + + let mut response = self.get_with_proof_sync(vec![req_item])?; + let value_with_proof = response.response_items.remove(0); + match value_with_proof { + ResponseItem::GetEventsByEventAccessPath { + events_with_proof, + proof_of_latest_event, + } => Ok((events_with_proof, proof_of_latest_event)), + _ => bail!( + "Incorrect type of response returned: {:?}", + value_with_proof + ), + } + } + + fn get_default_grpc_call_option() -> CallOption { + CallOption::default() + .wait_for_ready(true) + .timeout(std::time::Duration::from_millis(5000)) + } +} + +#[cfg(test)] +mod tests { + use crate::client_proxy::{AddressAndIndex, ClientProxy}; + use config::trusted_peers::TrustedPeersConfigHelpers; + use libra_wallet::io_utils; + use tempfile::NamedTempFile; + + pub fn generate_accounts_from_wallet(count: usize) -> (ClientProxy, Vec) { + let mut accounts = Vec::new(); + accounts.reserve(count); + let file = NamedTempFile::new().unwrap(); + let mnemonic_path = file.into_temp_path().to_str().unwrap().to_string(); + let trust_peer_file = NamedTempFile::new().unwrap(); + let (_, trust_peer_config) = TrustedPeersConfigHelpers::get_test_config(1, None); + let trust_peer_path = trust_peer_file.into_temp_path(); + trust_peer_config.save_config(&trust_peer_path); + + let val_set_file = trust_peer_path.to_str().unwrap().to_string(); + + // We don't need to specify host/port since the client won't be used to connect, only to + // generate random accounts + let mut client_proxy = ClientProxy::new( + "", /* host */ + "", /* port */ + &val_set_file, + &"", + None, + Some(mnemonic_path), + ) + .unwrap(); + for _ in 0..count { + accounts.push(client_proxy.create_next_account(&["c"]).unwrap()); + } + + (client_proxy, accounts) + } + + #[test] + fn test_generate() { + let num = 1; + let (_, accounts) = generate_accounts_from_wallet(num); + assert_eq!(accounts.len(), num); + } + + #[test] + fn test_write_recover() { + let num = 100; + let (client, accounts) = generate_accounts_from_wallet(num); + assert_eq!(accounts.len(), num); + + let file = NamedTempFile::new().unwrap(); + let path = file.into_temp_path(); + io_utils::write_recovery(&client.wallet, &path).expect("failed to write to file"); + + let wallet = io_utils::recover(&path).expect("failed to load from file"); + + assert_eq!(client.wallet.mnemonic(), 
wallet.mnemonic()); + } +} diff --git a/client/src/lib.rs b/client/src/lib.rs new file mode 100644 index 0000000000000..6e5660d46f40a --- /dev/null +++ b/client/src/lib.rs @@ -0,0 +1,47 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![feature(duration_float)] +#![deny(missing_docs)] +//! Libra Client +//! +//! Client (binary) is the CLI tool to interact with Libra validator. +//! It supposes all public APIs. +use crypto::signing::KeyPair; +use serde::{Deserialize, Serialize}; +use types::account_address::AccountAddress; + +pub(crate) mod account_commands; +/// Main instance of client holding corresponding information, e.g. account address. +pub mod client_proxy; +/// Command struct to interact with client. +pub mod commands; +/// gRPC client wrapper to connect to validator. +pub(crate) mod grpc_client; +pub(crate) mod query_commands; +pub(crate) mod transfer_commands; + +/// Struct used to store data for each created account. We track the sequence number +/// so we can create new transactions easily +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct AccountData { + /// Address of the account. + pub address: AccountAddress, + /// (private_key, public_key) pair if the account is not managed by wallet + pub key_pair: Option, + /// Latest sequence number maintained by client, it can be different from validator. + pub sequence_number: u64, +} + +impl AccountData { + /// Serialize account keypair if exists. + pub fn keypair_as_string(&self) -> Option<(String, String)> { + match &self.key_pair { + Some(key_pair) => Some(( + crypto::utils::encode_to_string(&key_pair.private_key()), + crypto::utils::encode_to_string(&key_pair.public_key()), + )), + None => None, + } + } +} diff --git a/client/src/main.rs b/client/src/main.rs new file mode 100644 index 0000000000000..6ff92c985264c --- /dev/null +++ b/client/src/main.rs @@ -0,0 +1,137 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use client::{client_proxy::ClientProxy, commands::*}; +use logger::set_default_global_logger; +use rustyline::{config::CompletionType, error::ReadlineError, Config, Editor}; +use structopt::StructOpt; + +#[derive(Debug, StructOpt)] +#[structopt( + name = "Libra Client", + author = "The Libra Association", + about = "Libra client to connect to a specific validator" +)] +struct Args { + /// Admission Control port to connect to. + #[structopt(short = "p", long = "port", default_value = "30307")] + pub port: String, + /// Host address/name to connect to. + #[structopt(short = "a", long = "host")] + pub host: String, + /// Path to the generated keypair for the faucet account. The faucet account can be used to + /// mint coins. If not passed, a new keypair will be generated for + /// you and placed in a temporary directory. + /// To manually generate a keypair, use generate_keypair: + /// `cargo run -p generate_keypair -- -o ` + #[structopt(short = "m", long = "faucet_key_file_path")] + pub faucet_account_file: Option, + /// Host that operates a faucet service + /// If not passed, will be derived from host parameter + #[structopt(short = "f", long = "faucet_server")] + pub faucet_server: Option, + /// File location from which to load mnemonic word for user account address/key generation. + /// If not passed, a new mnemonic file will be generated by libra_wallet in the current + /// directory. 
+ #[structopt(short = "n", long = "mnemonic_file")] + pub mnemonic_file: Option, + /// File location from which to load config of trusted validators. It is used to verify + /// validator signatures in validator query response. The file should at least include public + /// key of all validators trusted by the client - which should typically be all validators on + /// the network. To connect to testnet, use 'libra/scripts/cli/trusted_peers.config.toml'. + /// Can be generated by libra-config for local testing: + /// `cargo run --bin libra-config` + /// But the preferred method is to simply use libra-swarm to run local networks + #[structopt(short = "s", long = "validator_set_file")] + pub validator_set_file: String, +} + +fn main() -> std::io::Result<()> { + let _logger = set_default_global_logger(false /* async */, None); + crash_handler::setup_panic_handler(); + + let (commands, alias_to_cmd) = get_commands(); + + let args = Args::from_args(); + let faucet_account_file = args.faucet_account_file.unwrap_or_else(|| "".to_string()); + + let mut client_proxy = ClientProxy::new( + &args.host, + &args.port, + &args.validator_set_file, + &faucet_account_file, + args.faucet_server, + args.mnemonic_file, + ) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, &format!("{}", e)[..]))?; + + // Test connection to validator + let test_ret = client_proxy.test_validator_connection(); + + if let Err(e) = test_ret { + println!( + "Not able to connect to validator at {}:{}, error {:?}", + args.host, args.port, e + ); + return Ok(()); + } + let cli_info = format!("Connected to validator at: {}:{}", args.host, args.port); + print_help(&cli_info, &commands); + println!("Please, input commands: \n"); + + let config = Config::builder() + .history_ignore_space(true) + .completion_type(CompletionType::List) + .auto_add_history(true) + .build(); + let mut rl = Editor::<()>::with_config(config); + loop { + let readline = rl.readline("libra% "); + match readline { + Ok(line) => { + let params = parse_cmd(&line); + match alias_to_cmd.get(params[0]) { + Some(cmd) => cmd.execute(&mut client_proxy, ¶ms), + None => match params[0] { + "quit" | "q!" => break, + "help" | "h" => print_help(&cli_info, &commands), + "" => continue, + x => println!("Unknown command: {:?}", x), + }, + } + } + Err(ReadlineError::Interrupted) => { + println!("CTRL-C"); + break; + } + Err(ReadlineError::Eof) => { + println!("CTRL-D"); + break; + } + Err(err) => { + println!("Error: {:?}", err); + break; + } + } + } + + Ok(()) +} + +/// Print the help message for the client and underlying command. +fn print_help(client_info: &str, commands: &[std::sync::Arc]) { + println!("{}", client_info); + println!("usage: \n\nUse the following commands:\n"); + for cmd in commands { + println!( + "{} {}\n\t{}", + cmd.get_aliases().join(" | "), + cmd.get_params_help(), + cmd.get_description() + ); + } + + println!("help | h \n\tPrints this help"); + println!("quit | q! \n\tExit this client"); + println!("\n"); +} diff --git a/client/src/query_commands.rs b/client/src/query_commands.rs new file mode 100644 index 0000000000000..454572d9247d7 --- /dev/null +++ b/client/src/query_commands.rs @@ -0,0 +1,231 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{client_proxy::ClientProxy, commands::*}; +use types::account_config::get_account_resource_or_default; +use vm_genesis::get_transaction_name; + +/// Major command for query operations. 
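+/// Dispatches to the sub commands registered below: balance, sequence,
+/// account_state, txn_acc_seq, txn_range, and event.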
+pub struct QueryCommand {} + +impl Command for QueryCommand { + fn get_aliases(&self) -> Vec<&'static str> { + vec!["query", "q"] + } + fn get_description(&self) -> &'static str { + "Query operations" + } + fn execute(&self, client: &mut ClientProxy, params: &[&str]) { + let commands: Vec> = vec![ + Box::new(QueryCommandGetBalance {}), + Box::new(QueryCommandGetSeqNum {}), + Box::new(QueryCommandGetLatestAccountState {}), + Box::new(QueryCommandGetTxnByAccountSeq {}), + Box::new(QueryCommandGetTxnByRange {}), + Box::new(QueryCommandGetEvent {}), + ]; + + subcommand_execute(¶ms[0], commands, client, ¶ms[1..]); + } +} + +/// Sub commands to query balance for the account specified. +pub struct QueryCommandGetBalance {} + +impl Command for QueryCommandGetBalance { + fn get_aliases(&self) -> Vec<&'static str> { + vec!["balance", "b"] + } + fn get_params_help(&self) -> &'static str { + "|" + } + fn get_description(&self) -> &'static str { + "Get the current balance of an account" + } + fn execute(&self, client: &mut ClientProxy, params: &[&str]) { + if params.len() != 2 { + println!("Invalid number of arguments for balance query"); + return; + } + match client.get_balance(¶ms) { + Ok(balance) => println!("Balance is: {}", balance), + Err(e) => report_error("Failed to get balance", e), + } + } +} + +/// Sub command to get the latest sequence number from validator for the account specified. +pub struct QueryCommandGetSeqNum {} + +impl Command for QueryCommandGetSeqNum { + fn get_aliases(&self) -> Vec<&'static str> { + vec!["sequence", "s"] + } + fn get_params_help(&self) -> &'static str { + "| [reset_sequence_number=true|false]" + } + fn get_description(&self) -> &'static str { + "Get the current sequence number for an account, \ + and reset current sequence number in CLI (optional, default is false)" + } + fn execute(&self, client: &mut ClientProxy, params: &[&str]) { + println!(">> Getting current sequence number"); + match client.get_sequence_number(¶ms) { + Ok(sn) => println!("Sequence number is: {}", sn), + Err(e) => report_error("Error getting sequence number", e), + } + } +} + +/// Command to query latest account state from validator. +pub struct QueryCommandGetLatestAccountState {} + +impl Command for QueryCommandGetLatestAccountState { + fn get_aliases(&self) -> Vec<&'static str> { + vec!["account_state", "as"] + } + fn get_params_help(&self) -> &'static str { + "|" + } + fn get_description(&self) -> &'static str { + "Get the latest state for an account" + } + fn execute(&self, client: &mut ClientProxy, params: &[&str]) { + println!(">> Getting latest account state"); + match client.get_latest_account_state(¶ms) { + Ok((acc, version)) => match get_account_resource_or_default(&acc) { + Ok(_) => println!( + "Latest account state is: \n \ + Account: {:#?}\n \ + State: {:#?}\n \ + Blockchain Version: {}\n", + client + .get_account_address_from_parameter(params[1]) + .expect("Unable to parse account parameter"), + acc, + version, + ), + Err(e) => report_error("Error converting account blob to account resource", e), + }, + Err(e) => report_error("Error getting latest account state", e), + } + } +} + +/// Sub command to get transaction by account and sequence number from validator. 
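+/// Invoked as `query txn_acc_seq <account_ref_id|account_address>
+/// <sequence_number> <fetch_events:bool>`, e.g. `q ts 0 5 true`.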
+pub struct QueryCommandGetTxnByAccountSeq {}
+
+impl Command for QueryCommandGetTxnByAccountSeq {
+    fn get_aliases(&self) -> Vec<&'static str> {
+        vec!["txn_acc_seq", "ts"]
+    }
+    fn get_params_help(&self) -> &'static str {
+        "<account_ref_id>|<account_address> <sequence_number> <fetch_events=true|false>"
+    }
+    fn get_description(&self) -> &'static str {
+        "Get the committed transaction by account and sequence number. \
+         Optionally also fetch events emitted by this transaction."
+    }
+    fn execute(&self, client: &mut ClientProxy, params: &[&str]) {
+        println!(">> Getting committed transaction by account and sequence number");
+        match client.get_committed_txn_by_acc_seq(&params) {
+            Ok(txn_and_events) => {
+                match txn_and_events {
+                    Some((comm_txn, events)) => {
+                        println!(
+                            "Committed transaction: {}",
+                            comm_txn.format_for_client(get_transaction_name)
+                        );
+                        if let Some(events_inner) = &events {
+                            println!("Events: ");
+                            for event in events_inner {
+                                println!("{}", event);
+                            }
+                        }
+                    }
+                    None => println!("Transaction not available"),
+                };
+            }
+            Err(e) => report_error(
+                "Error getting committed transaction by account and sequence number",
+                e,
+            ),
+        }
+    }
+}
+
+/// Sub command to query transactions by range from validator.
+pub struct QueryCommandGetTxnByRange {}
+
+impl Command for QueryCommandGetTxnByRange {
+    fn get_aliases(&self) -> Vec<&'static str> {
+        vec!["txn_range", "tr"]
+    }
+    fn get_params_help(&self) -> &'static str {
+        "<start_version> <limit> <fetch_events=true|false>"
+    }
+    fn get_description(&self) -> &'static str {
+        "Get the committed transactions by version range. \
+         Optionally also fetch events emitted by these transactions."
+    }
+    fn execute(&self, client: &mut ClientProxy, params: &[&str]) {
+        println!(">> Getting committed transaction by range");
+        match client.get_committed_txn_by_range(&params) {
+            Ok(comm_txns_and_events) => {
+                // Note that this should never panic because we shouldn't return items
+                // if the version wasn't able to be parsed in the first place
+                let mut cur_version = params[1].parse::<u64>().expect("Unable to parse version");
+                for (txn, opt_events) in comm_txns_and_events {
+                    println!(
+                        "Transaction at version {}: {}",
+                        cur_version,
+                        txn.format_for_client(get_transaction_name)
+                    );
+                    if opt_events.is_some() {
+                        let events = opt_events.unwrap();
+                        if events.is_empty() {
+                            println!("No events returned");
+                        } else {
+                            for event in events {
+                                println!("{}", event);
+                            }
+                        }
+                    }
+                    cur_version += 1;
+                }
+            }
+            Err(e) => report_error("Error getting committed transactions by range", e),
+        }
+    }
+}
+
+/// Sub command to query events from validator.
+pub struct QueryCommandGetEvent {}
+
+impl Command for QueryCommandGetEvent {
+    fn get_aliases(&self) -> Vec<&'static str> {
+        vec!["event", "ev"]
+    }
+    fn get_params_help(&self) -> &'static str {
+        "<account_ref_id>|<account_address> <sent|received> <start_sequence_number> \
+         <ascending=true|false> <limit>"
+    }
+    fn get_description(&self) -> &'static str {
+        "Get events by account and event type (sent|received)."
+    }
+    fn execute(&self, client: &mut ClientProxy, params: &[&str]) {
+        println!(">> Getting events by account and event type.");
+        match client.get_events_by_account_and_type(&params) {
+            Ok((events, last_event_state)) => {
+                if events.is_empty() {
+                    println!("No events returned");
+                } else {
+                    for event in events {
+                        println!("{}", event);
+                    }
+                }
+                println!("Last event state: {:#?}", last_event_state);
+            }
+            Err(e) => report_error("Error getting events by access path", e),
+        }
+    }
+}
diff --git a/client/src/transfer_commands.rs b/client/src/transfer_commands.rs
new file mode 100644
index 0000000000000..21c4ece2dc407
--- /dev/null
+++ b/client/src/transfer_commands.rs
@@ -0,0 +1,51 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{client_proxy::ClientProxy, commands::*};
+
+/// Command to transfer coins between two accounts.
+pub struct TransferCommand {}
+
+impl Command for TransferCommand {
+    fn get_aliases(&self) -> Vec<&'static str> {
+        vec!["transfer", "transferb", "t", "tb"]
+    }
+    fn get_params_help(&self) -> &'static str {
+        "\n\t<sender_account_address>|<sender_account_ref_id> \
+         <receiver_account_address>|<receiver_account_ref_id> <number_of_coins> \
+         [gas_unit_price (default=0)] [max_gas_amount (default 10000)] \
+         Suffix 'b' is for blocking. "
+    }
+    fn get_description(&self) -> &'static str {
+        "Transfer coins from one account to another."
+    }
+    fn execute(&self, client: &mut ClientProxy, params: &[&str]) {
+        if params.len() < 4 || params.len() > 6 {
+            println!("Invalid number of arguments for transfer");
+            println!(
+                "{} {}",
+                self.get_aliases().join(" | "),
+                self.get_params_help()
+            );
+            return;
+        }
+
+        println!(">> Transferring");
+        let is_blocking = blocking_cmd(&params[0]);
+        match client.transfer_coins(&params, is_blocking) {
+            Ok(index_and_seq) => {
+                if is_blocking {
+                    println!("Finished transaction!");
+                } else {
+                    println!("Transaction submitted to validator");
+                }
+                println!(
+                    "To query for transaction status, run: query txn_acc_seq {} {} \
+                     <fetch_events=true|false>",
+                    index_and_seq.account_index, index_and_seq.sequence_number
+                );
+            }
+            Err(e) => report_error("Failed to perform transaction", e),
+        }
+    }
+}
diff --git a/clippy.toml b/clippy.toml
new file mode 100644
index 0000000000000..a1e1f754f92f4
--- /dev/null
+++ b/clippy.toml
@@ -0,0 +1,6 @@
+# cognitive complexity is not always useful
+cognitive-complexity-threshold = 100
+# types are used for safety encoding
+type-complexity-threshold = 10000
+# manipulating complex state machines in consensus
+too-many-arguments-threshold = 15
diff --git a/common/build_helpers/Cargo.toml b/common/build_helpers/Cargo.toml
new file mode 100644
index 0000000000000..7731cce35a643
--- /dev/null
+++ b/common/build_helpers/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "build_helpers"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+protoc-grpcio = "0.3.1"
+walkdir = "2.2.0"
+
+grpcio-client = {path = "../grpcio-client"}
diff --git a/common/build_helpers/src/build_helpers.rs b/common/build_helpers/src/build_helpers.rs
new file mode 100644
index 0000000000000..1f1efa438ca80
--- /dev/null
+++ b/common/build_helpers/src/build_helpers.rs
@@ -0,0 +1,81 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+/// Contains helpers for build.rs files. Includes helpers for proto compilation
+use std::path::{Path, PathBuf};
+
+use walkdir::WalkDir;
+
+// Compiles all proto files under proto root and dependent roots.
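+// Dependent roots are compiled first, without client stubs, so that their
+// generated types are available when the main proto root is compiled.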
+// For example, if there is a file `src/a/b/c.proto`, it will generate `src/a/b/c.rs` and +// `src/a/b/c_grpc.rs`. +pub fn compile_proto(proto_root: &str, dependent_roots: Vec<&str>, generate_client_code: bool) { + let mut additional_includes = vec![]; + for dependent_root in dependent_roots { + // First compile dependent directories + compile_dir( + &dependent_root, + vec![], /* additional_includes */ + false, /* generate_client_code */ + ); + additional_includes.push(Path::new(dependent_root).to_path_buf()); + } + // Now compile this directory + compile_dir(&proto_root, additional_includes, generate_client_code); +} + +// Compile all of the proto files in proto_root directory and use the additional +// includes when compiling. +pub fn compile_dir( + proto_root: &str, + additional_includes: Vec, + generate_client_code: bool, +) { + for entry in WalkDir::new(proto_root) { + let p = entry.unwrap(); + if p.file_type().is_dir() { + continue; + } + + let path = p.path(); + if let Some(ext) = path.extension() { + if ext != "proto" { + continue; + } + println!("cargo:rerun-if-changed={}", path.display()); + compile(&path, &additional_includes, generate_client_code); + } + } +} + +fn compile(path: &Path, additional_includes: &[PathBuf], generate_client_code: bool) { + let parent = path.parent().unwrap(); + let mut src_path = parent.to_owned().to_path_buf(); + src_path.push("src"); + + let mut includes = additional_includes.to_owned(); + includes.push(parent.to_path_buf()); + + ::protoc_grpcio::compile_grpc_protos(&[path], includes.as_slice(), parent) + .unwrap_or_else(|_| panic!("Failed to compile protobuf input: {:?}", path)); + + if generate_client_code { + let file_string = path + .file_name() + .expect("unable to get filename") + .to_str() + .unwrap(); + let includes_strings = includes + .iter() + .map(|x| x.to_str().unwrap()) + .collect::>(); + + // generate client code + grpcio_client::client_stub_gen( + &[file_string], + includes_strings.as_slice(), + &parent.to_str().unwrap(), + ) + .expect("Unable to generate client stub"); + } +} diff --git a/common/build_helpers/src/lib.rs b/common/build_helpers/src/lib.rs new file mode 100644 index 0000000000000..454e5fa1739ea --- /dev/null +++ b/common/build_helpers/src/lib.rs @@ -0,0 +1,4 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +pub mod build_helpers; diff --git a/common/canonical_serialization/Cargo.toml b/common/canonical_serialization/Cargo.toml new file mode 100644 index 0000000000000..28b85c818b18a --- /dev/null +++ b/common/canonical_serialization/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "canonical_serialization" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +byteorder = "1.3.1" + +failure = { path = "../failure_ext", package = "failure_ext" } + +[dev-dependencies] +hex = "0.3" diff --git a/common/canonical_serialization/src/canonical_serialization_test.rs b/common/canonical_serialization/src/canonical_serialization_test.rs new file mode 100644 index 0000000000000..c44a830f458d0 --- /dev/null +++ b/common/canonical_serialization/src/canonical_serialization_test.rs @@ -0,0 +1,321 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//https://rust-lang.github.io/rust-clippy/master/index.html#blacklisted_name +//disable it in test so that we can use variable names such as 'foo' and 'bar' +#![allow(clippy::blacklisted_name)] +#![allow(clippy::many_single_char_names)] + 
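+// A sketch of the round-trip pattern used throughout the tests below (`Foo` is
+// one of the test structs defined in this file):
+//
+//     let bytes = SimpleSerializer::<Vec<u8>>::serialize(&foo).unwrap();
+//     let decoded: Foo = SimpleDeserializer::deserialize(&bytes).unwrap();
+//     assert_eq!(foo, decoded);
+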
+use super::*;
+use byteorder::WriteBytesExt;
+use failure::Result;
+use std::u32;
+
+// Do not change the test vectors. Please read the comment below.
+const TEST_VECTOR_1: &str = "ffffffffffffffff060000006463584d4237640000000000000009000000000102\
+                             03040506070805050505050505050505050505050505050505050505050505050505\
+                             05050505630000000103000000010000000103000000161543030000000038150300\
+                             0000160a05040000001415596903000000c9175a";
+
+// Why do we need test vectors?
+//
+// 1. They sometimes catch bugs shared between the serialization and
+// deserialization functions that a simple round-trip test would miss. For
+// example, if a procedure called by both the serializer and the deserializer
+// has a bug, a round trip can still succeed.
+//
+// 2. They catch code changes that inadvertently break the serialization
+// format: either the new code produces bytes incompatible with what was
+// generated in the past (which round-trip tests would miss), or it fails to
+// deserialize bytes generated in the past.
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Addr(pub [u8; 32]);
+
+impl Addr {
+    fn new(bytes: [u8; 32]) -> Self {
+        Addr(bytes)
+    }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+struct Foo {
+    a: u64,
+    b: Vec<u8>,
+    c: Bar,
+    d: bool,
+    e: BTreeMap<Vec<u8>, Vec<u8>>,
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+struct Bar {
+    a: u64,
+    b: Vec<u8>,
+    c: Addr,
+    d: u32,
+}
+
+impl CanonicalSerialize for Foo {
+    fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> {
+        serializer
+            .encode_u64(self.a)?
+            .encode_variable_length_bytes(&self.b)?
+            .encode_struct(&self.c)?
+            .encode_bool(self.d)?
+            .encode_btreemap(&self.e)?;
+        Ok(())
+    }
+}
+
+impl CanonicalSerialize for Bar {
+    fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> {
+        serializer
+            .encode_u64(self.a)?
+            .encode_variable_length_bytes(&self.b)?
+            .encode_raw_bytes(&self.c.0)?
+ .encode_u32(self.d)?; + Ok(()) + } +} + +impl CanonicalDeserialize for Foo { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result { + let a = deserializer.decode_u64()?; + let b = deserializer.decode_variable_length_bytes()?; + let c: Bar = deserializer.decode_struct::()?; + let d: bool = deserializer.decode_bool()?; + let e: BTreeMap, Vec> = deserializer.decode_btreemap()?; + Ok(Foo { a, b, c, d, e }) + } +} + +impl CanonicalDeserialize for Bar { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result { + let a = deserializer.decode_u64()?; + let b = deserializer.decode_variable_length_bytes()?; + let c = deserializer.decode_bytes_with_len(32)?; + let mut cc: [u8; 32] = [0; 32]; + cc.copy_from_slice(c.as_slice()); + + let d = deserializer.decode_u32()?; + Ok(Bar { + a, + b, + c: Addr::new(cc), + d, + }) + } +} + +#[test] +fn test_btreemap_encode() { + let mut map = BTreeMap::new(); + let value = vec![54, 20, 21, 200]; + let key1 = vec![0]; // after serialization: [1, 0] + let key2 = vec![0, 6]; // after serialization: [2, 0, 6] + let key3 = vec![1]; // after serialization: [1, 1] + let key4 = vec![2]; // after serialization: [1, 2] + map.insert(key1.clone(), value.clone()); + map.insert(key2.clone(), value.clone()); + map.insert(key3.clone(), value.clone()); + map.insert(key4.clone(), value.clone()); + + let serialized_bytes = SimpleSerializer::>::serialize(&map).unwrap(); + + let mut deserializer = SimpleDeserializer::new(&serialized_bytes); + + // ensure the order was encoded in lexicographic order + assert_eq!(deserializer.raw_bytes.read_u32::().unwrap(), 4); + assert_eq!(deserializer.decode_variable_length_bytes().unwrap(), key1); + assert_eq!(deserializer.decode_variable_length_bytes().unwrap(), value); + assert_eq!(deserializer.decode_variable_length_bytes().unwrap(), key3); + assert_eq!(deserializer.decode_variable_length_bytes().unwrap(), value); + assert_eq!(deserializer.decode_variable_length_bytes().unwrap(), key4); + assert_eq!(deserializer.decode_variable_length_bytes().unwrap(), value); + assert_eq!(deserializer.decode_variable_length_bytes().unwrap(), key2); + assert_eq!(deserializer.decode_variable_length_bytes().unwrap(), value); +} + +#[test] +fn test_serialization_roundtrip() { + let bar = Bar { + a: 50, + b: vec![10u8; 100], + c: Addr::new([3u8; 32]), + d: 12, + }; + + let mut map = BTreeMap::new(); + map.insert(vec![0, 56, 21], vec![22, 10, 5]); + map.insert(vec![1], vec![22, 21, 67]); + map.insert(vec![20, 21, 89, 105], vec![201, 23, 90]); + + let foo = Foo { + a: 1, + b: vec![32, 41, 190, 200, 2, 5, 90, 100, 123, 234, 159, 159, 101], + c: bar, + d: false, + e: map, + }; + + let mut serializer = SimpleSerializer::>::new(); + foo.serialize(&mut serializer).unwrap(); + let serialized_bytes = serializer.get_output(); + + let mut deserializer = SimpleDeserializer::new(&serialized_bytes); + let deserialized_foo = Foo::deserialize(&mut deserializer).unwrap(); + assert_eq!(foo, deserialized_foo); + assert_eq!( + deserializer.raw_bytes.position(), + deserializer.raw_bytes.get_ref().len() as u64 + ); +} + +#[test] +fn test_encode_vec() { + let bar1 = Bar { + a: 55, + b: vec![10u8; 100], + c: Addr::new([3u8; 32]), + d: 77, + }; + let bar2 = Bar { + a: 123, + b: vec![1, 5, 20], + c: Addr::new([8u8; 32]), + d: 127, + }; + + let mut vec = Vec::new(); + vec.push(bar1.clone()); + vec.push(bar2.clone()); + let mut serializer = SimpleSerializer::>::new(); + serializer.encode_vec(&vec).unwrap(); + let serialized_bytes = serializer.get_output(); 
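+        // At this point `serialized_bytes` holds a 4-byte little-endian count (2)
+        // followed by the canonical encodings of bar1 and then bar2.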
+ + let de_vec: Vec = SimpleDeserializer::deserialize(&serialized_bytes).unwrap(); + + assert_eq!(2, de_vec.len()); + assert_eq!(bar1, de_vec[0]); + assert_eq!(bar2, de_vec[1]); + + // test Vec implementation + let mut serializer = SimpleSerializer::>::new(); + serializer.encode_struct(&vec).unwrap(); + let serialized_bytes = serializer.get_output(); + let de_vec: Vec = SimpleDeserializer::deserialize(&serialized_bytes).unwrap(); + + assert_eq!(2, de_vec.len()); + assert_eq!(bar1, de_vec[0]); + assert_eq!(bar2, de_vec[1]); +} + +#[test] +fn test_vec_impl() { + let mut vec: Vec = Vec::new(); + vec.push(std::i32::MIN); + vec.push(std::i32::MAX); + vec.push(100); + + let mut serializer = SimpleSerializer::>::new(); + serializer.encode_struct(&vec).unwrap(); + let serialized_bytes = serializer.get_output(); + let de_vec: Vec = SimpleDeserializer::deserialize(&serialized_bytes).unwrap(); + assert_eq!(vec, de_vec); +} + +#[test] +fn test_vectors_1() { + let bar = Bar { + a: 100, + b: vec![0, 1, 2, 3, 4, 5, 6, 7, 8], + c: Addr::new([5u8; 32]), + d: 99, + }; + + let mut map = BTreeMap::new(); + map.insert(vec![0, 56, 21], vec![22, 10, 5]); + map.insert(vec![1], vec![22, 21, 67]); + map.insert(vec![20, 21, 89, 105], vec![201, 23, 90]); + + let foo = Foo { + a: u64::max_value(), + b: vec![100, 99, 88, 77, 66, 55], + c: bar, + d: true, + e: map, + }; + + let mut serializer = SimpleSerializer::>::new(); + foo.serialize(&mut serializer).unwrap(); + let serialized_bytes = serializer.get_output(); + + // make sure we serialize into exact same bytes as before + assert_eq!(TEST_VECTOR_1, hex::encode(serialized_bytes)); + + // make sure we can deserialize the test vector into expected struct + let test_vector_bytes = hex::decode(TEST_VECTOR_1).unwrap(); + let deserialized_foo: Foo = SimpleDeserializer::deserialize(&test_vector_bytes).unwrap(); + assert_eq!(foo, deserialized_foo); +} + +#[test] +fn test_serialization_failure_cases() { + // a vec longer than representable range should result in failure + let bar = Bar { + a: 100, + b: vec![0; i32::max_value() as usize + 1], + c: Addr::new([0u8; 32]), + d: 222, + }; + + let mut serializer = SimpleSerializer::>::new(); + assert!(bar.serialize(&mut serializer).is_err()); +} + +#[test] +fn test_deserialization_failure_cases() { + // invalid length prefix should fail on all decoding methods + let bytes_len_2 = vec![0; 2]; + let mut deserializer = SimpleDeserializer::new(&bytes_len_2); + assert!(deserializer.clone().decode_u64().is_err()); + assert!(deserializer.clone().decode_bytes_with_len(32).is_err()); + assert!(deserializer.clone().decode_variable_length_bytes().is_err()); + assert!(deserializer.clone().decode_struct::().is_err()); + assert!(Foo::deserialize(&mut deserializer.clone()).is_err()); + + // a length prefix longer than maximum allowed should fail + let mut long_bytes = Vec::new(); + long_bytes + .write_u32::(ARRAY_MAX_LENGTH as u32 + 1) + .unwrap(); + deserializer = SimpleDeserializer::new(&long_bytes); + assert!(deserializer.clone().decode_variable_length_bytes().is_err()); + + // vec not long enough should fail + let mut bytes_len_10 = Vec::new(); + bytes_len_10.write_u32::(32).unwrap(); + deserializer = SimpleDeserializer::new(&bytes_len_10); + assert!(deserializer.clone().decode_variable_length_bytes().is_err()); + assert!(deserializer.clone().decode_bytes_with_len(32).is_err()); + + // malformed struct should fail + let mut some_bytes = Vec::new(); + some_bytes.write_u64::(10).unwrap(); + some_bytes.write_u32::(50).unwrap(); + deserializer = 
SimpleDeserializer::new(&some_bytes); + assert!(deserializer.clone().decode_struct::().is_err()); + + // malformed encoded bytes with length prefix larger than real + let mut evil_bytes = Vec::new(); + evil_bytes.write_u32::(500).unwrap(); + evil_bytes.resize_with(4 + 499, Default::default); + deserializer = SimpleDeserializer::new(&evil_bytes); + assert!(deserializer.clone().decode_variable_length_bytes().is_err()); + + // malformed encoded bool with value not 0 or 1 + let mut bool_bytes = Vec::new(); + bool_bytes.write_u8(2).unwrap(); + deserializer = SimpleDeserializer::new(&bool_bytes); + assert!(deserializer.clone().decode_bool().is_err()); +} diff --git a/common/canonical_serialization/src/lib.rs b/common/canonical_serialization/src/lib.rs new file mode 100644 index 0000000000000..612e3471e802d --- /dev/null +++ b/common/canonical_serialization/src/lib.rs @@ -0,0 +1,535 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines traits and implementations of canonical serialization mechanism. +//! +//! A struct can implement the CanonicalSerialize trait to specify how to serialize itself, +//! and the CanonicalDeserialize trait to specify deserialization, if it needs to. One design +//! goal of this serialization format is to optimize for simplicity. It is not designed to be +//! another full-fledged network serialization as Protobuf or Thrift. It is designed +//! for doing only one thing right, which is to deterministically generate consistent bytes +//! from a data structure. +//! +//! A good example of how to use this framework is described in +//! ./canonical_serialization_test.rs +//! +//! An extremely simple implementation of CanonicalSerializer is also provided, the encoding +//! rules are: +//! (All unsigned integers are encoded in little endian representation unless specified otherwise) +//! +//! 1. The encoding of an unsigned 64-bit integer is defined as its little endian representation +//! in 8 bytes +//! +//! 2. The encoding of an item (byte array) is defined as: +//! [length in bytes, represented as 4-byte integer] || [item in bytes] +//! +//! +//! 3. The encoding of a list of items is defined as: (This is not implemented yet because +//! there is no known struct that needs it yet, but can be added later easily) +//! [No. of items in the list, represented as 4-byte integer] || encoding(item_0) || .... +//! +//! 4. The encoding of an ordered map where the keys are ordered by lexicographic order. +//! Currently we only support key and value of type Vec. The encoding is defined as: +//! [No. of key value pairs in the map, represented as 4-byte integer] || encode(key1) || +//! encode(value1) || encode(key2) || encode(value2)... +//! where the pairs are appended following the lexicographic order of the key +//! +//! What is canonical serialization? +//! +//! Canonical serialization guarantees byte consistency when serializing an in-memory +//! data structure. It is useful for situations where two parties want to efficiently compare +//! data structures they independently maintain. It happens in consensus where +//! independent validators need to agree on the state they independently compute. A cryptographic +//! hash of the serialized data structure is what ultimately gets compared. In order for +//! this to work, the serialization of the same data structures must be identical when computed +//! by independent validators potentially running different implementations +//! of the same spec in different languages. 
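+//!
+//! A minimal round-trip sketch using the simple implementation below (`MyType`
+//! stands in for any type implementing both traits):
+//!
+//! ```ignore
+//! let bytes: Vec<u8> = SimpleSerializer::<Vec<u8>>::serialize(&value)?;
+//! let decoded: MyType = SimpleDeserializer::deserialize(&bytes)?;
+//! assert_eq!(value, decoded);
+//! ```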
+
+use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
+use failure::prelude::*;
+use std::{
+    collections::BTreeMap,
+    io::{Cursor, Read},
+    mem::size_of,
+};
+
+pub mod test_helper;
+
+#[cfg(test)]
+mod canonical_serialization_test;
+
+// use the signed 32-bit integer's max value as the maximum array length instead of
+// the unsigned 32-bit integer's. This gives us the opportunity to use the additional sign bit
+// to signal a length extension to support arrays longer than 2^31 in the future
const ARRAY_MAX_LENGTH: usize = i32::max_value() as usize;
+
+pub trait CanonicalSerialize {
+    fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()>;
+}
+
+pub trait CanonicalSerializer {
+    fn encode_struct(&mut self, structure: &impl CanonicalSerialize) -> Result<&mut Self>
+    where
+        Self: std::marker::Sized,
+    {
+        structure.serialize(self)?;
+        Ok(self)
+    }
+
+    fn encode_u64(&mut self, v: u64) -> Result<&mut Self>;
+
+    fn encode_u32(&mut self, v: u32) -> Result<&mut Self>;
+
+    fn encode_u8(&mut self, v: u8) -> Result<&mut Self>;
+
+    fn encode_bool(&mut self, b: bool) -> Result<&mut Self>;
+
+    // Use this encoder when the length of the array is known to be fixed and always known at
+    // deserialization time. The raw bytes of the array without length prefix are encoded.
+    // For deserialization, use decode_bytes_with_len() which requires giving the length
+    // as input
+    fn encode_raw_bytes(&mut self, bytes: &[u8]) -> Result<&mut Self>;
+
+    // Use this encoder to encode variable length byte arrays whose length may not be known at
+    // deserialization time.
+    fn encode_variable_length_bytes(&mut self, v: &[u8]) -> Result<&mut Self>;
+
+    fn encode_btreemap<K: CanonicalSerialize, V: CanonicalSerialize>(
+        &mut self,
+        v: &BTreeMap<K, V>,
+    ) -> Result<&mut Self>;
+
+    fn encode_vec<T: CanonicalSerialize>(&mut self, v: &[T]) -> Result<&mut Self>;
+}
+
+type Endianness = LittleEndian;
+
+/// An implementation of a simple canonical serialization format that implements the
+/// CanonicalSerializer trait over any `std::io::Write` sink (typically a byte vector).
+#[derive(Clone)]
+pub struct SimpleSerializer<W> {
+    output: W,
+}
+
+impl<W> Default for SimpleSerializer<W>
+where
+    W: Default + std::io::Write,
+{
+    fn default() -> Self {
+        SimpleSerializer::new()
+    }
+}
+
+impl<W> SimpleSerializer<W>
+where
+    W: Default + std::io::Write,
+{
+    pub fn new() -> Self {
+        SimpleSerializer {
+            output: W::default(),
+        }
+    }
+
+    /// Create a SimpleSerializer on the fly and serialize `object`
+    pub fn serialize(object: &impl CanonicalSerialize) -> Result<W> {
+        let mut serializer = Self::default();
+        object.serialize(&mut serializer)?;
+        Ok(serializer.get_output())
+    }
+
+    /// Consume the SimpleSerializer and return the output
+    pub fn get_output(self) -> W {
+        self.output
+    }
+}
+
+impl<W> CanonicalSerializer for SimpleSerializer<W>
+where
+    W: std::io::Write,
+{
+    fn encode_u64(&mut self, v: u64) -> Result<&mut Self> {
+        self.output.write_u64::<Endianness>(v)?;
+        Ok(self)
+    }
+
+    fn encode_u32(&mut self, v: u32) -> Result<&mut Self> {
+        self.output.write_u32::<Endianness>(v)?;
+        Ok(self)
+    }
+
+    fn encode_u8(&mut self, v: u8) -> Result<&mut Self> {
+        self.output.write_u8(v)?;
+        Ok(self)
+    }
+
+    fn encode_bool(&mut self, b: bool) -> Result<&mut Self> {
+        let byte: u8 = if b { 1 } else { 0 };
+        self.output.write_u8(byte)?;
+        Ok(self)
+    }
+
+    fn encode_raw_bytes(&mut self, bytes: &[u8]) -> Result<&mut Self> {
+        self.output.write_all(bytes.as_ref())?;
+        Ok(self)
+    }
+
+    fn encode_variable_length_bytes(&mut self, v: &[u8]) -> Result<&mut Self> {
+        ensure!(
+            v.len() <= ARRAY_MAX_LENGTH,
+            "array length exceeded the maximum length limit. 
\
+             length: {}, Max length limit: {}",
+            v.len(),
+            ARRAY_MAX_LENGTH,
+        );
+
+        // first add the length as a 4-byte integer
+        self.output.write_u32::<Endianness>(v.len() as u32)?;
+        self.output.write_all(v)?;
+        Ok(self)
+    }
+
+    fn encode_btreemap<K: CanonicalSerialize, V: CanonicalSerialize>(
+        &mut self,
+        v: &BTreeMap<K, V>,
+    ) -> Result<&mut Self> {
+        ensure!(
+            v.len() <= ARRAY_MAX_LENGTH,
+            "map size exceeded the maximum limit. length: {}, max length limit: {}",
+            v.len(),
+            ARRAY_MAX_LENGTH,
+        );
+
+        // add the number of pairs in the map
+        self.output.write_u32::<Endianness>(v.len() as u32)?;
+
+        // Regardless of the order defined for K of the map, write in the order of the
+        // lexicographic order of the canonical serialized bytes of K
+        let mut map = BTreeMap::new();
+        for (key, value) in v {
+            map.insert(
+                SimpleSerializer::<Vec<u8>>::serialize(key)?,
+                SimpleSerializer::<Vec<u8>>::serialize(value)?,
+            );
+        }
+
+        for (key, value) in map {
+            self.encode_raw_bytes(&key)?;
+            self.encode_raw_bytes(&value)?;
+        }
+        Ok(self)
+    }
+
+    fn encode_vec<T: CanonicalSerialize>(&mut self, v: &[T]) -> Result<&mut Self> {
+        ensure!(
+            v.len() <= ARRAY_MAX_LENGTH,
+            "vec length exceeded the maximum limit. length: {}, max length limit: {}",
+            v.len(),
+            ARRAY_MAX_LENGTH,
+        );
+
+        // add the number of items in the vec
+        self.output.write_u32::<Endianness>(v.len() as u32)?;
+        for value in v {
+            self.encode_struct(value)?;
+        }
+        Ok(self)
+    }
+}
+
+pub trait CanonicalDeserializer {
+    fn decode_struct<T>(&mut self) -> Result<T>
+    where
+        T: CanonicalDeserialize,
+        Self: Sized,
+    {
+        T::deserialize(self)
+    }
+
+    fn decode_u64(&mut self) -> Result<u64>;
+
+    fn decode_u32(&mut self) -> Result<u32>;
+
+    fn decode_u8(&mut self) -> Result<u8>;
+
+    fn decode_bool(&mut self) -> Result<bool>;
+
+    // decode a byte array with the given length as input
+    fn decode_bytes_with_len(&mut self, len: u32) -> Result<Vec<u8>>;
+
+    fn decode_variable_length_bytes(&mut self) -> Result<Vec<u8>>;
+
+    fn decode_btreemap<K: CanonicalDeserialize + Ord, V: CanonicalDeserialize>(
+        &mut self,
+    ) -> Result<BTreeMap<K, V>>;
+
+    fn decode_vec<T: CanonicalDeserialize>(&mut self) -> Result<Vec<T>>;
+}
+
+pub trait CanonicalDeserialize {
+    fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result<Self>
+    where
+        Self: Sized;
+}
+
+#[derive(Clone)]
+pub struct SimpleDeserializer<'a> {
+    raw_bytes: Cursor<&'a [u8]>,
+}
+
+impl<'a> SimpleDeserializer<'a> {
+    pub fn new<T>(raw_bytes: &'a T) -> Self
+    where
+        T: AsRef<[u8]> + ?Sized,
+    {
+        Self {
+            raw_bytes: Cursor::new(raw_bytes.as_ref()),
+        }
+    }
+
+    pub fn deserialize<T>(data: &'a [u8]) -> Result<T>
+    where
+        T: CanonicalDeserialize,
+    {
+        let mut deserializer = Self::new(data);
+        T::deserialize(&mut deserializer)
+    }
+}
+
+impl<'a> CanonicalDeserializer for SimpleDeserializer<'a> {
+    fn decode_u64(&mut self) -> Result<u64> {
+        let num = self.raw_bytes.read_u64::<Endianness>()?;
+        Ok(num)
+    }
+
+    fn decode_u32(&mut self) -> Result<u32> {
+        let num = self.raw_bytes.read_u32::<Endianness>()?;
+        Ok(num)
+    }
+
+    fn decode_u8(&mut self) -> Result<u8> {
+        let num = self.raw_bytes.read_u8()?;
+        Ok(num)
+    }
+
+    fn decode_bool(&mut self) -> Result<bool> {
+        let b = self.raw_bytes.read_u8()?;
+        ensure!(b == 0 || b == 1, "bool must be 0 or 1, found {}", b,);
+        Ok(b != 0)
+    }
+
+    fn decode_bytes_with_len(&mut self, len: u32) -> Result<Vec<u8>> {
+        // make sure there is enough bytes left in the buffer
+        let remain = self.raw_bytes.get_ref().len() as u64 - self.raw_bytes.position();
+        ensure!(
+            remain >= len.into(),
+            "not enough bytes left. \
input size: {}, remaining: {}",
+            len,
+            remain
+        );
+
+        let mut buffer = vec![0; len as usize];
+        self.raw_bytes.read_exact(&mut buffer)?;
+        Ok(buffer)
+    }
+
+    fn decode_variable_length_bytes(&mut self) -> Result<Vec<u8>> {
+        let len = self.raw_bytes.read_u32::<Endianness>()?;
+        ensure!(
+            len as usize <= ARRAY_MAX_LENGTH,
+            "array length longer than max allowed length. len: {}, max: {}",
+            len,
+            ARRAY_MAX_LENGTH
+        );
+
+        // make sure there is enough bytes left in the buffer
+        let remain = self.raw_bytes.get_ref().len() - self.raw_bytes.position() as usize;
+        ensure!(
+            remain >= (len as usize),
+            "not enough bytes left. len: {}, remaining: {}",
+            len,
+            remain
+        );
+
+        let mut vec = vec![0; len as usize];
+        self.raw_bytes.read_exact(&mut vec)?;
+        Ok(vec)
+    }
+
+    fn decode_btreemap<K: CanonicalDeserialize + Ord, V: CanonicalDeserialize>(
+        &mut self,
+    ) -> Result<BTreeMap<K, V>> {
+        let len = self.raw_bytes.read_u32::<Endianness>()?;
+        ensure!(
+            len as usize <= ARRAY_MAX_LENGTH,
+            "map size bigger than max allowed. size: {}, max: {}",
+            len,
+            ARRAY_MAX_LENGTH
+        );
+
+        let mut map = BTreeMap::new();
+        for _i in 0..len {
+            let key = K::deserialize(self)?;
+            let value = V::deserialize(self)?;
+            map.insert(key, value);
+        }
+        Ok(map)
+    }
+
+    fn decode_vec<T: CanonicalDeserialize>(&mut self) -> Result<Vec<T>> {
+        let len = self.raw_bytes.read_u32::<Endianness>()?;
+        ensure!(
+            len as usize <= ARRAY_MAX_LENGTH,
+            "vec length bigger than max allowed. size: {}, max: {}",
+            len,
+            ARRAY_MAX_LENGTH
+        );
+
+        let mut vec = Vec::new();
+        for _i in 0..len {
+            let v = T::deserialize(self)?;
+            vec.push(v);
+        }
+        Ok(vec)
+    }
+}
+
+impl<T> CanonicalSerialize for Vec<T>
+where
+    T: CanonicalSerialize,
+{
+    fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> {
+        serializer.encode_vec(self.as_ref())?;
+        Ok(())
+    }
+}
+
+impl<T> CanonicalDeserialize for Vec<T>
+where
+    T: CanonicalDeserialize,
+{
+    fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result<Self>
+    where
+        Self: Sized,
+    {
+        deserializer.decode_vec()
+    }
+}
+
+impl CanonicalSerialize for u32 {
+    fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> {
+        serializer.encode_u32(*self)?;
+        Ok(())
+    }
+}
+
+impl CanonicalDeserialize for u32 {
+    fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result<Self>
+    where
+        Self: Sized,
+    {
+        deserializer.decode_u32()
+    }
+}
+
+impl CanonicalSerialize for i32 {
+    fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> {
+        serializer.encode_u32(*self as u32)?;
+        Ok(())
+    }
+}
+
+impl CanonicalDeserialize for i32 {
+    fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result<Self>
+    where
+        Self: Sized,
+    {
+        let num = deserializer.decode_u32()? as i32;
+        Ok(num)
+    }
+}
+
+impl CanonicalSerialize for u64 {
+    fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> {
+        serializer.encode_u64(*self)?;
+        Ok(())
+    }
+}
+
+impl CanonicalDeserialize for u64 {
+    fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result<Self>
+    where
+        Self: Sized,
+    {
+        let num = deserializer.decode_u64()?;
+        Ok(num)
+    }
+}
+
+impl CanonicalSerialize for i64 {
+    fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> {
+        serializer.encode_u64(*self as u64)?;
+        Ok(())
+    }
+}
+
+impl CanonicalDeserialize for i64 {
+    fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result<Self>
+    where
+        Self: Sized,
+    {
+        let num = deserializer.decode_u64()?
as i64; + Ok(num) + } +} + +impl CanonicalSerialize for usize { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + assert_eq!(8, size_of::()); + serializer.encode_u64(*self as u64)?; + Ok(()) + } +} + +impl CanonicalDeserialize for usize { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result + where + Self: Sized, + { + assert_eq!(8, size_of::()); + let num = deserializer.decode_u64()? as usize; + Ok(num) + } +} + +impl CanonicalSerialize for u8 { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer.encode_u8(*self)?; + Ok(()) + } +} + +impl CanonicalDeserialize for u8 { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result + where + Self: Sized, + { + let num = deserializer.decode_u8()?; + Ok(num) + } +} + +impl CanonicalSerialize for BTreeMap, Vec> { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer.encode_btreemap(self)?; + Ok(()) + } +} + +impl CanonicalDeserialize for BTreeMap, Vec> { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result + where + Self: Sized, + { + Ok(deserializer.decode_btreemap()?) + } +} diff --git a/common/canonical_serialization/src/test_helper.rs b/common/canonical_serialization/src/test_helper.rs new file mode 100644 index 0000000000000..bb0dda1479ac4 --- /dev/null +++ b/common/canonical_serialization/src/test_helper.rs @@ -0,0 +1,16 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{CanonicalDeserialize, CanonicalSerialize, SimpleDeserializer, SimpleSerializer}; +use std::fmt::Debug; + +pub fn assert_canonical_encode_decode(object: &T) +where + T: CanonicalSerialize + CanonicalDeserialize + Debug + Eq, +{ + let serialized: Vec = + SimpleSerializer::serialize(object).expect("Serialization should work"); + let deserialized: T = + SimpleDeserializer::deserialize(&serialized).expect("Deserialization should work"); + assert_eq!(*object, deserialized); +} diff --git a/common/channel/Cargo.toml b/common/channel/Cargo.toml new file mode 100644 index 0000000000000..502b7cc31c595 --- /dev/null +++ b/common/channel/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "channel" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +futures = { version = "=0.3.0-alpha.16", package = "futures-preview" } +lazy_static = "1.3.0" +metrics = { path = "../metrics" } + +[dev-dependencies] +rusty-fork = "0.2.1" diff --git a/common/channel/src/lib.rs b/common/channel/src/lib.rs new file mode 100644 index 0000000000000..c831e2ca9528c --- /dev/null +++ b/common/channel/src/lib.rs @@ -0,0 +1,157 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Provides an mpsc (multi-producer single-consumer) channel wrapped in an +//! [`IntGauge`](metrics::IntGauge) +//! +//! The original futures mpsc channels has the behavior that each cloned sender gets a guaranteed +//! slot. There are cases in our codebase that senders need to be cloned to work with combinators +//! like `buffer_unordered`. The bounded mpsc channels turn to be unbounded in this way. There are +//! great discussions in this [PR](https://github.com/rust-lang-nursery/futures-rs/pull/984). The +//! argument of the current behavior is to have only local limit on each sender, and relies on +//! global coordination for the number of cloned senders. However, this isn't really feasible in +//! 
some cases. One solution that came up from the discussion is to have poll_flush call poll_ready
+//! (instead of a noop) to make sure the current sender task isn't parked. For the case that a new
+//! cloned sender tries to send a message to a full channel, send executes poll_ready, start_send
+//! and poll_flush. The first poll_ready would return Ready because maybe_parked is initially
+//! false. start_send then pushes the message to the internal message queue and parks the sender
+//! task. poll_flush calls poll_ready again, and this time it would return Pending because the
+//! sender task is parked. So the send will block until the receiver removes more messages from the
+//! queue and that sender's task is unparked.
+//! [This PR](https://github.com/rust-lang-nursery/futures-rs/pull/1671) is supposed to fix this in
+//! futures 0.3. It'll be consistent once it's merged.
+//!
+//! This change does have some implications though.
+//! 1. When the channel size is 0, it becomes synchronous. `send` won't finish until the item is
+//! taken from the receiver.
+//! 2. `send` may fail if the receiver drops after receiving the item.
+//!
+//! let (tx, rx) = channel::new_test(1);
+//! let f_tx = async move {
+//!     block_on(tx.send(1)).unwrap();
+//! };
+//! let f_rx = async move {
+//!     let item = block_on(rx.next()).unwrap();
+//!     assert_eq!(item, 1);
+//! };
+//! block_on(join(f_tx, f_rx));
+//!
+//! For the example above, `tx.send` could fail because send has three steps - poll_ready,
+//! start_send and poll_flush. After start_send, the rx can receive the item, but if rx gets
+//! dropped before poll_flush, it'll trigger a disconnected send error. That's why the disconnected
+//! error is converted to an Ok in poll_flush.
+
+use futures::{
+    channel::mpsc,
+    sink::Sink,
+    stream::{FusedStream, Stream},
+    task::{Context, Poll},
+};
+use metrics::IntGauge;
+use std::pin::Pin;
+
+#[cfg(test)]
+mod test;
+
+/// Wrapper around a value with an `IntGauge`
+/// It is used to gauge the number of elements in a `mpsc::channel`
+#[derive(Clone)]
+pub struct WithGauge<T> {
+    gauge: IntGauge,
+    value: T,
+}
+
+/// Similar to `mpsc::Sender`, but with an `IntGauge`
+pub type Sender<T> = WithGauge<mpsc::Sender<T>>;
+/// Similar to `mpsc::Receiver`, but with an `IntGauge`
+pub type Receiver<T> = WithGauge<mpsc::Receiver<T>>;
+
+/// `Sender` implements `Sink` in the same way as `mpsc::Sender`, but it increments the
+/// associated `IntGauge` when it sends a message successfully.
+impl<T> Sink<T> for Sender<T> {
+    type SinkError = mpsc::SendError;
+
+    fn poll_ready(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<(), Self::SinkError>> {
+        (*self).value.poll_ready(cx)
+    }
+
+    fn start_send(mut self: Pin<&mut Self>, msg: T) -> Result<(), Self::SinkError> {
+        self.gauge.inc();
+        (*self).value.start_send(msg).map_err(|e| {
+            self.gauge.dec();
+            e
+        })?;
+        Ok(())
+    }
+
+    // `poll_flush` would block if `poll_ready` returns pending, which means the channel is at
+    // capacity and the sender task is parked.
+    fn poll_flush(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<(), Self::SinkError>> {
+        match (*self).value.poll_ready(cx) {
+            Poll::Ready(Err(ref e)) if e.is_disconnected() => {
+                // If the receiver disconnected, we consider the sink to be flushed.
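+                // (Per the module docs: the receiver may legitimately be dropped
+                // right after taking the item, so this is not surfaced as an error.)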
+ Poll::Ready(Ok(())) + } + x => x, + } + } + + fn poll_close( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll> { + self.value.disconnect(); + Poll::Ready(Ok(())) + } +} + +impl FusedStream for Receiver { + fn is_terminated(&self) -> bool { + self.value.is_terminated() + } +} + +/// `Receiver` implements `Stream` in the same way as `mpsc::Stream`, but it decrements the +/// associated `IntGauge` when it gets polled sucessfully. +impl Stream for Receiver { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let poll = Pin::new(&mut self.value).poll_next(cx); + if let Poll::Ready(Some(_)) = poll { + self.gauge.dec(); + } + poll + } +} + +/// Similar to `mpsc::channel`, `new` creates a pair of `Sender` and `Receiver` +pub fn new(size: usize, gauge: &IntGauge) -> (Sender, Receiver) { + gauge.set(0); + let (sender, receiver) = mpsc::channel(size); + ( + WithGauge { + gauge: gauge.clone(), + value: sender, + }, + WithGauge { + gauge: gauge.clone(), + value: receiver, + }, + ) +} + +lazy_static::lazy_static! { + pub static ref TEST_COUNTER: IntGauge = + IntGauge::new("TEST_COUNTER", "Counter of network tests").unwrap(); +} + +pub fn new_test(size: usize) -> (Sender, Receiver) { + new(size, &TEST_COUNTER) +} diff --git a/common/channel/src/test.rs b/common/channel/src/test.rs new file mode 100644 index 0000000000000..472a81b2d2896 --- /dev/null +++ b/common/channel/src/test.rs @@ -0,0 +1,67 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{new_test, TEST_COUNTER}; +use futures::{ + executor::block_on, + task::{noop_waker, Context, Poll}, + FutureExt, SinkExt, StreamExt, +}; +use rusty_fork::{rusty_fork_id, rusty_fork_test, rusty_fork_test_name}; + +#[test] +fn test_send() { + let (mut tx, mut rx) = new_test(8); + assert_eq!(TEST_COUNTER.get(), 0); + let item = 42; + block_on(tx.send(item)).unwrap(); + assert_eq!(TEST_COUNTER.get(), 1); + let received_item = block_on(rx.next()).unwrap(); + assert_eq!(received_item, item); + assert_eq!(TEST_COUNTER.get(), 0); +} + +// Fork the unit tests into separate processes to avoid the conflict that these tests executed in +// multiple threads may manipulate TEST_COUNTER at the same time. +rusty_fork_test! { +#[test] +fn test_send_backpressure() { + let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + + let (mut tx, mut rx) = new_test(1); + assert_eq!(TEST_COUNTER.get(), 0); + block_on(tx.send(1)).unwrap(); + assert_eq!(TEST_COUNTER.get(), 1); + + let mut task = tx.send(2); + assert_eq!(task.poll_unpin(&mut cx), Poll::Pending); + let item = block_on(rx.next()).unwrap(); + assert_eq!(item, 1); + assert_eq!(TEST_COUNTER.get(), 1); + assert_eq!(task.poll_unpin(&mut cx), Poll::Ready(Ok(()))); +} +} + +// Fork the unit tests into separate processes to avoid the conflict that these tests executed in +// multiple threads may manipulate TEST_COUNTER at the same time. +rusty_fork_test! 
{ +#[test] +fn test_send_backpressure_multi_senders() { + let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + + let (mut tx1, mut rx) = new_test(1); + assert_eq!(TEST_COUNTER.get(), 0); + block_on(tx1.send(1)).unwrap(); + assert_eq!(TEST_COUNTER.get(), 1); + + let mut tx2 = tx1.clone(); + let mut task = tx2.send(2); + assert_eq!(task.poll_unpin(&mut cx), Poll::Pending); + let item = block_on(rx.next()).unwrap(); + assert_eq!(item, 1); + assert_eq!(TEST_COUNTER.get(), 1); + assert_eq!(task.poll_unpin(&mut cx), Poll::Ready(Ok(()))); +} +} diff --git a/common/crash_handler/Cargo.toml b/common/crash_handler/Cargo.toml new file mode 100644 index 0000000000000..024111019c443 --- /dev/null +++ b/common/crash_handler/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "crash_handler" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +backtrace = "0.3.9" +toml = "0.4.7" + +logger = { path = "../logger" } +serde = { version = "1.0.89", features = ["derive"] } diff --git a/common/crash_handler/src/lib.rs b/common/crash_handler/src/lib.rs new file mode 100644 index 0000000000000..f5a6bb6d654b1 --- /dev/null +++ b/common/crash_handler/src/lib.rs @@ -0,0 +1,75 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![feature(panic_info_message)] + +use backtrace::Backtrace; +use logger::prelude::*; +use serde::Serialize; +use std::{ + panic::{self, PanicInfo}, + process, thread, time, +}; + +#[derive(Debug, Serialize)] +pub struct CrashInfo { + reason: String, + details: String, + backtrace: String, +} + +pub fn setup_panic_handler() { + // If RUST_BACKTRACE variable isn't present or RUST_BACKTRACE=0, we setup panic handler + let is_backtrace_set = std::env::var_os("RUST_BACKTRACE") + .map(|x| &x != "0") + .unwrap_or(false); + + if is_backtrace_set { + info!("Skip panic handler setup because RUST_BACKTRACE is set"); + } else { + panic::set_hook(Box::new(move |pi: &PanicInfo<'_>| { + handle_panic(pi); + })); + } +} + +// formats and logs panic information +fn handle_panic(panic_info: &PanicInfo<'_>) { + let reason = match panic_info.message() { + Some(m) => format!("{}", m), + None => "Unknown Reason".into(), + }; + + let mut details = String::new(); + + let payload = match panic_info.payload().downcast_ref::<&str>() { + Some(pld) => format!("Details: {}. ", pld), + None => "[no extra details]. 
".into(), + }; + details.push_str(&payload); + + let location = match panic_info.location() { + Some(loc) => format!( + "Thread panicked at file '{}' at line {}", + loc.file(), + loc.line() + ), + None => "[no location details].".into(), + }; + details.push_str(&location); + + let backtrace = format!("{:#?}", Backtrace::new()); + + let info = CrashInfo { + reason, + details, + backtrace, + }; + crit!("{}", toml::to_string_pretty(&info).unwrap()); + + // allow to save on disk + thread::sleep(time::Duration::from_millis(100)); + + // kill the process + process::exit(12); +} diff --git a/common/debug_interface/Cargo.toml b/common/debug_interface/Cargo.toml new file mode 100644 index 0000000000000..d444c98b9a96d --- /dev/null +++ b/common/debug_interface/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "debug_interface" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +grpcio = "0.4.3" +futures = "0.1.23" +protobuf = "2.6" + +failure = { package = "failure_ext", path = "../failure_ext" } +jemalloc = { path = "../jemalloc" } +logger = { path = "../logger" } +metrics = { path = "../metrics" } + +[build-dependencies] +build_helpers = { path = "../build_helpers" } diff --git a/common/debug_interface/build.rs b/common/debug_interface/build.rs new file mode 100644 index 0000000000000..0095edf448c71 --- /dev/null +++ b/common/debug_interface/build.rs @@ -0,0 +1,17 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This compiles all the `.proto` files under `src/` directory. +//! +//! For example, if there is a file `src/a/b/c.proto`, it will generate `src/a/b/c.rs` and +//! `src/a/b/c_grpc.rs`. + +fn main() { + let proto_root = "src/proto"; + + build_helpers::build_helpers::compile_proto( + proto_root, + vec![], /* dependent roots */ + false, /* generate_client_stub */ + ); +} diff --git a/common/debug_interface/src/lib.rs b/common/debug_interface/src/lib.rs new file mode 100644 index 0000000000000..b08195a48cef2 --- /dev/null +++ b/common/debug_interface/src/lib.rs @@ -0,0 +1,79 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::proto::{ + node_debug_interface::{DumpJemallocHeapProfileRequest, GetNodeDetailsRequest}, + node_debug_interface_grpc::NodeDebugInterfaceClient, +}; +use failure::prelude::*; +use grpcio::{ChannelBuilder, EnvBuilder}; +use std::{collections::HashMap, sync::Arc}; + +// Generated +pub mod proto; + +pub mod node_debug_helpers; +pub mod node_debug_service; + +/// Implement default utility client for NodeDebugInterface +pub struct NodeDebugClient { + client: NodeDebugInterfaceClient, + address: String, + port: u16, +} + +impl NodeDebugClient { + pub fn new>(address: A, port: u16) -> Self { + let env = Arc::new(EnvBuilder::new().name_prefix("grpc-debug-").build()); + let ch = ChannelBuilder::new(env).connect(&format!("{}:{}", address.as_ref(), port)); + let client = NodeDebugInterfaceClient::new(ch); + + Self { + client, + address: address.as_ref().to_owned(), + port, + } + } + + pub fn get_address(&self) -> &str { + &self.address + } + + pub fn get_port(&self) -> u16 { + self.port + } + + pub fn get_node_metric>(&self, metric: S) -> Result> { + let metrics = self.get_node_metrics()?; + Ok(metrics.get(metric.as_ref()).cloned()) + } + + pub fn get_node_metrics(&self) -> Result> { + let response = self + .client + .get_node_details(&GetNodeDetailsRequest::new()) + .context("Unable to query Node metrics")?; + + 
response + .stats + .into_iter() + .map(|(k, v)| match v.parse::() { + Ok(v) => Ok((k, v)), + Err(_) => Err(format_err!( + "Failed to parse stat value to i64 {}: {}", + &k, + &v + )), + }) + .collect() + } + + pub fn dump_heap_profile(&self) -> Result { + let response = self + .client + .dump_jemalloc_heap_profile(&DumpJemallocHeapProfileRequest::new()) + .context("Unable to request heap dump")?; + + Ok(response.status_code) + } +} diff --git a/common/debug_interface/src/node_debug_helpers.rs b/common/debug_interface/src/node_debug_helpers.rs new file mode 100644 index 0000000000000..1ecdc83d643d5 --- /dev/null +++ b/common/debug_interface/src/node_debug_helpers.rs @@ -0,0 +1,38 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Helper functions for debug interface. + +use crate::proto::node_debug_interface_grpc::NodeDebugInterfaceClient; +use grpcio::{ChannelBuilder, EnvBuilder}; +use logger::prelude::*; +use std::{sync::Arc, thread, time}; + +pub fn create_debug_client(debug_port: u16) -> NodeDebugInterfaceClient { + let node_connection_str = format!("localhost:{}", debug_port); + let env = Arc::new(EnvBuilder::new().name_prefix("grpc-debug-").build()); + let ch = ChannelBuilder::new(env).connect(&node_connection_str); + NodeDebugInterfaceClient::new(ch) +} + +pub fn check_node_up(client: &NodeDebugInterfaceClient) { + let mut attempt = 200; + let get_details_req = crate::proto::node_debug_interface::GetNodeDetailsRequest::new(); + + loop { + match client.get_node_details(&get_details_req) { + Ok(_) => { + info!("Node is up"); + break; + } + Err(e) => { + if attempt > 0 { + attempt -= 1; + thread::sleep(time::Duration::from_millis(100)); + } else { + panic!("Node is not up after many attempts: {}", e); + } + } + } + } +} diff --git a/common/debug_interface/src/node_debug_service.rs b/common/debug_interface/src/node_debug_service.rs new file mode 100644 index 0000000000000..16fc234b48ae0 --- /dev/null +++ b/common/debug_interface/src/node_debug_service.rs @@ -0,0 +1,60 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Debug interface to access information in a specific node. 
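+//!
+//! A sketch of querying this service through the utility client in this crate
+//! (the address, port, and metric name here are placeholders):
+//!
+//! ```ignore
+//! let client = NodeDebugClient::new("localhost", 6191);
+//! let metrics = client.get_node_metrics()?;
+//! println!("{:?}", metrics.get("some_counter"));
+//! ```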
+ +use crate::proto::{ + node_debug_interface::{ + DumpJemallocHeapProfileRequest, DumpJemallocHeapProfileResponse, GetNodeDetailsRequest, + GetNodeDetailsResponse, + }, + node_debug_interface_grpc::NodeDebugInterface, +}; +use futures::Future; +use logger::prelude::*; +use metrics::counters::COUNTER_ADMISSION_CONTROL_CANNOT_SEND_REPLY; + +#[derive(Clone, Default)] +pub struct NodeDebugService {} + +impl NodeDebugService { + pub fn new() -> Self { + Default::default() + } +} + +impl NodeDebugInterface for NodeDebugService { + fn get_node_details( + &mut self, + ctx: ::grpcio::RpcContext<'_>, + _req: GetNodeDetailsRequest, + sink: ::grpcio::UnarySink, + ) { + info!("[GRPC] get_node_details"); + let mut response = GetNodeDetailsResponse::new(); + response.stats = metrics::get_all_metrics(); + ctx.spawn(sink.success(response).map_err(default_reply_error_logger)) + } + + fn dump_jemalloc_heap_profile( + &mut self, + ctx: ::grpcio::RpcContext<'_>, + _request: DumpJemallocHeapProfileRequest, + sink: ::grpcio::UnarySink, + ) { + trace!("[GRPC] dump_jemalloc_heap_profile"); + let status_code = match jemalloc::dump_jemalloc_memory_profile() { + Ok(_) => 0, + Err(err_code) => err_code, + }; + let mut resp = DumpJemallocHeapProfileResponse::new(); + resp.status_code = status_code; + let f = sink.success(resp).map_err(default_reply_error_logger); + ctx.spawn(f) + } +} + +fn default_reply_error_logger(e: T) { + COUNTER_ADMISSION_CONTROL_CANNOT_SEND_REPLY.inc(); + error!("Failed to reply error due to {:?}", e) +} diff --git a/common/debug_interface/src/proto/.gitignore b/common/debug_interface/src/proto/.gitignore new file mode 100644 index 0000000000000..0b0298b4200cd --- /dev/null +++ b/common/debug_interface/src/proto/.gitignore @@ -0,0 +1,3 @@ +# Ignore all the generated files. +node_debug_interface.rs +node_debug_interface_grpc.rs diff --git a/common/debug_interface/src/proto/mod.rs b/common/debug_interface/src/proto/mod.rs new file mode 100644 index 0000000000000..f493d017aaebd --- /dev/null +++ b/common/debug_interface/src/proto/mod.rs @@ -0,0 +1,5 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +pub mod node_debug_interface; +pub mod node_debug_interface_grpc; diff --git a/common/debug_interface/src/proto/node_debug_interface.proto b/common/debug_interface/src/proto/node_debug_interface.proto new file mode 100644 index 0000000000000..47f4611394e05 --- /dev/null +++ b/common/debug_interface/src/proto/node_debug_interface.proto @@ -0,0 +1,27 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +// A Debugging interface to be used to query debug information from a Node +syntax = "proto3"; + +package debug; + +message GetNodeDetailsRequest {} + +message GetNodeDetailsResponse { map stats = 1; } + +message DumpJemallocHeapProfileRequest {} + +message DumpJemallocHeapProfileResponse { + // Status code from jemalloc mallctl call. 0 indicates success. + int32 status_code = 1; +} + +service NodeDebugInterface { + // Returns debug information about node + rpc GetNodeDetails(GetNodeDetailsRequest) returns (GetNodeDetailsResponse) {} + + // Triggers a dump of heap profile. 
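+  // The returned status_code is the raw jemalloc mallctl result, with 0
+  // indicating success (see DumpJemallocHeapProfileResponse above).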
+ rpc DumpJemallocHeapProfile(DumpJemallocHeapProfileRequest) + returns (DumpJemallocHeapProfileResponse) {} +} diff --git a/common/executable_helpers/Cargo.toml b/common/executable_helpers/Cargo.toml new file mode 100644 index 0000000000000..a223522f416f3 --- /dev/null +++ b/common/executable_helpers/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "executable_helpers" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +clap = "2.32.0" +slog-scope = "4.0" + +config = { path = "../../config" } +crash_handler = { path = "../crash_handler" } +logger = { path = "../logger" } +metrics = { path = "../metrics" } diff --git a/common/executable_helpers/src/helpers.rs b/common/executable_helpers/src/helpers.rs new file mode 100644 index 0000000000000..dd0cb254d4fa1 --- /dev/null +++ b/common/executable_helpers/src/helpers.rs @@ -0,0 +1,158 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use clap::{value_t, App, Arg, ArgMatches}; +use config::config::{NodeConfig, NodeConfigHelpers}; +use logger::prelude::*; +use slog_scope::GlobalLoggerGuard; + +// General args +pub const ARG_PEER_ID: &str = "--peer_id"; +pub const ARG_DISABLE_LOGGING: &str = "--no_logging"; +pub const ARG_CONFIG_PATH: &str = "--config_path"; + +// Used for consensus +pub const ARG_NUM_PAYLOAD: &str = "--num_payload"; +pub const ARG_PAYLOAD_SIZE: &str = "--payload_size"; + +pub fn load_configs_from_args(args: &ArgMatches<'_>) -> NodeConfig { + let node_config; + + if args.is_present(ARG_CONFIG_PATH) { + // Allow peer id over-ride via command line + let peer_id = value_t!(args, ARG_PEER_ID, String).ok(); + + let config_path = + value_t!(args, ARG_CONFIG_PATH, String).expect("Path to config file must be specified"); + info!("Loading node config from: {}", &config_path); + node_config = NodeConfig::load_config(peer_id, &config_path).expect("NodeConfig"); + + info!("Starting Full {}", node_config.base.peer_id); + } else { + // Note we will silently ignore --peer_id arg here + info!("Loading test configs"); + node_config = NodeConfigHelpers::get_single_node_test_config(false /* random ports */); + + info!("Starting Single-Mode {}", node_config.base.peer_id); + } + + // Node configuration contains important ephemeral port information and should + // not be subject to being disabled as with other logs + println!("Using node config {:?}", &node_config); + + node_config +} + +pub fn setup_metrics(peer_id: &str, node_config: &NodeConfig) { + if !node_config.metrics.dir.as_os_str().is_empty() { + metrics::dump_all_metrics_to_file_periodically( + &node_config.metrics.dir, + &format!("{}.metrics", peer_id), + node_config.metrics.collection_interval_ms, + ); + } + + // TODO: should we do this differently for different binaries? + if !node_config.metrics.push_server_addr.is_empty() { + metrics::push_all_metrics_to_pushgateway_periodically( + "libra_node", + &node_config.metrics.push_server_addr, + peer_id, + node_config.metrics.collection_interval_ms, + ); + } +} + +/// Performs common setup for the executable. 
Takes in args that +/// you wish to use for this executable +pub fn setup_executable( + app_name: String, + arg_names: Vec<&str>, +) -> (NodeConfig, Option, ArgMatches<'_>) { + crash_handler::setup_panic_handler(); + + let args = get_arg_matches(app_name, arg_names); + let is_logging_disabled = args.is_present(ARG_DISABLE_LOGGING); + let mut _logger = set_default_global_logger(is_logging_disabled, None); + + let config = load_configs_from_args(&args); + + // Reset the global logger using config (for chan_size currently). + // We need to drop the global logger guard first before resetting it. + _logger = None; + let logger = set_default_global_logger( + is_logging_disabled, + Some(config.base.node_async_log_chan_size), + ); + + setup_metrics(&config.base.peer_id, &config); + + (config, logger, args) +} + +fn set_default_global_logger( + is_logging_disabled: bool, + chan_size: Option, +) -> Option { + if is_logging_disabled { + return None; + } + + Some(logger::set_default_global_logger( + true, /* async */ + chan_size, /* chan_size */ + )) +} + +fn get_arg_matches(app_name: String, arg_names: Vec<&str>) -> ArgMatches<'_> { + let mut service_name = app_name.clone(); + service_name.push_str(" Service"); + + let mut app = App::new(app_name) + .version("0.1.0") + .author("Libra Association ") + .about(service_name.as_str()); + + for arg in arg_names { + let short; + let takes_value; + let help; + match arg { + ARG_PEER_ID => { + short = "-p"; + takes_value = true; + help = "Specify peer id for this node"; + } + ARG_CONFIG_PATH => { + short = "-f"; + takes_value = true; + help = "Specify the path to the config file"; + } + ARG_DISABLE_LOGGING => { + short = "-d"; + takes_value = false; + help = "Controls logging"; + } + ARG_NUM_PAYLOAD => { + short = "-n"; + takes_value = true; + help = "Specify the number of payload each node send"; + } + ARG_PAYLOAD_SIZE => { + short = "-s"; + takes_value = true; + help = "Specify the byte size of each payload"; + } + x => panic!("Invalid argument: {}", x), + } + app = app.arg( + Arg::with_name(arg) + .short(short) + .long(arg) + .takes_value(takes_value) + .help(help), + ); + } + + app.get_matches() +} diff --git a/common/executable_helpers/src/lib.rs b/common/executable_helpers/src/lib.rs new file mode 100644 index 0000000000000..db9e3052a1d29 --- /dev/null +++ b/common/executable_helpers/src/lib.rs @@ -0,0 +1,4 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +pub mod helpers; diff --git a/common/failure_ext/Cargo.toml b/common/failure_ext/Cargo.toml new file mode 100644 index 0000000000000..1b5035ccfc84b --- /dev/null +++ b/common/failure_ext/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "failure_ext" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +failure = "0.1.3" + +failure_macros = { path = "failure_macros" } diff --git a/common/failure_ext/failure_macros/Cargo.toml b/common/failure_ext/failure_macros/Cargo.toml new file mode 100644 index 0000000000000..23fab19e523af --- /dev/null +++ b/common/failure_ext/failure_macros/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "failure_macros" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] diff --git a/common/failure_ext/failure_macros/src/lib.rs b/common/failure_ext/failure_macros/src/lib.rs new file mode 100644 index 0000000000000..0d417cbf80ffd --- /dev/null +++ 
b/common/failure_ext/failure_macros/src/lib.rs
@@ -0,0 +1,15 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Collection of convenience macros for error handling
+
+/// Exits a function early with an `Error`.
+///
+/// Equivalent to the `bail!` macro, except an error type is provided instead of
+/// a message.
+#[macro_export]
+macro_rules! bail_err {
+    ($e:expr) => {
+        return Err(From::from($e));
+    };
+}
diff --git a/common/failure_ext/src/lib.rs b/common/failure_ext/src/lib.rs
new file mode 100644
index 0000000000000..41e179a0db4bc
--- /dev/null
+++ b/common/failure_ext/src/lib.rs
@@ -0,0 +1,31 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! A common error handling library for the Libra project.
+//!
+//! ## Usage
+//!
+//! // This crate must be imported as 'failure' in order to ensure the
+//! // procedural derive macro for the `Fail` trait can function properly.
+//! failure = { path = "../common/failure_ext", package = "failure_ext" }
+//! // Most of the types and macros you'll need can be found in the prelude.
+//! use failure::prelude::*;
+
+pub use failure::{
+    _core, bail, ensure, err_msg, format_err, AsFail, Backtrace, Causes, Compat, Context, Error,
+    Fail, ResultExt, SyncFailure,
+};
+
+// Custom error handling macros are placed in the failure_macros crate. Due to
+// the way intra-crate macro exports currently work, macros can't be exported
+// from anywhere but the top level when they are defined in the same crate.
+pub use failure_macros::bail_err;
+
+pub type Result<T> = ::std::result::Result<T, Error>;
+
+/// Prelude module containing the most commonly used types/macros this crate exports.
+pub mod prelude {
+    pub use crate::Result;
+    pub use failure::{bail, ensure, err_msg, format_err, Error, Fail, ResultExt};
+    pub use failure_macros::bail_err;
+}
diff --git a/common/grpc_helpers/Cargo.toml b/common/grpc_helpers/Cargo.toml
new file mode 100644
index 0000000000000..0158799354c96
--- /dev/null
+++ b/common/grpc_helpers/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "grpc_helpers"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+grpcio = "0.4.3"
+futures = { version = "0.3.0-alpha.13", package = "futures-preview", features = ["compat"] }
+futures_01 = { version = "0.1.25", package = "futures" }
+
+failure = { package = "failure_ext", path = "../failure_ext" }
+logger = { path = "../logger" }
+metrics = { path = "../metrics" }
+
+[dependencies.prometheus]
+version = "0.4.2"
+default-features = false
+features = ["nightly", "push"]
diff --git a/common/grpc_helpers/src/lib.rs b/common/grpc_helpers/src/lib.rs
new file mode 100644
index 0000000000000..6e8dafce57fe0
--- /dev/null
+++ b/common/grpc_helpers/src/lib.rs
@@ -0,0 +1,144 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use failure::{prelude::*, Result};
+use futures::{compat::Future01CompatExt, future::Future, prelude::*};
+use futures_01::future::Future as Future01;
+use grpcio::{EnvBuilder, ServerBuilder};
+use logger::prelude::*;
+use metrics::counters::SVC_COUNTERS;
+use std::{
+    str::from_utf8,
+    sync::{
+        mpsc::{self, Sender},
+        Arc,
+    },
+    thread, time,
+};
+
+pub fn default_reply_error_logger<T: std::fmt::Debug>(e: T) {
+    error!("Failed to reply error due to {:?}", e)
+}
+
+pub fn create_grpc_invalid_arg_status(method: &str, err: ::failure::Error) -> ::grpcio::RpcStatus {
+    let msg = format!("Request failed {}", err);
+    error!("{} failed with {}", method, &msg);
+    ::grpcio::RpcStatus::new(::grpcio::RpcStatusCode::InvalidArgument, Some(msg))
+}
+
+/// Helper to return a response to the GRPC context and signal that the
+/// operation is done. It also logs any error and reports the outcome,
+/// success (true) or failure (false), to the service counters.
+pub fn provide_grpc_response<T>(
+    resp: Result<T>,
+    ctx: ::grpcio::RpcContext<'_>,
+    sink: ::grpcio::UnarySink<T>,
+) {
+    let mut success = true;
+    match resp {
+        Ok(resp) => ctx.spawn(sink.success(resp).map_err(default_reply_error_logger)),
+        Err(e) => {
+            success = false;
+            let f = sink
+                .fail(create_grpc_invalid_arg_status(
+                    from_utf8(ctx.method()).expect("Unable to convert function name to string"),
+                    e,
+                ))
+                .map_err(default_reply_error_logger);
+            ctx.spawn(f)
+        }
+    }
+    SVC_COUNTERS.resp(&ctx, success);
+}
+
+pub fn spawn_service_thread(
+    service: ::grpcio::Service,
+    service_host_address: String,
+    service_public_port: u16,
+    service_name: impl Into<String>,
+) -> ServerHandle {
+    spawn_service_thread_with_drop_closure(
+        service,
+        service_host_address,
+        service_public_port,
+        service_name,
+        || { /* no code, to make compiler happy */ },
+    )
+}
+
+pub fn spawn_service_thread_with_drop_closure<F>(
+    service: ::grpcio::Service,
+    service_host_address: String,
+    service_public_port: u16,
+    service_name: impl Into<String>,
+    service_drop_closure: F,
+) -> ServerHandle
+where
+    F: FnOnce() + 'static,
+{
+    let env = Arc::new(EnvBuilder::new().name_prefix(service_name).build());
+    let server = ServerBuilder::new(env)
+        .register_service(service)
+        .bind(service_host_address, service_public_port)
+        .build()
+        .expect("Unable to create grpc server");
+    ServerHandle::setup_with_drop_closure(server, Some(Box::new(service_drop_closure)))
+}
+
+pub struct ServerHandle {
+    stop_sender: Sender<()>,
+    drop_closure: Option<Box<dyn FnOnce()>>,
+}
+
+impl ServerHandle {
+    pub fn setup_with_drop_closure(
+        mut server: ::grpcio::Server,
+        drop_closure: Option<Box<dyn FnOnce()>>,
+    ) -> Self {
+        let (start_sender, start_receiver) = mpsc::channel();
+        let (stop_sender, stop_receiver) = mpsc::channel();
+        let handle = Self {
+            stop_sender,
+            drop_closure,
+        };
+        thread::spawn(move || {
+            server.start();
+            start_sender.send(()).unwrap();
+            loop {
+                if stop_receiver.try_recv().is_ok() {
+                    return;
+                }
+                thread::sleep(time::Duration::from_millis(100));
+            }
+        });
+
+        start_receiver.recv().unwrap();
+        handle
+    }
+
+    pub fn setup(server: ::grpcio::Server) -> Self {
+        Self::setup_with_drop_closure(server, None)
+    }
+}
+
+impl Drop for ServerHandle {
+    fn drop(&mut self) {
+        self.stop_sender.send(()).unwrap();
+        if let Some(f) = self.drop_closure.take() {
+            f()
+        }
+    }
+}
+
+pub fn convert_grpc_response<T>(
+    response: grpcio::Result<impl Future01<Item = T, Error = grpcio::Error>>,
+) -> impl Future<Output = Result<T>> {
+    future::ready(response.map_err(convert_grpc_err))
+        .map_ok(Future01CompatExt::compat)
+        .and_then(|x| x.map_err(convert_grpc_err))
+}
+
+fn convert_grpc_err(e: ::grpcio::Error) -> Error {
+    format_err!("grpc error: {}", e)
+}
diff --git a/common/grpcio-client/Cargo.toml b/common/grpcio-client/Cargo.toml
new file mode 100644
index 0000000000000..e0ef6c07f1e66
--- /dev/null
+++ b/common/grpcio-client/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "grpcio-client"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+mktemp = "0.3"
+protobuf = "2.6"
+protobuf-codegen = "2.*"
+protoc = "2.*"
+protoc-grpcio = "1.0.1"
+regex = "1.1.6"
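A minimal sketch of how these helpers compose, assuming a hypothetical `create_my_service()` factory that returns a `::grpcio::Service` (not part of this diff):

```rust
use grpc_helpers::spawn_service_thread;

fn serve() {
    // Spawn the gRPC server on a dedicated thread; the returned handle owns
    // the stop channel polled by that thread.
    let handle = spawn_service_thread(
        create_my_service(),   // hypothetical ::grpcio::Service
        "0.0.0.0".to_string(), // host address to bind
        8080,                  // public port
        "my_service",          // name prefix for the grpcio env threads
    );

    // ... handle requests ...

    // Dropping the ServerHandle signals the server thread to stop.
    drop(handle);
}
```

diff --git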
a/common/grpcio-client/src/codegen.rs b/common/grpcio-client/src/codegen.rs new file mode 100644 index 0000000000000..7aa6a361f871b --- /dev/null +++ b/common/grpcio-client/src/codegen.rs @@ -0,0 +1,308 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashMap; + +use protobuf::{ + compiler_plugin, + descriptor::{FileDescriptorProto, MethodDescriptorProto, ServiceDescriptorProto}, + descriptorx::{RootScope, WithScope}, +}; + +use super::util; +use protobuf_codegen::code_writer::CodeWriter; + +/* +This is mostly copied from grpcio-compiler. +It's copied and not reimplemented or re-used in some way because: +Most methods there are private, and I have to use the same names/structs +for the generated output. +*/ + +struct MethodGen<'a> { + proto: &'a MethodDescriptorProto, + root_scope: &'a RootScope<'a>, +} + +impl<'a> MethodGen<'a> { + fn new(proto: &'a MethodDescriptorProto, root_scope: &'a RootScope<'a>) -> MethodGen<'a> { + MethodGen { proto, root_scope } + } + + fn input(&self) -> String { + format!( + "super::{}", + self.root_scope + .find_message(self.proto.get_input_type()) + .rust_fq_name() + ) + } + + fn output(&self) -> String { + format!( + "super::{}", + self.root_scope + .find_message(self.proto.get_output_type()) + .rust_fq_name() + ) + } + + fn method_type(&self) -> (util::MethodType, String) { + match ( + self.proto.get_client_streaming(), + self.proto.get_server_streaming(), + ) { + (false, false) => ( + util::MethodType::Unary, + util::fq_grpc("util::MethodType::Unary"), + ), + (true, false) => ( + util::MethodType::ClientStreaming, + util::fq_grpc("util::MethodType::ClientStreaming"), + ), + (false, true) => ( + util::MethodType::ServerStreaming, + util::fq_grpc("util::MethodType::ServerStreaming"), + ), + (true, true) => ( + util::MethodType::Duplex, + util::fq_grpc("util::MethodType::Duplex"), + ), + } + } + + fn name(&self) -> String { + util::to_snake_case(self.proto.get_name()) + } + + // Method signatures + fn unary(&self, method_name: &str) -> String { + format!( + "{}(&self, req: &{}) -> {}<{}>", + method_name, + self.input(), + util::fq_grpc("Result"), + self.output() + ) + } + + fn unary_async(&self, method_name: &str) -> String { + format!( + "{}_async(&self, req: &{}) -> {} + Send>>", + method_name, + self.input(), + util::fq_grpc("Result"), + self.output(), + util::fq_grpc("Error") + ) + } + + fn client_streaming(&self, method_name: &str) -> String { + format!( + "{}(&self) -> {}<({}<{}>, {}<{}>)>", + method_name, + util::fq_grpc("Result"), + util::fq_grpc("ClientCStreamSender"), + self.input(), + util::fq_grpc("ClientCStreamReceiver"), + self.output() + ) + } + + fn server_streaming(&self, method_name: &str) -> String { + format!( + "{}(&self, req: &{}) -> {}<{}<{}>>", + method_name, + self.input(), + util::fq_grpc("Result"), + util::fq_grpc("ClientSStreamReceiver"), + self.output() + ) + } + + fn duplex_streaming(&self, method_name: &str) -> String { + format!( + "{}(&self) -> {}<({}<{}>, {}<{}>)>", + method_name, + util::fq_grpc("Result"), + util::fq_grpc("ClientDuplexSender"), + self.input(), + util::fq_grpc("ClientDuplexReceiver"), + self.output() + ) + } + + fn write_method(&self, has_impl: bool, w: &mut CodeWriter<'_>) { + let method_name = self.name(); + let method_name = method_name.as_str(); + + // some parts are not implemented yet, account for both + let (sig, implemented) = match self.method_type().0 { + util::MethodType::Unary => (self.unary(method_name), true), + 
util::MethodType::ClientStreaming => (self.client_streaming(method_name), false), + util::MethodType::ServerStreaming => (self.server_streaming(method_name), false), + util::MethodType::Duplex => (self.duplex_streaming(method_name), false), + }; + + if has_impl { + w.def_fn(sig.as_str(), |w| { + if implemented { + w.write_line(format!("self.{}(req)", method_name)); + } else { + w.unimplemented(); + } + }); + } else { + w.def_fn(sig.as_str(), |w| { + w.unimplemented(); + }); + } + + // async variant: only implemented for unary methods for now + if let util::MethodType::Unary = self.method_type().0 { + w.def_fn(self.unary_async(method_name).as_str(), |w| { + if has_impl { + w.match_expr(format!("self.{}_async(req)", method_name), |w| { + w.case_expr("Ok(f)", "Ok(Box::new(f))"); + w.case_expr("Err(e)", "Err(e)"); + }); + } else { + w.unimplemented(); + } + }); + } + } + + fn write_trait(&self, w: &mut CodeWriter<'_>) { + self.write_method(false, w); + } + + fn write_impl(&self, w: &mut CodeWriter<'_>) { + self.write_method(true, w); + } +} + +struct ClientTraitGen<'a> { + proto: &'a ServiceDescriptorProto, + methods: Vec>, + base_name: String, +} + +impl<'a> ClientTraitGen<'a> { + fn new( + proto: &'a ServiceDescriptorProto, + file: &FileDescriptorProto, + root_scope: &'a RootScope<'_>, + ) -> ClientTraitGen<'a> { + let methods = proto + .get_method() + .iter() + .map(|m| MethodGen::new(m, root_scope)) + .collect(); + + let base = protobuf::descriptorx::proto_path_to_rust_mod(file.get_name()); + + ClientTraitGen { + proto, + methods, + base_name: base, + } + } + + fn service_name(&self) -> String { + util::to_camel_case(self.proto.get_name()) + } + + fn trait_name(&self) -> String { + format!("{}Trait", self.client_name()) + } + + fn client_name(&self) -> String { + format!("{}Client", self.service_name()) + } + + fn write_trait(&self, w: &mut CodeWriter<'_>) { + w.pub_trait_extend(self.trait_name().as_str(), "Clone + Send + Sync", |w| { + // methods that go inside the trait + + for method in &self.methods { + w.write_line(""); + method.write_trait(w); + } + }); + } + + fn write_impl(&self, w: &mut CodeWriter<'_>) { + let type_name = format!("super::{}_grpc::{}", self.base_name, self.client_name()); + w.impl_for_block(self.trait_name(), type_name, |w| { + for method in &self.methods { + w.write_line(""); + method.write_impl(w); + } + }); + } + + fn write(&self, w: &mut CodeWriter<'_>) { + // Client trait definition + self.write_trait(w); + w.write_line(""); + + // Impl block for client trait + self.write_impl(w); + } +} + +fn gen_file( + file: &FileDescriptorProto, + root_scope: &RootScope<'_>, +) -> Option { + if file.get_service().is_empty() { + return None; + } + + let base = protobuf::descriptorx::proto_path_to_rust_mod(file.get_name()); + + let mut v = Vec::new(); + { + let mut w = CodeWriter::new(&mut v); + w.write_generated(); + w.write_line("#![allow(unused_variables)]"); + w.write_line("use futures::Future;"); + w.write_line(""); + for service in file.get_service() { + w.write_line(""); + ClientTraitGen::new(service, file, root_scope).write(&mut w); + } + } + + Some(compiler_plugin::GenResult { + name: base + "_client.rs", + content: v, + }) +} + +pub fn gen( + file_descriptors: &[FileDescriptorProto], + files_to_generate: &[&str], +) -> Vec { + let files_map: HashMap<&str, &FileDescriptorProto> = + file_descriptors.iter().map(|f| (f.get_name(), f)).collect(); + + let root_scope = RootScope { file_descriptors }; + + let mut results = Vec::new(); + + for file_name in files_to_generate 
{
+        // not all files need client stubs, some are simple protobufs, no service
+        let temp1 = file_name.to_string();
+        let file = files_map[&temp1[..]];
+
+        if file.get_service().is_empty() {
+            continue;
+        }
+
+        results.extend(gen_file(file, &root_scope).into_iter());
+    }
+
+    results
+}
diff --git a/common/grpcio-client/src/lib.rs b/common/grpcio-client/src/lib.rs
new file mode 100644
index 0000000000000..1968286656eaa
--- /dev/null
+++ b/common/grpcio-client/src/lib.rs
@@ -0,0 +1,71 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use protoc_grpcio::CompileResult;
+use std::path::Path;
+
+/// This crate provides a library for generating a Client trait for GRPC clients
+/// generated by grpc-rs (protoc-grpcio).
+///
+/// This crate complements the functionality provided by `protoc-grpcio` by defining a Trait for
+/// the GRPC client service that can be used instead of the client directly for polymorphism and
+/// testing.
+///
+/// ## Usage Example
+///
+/// To generate client traits as part of a `build.rs` script, add:
+///
+/// ```ignore,no_run
+/// grpcio_client::client_stub_gen(
+///     &["calculator.proto"],               /* service files to generate traits for */
+///     &["src/proto", "../deps/src/proto"], /* proto paths & includes */
+///     "src/proto",                         /* target dir */
+/// );
+/// ```
+///
+/// This will create the file `calculator_client.rs` under the `src/proto` folder.
+///
+/// The generated file will include 2 structures:
+/// ```rust
+/// // assuming the service name is `Calculator`
+/// pub trait CalculatorClientTrait {
+///     // methods
+/// }
+/// ```
+/// and
+///
+/// ```rust
+/// # struct CalculatorClient;
+/// # pub trait CalculatorClientTrait {
+/// #     // methods
+/// # }
+///
+/// impl CalculatorClientTrait for CalculatorClient {
+///     // method impl -- calling method from client
+/// }
+/// ```
+mod codegen;
+mod util;
+
+/// Generates client traits for the GRPC services defined in the `from` files.
+/// * `from` - the files with the services to generate client traits for
+/// * `includes` - the parent folders of the `from` files and all of their includes.
+/// * `to` - a path to a folder in which to store the generated files.
+///
+/// ## Example use:
+/// client_stub_gen(&["src/proto/myservice.proto"], &[], "src/proto");
+pub fn client_stub_gen<P: AsRef<Path>>(
+    from: &[&str],
+    includes: &[&str],
+    to: P,
+) -> CompileResult<()> {
+    let descriptor_set = util::protoc_descriptor_set(from, includes)?;
+    util::write_out_generated_files(codegen::gen(descriptor_set.get_file(), from), &to)
+        .expect("failed to write generated grpc definitions");
+
+    Ok(())
+}
diff --git a/common/grpcio-client/src/util.rs b/common/grpcio-client/src/util.rs
new file mode 100644
index 0000000000000..3e96747b8cb0b
--- /dev/null
+++ b/common/grpcio-client/src/util.rs
@@ -0,0 +1,97 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use mktemp::Temp;
+use protobuf::{compiler_plugin, descriptor::FileDescriptorSet, error::ProtobufError};
+use protoc::{DescriptorSetOutArgs, Protoc};
+use protoc_grpcio::CompileResult;
+use regex::Regex;
+use std::{
+    fs::File,
+    io::{Read, Write},
+    path::Path,
+};
+
+/// Copied from protoc-grpcio, where it's not public.
+pub(crate) fn write_out_generated_files<P>
( + generation_results: Vec, + output_dir: P, +) -> CompileResult<()> +where + P: AsRef, +{ + for result in generation_results { + let file = output_dir.as_ref().join(result.name); + File::create(&file) + .expect("failed to create file") + .write_all(&result.content) + .expect("failed to write file"); + } + + Ok(()) +} + +/// Generate snake case names. This is useful +/// helloWorldFoo => hello_world_foo +/// ID => :( not working for this case +/// +/// This needs to be the same as in grpcio-compiler, but I +/// didn't copy it. +pub fn to_snake_case(name: &str) -> String { + let re = Regex::new("((:?^|(:?[A-Z]))[a-z0-9_]+)").unwrap(); + let mut words = vec![]; + for cap in re.captures_iter(name) { + words.push(cap.get(1).unwrap().as_str().to_lowercase()); + } + words.join("_") // my best line of code +} + +// TODO: frumious: make camel case +pub fn to_camel_case(name: &str) -> String { + // do nothing for now + name.to_string() +} + +pub fn fq_grpc(item: &str) -> String { + format!("::grpcio::{}", item) +} + +pub enum MethodType { + Unary, + ClientStreaming, + ServerStreaming, + Duplex, +} + +pub fn protoc_descriptor_set( + from: &[&str], + includes: &[&str], +) -> Result { + let protoc = Protoc::from_env_path(); + protoc + .check() + .expect("failed to find `protoc`, `protoc` must be availabe in `PATH`"); + + let descriptor_set = Temp::new_file()?; + + protoc + .write_descriptor_set(DescriptorSetOutArgs { + out: match descriptor_set.as_ref().to_str() { + Some(s) => s, + None => unreachable!("failed to convert path to string"), + }, + input: from, + includes, + include_imports: true, + }) + .expect("failed to write descriptor set"); + + let mut serialized_descriptor_set = Vec::new(); + File::open(&descriptor_set) + .expect("failed to open descriptor set") + .read_to_end(&mut serialized_descriptor_set) + .expect("failed to read descriptor set"); + + protobuf::parse_from_bytes::(&serialized_descriptor_set) +} diff --git a/common/grpcio-extras/Cargo.toml b/common/grpcio-extras/Cargo.toml new file mode 100644 index 0000000000000..f7b16f44ca398 --- /dev/null +++ b/common/grpcio-extras/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "grpcio-extras" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +futures = "0.1.25" +grpcio = "0.4.3" diff --git a/common/grpcio-extras/src/lib.rs b/common/grpcio-extras/src/lib.rs new file mode 100644 index 0000000000000..da9a8b3b635b7 --- /dev/null +++ b/common/grpcio-extras/src/lib.rs @@ -0,0 +1,14 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use futures::Future; + +pub trait Cancelable { + fn cancel(&mut self); +} + +impl Cancelable for T { + fn cancel(&mut self) { + unimplemented!(); + } +} diff --git a/common/jemalloc/Cargo.toml b/common/jemalloc/Cargo.toml new file mode 100644 index 0000000000000..b6b7c59274a96 --- /dev/null +++ b/common/jemalloc/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "jemalloc" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +jemalloc-sys= { version = "0.1.8", features = ["profiling", "unprefixed_malloc_on_supported_platforms"] } +jemallocator = { version = "0.1.8", features = ["alloc_trait", "profiling"] } diff --git a/common/jemalloc/src/lib.rs b/common/jemalloc/src/lib.rs new file mode 100644 index 0000000000000..a29bd35d91693 --- /dev/null +++ b/common/jemalloc/src/lib.rs @@ -0,0 +1,37 @@ +// Copyright (c) The Libra Core 
Contributors +// SPDX-License-Identifier: Apache-2.0 + +use std::{ffi::CString, ptr}; + +#[global_allocator] +static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; + +/// Tell jemalloc to dump the heap memory profile. This only works if +/// the binary is started with environment variable +/// +/// MALLOC_CONF="prof:true,prof_prefix:jeprof.out" +/// +/// Calling this function will cause jemalloc to write the memory +/// profile to the current working directory. Then one can process the +/// heap profile to various format with the 'jeprof' utility. ie +/// +/// jeprof --pdf target/debug/libra_node jeprof.out.2141437.2.m2.heap > out.pdf +/// +/// Returns the error code coming out of jemalloc if heap dump fails. +pub fn dump_jemalloc_memory_profile() -> Result<(), i32> { + let opt_name = CString::new("prof.dump").expect("CString::new failed."); + unsafe { + let err_code = jemalloc_sys::mallctl( + opt_name.as_ptr(), + ptr::null_mut(), + ptr::null_mut(), + ptr::null_mut(), + 0, + ); + if err_code == 0 { + Ok(()) + } else { + Err(err_code) + } + } +} diff --git a/common/logger/Cargo.toml b/common/logger/Cargo.toml new file mode 100644 index 0000000000000..3a89d20a47906 --- /dev/null +++ b/common/logger/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "logger" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +backtrace = { version = "0.3", features = ["serialize-serde"] } +chrono = "0.4" +crossbeam = "^0.4.1" +futures = "0.1.24" +hyper = "0.12" +itertools = "0.8.0" +lazy_static = "1.3.0" +mime = "0.3.2" +rand = "0.6.5" +serde = { version = "1.0.91", features = ["derive"] } +serde_json = "1.0.39" +# use this line to change verbosity +slog = { version = "2.4", features = ["max_level_trace", "release_max_level_debug"] } +slog-async = "2.3" +slog-envlogger = "2.1.0" +slog-scope = "4.0" +slog-term = "2.4" +thread-id = "3.3.0" +tokio = "0.1.16" + +# Do NOT add any other inter-project dependencies. +# This is to avoid ever having a circular dependency with the logger crate. +failure = { package = "failure_ext", path = "../failure_ext" } + +[dev-dependencies] +rand = "0.6.5" +regex = "1.1.6" diff --git a/common/logger/src/collector_serializer.rs b/common/logger/src/collector_serializer.rs new file mode 100644 index 0000000000000..2c7d88817f685 --- /dev/null +++ b/common/logger/src/collector_serializer.rs @@ -0,0 +1,306 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Provides CollectorSerializer. See it's documentation for more help + +use std::{ + collections::HashMap, + fmt::{format, Arguments}, +}; + +use slog::{Key, Result, Serializer}; + +use crate::kv_categorizer::KVCategorizer; + +/// This serializer collects all KV pairs into a Vec, converting the values to `String`. +/// It filters out the one that are of `KVCategory::Ignore` +pub struct CollectorSerializer<'a, C: KVCategorizer>(Vec<(Key, String)>, &'a C); + +impl<'a, C: KVCategorizer> CollectorSerializer<'a, C> { + /// Create a collector serializer that will use the given categorizer to collect desired values + pub fn new(categorizer: &'a C) -> Self { + CollectorSerializer(Vec::new(), categorizer) + } + + /// Once done collecting KV pairs call this to retrieve collected values + pub fn into_inner(self) -> Vec<(Key, String)> { + self.0 + } +} + +/// Define a macro to implement serializer emit functions. +macro_rules! 
impl_emit_body( + ($s:expr, $k:expr, $v:expr) => { + if $s.1.ignore($k) { + return Ok(()) + } + $s.0.push(($k, format!("{}", $v))); + }; +); + +/// Define a macro to implement serializer emit functions for standard types. +macro_rules! impl_emit( + ($name:ident, $t:ty) => { + /// Emit $t + fn $name(&mut self, key: Key, val: $t) -> Result { + impl_emit_body!(self, key, val); + Ok(()) + } + }; +); + +impl<'a, C: KVCategorizer> Serializer for CollectorSerializer<'a, C> { + /// Emit None + fn emit_none(&mut self, key: Key) -> Result { + impl_emit_body!(self, key, "None"); + Ok(()) + } + + /// Emit () + fn emit_unit(&mut self, key: Key) -> Result { + impl_emit_body!(self, key, "()"); + Ok(()) + } + + impl_emit!(emit_usize, usize); + impl_emit!(emit_isize, isize); + impl_emit!(emit_bool, bool); + impl_emit!(emit_char, char); + impl_emit!(emit_u8, u8); + impl_emit!(emit_i8, i8); + impl_emit!(emit_u16, u16); + impl_emit!(emit_i16, i16); + impl_emit!(emit_u32, u32); + impl_emit!(emit_i32, i32); + impl_emit!(emit_f32, f32); + impl_emit!(emit_u64, u64); + impl_emit!(emit_i64, i64); + impl_emit!(emit_f64, f64); + impl_emit!(emit_str, &str); + impl_emit!(emit_arguments, &Arguments<'_>); +} + +/// This serializer collects all KV pairs into a map, converting the values to `String`. +#[derive(Default)] +pub struct PlainKVSerializer(HashMap<&'static str, String>); + +impl PlainKVSerializer { + pub fn new() -> Self { + Default::default() + } + /// Once done collecting KV pairs call this to retrieve collected values + pub fn into_inner(self) -> HashMap<&'static str, String> { + self.0 + } +} + +impl Serializer for PlainKVSerializer { + fn emit_arguments(&mut self, key: &'static str, value: &Arguments) -> slog::Result { + self.0.insert(key, format(*value)); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use itertools::assert_equal; + use rand::{rngs::StdRng, Rng, SeedableRng}; + use slog::{b, record, Level, Record, Result as SlogResult, KV}; + + use crate::kv_categorizer::{InlineCategorizer, KVCategory}; + + #[derive(Clone)] + struct TestKv { + key: Key, + vusize: usize, + visize: isize, + vbool: bool, + vchar: char, + vu8: u8, + vi8: i8, + vu16: u16, + vi16: i16, + vu32: u32, + vi32: i32, + vf32: f32, + vu64: u64, + vi64: i64, + vf64: f64, + vstr: String, + } + + impl TestKv { + fn new(key: Key, rng: &mut R) -> Self { + TestKv { + key, + vusize: rng.gen(), + visize: rng.gen(), + vbool: rng.gen(), + vchar: rng.gen(), + vu8: rng.gen(), + vi8: rng.gen(), + vu16: rng.gen(), + vi16: rng.gen(), + vu32: rng.gen(), + vi32: rng.gen(), + vf32: rng.gen(), + vu64: rng.gen(), + vi64: rng.gen(), + vf64: rng.gen(), + vstr: format!("value{}", rng.gen::()), + } + } + + fn to_vec(&self) -> Vec<(Key, String)> { + vec![ + (self.key, "None".to_owned()), + (self.key, "()".to_owned()), + (self.key, format!("{}", self.vusize)), + (self.key, format!("{}", self.visize)), + (self.key, format!("{}", self.vbool)), + (self.key, format!("{}", self.vchar)), + (self.key, format!("{}", self.vu8)), + (self.key, format!("{}", self.vi8)), + (self.key, format!("{}", self.vu16)), + (self.key, format!("{}", self.vi16)), + (self.key, format!("{}", self.vu32)), + (self.key, format!("{}", self.vi32)), + (self.key, format!("{}", self.vf32)), + (self.key, format!("{}", self.vu64)), + (self.key, format!("{}", self.vi64)), + (self.key, format!("{}", self.vf64)), + (self.key, self.vstr.clone()), + ] + } + } + + impl KV for TestKv { + fn serialize(&self, _record: &Record<'_>, serializer: &mut dyn Serializer) -> SlogResult { + serializer + 
.emit_none(self.key) + .expect("failure emitting none"); + serializer + .emit_unit(self.key) + .expect("failure emitting unit"); + serializer + .emit_usize(self.key, self.vusize) + .expect("failure emitting usize"); + serializer + .emit_isize(self.key, self.visize) + .expect("failure emitting isize"); + serializer + .emit_bool(self.key, self.vbool) + .expect("failure emitting bool"); + serializer + .emit_char(self.key, self.vchar) + .expect("failure emitting char"); + serializer + .emit_u8(self.key, self.vu8) + .expect("failure emitting u8"); + serializer + .emit_i8(self.key, self.vi8) + .expect("failure emitting i8"); + serializer + .emit_u16(self.key, self.vu16) + .expect("failure emitting u16"); + serializer + .emit_i16(self.key, self.vi16) + .expect("failure emitting i16"); + serializer + .emit_u32(self.key, self.vu32) + .expect("failure emitting u32"); + serializer + .emit_i32(self.key, self.vi32) + .expect("failure emitting i32"); + serializer + .emit_f32(self.key, self.vf32) + .expect("failure emitting f32"); + serializer + .emit_u64(self.key, self.vu64) + .expect("failure emitting u64"); + serializer + .emit_i64(self.key, self.vi64) + .expect("failure emitting i64"); + serializer + .emit_f64(self.key, self.vf64) + .expect("failure emitting f64"); + serializer + .emit_str(self.key, &self.vstr) + .expect("failure emitting str"); + Ok(()) + } + } + + fn do_test(categorizer: &C, kv_values: V, kv_expected: E) + where + C: KVCategorizer, + V: IntoIterator, + E: IntoIterator, + { + let mut serializer = CollectorSerializer::new(categorizer); + + for value in kv_values { + value + .serialize( + &record!(Level::Info, "test", &format_args!(""), b!()), + &mut serializer, + ) + .expect("serialize failed!"); + } + + assert_equal( + serializer.into_inner(), + kv_expected.into_iter().flat_map(|x| x.to_vec()), + ); + } + + #[test] + fn test_inline_all() { + let mut rng: StdRng = SeedableRng::from_seed([1; 32]); + let input = vec![ + TestKv::new("test1", &mut rng), + TestKv::new("test2", &mut rng), + ]; + do_test(&InlineCategorizer, vec![], vec![]); + do_test(&InlineCategorizer, input.clone(), input); + } + + struct TestCategorizer; + impl KVCategorizer for TestCategorizer { + fn categorize(&self, _key: Key) -> KVCategory { + unimplemented!(); // It's not used by serializer + } + + fn name(&self, _key: Key) -> &'static str { + unimplemented!(); // It's not used by serializer + } + + fn ignore(&self, key: Key) -> bool { + key.starts_with("ignore") + } + } + + #[test] + fn test_ignoring() { + let mut rng: StdRng = SeedableRng::from_seed([2; 32]); + let normal = vec![ + TestKv::new("test1", &mut rng), + TestKv::new("test2", &mut rng), + ]; + let ignore = vec![ + TestKv::new("ignore1", &mut rng), + TestKv::new("ignore2", &mut rng), + ]; + let n = || normal.iter().cloned(); + let i = || ignore.iter().cloned(); + let e = || vec![]; + + do_test(&TestCategorizer, e(), e()); + do_test(&TestCategorizer, n(), n()); + do_test(&TestCategorizer, i(), e()); + do_test(&TestCategorizer, n().chain(i()), n()); + do_test(&TestCategorizer, i().chain(n()), n()); + } +} diff --git a/common/logger/src/glog_format.rs b/common/logger/src/glog_format.rs new file mode 100644 index 0000000000000..fd7a7c4a45c39 --- /dev/null +++ b/common/logger/src/glog_format.rs @@ -0,0 +1,258 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Provides common slog with glog formatting. 
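+//!
+//! A formatted line looks roughly like this (sketch, not verbatim output):
+//!
+//!   I0616 12:00:00.000000  1234 common/logger/src/lib.rs:10] message, key: value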
+ +use crate::{ + collector_serializer::CollectorSerializer, + kv_categorizer::{KVCategorizer, KVCategory}, +}; +use chrono; +use itertools::{Either, Itertools}; +use slog::{Drain, Key, Level, OwnedKVList, Record, KV}; +use slog_term::{Decorator, RecordDecorator}; +use std::{io, str}; +use thread_id; + +/// A slog `Drain` for glog-formatted logs. +pub struct GlogFormat { + decorator: D, + categorizer: C, +} + +impl GlogFormat { + /// Create a glog-formatted `Drain` using the provided `Decorator`, and `Categorizer` + pub fn new(decorator: D, categorizer: C) -> GlogFormat { + GlogFormat { + decorator, + categorizer, + } + } +} + +fn write_logline( + decorator: &mut dyn RecordDecorator, + level: Level, + metadata: &OnelineMetadata, +) -> io::Result<()> { + // Convert log level to a single character representation. + let level = match level { + Level::Critical => 'C', + Level::Error => 'E', + Level::Warning => 'W', + Level::Info => 'I', + Level::Debug => 'D', + Level::Trace => 'T', + }; + + decorator.start_level()?; + write!(decorator, "{}", level)?; + + decorator.start_timestamp()?; + write!(decorator, "{}", metadata.now.format("%m%d %H:%M:%S%.6f"))?; + + decorator.start_whitespace()?; + write!(decorator, " ")?; + + // Write the message. + decorator.start_msg()?; + write!( + decorator, + "{tid:>5} {file}:{line}] ", + tid = metadata.tid, + file = metadata.file, + line = metadata.line, + ) +} + +fn print_inline_kv( + decorator: &mut dyn RecordDecorator, + categorizer: &C, + kv: Vec<(Key, String)>, +) -> io::Result<()> { + for (k, v) in kv { + decorator.start_comma()?; + write!(decorator, ", ")?; + decorator.start_key()?; + write!(decorator, "{}", categorizer.name(k))?; + decorator.start_separator()?; + write!(decorator, ": ")?; + decorator.start_value()?; + write!(decorator, "{}", v)?; + } + Ok(()) +} + +fn finish_logline(decorator: &mut dyn RecordDecorator) -> io::Result<()> { + decorator.start_whitespace()?; + writeln!(decorator)?; + decorator.flush() +} + +impl Drain for GlogFormat { + type Ok = (); + type Err = io::Error; + + fn log(&self, record: &Record<'_>, values: &OwnedKVList) -> io::Result { + self.decorator.with_record(record, values, |decorator| { + let (inline_kv, level_kv): (Vec<_>, Vec<_>) = { + let mut serializer = CollectorSerializer::new(&self.categorizer); + values.serialize(record, &mut serializer)?; + record.kv().serialize(record, &mut serializer)?; + + serializer + .into_inner() + .into_iter() + .filter_map(|(k, v)| match self.categorizer.categorize(k) { + KVCategory::Ignore => None, + KVCategory::Inline => Some((None, k, v)), + KVCategory::LevelLog(level) => Some((Some(level), k, v)), + }) + .partition_map(|(l, k, v)| match l { + None => Either::Left((k, v)), + Some(level) => Either::Right((level, k, v)), + }) + }; + + let metadata = OnelineMetadata::new(record); + + write_logline(decorator, record.level(), &metadata)?; + write!(decorator, "{}", record.msg())?; + print_inline_kv(decorator, &self.categorizer, inline_kv)?; + finish_logline(decorator)?; + + for (level, k, v) in level_kv { + write_logline(decorator, level, &metadata)?; + write!(decorator, "{}: {}", self.categorizer.name(k), v)?; + finish_logline(decorator)?; + } + Ok(()) + }) + } +} + +struct OnelineMetadata { + now: chrono::DateTime, + tid: usize, + file: &'static str, + line: u32, +} + +impl OnelineMetadata { + fn new(record: &Record<'_>) -> Self { + OnelineMetadata { + now: chrono::Local::now(), + tid: thread_id::get(), + file: record.file(), + line: record.line(), + } + } +} + +#[cfg(test)] +mod tests { + 
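+    // These tests render a record through GlogFormat into an in-memory
+    // buffer, then parse the output with LOG_REGEX to check each field.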
use super::GlogFormat; + + use std::{ + io, + sync::{Arc, Mutex}, + }; + + use crate::kv_categorizer::InlineCategorizer; + use lazy_static::lazy_static; + use regex::{Captures, Regex}; + use slog::{info, o, Drain, Logger}; + use slog_term::PlainSyncDecorator; + use thread_id; + + lazy_static! { + // Create a regex that matches log lines. + static ref LOG_REGEX: Regex = Regex::new( + r"^(.)(\d{4} \d\d:\d\d:\d\d\.\d{6}) +(\d+) ([^:]+):(\d+)\] (.*)$", + ).unwrap(); + } + + /// Wrap a buffer so that it can be used by slog as a log output. + #[derive(Clone)] + pub struct TestBuffer { + buffer: Arc>>, + } + + impl TestBuffer { + pub fn new() -> TestBuffer { + TestBuffer { + buffer: Arc::new(Mutex::new(Vec::new())), + } + } + + pub fn get_string(&self) -> String { + let buffer = self.buffer.lock().unwrap(); + String::from_utf8(buffer.clone()).unwrap() + } + } + + impl io::Write for TestBuffer { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.buffer.lock().unwrap().write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.buffer.lock().unwrap().flush() + } + } + + #[derive(Debug, PartialEq, Eq)] + struct TestLine { + level: String, + tid: String, + file: String, + line: String, + msg: String, + } + + impl<'a> TestLine { + fn new(level: &'static str, line: u32, msg: &'static str) -> Self { + TestLine { + level: level.to_owned(), + tid: thread_id::get().to_string(), + file: file!().to_owned(), + line: line.to_string(), + msg: msg.to_owned(), + } + } + + fn with_captures(captures: Captures<'a>) -> Self { + TestLine { + level: captures.get(1).unwrap().as_str().to_owned(), + tid: captures.get(3).unwrap().as_str().to_owned(), + file: captures.get(4).unwrap().as_str().to_owned(), + line: captures.get(5).unwrap().as_str().to_owned(), + msg: captures.get(6).unwrap().as_str().to_owned(), + } + } + } + + #[test] + fn test_inline() { + // Create a logger that logs to a buffer instead of stderr. + let test_buffer = TestBuffer::new(); + let decorator = PlainSyncDecorator::new(test_buffer.clone()); + let drain = GlogFormat::new(decorator, InlineCategorizer).fuse(); + let log = Logger::root(drain, o!("mode" => "test")); + + // Send a log to the buffer. Remember the line the log was on. + let line = line!() + 1; + info!(log, "Test log {}", 1; "tau" => 6.28); + + // Get the log string back out of the buffer. + let log_string = test_buffer.get_string(); + + // Check the log line's fields to make sure they match expected values. + // For the timestamp, it's sufficient to just check it has the right form. + let captures = LOG_REGEX.captures(log_string.as_str().trim_end()).unwrap(); + assert_eq!( + TestLine::with_captures(captures), + TestLine::new("I", line, "Test log 1, mode: test, tau: 6.28",) + ); + } +} diff --git a/common/logger/src/http_local_slog_drain.rs b/common/logger/src/http_local_slog_drain.rs new file mode 100644 index 0000000000000..f5fb303e4304c --- /dev/null +++ b/common/logger/src/http_local_slog_drain.rs @@ -0,0 +1,69 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Slog::Drain implementation to send log to specified http end point using POST. +//! A PlainKVSerializer is used to collect KVs from both Record and OwnedKVList. +//! The log sent will be plain json. +//! +//! ## Usage +//! +//! use slog::{o, Drain, Logger, *}; +//! let client = crate::logger::http_log_client::HttpLogClient{ +//! use_https: false, +//! destination: "http://localhost:1234".to_string(), +//! }; +//! 
let drain = crate::logger::http_local_slog_drain::HttpLocalSlogDrain { client }; +//! let logger = Logger::root(drain.fuse(), o!("component" => "admission_control")); +//! slog_info!(logger, "test info log"; "log-key" => true); +use crate::{collector_serializer::PlainKVSerializer, http_log_client::HttpLogClient}; +use serde_json::json; +use slog::{slog_error, Drain, OwnedKVList, Record, KV}; +pub use slog_scope::error; +use std::{ + error::Error, + time::{SystemTime, UNIX_EPOCH}, +}; + +#[derive(Debug)] +pub struct HttpLocalSlogDrain { + client: HttpLogClient, +} +impl Drain for HttpLocalSlogDrain { + type Ok = (); + type Err = slog::Never; + fn log(&self, record: &Record, values: &OwnedKVList) -> Result { + let ret = self.log_impl(record, values); + if let Some(e) = ret.err() { + // The error from logging should not be cascading, but we have to log it + // somewhere for troubleshooting. + error!("Error sending log using http client: {}", e); + } + Ok(()) + } +} + +impl HttpLocalSlogDrain { + pub fn new(client: HttpLogClient) -> Self { + HttpLocalSlogDrain { client } + } + fn log_impl(&self, record: &Record, values: &OwnedKVList) -> Result<(), Box> { + let mut serializer = PlainKVSerializer::new(); + values.serialize(record, &mut serializer)?; + record.kv().serialize(record, &mut serializer)?; + let mut kvs = serializer.into_inner(); + kvs.insert("msg", format!("{}", record.msg())); + kvs.insert("level", record.level().as_str().to_string()); + kvs.insert("current_mod", module_path!().to_string()); + kvs.insert( + "time", + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Current timestamp is earlier than UNIX epoch") + .as_secs() + .to_string(), + ); + + self.client.send_log(json!(kvs).to_string())?; + Ok(()) + } +} diff --git a/common/logger/src/http_log_client.rs b/common/logger/src/http_log_client.rs new file mode 100644 index 0000000000000..37e75b9daeec3 --- /dev/null +++ b/common/logger/src/http_log_client.rs @@ -0,0 +1,70 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Send log to a http end point. Error will be propagate to call side. +//! +//! ## Usage +//! +//! let client = crate::logger::http_log_client::HttpLogClient{ +//! use_https: false, +//! destination: "http://localhost:1234".to_string(), +//! }; +use failure::Result; +use futures::{future::Future, stream::Stream}; +use hyper::{header, Body, Client, Request, Uri}; +use slog::{slog_error, slog_trace}; +pub use slog_scope::{debug, error, trace}; +use std::thread; +use tokio::runtime::current_thread; + +#[derive(Debug)] +pub struct HttpLogClient { + // Destination string for this client. + // Https is not supported and will result an error. 
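+    // e.g. "http://localhost:1234". Only plain http is supported; `new()`
+    // logs an error when given an https URI.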
+ destination: String, +} + +impl HttpLogClient { + pub fn new(destination: String) -> Result { + let uri = destination.parse::()?; + if let Some(schema) = uri.scheme_part() { + if schema.as_str() == "https" { + error!("Https is not supported, uri {}", destination) + } + } + + Ok(HttpLogClient { destination }) + } + + pub fn send_log(&self, logs: String) -> Result<()> { + let client = Client::builder().build_http(); + let body = format!("json={}", logs); + let req = Request::post(self.destination.clone()) + .header( + header::CONTENT_TYPE, + mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(), + ) + .header(header::CONTENT_LENGTH, body.len() as u64) + .body(Body::from(body))?; + + let fut = client + .request(req) + .and_then(|res| { + let status = res.status(); + res.into_body().concat2().and_then(move |body| { + trace!( + "Log sent, status: {}, response: {}", + status, + String::from_utf8_lossy(&body).into_owned() + ); + Ok(()) + }) + }) + .map_err(|e| error!("error sending log: {}", e)); + + // TODO Verify whether we need to spawn thread for each call or slog_async (queue + 1 + // worker) is sufficient. + thread::spawn(move || current_thread::Runtime::new().unwrap().block_on(fut)); + Ok(()) + } +} diff --git a/common/logger/src/kv_categorizer.rs b/common/logger/src/kv_categorizer.rs new file mode 100644 index 0000000000000..34ba69a8d1e7c --- /dev/null +++ b/common/logger/src/kv_categorizer.rs @@ -0,0 +1,83 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Provides ways to control how the KV values passed to slog macros are printed + +use slog::{Key, Level}; + +/// The KV value is being processed based on the category it is bucketed in +#[derive(Debug, PartialEq, Eq)] +pub enum KVCategory { + /// KV value is not printed at all + Ignore, + /// KV value is inlined with the main message passed to slog macro + Inline, + /// KV value is printed as a separate line with the provided log level + LevelLog(Level), +} + +/// Structures implementing this trait are being used to categorize the KV values into one of the +/// `KVCategory`. +pub trait KVCategorizer { + /// For a given key from KV decide which category it belongs to + fn categorize(&self, key: Key) -> KVCategory; + /// For a given key from KV return a name that should be printed for it + fn name(&self, key: Key) -> &'static str; + /// True if category of a given key is KVCategory::Ignore + fn ignore(&self, key: Key) -> bool { + self.categorize(key) == KVCategory::Ignore + } +} + +/// Placeholder categorizer that inlines all KV values with names equal to key +pub struct InlineCategorizer; +impl KVCategorizer for InlineCategorizer { + fn categorize(&self, _key: Key) -> KVCategory { + KVCategory::Inline + } + + fn name(&self, key: Key) -> &'static str { + key + } +} + +/// Used to properly print `error_chain` `Error`s. It displays the error and it's causes in +/// separate log lines as well as backtrace if provided. +/// The `error_chain` `Error` must implement `KV` trait. It is recommended to use `impl_kv_error` +/// macro to generate the implementation. 
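+///
+/// For example (sketch): a log call carrying an "error" KV is emitted as a
+/// separate `Error`-level line, its "cause" at `Debug`, and any "backtrace"
+/// at `Trace`, while all other keys stay inline with the main message.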
+pub struct ErrorCategorizer; +impl KVCategorizer for ErrorCategorizer { + fn categorize(&self, key: Key) -> KVCategory { + match key { + "error" => KVCategory::LevelLog(Level::Error), + "cause" => KVCategory::LevelLog(Level::Debug), + "backtrace" => KVCategory::LevelLog(Level::Trace), + _ => InlineCategorizer.categorize(key), + } + } + + fn name(&self, key: Key) -> &'static str { + match key { + "error" => "Error", + "cause" => "Caused by", + "backtrace" => "Originated in", + "root_cause" => "Root cause", + _ => InlineCategorizer.name(key), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_inline() { + let categorizer = InlineCategorizer; + let values = vec!["test", "test2"]; + for v in values { + assert_eq!(categorizer.categorize(v), KVCategory::Inline); + assert_eq!(categorizer.name(v), v); + } + } +} diff --git a/common/logger/src/lib.rs b/common/logger/src/lib.rs new file mode 100644 index 0000000000000..bde7d59853439 --- /dev/null +++ b/common/logger/src/lib.rs @@ -0,0 +1,225 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! A default logger for Libra project. +//! +//! ## Usage +//! +//! ```rust, no_run +//! use logger::prelude::*; +//! +//! pub fn main() { +//! let _g = logger::set_default_global_logger(false /* async */, Some(256)); +//! info!("Starting..."); +//! } +//! ``` + +mod collector_serializer; +mod glog_format; +mod http_local_slog_drain; +mod http_log_client; +mod kv_categorizer; +mod security; +mod simple_logger; + +use crate::{ + http_local_slog_drain::HttpLocalSlogDrain, http_log_client::HttpLogClient, + kv_categorizer::ErrorCategorizer, +}; +use crossbeam::atomic::ArcCell; +use failure::prelude::*; +use glog_format::GlogFormat; +use lazy_static::lazy_static; +use slog::{o, Discard, Drain, FilterLevel, Logger, Never}; +pub use slog::{slog_crit, slog_debug, slog_error, slog_info, slog_trace, slog_warn}; +use slog_async::Async; +use slog_envlogger::{EnvLogger, LogBuilder}; +pub use slog_scope::{crit, debug, error, info, trace, warn}; +use slog_scope::{set_global_logger, GlobalLoggerGuard}; +use slog_term::{PlainDecorator, TermDecorator}; +use std::sync::{Arc, Mutex}; + +/// Logger prelude which includes all logging macros. +pub mod prelude { + pub use crate::{ + log_collector_crit, log_collector_debug, log_collector_error, log_collector_info, + log_collector_trace, log_collector_warn, + security::{security_log, SecurityEvent}, + }; + pub use slog::{slog_crit, slog_debug, slog_error, slog_info, slog_trace, slog_warn}; + pub use slog_scope::{crit, debug, error, info, trace, warn}; +} + +pub use simple_logger::{set_simple_logger, set_simple_logger_prefix}; + +/// Creates and sets default global logger. +/// Caller must keep the returned guard alive. +pub fn set_default_global_logger(async_drain: bool, chan_size: Option) -> GlobalLoggerGuard { + let logger = create_default_root_logger(async_drain, chan_size); + set_global_logger(logger) +} + +/// Creates a root logger with default settings. 
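+/// The drain chain is glog-formatted stderr, filtered per `RUST_LOG` (default
+/// `Info`, see `create_env_logger`), wrapped sync or async by `get_logger`.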
+fn create_default_root_logger(async_drain: bool, chan_size: Option) -> Logger { + let drain = GlogFormat::new(PlainDecorator::new(::std::io::stderr()), ErrorCategorizer).fuse(); + let logger = create_env_logger(drain); + get_logger(async_drain, chan_size, logger) +} + +/// Creates a logger that respects RUST_LOG environment variable +fn create_env_logger_with_level(drain: D, level: FilterLevel) -> EnvLogger +where + D: Drain + Send + 'static, +{ + let mut builder = LogBuilder::new(drain); + // Have the default logging level be 'Info' + builder = builder.filter(None, level); + + // Apply directives from the "RUST_LOG" env var + if let Ok(s) = ::std::env::var("RUST_LOG") { + builder = builder.parse(&s); + } + builder.build() +} + +/// Creates a logger that respects RUST_LOG environment variable +fn create_env_logger(drain: D) -> EnvLogger +where + D: Drain + Send + 'static, +{ + // Have the default logging level be 'Info' + create_env_logger_with_level(drain, FilterLevel::Info) +} + +/// Creates a root logger with test settings: does not do output if test passes. +/// Caveat: cargo test does not capture output for non main thread. So this logger is not +/// very useful for multithreading scenarios. +fn create_test_root_logger() -> Logger { + let drain = GlogFormat::new(TermDecorator::new().build(), ErrorCategorizer).fuse(); + let envlogger = create_env_logger_with_level(drain, FilterLevel::Debug); + Logger::root(Mutex::new(envlogger).fuse(), o!()) +} + +// TODO: redo this +lazy_static! { + static ref TESTING_ENVLOGGER_GUARD: GlobalLoggerGuard = { + let logger = { + if ::std::env::var("RUST_LOG").is_ok() { + create_default_root_logger(false /* async */, None /* chan_size */) + } else { + Logger::root(Discard, o!()) + } + }; + set_global_logger(logger) + }; + + static ref END_TO_END_TESTING_ENVLOGGER_GUARD: GlobalLoggerGuard = { + let logger = create_test_root_logger(); + set_global_logger(logger) + }; +} + +lazy_static! { + static ref GLOBAL_LOG_COLLECTOR: ArcCell = + ArcCell::new(Arc::new(Logger::root(Discard, o!()))); +} + +#[derive(Clone, Debug)] +pub enum LoggerType { + // Logger sending data to http destination. Data includes endpoint. + Http(String), + // Logger sending data to stdio and stderr, which will be used along with + // kv_categorizer + StdOutput, +} + +pub fn set_global_log_collector(collector: LoggerType, is_async: bool, chan_size: Option) { + // Log collector should be available at this time. + let log_collector = get_log_collector(collector, is_async, chan_size).unwrap(); + GLOBAL_LOG_COLLECTOR.set(Arc::new(log_collector)); +} + +/// Create and setup default global logger following the env-logger conventions, +/// i.e. configured by environment variable RUST_LOG. +/// This is useful to make logging optional in unit tests. +pub fn try_init_for_testing() { + ::lazy_static::initialize(&TESTING_ENVLOGGER_GUARD); +} + +/// Create and setup default global logger for use in end to end testing. +pub fn init_for_e2e_testing() { + ::lazy_static::initialize(&END_TO_END_TESTING_ENVLOGGER_GUARD); +} + +// Get external logger according to config. 
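+// `LoggerType::Http` wires the drain to an HTTP endpoint through
+// `HttpLogClient`; `LoggerType::StdOutput` reuses the default glog-formatted
+// stderr root logger.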
+fn get_log_collector( + collector: LoggerType, + is_async: bool, + chan_size: Option, +) -> Result { + match collector { + LoggerType::Http(http_endpoint) => { + let client = HttpLogClient::new(http_endpoint); + let drain = HttpLocalSlogDrain::new(client?); + Ok(get_logger(is_async, chan_size, drain)) + } + LoggerType::StdOutput => Ok(create_default_root_logger(is_async, chan_size)), + } +} + +fn get_logger(is_async: bool, chan_size: Option, drain: D) -> Logger +where + D: Drain + Send + 'static, +{ + if is_async { + let async_builder = match chan_size { + Some(chan_size_inner) => Async::new(drain).chan_size(chan_size_inner), + None => Async::new(drain), + }; + Logger::root(async_builder.build().fuse(), o!()) + } else { + Logger::root(Mutex::new(drain).fuse(), o!()) + } +} + +/// Access the `Global Logger Collector` for the current logging scope +/// +/// This function doesn't have to clone the Logger +/// so it might be a bit faster. +pub fn with_logger(f: F) -> R +where + F: FnOnce(&Logger) -> R, +{ + f(&(*GLOBAL_LOG_COLLECTOR.get())) +} + +/// Log a critical level message using current log collector +#[macro_export] +macro_rules! log_collector_crit( ($($args:tt)+) => { + $crate::with_logger(|logger| slog_crit![logger, $($args)+]) +};); +/// Log a error level message using current log collector +#[macro_export] +macro_rules! log_collector_error( ($($args:tt)+) => { + $crate::with_logger(|logger| slog_error![logger, $($args)+]) +};); +/// Log a warning level message using current log collector +#[macro_export] +macro_rules! log_collector_warn( ($($args:tt)+) => { + $crate::with_logger(|logger| slog_warn![logger, $($args)+]) +};); +/// Log a info level message using current log collector +#[macro_export] +macro_rules! log_collector_info( ($($args:tt)+) => { + $crate::with_logger(|logger| slog_info![logger, $($args)+]) +};); +/// Log a debug level message using current log collector +#[macro_export] +macro_rules! log_collector_debug( ($($args:tt)+) => { + $crate::with_logger(|logger| slog_debug![logger, $($args)+]) +};); +/// Log a trace level message using current log collector +#[macro_export] +macro_rules! 
log_collector_trace( ($($args:tt)+) => { + $crate::with_logger(|logger| slog_trace![logger, $($args)+]) +};); diff --git a/common/logger/src/security.rs b/common/logger/src/security.rs new file mode 100644 index 0000000000000..6f9a5b3ff11ca --- /dev/null +++ b/common/logger/src/security.rs @@ -0,0 +1,193 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use backtrace::Backtrace; +use rand::{rngs::SmallRng, FromEntropy, Rng}; +use serde::Serialize; +use std::fmt::Debug; + +#[derive(Serialize)] +pub enum SecurityEvent { + /// Admission Control received a transaction with an invalid signature + InvalidTransactionAC, + + /// Mempool received a transaction with an invalid signature + InvalidTransactionMP, + + /// Consensus received a transaction with an invalid signature + InvalidTransactionConsensus, + + /// Executor received an invalid transactions chunk + InvalidChunkExecutor, + + /// Mempool received an invalid network event + InvalidNetworkEventMP, + + /// Consensus received an invalid vote + DuplicateConsensusVote, + + /// Consensus received an invalid proposal + InvalidConsensusProposal, + + /// Consensus received an invalid vote + InvalidConsensusVote, + + /// Consensus received an invalid new round message + InvalidConsensusRound, + + /// A block being committed or executed is invalid + InvalidBlock, + + /// Network identified an invalid peer + InvalidNetworkPeer, + + /// Error for testing + #[cfg(test)] + TestError, +} + +/// The `SecurityLog` struct is used to log security-sensitive operations, for instance when an +/// invalid signature is detected or when an unexpected event happens. +/// +/// The `security_log()` function should be used to instantiate this struct. It can be decorated +/// with different type of metadata: +/// - `event` contains a pre-defined element in the `SecurityEvent` enum +/// - `error` can contain an error type provided by the application +/// - `data` can contain associated metadata related to the event +/// - `backtrace` can contain a backtrace of the current call stack +/// +/// All these information can be defined by using the appropriate function with the same name. +/// +/// The method `log()` needs to be called to ensure that the event is actually printed. +/// +/// # Example: +/// ```rust +/// use logger::prelude::*; +/// use std::fmt::Debug; +/// +/// #[derive(Debug)] +/// struct SampleData { +/// i: u8, +/// s: Vec, +/// } +/// +/// #[derive(Debug)] +/// enum TestError { +/// Error, +/// } +/// +/// pub fn main() { +/// security_log(SecurityEvent::InvalidTransactionAC) +/// .error(&TestError::Error) +/// .data(&SampleData { +/// i: 0xff, +/// s: vec![0x90, 0xcd, 0x80], +/// }) +/// .data("additional payload") +/// .backtrace(100) +/// .log(); +/// } +/// ``` +/// In this example, `security_log()` logs an event of type `SecurityEvent::InvalidTransactionAC`, +/// having `TestError::Error` as application error, a `SimpleData` struct and a `String` as +/// additional metadata, and a backtrace that samples 100% of the times. + +#[must_use = "must use `log()`"] +#[derive(Serialize)] +pub struct SecurityLog { + event: SecurityEvent, + error: Option, + data: Vec, + backtrace: Option, +} + +/// Creates a `SecurityLog` struct that can be decorated with additional data. 
+
+/// The `SecurityLog` struct is used to log security-sensitive operations, for instance when an
+/// invalid signature is detected or when an unexpected event happens.
+///
+/// The `security_log()` function should be used to instantiate this struct. It can be decorated
+/// with different types of metadata:
+/// - `event` contains a pre-defined element in the `SecurityEvent` enum
+/// - `error` can contain an error type provided by the application
+/// - `data` can contain associated metadata related to the event
+/// - `backtrace` can contain a backtrace of the current call stack
+///
+/// Each piece of metadata can be set with the builder method of the same name.
+///
+/// The method `log()` needs to be called to ensure that the event is actually printed.
+///
+/// # Example:
+/// ```rust
+/// use logger::prelude::*;
+/// use std::fmt::Debug;
+///
+/// #[derive(Debug)]
+/// struct SampleData {
+///     i: u8,
+///     s: Vec<u8>,
+/// }
+///
+/// #[derive(Debug)]
+/// enum TestError {
+///     Error,
+/// }
+///
+/// pub fn main() {
+///     security_log(SecurityEvent::InvalidTransactionAC)
+///         .error(&TestError::Error)
+///         .data(&SampleData {
+///             i: 0xff,
+///             s: vec![0x90, 0xcd, 0x80],
+///         })
+///         .data("additional payload")
+///         .backtrace(100)
+///         .log();
+/// }
+/// ```
+/// In this example, `security_log()` logs an event of type `SecurityEvent::InvalidTransactionAC`,
+/// having `TestError::Error` as the application error, a `SampleData` struct and a `String` as
+/// additional metadata, and a backtrace that is sampled 100% of the time.
+
+#[must_use = "must use `log()`"]
+#[derive(Serialize)]
+pub struct SecurityLog {
+    event: SecurityEvent,
+    error: Option<String>,
+    data: Vec<String>,
+    backtrace: Option<Backtrace>,
+}
+
+/// Creates a `SecurityLog` struct that can be decorated with additional data.
+pub fn security_log(event: SecurityEvent) -> SecurityLog {
+    SecurityLog::new(event)
+}
+
+impl SecurityLog {
+    pub(crate) fn new(event: SecurityEvent) -> Self {
+        SecurityLog {
+            event,
+            error: None,
+            data: Vec::new(),
+            backtrace: None,
+        }
+    }
+
+    /// Adds additional metadata to the `SecurityLog` struct. The argument needs to implement the
+    /// `std::fmt::Debug` trait.
+    pub fn data<T: Debug>(mut self, data: T) -> Self {
+        let data = format!("{:?}", data);
+        if usize::checked_add(self.data.len(), 1).is_some() {
+            self.data.push(data);
+        }
+        self
+    }
+
+    /// Adds an application error to the `SecurityLog` struct. The argument needs to implement the
+    /// `std::fmt::Debug` trait.
+    pub fn error<T: Debug>(mut self, error: T) -> Self {
+        self.error = Some(format!("{:?}", error));
+        self
+    }
+
+    /// Adds a backtrace to the `SecurityLog` struct, sampled `sampling_rate` percent of the time
+    /// (values above 100 are clamped to 100).
+    pub fn backtrace(mut self, sampling_rate: u8) -> Self {
+        let sampling_rate = std::cmp::min(sampling_rate, 100);
+        self.backtrace = {
+            let mut rng = SmallRng::from_entropy();
+            match rng.gen_range(0, 100) {
+                x if x < sampling_rate => Some(Backtrace::new()),
+                _ => None,
+            }
+        };
+        self
+    }
+
+    pub(crate) fn to_string(&self) -> String {
+        match serde_json::to_string(&self) {
+            Ok(s) => s,
+            Err(e) => e.to_string(),
+        }
+    }
+
+    /// Prints the `SecurityLog` struct.
+    pub fn log(self) {
+        error!("[security] {}", self.to_string());
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[derive(Debug)]
+    struct SampleData {
+        i: u8,
+        s: Vec<u8>,
+    }
+
+    #[derive(Debug)]
+    enum TestError {
+        Error,
+    }
+
+    #[test]
+    fn test_log() {
+        let s = security_log(SecurityEvent::TestError)
+            .error(&TestError::Error)
+            .data(&SampleData {
+                i: 0xff,
+                s: vec![0x90, 0xcd, 0x80],
+            })
+            .data("second_payload");
+        assert_eq!(
+            s.to_string(),
+            r#"{"event":"TestError","error":"Error","data":["SampleData { i: 255, s: [144, 205, 128] }","\"second_payload\""],"backtrace":null}"#,
+        );
+    }
+}
diff --git a/common/logger/src/simple_logger.rs b/common/logger/src/simple_logger.rs
new file mode 100644
index 0000000000000..3d51a8ccff34a
--- /dev/null
+++ b/common/logger/src/simple_logger.rs
@@ -0,0 +1,91 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use slog::{o, Drain, FilterLevel, Level, Logger, OwnedKVList, Record};
+use slog_envlogger::LogBuilder;
+use slog_scope::set_global_logger;
+use std::{cell::RefCell, io, mem, sync::Mutex};
+
+/// Simple logger mostly intended for use in test code
+/// It contains a bunch of boilerplate hacks to make test output less verbose (but still
+/// keep some logs)
+///
+/// Simple logger outputs logs to stdout in a simple format:
+///

+pub trait IntoProtoBytes<P> {
+    /// Encode a Rust struct to Protobuf bytes.
+    fn into_proto_bytes(self) -> Result<Vec<u8>>;
+}
+
+/// blanket implementation for `protobuf::Message`.
+impl<P, T> IntoProtoBytes<P> for T
+where
+    P: Message,
+    T: IntoProto<ProtoType = P>,
+{
+    fn into_proto_bytes(self) -> Result<Vec<u8>> {
+        Ok(self.into_proto().write_to_bytes()?)
+    }
+}
+
+pub trait FromProtoBytes<P>: Sized {
+    /// Decode a Rust struct from encoded Protobuf bytes.
+    fn from_proto_bytes(bytes: &[u8]) -> Result<Self>;
+}
+
+/// blanket implementation for `protobuf::Message`.
+impl<P, T> FromProtoBytes<P> for T
+where
+    P: Message,
+    T: FromProto<ProtoType = P>,
+{
+    /// Decode a Rust struct from encoded Protobuf bytes.
+    fn from_proto_bytes(bytes: &[u8]) -> Result<Self> {
+        Self::from_proto(protobuf::parse_from_bytes(bytes)?)
+    }
+}
+
+#[cfg(feature = "derive")]
+pub use proto_conv_derive::{FromProto, IntoProto};
+
+// For a few types like integers, the Rust type and Protobuf type are identical.
+macro_rules! impl_direct_conversion {
+    ($type_name: ty) => {
+        impl FromProto for $type_name {
+            type ProtoType = $type_name;
+
+            fn from_proto(object: Self::ProtoType) -> Result<Self> {
+                Ok(object)
+            }
+        }
+
+        impl IntoProto for $type_name {
+            type ProtoType = $type_name;
+
+            fn into_proto(self) -> Self::ProtoType {
+                self
+            }
+        }
+    };
+}
+
+impl_direct_conversion!(u32);
+impl_direct_conversion!(u64);
+impl_direct_conversion!(i32);
+impl_direct_conversion!(i64);
+impl_direct_conversion!(bool);
+impl_direct_conversion!(String);
+impl_direct_conversion!(Vec<u8>);
+
+// Note: repeated primitive type fields like Vec<u64> are not supported right now, because their
+// corresponding protobuf type is Vec<u64>, not protobuf::RepeatedField<u64>.
+
+impl<P, T> FromProto for Vec<T>
+where
+    P: protobuf::Message,
+    T: FromProto<ProtoType = P>,
+{
+    type ProtoType = protobuf::RepeatedField<P>;
+
+    fn from_proto(object: Self::ProtoType) -> Result<Self> {
+        object
+            .into_iter()
+            .map(T::from_proto)
+            .collect::<Result<Vec<_>>>()
+    }
+}
+
+impl<P, T> IntoProto for Vec<T>
+where
+    P: protobuf::Message,
+    T: IntoProto<ProtoType = P>,
+{
+    type ProtoType = protobuf::RepeatedField<P>;
+
+    fn into_proto(self) -> Self::ProtoType {
+        self.into_iter().map(T::into_proto).collect()
+    }
+}
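+
+// Round-trip sketch (assuming a type `T: IntoProto<ProtoType = P> + FromProto<ProtoType = P>`
+// for some `P: protobuf::Message`):
+//
+//     let bytes = value.into_proto_bytes()?;        // T -> P -> Vec<u8>
+//     let decoded = T::from_proto_bytes(&bytes)?;   // Vec<u8> -> P -> T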
diff --git a/common/proto_conv/src/test_helper.rs b/common/proto_conv/src/test_helper.rs
new file mode 100644
index 0000000000000..aae9ba47ebd49
--- /dev/null
+++ b/common/proto_conv/src/test_helper.rs
@@ -0,0 +1,61 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{FromProto, FromProtoBytes, IntoProto, IntoProtoBytes};
+use std::fmt::Debug;
+
+pub fn assert_protobuf_encode_decode<P, T>(object: &T)
+where
+    T: FromProto<ProtoType = P> + IntoProto<ProtoType = P> + Clone + Debug + Eq,
+{
+    object.assert_protobuf_encode_decode()
+}
+
+trait ProtobufEncodeDecodeTest<P>:
+    FromProto<ProtoType = P> + IntoProto<ProtoType = P> + Clone + Debug + Eq
+{
+    /// The default implementation tests conversion roundtrip via `{Into/From}Proto`. If the type
+    /// implements `protobuf::Message`, also tests conversion roundtrip via
+    /// `{Into/From}ProtoBytes`.
+    fn assert_protobuf_encode_decode(&self);
+}
+
+impl<P, T> ProtobufEncodeDecodeTest<P> for T
+where
+    T: FromProto<ProtoType = P> + IntoProto<ProtoType = P> + Clone + Debug + Eq,
+{
+    default fn assert_protobuf_encode_decode(&self) {
+        test_into_from_proto(self);
+    }
+}
+
+impl<P, T> ProtobufEncodeDecodeTest<P> for T
+where
+    P: protobuf::Message,
+    T: FromProto<ProtoType = P> + IntoProto<ProtoType = P> + Clone + Debug + Eq,
+{
+    fn assert_protobuf_encode_decode(&self) {
+        test_into_from_proto(self);
+        test_into_from_proto_bytes(self);
+    }
+}
+
+/// Tests conversion roundtrip via `{Into/From}Proto`.
+fn test_into_from_proto<P, T>(object: &T)
+where
+    T: FromProto<ProtoType = P> + IntoProto<ProtoType = P> + Clone + Debug + Eq,
+{
+    let proto = object.clone().into_proto();
+    let from_proto = T::from_proto(proto).expect("Should convert.");
+    assert_eq!(*object, from_proto);
+}
+
+/// Tests conversion roundtrip via `{Into/From}ProtoBytes`.
+fn test_into_from_proto_bytes<P, T>(object: &T)
+where
+    T: FromProtoBytes<P> + IntoProtoBytes<P> + Clone + Debug + Eq,
+{
+    let proto_bytes = object.clone().into_proto_bytes().expect("Should convert.");
+    let from_proto_bytes = T::from_proto_bytes(&proto_bytes).expect("Should convert.");
+    assert_eq!(*object, from_proto_bytes);
+}
diff --git a/common/proto_conv/tests/derive.rs b/common/proto_conv/tests/derive.rs
new file mode 100644
index 0000000000000..5290c7e0c6cf2
--- /dev/null
+++ b/common/proto_conv/tests/derive.rs
@@ -0,0 +1,80 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+#![allow(clippy::unit_arg)]
+
+mod proto;
+
+use proptest::prelude::*;
+use proptest_derive::Arbitrary;
+use proto_conv::{test_helper::assert_protobuf_encode_decode, FromProto, IntoProto};
+
+macro_rules! test_conversion {
+    ($struct_name: ident, $test_name: ident, $field_type: ty) => {
+        #[derive(Arbitrary, Clone, Debug, Eq, PartialEq, FromProto, IntoProto)]
+        #[ProtoType(crate::proto::test::$struct_name)]
+        struct $struct_name {
+            value: $field_type,
+        }
+
+        proptest! {
+            #[test]
+            fn $test_name(rust_object in any::<$struct_name>()) {
+                let proto_object = rust_object.clone().into_proto();
+                prop_assert_eq!(rust_object.clone().value, proto_object.get_value());
+
+                let rust_object2 = $struct_name::from_proto(proto_object)
+                    .expect("Converting Protobuf object to Rust object should work.");
+                prop_assert_eq!(rust_object, rust_object2);
+            }
+        }
+    };
+}
+
+test_conversion!(Int32, test_convert_int32, i32);
+test_conversion!(Int64, test_convert_int64, i64);
+test_conversion!(UInt32, test_convert_uint32, u32);
+test_conversion!(UInt64, test_convert_uint64, u64);
+test_conversion!(SInt32, test_convert_sint32, i32);
+test_conversion!(SInt64, test_convert_sint64, i64);
+test_conversion!(Fixed32, test_convert_fixed32, u32);
+test_conversion!(Fixed64, test_convert_fixed64, u64);
+test_conversion!(Boolean, test_convert_boolean, bool);
+test_conversion!(Strings, test_convert_strings, String);
+test_conversion!(Bytes, test_convert_bytes, Vec<u8>);
+
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, FromProto, IntoProto)]
+#[ProtoType(crate::proto::test::Structs)]
+struct Structs {
+    a: Int32,
+    b: UInt32,
+    c: SInt32,
+    d: Fixed32,
+    e: Boolean,
+    f: Strings,
+    g: Bytes,
+    h: Vec<Int32>,
+}
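+
+// For reference, a rough hand-written equivalent of what the derive generates for a
+// one-field message (sketch only; the getters/setters are the ones rust-protobuf
+// emits for the generated `proto::test` types):
+//
+//     impl IntoProto for Int32 {
+//         type ProtoType = crate::proto::test::Int32;
+//         fn into_proto(self) -> Self::ProtoType {
+//             let mut proto = Self::ProtoType::new();
+//             proto.set_value(self.value.into_proto());
+//             proto
+//         }
+//     }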
+
+proptest! {
+    #[test]
+    fn test_convert_vecs(rust_object in any::<Vec<Int32>>()) {
+        assert_protobuf_encode_decode(&rust_object);
+    }
+
+    #[test]
+    fn test_convert_structs(rust_object in any::<Structs>()) {
+        let proto_object = rust_object.clone().into_proto();
+        prop_assert_eq!(rust_object.clone().a.value, proto_object.get_a().get_value());
+        prop_assert_eq!(rust_object.clone().b.value, proto_object.get_b().get_value());
+        prop_assert_eq!(rust_object.clone().c.value, proto_object.get_c().get_value());
+        prop_assert_eq!(rust_object.clone().d.value, proto_object.get_d().get_value());
+        prop_assert_eq!(rust_object.clone().e.value, proto_object.get_e().get_value());
+        prop_assert_eq!(rust_object.clone().f.value, proto_object.get_f().get_value());
+        prop_assert_eq!(rust_object.clone().g.value, proto_object.get_g().get_value());
+
+        let rust_object2 = Structs::from_proto(proto_object)
+            .expect("Converting Protobuf object to Rust object should work.");
+        prop_assert_eq!(rust_object, rust_object2);
+    }
+}
diff --git a/common/proto_conv/tests/proto/mod.rs b/common/proto_conv/tests/proto/mod.rs
new file mode 100644
index 0000000000000..c22abb3305cf2
--- /dev/null
+++ b/common/proto_conv/tests/proto/mod.rs
@@ -0,0 +1,4 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod test;
diff --git a/common/proto_conv/tests/proto/test.proto b/common/proto_conv/tests/proto/test.proto
new file mode 100644
index 0000000000000..27aafcff5aa84
--- /dev/null
+++ b/common/proto_conv/tests/proto/test.proto
@@ -0,0 +1,41 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+syntax = "proto3";
+
+package test;
+
+message Int32 { int32 value = 1; }
+
+message Int64 { int64 value = 1; }
+
+message UInt32 { uint32 value = 1; }
+
+message UInt64 { uint64 value = 1; }
+
+message SInt32 { sint32 value = 1; }
+
+message SInt64 { sint64 value = 1; }
+
+message Fixed32 { fixed32 value = 1; }
+
+message Fixed64 { fixed64 value = 1; }
+
+message Boolean { bool value = 1; }
+
+message Strings { string value = 1; }
+
+message Bytes { bytes value = 1; }
+
+message Repeated { repeated Int32 value = 1; }
+
+message Structs {
+    Int32 a = 1;
+    UInt32 b = 2;
+    SInt32 c = 3;
+    Fixed32 d = 4;
+    Boolean e = 5;
+    Strings f = 6;
+    Bytes g = 7;
+    repeated Int32 h = 8;
+}
diff --git a/common/tools/Cargo.toml b/common/tools/Cargo.toml
new file mode 100644
index 0000000000000..7054e90b52f98
--- /dev/null
+++ b/common/tools/Cargo.toml
@@ -0,0 +1,10 @@
+[package]
+name = "tools"
+version = "0.1.0"
+authors = ["Libra Association "]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+futures = "0.1.24"
diff --git a/common/tools/src/lib.rs b/common/tools/src/lib.rs
new file mode 100644
index 0000000000000..c6c4107a74708
--- /dev/null
+++ b/common/tools/src/lib.rs
@@ -0,0 +1,6 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+#![feature(set_stdio)]
+
+pub mod output_capture;
diff --git a/common/tools/src/output_capture.rs b/common/tools/src/output_capture.rs
new file mode 100644
index 0000000000000..c4575e914aee4
--- /dev/null
+++ b/common/tools/src/output_capture.rs
@@ -0,0 +1,82 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use std::{
+    io,
+    sync::{Arc, Mutex},
+};
+
+/// The Rust test runner normally captures the output of tests.
+/// However, this does not work if a test spawns new threads: only the main thread's output is
+/// captured. See https://github.com/rust-lang/rust/issues/42474
+/// This struct works around the problem by grabbing the output capture in the main thread and
+/// allowing test code to apply it to the threads it creates.
+///
+/// It is not magical though: you need to grab the capturing writer in the main thread and
+/// manually call apply() on every thread that you create in the test in order for this to work
+///
+/// For tokio runtime, runtime::Builder::after_start can be used to set up the capture in tokio
+/// threads
+///
+/// See more details in description of 15fb112518f01d729ca49abe0c900c5c550783ab
+#[derive(Clone)]
+pub struct OutputCapture {
+    writer: Option<AggregateWriter>,
+}
+
+impl OutputCapture {
+    /// Grabs the output override on the current thread.
+    /// Call this method in the main thread of a test to grab the current stdout override.
+    /// If no override is set, this function returns a no-op OutputCapture.
+    pub fn grab() -> OutputCapture {
+        OutputCapture {
+            writer: AggregateWriter::grab(),
+        }
+    }
+
+    /// Applies the output capture to the current thread.
+    /// If no capture was grabbed in grab(), this method is a no-op.
+    pub fn apply(&self) {
+        if let Some(ref writer) = self.writer {
+            io::set_print(Some(Box::new(writer.clone())));
+            io::set_panic(Some(Box::new(writer.clone())));
+        }
+    }
+}
+
+// A cloneable writer: it aggregates output from all cloned instances into a single inner writer.
+#[derive(Clone)]
+struct AggregateWriter {
+    inner: Arc<Mutex<Box<dyn io::Write + Send>>>,
+}
+
+impl AggregateWriter {
+    fn grab() -> Option<AggregateWriter> {
+        // Because dyn Write is not cloneable, the only way to take the current writer
+        // is to push it out by setting print to some new value (None in this case)
+        let previous = io::set_print(None);
+        if let Some(previous) = previous {
+            let writer = AggregateWriter::new(previous);
+            io::set_print(Some(Box::new(writer.clone())));
+            Some(writer)
+        } else {
+            None
+        }
+    }
+
+    fn new(inner: Box<dyn io::Write + Send>) -> AggregateWriter {
+        AggregateWriter {
+            inner: Arc::new(Mutex::new(inner)),
+        }
+    }
+}
+
+impl io::Write for AggregateWriter {
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        self.inner.lock().unwrap().write(buf)
+    }
+
+    fn flush(&mut self) -> io::Result<()> {
+        self.inner.lock().unwrap().flush()
+    }
+}
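+
+// Usage sketch: grab the capture on the test's main thread, then re-apply it in any
+// spawned thread so that thread's output is captured with the test output too:
+//
+//     let capture = OutputCapture::grab();
+//     std::thread::spawn(move || {
+//         capture.apply();
+//         println!("this line is captured along with the test output");
+//     });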
diff --git a/config/Cargo.toml b/config/Cargo.toml
new file mode 100644
index 0000000000000..0492a3c0f0736
--- /dev/null
+++ b/config/Cargo.toml
@@ -0,0 +1,23 @@
+[package]
+name = "config"
+version = "0.1.0"
+authors = ["Libra Association "]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+clap = "2.32"
+get_if_addrs = "0.5.3"
+hex = "0.3.2"
+parity-multiaddr = "0.4.0"
+rand = "0.6.5"
+serde = { version = "1.0.89", features = ["derive"] }
+tempfile = "3.0.6"
+toml = "0.4"
+
+crypto = { path = "../crypto/legacy_crypto" }
+proto_conv = { path = "../common/proto_conv" }
+logger = { path = "../common/logger" }
+failure = { path = "../common/failure_ext", package = "failure_ext" }
+types = { path = "../types" }
diff --git a/config/config_builder/Cargo.toml b/config/config_builder/Cargo.toml
new file mode 100644
index 0000000000000..ec94505bc3018
--- /dev/null
+++ b/config/config_builder/Cargo.toml
@@ -0,0 +1,23 @@
+[package]
+name = "config_builder"
+version = "0.1.0"
+authors = ["Libra Association "]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+bincode = "1.1.1"
+clap = "2.32"
+hex = "0.3.2"
+serde = { version = "1.0.91", features = ["derive"] }
+tempfile = "3.0.6"
+toml = "0.4"
+
+config = { path = ".." }
+crypto = { path = "../../crypto/legacy_crypto" }
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+generate_keypair = { path = "../generate_keypair" }
+proto_conv = { path = "../../common/proto_conv", features = ["derive"] }
+types = { path = "../../types" }
+vm_genesis = { path = "../../language/vm/vm_genesis" }
diff --git a/config/config_builder/src/bin/libra-config.rs b/config/config_builder/src/bin/libra-config.rs
new file mode 100644
index 0000000000000..f526a60a5b40a
--- /dev/null
+++ b/config/config_builder/src/bin/libra-config.rs
@@ -0,0 +1,113 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use clap::{value_t, App, Arg};
+use config_builder::swarm_config::SwarmConfigBuilder;
+use std::convert::TryInto;
+
+const BASE_ARG: &str = "base";
+const NODES_ARG: &str = "nodes";
+const OUTPUT_DIR_ARG: &str = "output-dir";
+const DISCOVERY_ARG: &str = "discovery";
+const KEY_SEED_ARG: &str = "key-seed";
+const FAUCET_ACCOUNT_FILE_ARG: &str = "faucet_account_file";
+
+fn main() {
+    let args = App::new("Libra Config Tool")
+        .version("0.1.0")
+        .author("Libra Association ")
+        .about("Tool to manage and create Libra Configs")
+        .arg(
+            Arg::with_name(BASE_ARG)
+                .short("b")
+                .long(BASE_ARG)
+                .takes_value(true)
+                .required(true)
+                .help("Base config to use"),
+        )
+        .arg(
+            Arg::with_name(NODES_ARG)
+                .short("n")
+                .long(NODES_ARG)
+                .takes_value(true)
+                .default_value("1")
+                .help("Specify the number of nodes to configure"),
+        )
+        .arg(
+            Arg::with_name(OUTPUT_DIR_ARG)
+                .short("o")
+                .long(OUTPUT_DIR_ARG)
+                .takes_value(true)
+                .help("The output directory"),
+        )
+        .arg(
+            Arg::with_name(DISCOVERY_ARG)
+                .short("d")
+                .long(DISCOVERY_ARG)
+                .help("Generate peer config with one peer only (to force discovery)"),
+        )
+        .arg(
+            Arg::with_name(KEY_SEED_ARG)
+                .short("s")
+                .long(KEY_SEED_ARG)
+                .takes_value(true)
+                .help("Use the provided seed for generating keys for each of the validators"),
+        )
+        .arg(
+            Arg::with_name(FAUCET_ACCOUNT_FILE_ARG)
+                .short("m")
+                .long(FAUCET_ACCOUNT_FILE_ARG)
+                .help("File location from which to load faucet account generated via generate_keypair tool")
+                .takes_value(true),
+        )
+        .get_matches();
+    let base_path = value_t!(args, BASE_ARG, String).expect("Path to base config");
+    let nodes_count = value_t!(args, NODES_ARG, usize).unwrap();
+    let output_dir = if args.is_present(OUTPUT_DIR_ARG) {
+        let dir = value_t!(args, OUTPUT_DIR_ARG, String).unwrap();
+        dir.into()
+    } else {
+        ::std::env::current_dir().unwrap()
+    };
+    let faucet_account_file_path = value_t!(args, FAUCET_ACCOUNT_FILE_ARG, String)
+        .expect("Must provide faucet account file path");
+    let (faucet_account_keypair, _faucet_key_file_path, _temp_dir) =
+        generate_keypair::load_faucet_key_or_create_default(Some(faucet_account_file_path));
+
+    let mut config_builder = SwarmConfigBuilder::new();
+    config_builder
+        .with_nodes(nodes_count)
+        .with_base(base_path)
+        .with_output_dir(output_dir)
+        .with_faucet_keypair(faucet_account_keypair);
+    if args.is_present(DISCOVERY_ARG) {
+        config_builder.force_discovery();
+    }
+    if args.is_present(KEY_SEED_ARG) {
+        let seed_hex = value_t!(args, KEY_SEED_ARG, String).expect("Missing Seed");
+        let seed = hex::decode(seed_hex).unwrap();
+        config_builder.with_key_seed(seed[..32].try_into().unwrap());
+    }
+    let generated_configs = config_builder.build().expect("Unable to generate configs");
+
+    println!(
+        "Trusted Peers Config: {:?}",
+        generated_configs.get_trusted_peers_config().0
+    );
+
+    println!(
+        "Seed Peers Config: {:?}",
+        generated_configs.get_seed_peers_config().0
+    );
+
+    for (path, node_config) in generated_configs.get_configs() {
+        println!(
+            "Node Config for PeerId({}): {:?}",
+            node_config.base.peer_id, path
+        );
+        println!(
+            "Node Keys for PeerId({}): {:?}",
+            node_config.base.peer_id, node_config.base.peer_keypairs_file
+        );
+    }
+}
diff --git a/config/config_builder/src/lib.rs b/config/config_builder/src/lib.rs
new file mode 100644
index 0000000000000..d46854282498b
--- /dev/null
+++ b/config/config_builder/src/lib.rs
@@ -0,0 +1,5 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod swarm_config;
+pub mod util;
diff --git a/config/config_builder/src/swarm_config.rs b/config/config_builder/src/swarm_config.rs
new file mode 100644
index 0000000000000..3b57f4ff9401b
--- /dev/null
+++ b/config/config_builder/src/swarm_config.rs
@@ -0,0 +1,276 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Convenience structs and functions for generating configuration for a swarm of libra nodes
+use crate::util::gen_genesis_transaction;
+use config::{
+    config::{KeyPairs, NodeConfig, NodeConfigHelpers},
+    seed_peers::{SeedPeersConfig, SeedPeersConfigHelpers},
+    trusted_peers::{TrustedPeersConfig, TrustedPeersConfigHelpers},
+};
+use crypto::signing::KeyPair;
+use failure::prelude::*;
+use std::path::{Path, PathBuf};
+use tempfile;
+use vm_genesis::default_config;
+
+pub struct SwarmConfig {
+    configs: Vec<(PathBuf, NodeConfig)>,
+    seed_peers: (PathBuf, SeedPeersConfig),
+    trusted_peers: (PathBuf, TrustedPeersConfig),
+}
+
+impl SwarmConfig {
+    //TODO convert this to use the Builder paradigm
+    pub fn new(
+        mut template: NodeConfig,
+        num_nodes: usize,
+        faucet_key: KeyPair,
+        prune_seed_peers_for_discovery: bool,
+        is_ipv4: bool,
+        key_seed: Option<[u8; 32]>,
+        output_dir: &Path,
+        static_ports: bool,
+        storage_dir: Option<PathBuf>,
+    ) -> Result<Self> {
+        // Generate trusted peer configs + their private keys.
+        let (peers_private_keys, trusted_peers_config) =
+            TrustedPeersConfigHelpers::get_test_config(num_nodes, key_seed);
+        trusted_peers_config.save_config(&output_dir.join(&template.base.trusted_peers_file));
+        let mut seed_peers_config = SeedPeersConfigHelpers::get_test_config_with_ipver(
+            &trusted_peers_config,
+            None,
+            is_ipv4,
+        );
+
+        template.storage.dir = storage_dir.unwrap_or_else(|| {
+            let dir = tempfile::tempdir().expect("error creating tempdir");
+            dir.path().to_path_buf()
+        });
+
+        gen_genesis_transaction(
+            &output_dir.join(&template.execution.genesis_file_location),
+            &faucet_key,
+            &trusted_peers_config,
+        )?;
+
+        let mut configs = Vec::new();
+        // Generate configs for all nodes.
+        for (node_id, addrs) in &seed_peers_config.seed_peers {
+            let mut config = template.clone();
+            config.base.peer_id = node_id.clone();
+            // serialize keypairs on independent {node}.node.keys.toml file
+            // this is because the peer_keypairs field is skipped during (de)serialization
+            let private_keys = peers_private_keys.get(node_id.as_str()).unwrap();
+            let peer_keypairs = KeyPairs::load(private_keys);
+            let key_file_name = format!("{}.node.keys.toml", config.base.peer_id);
+
+            config.base.peer_keypairs_file = key_file_name.into();
+            peer_keypairs.save_config(&output_dir.join(&config.base.peer_keypairs_file));
+            if !static_ports {
+                NodeConfigHelpers::randomize_config_ports(&mut config);
+            }
+
+            // create subdirectory for storage: <node_id>/db, unless provided directly
+            config.storage.dir = config.storage.dir.join(node_id).join("db");
+
+            // If listen address is different from advertised address, we need to set it
+            // appropriately below.
+            config.network.listen_address = addrs[0].clone();
+            config.network.advertised_address = addrs[0].clone();
+
+            config.vm_config = default_config();
+            configs.push(config);
+        }
+        if prune_seed_peers_for_discovery {
+            seed_peers_config.seed_peers = seed_peers_config
+                .seed_peers
+                .clone()
+                .into_iter()
+                .take(1)
+                .collect();
+        }
+        seed_peers_config.save_config(&output_dir.join(&template.network.seed_peers_file));
+        let configs = configs
+            .into_iter()
+            .map(|config| {
+                let file_name = format!("{}.node.config.toml", config.base.peer_id);
+                let config_file = output_dir.join(file_name);
+                (config_file, config)
+            })
+            .collect::<Vec<_>>();
+
+        for (path, node_config) in &configs {
+            node_config.save_config(&path);
+        }
+
+        Ok(Self {
+            configs,
+            seed_peers: (
+                output_dir.join(template.network.seed_peers_file),
+                seed_peers_config,
+            ),
+            trusted_peers: (
+                output_dir.join(template.base.trusted_peers_file),
+                trusted_peers_config,
+            ),
+        })
+    }
+
+    pub fn get_configs(&self) -> &[(PathBuf, NodeConfig)] {
+        &self.configs
+    }
+
+    pub fn get_seed_peers_config(&self) -> &(PathBuf, SeedPeersConfig) {
+        &self.seed_peers
+    }
+
+    pub fn get_trusted_peers_config(&self) -> &(PathBuf, TrustedPeersConfig) {
+        &self.trusted_peers
+    }
+}
+
+pub struct SwarmConfigBuilder {
+    node_count: usize,
+    template_path: PathBuf,
+    static_ports: bool,
+    output_dir: PathBuf,
+    force_discovery: bool,
+    is_ipv4: bool,
+    key_seed: Option<[u8; 32]>,
+    faucet_account_keypair_filepath: Option<PathBuf>,
+    faucet_account_keypair: Option<KeyPair>,
+    storage_dir: Option<PathBuf>,
+}
+impl Default for SwarmConfigBuilder {
+    fn default() -> Self {
+        SwarmConfigBuilder {
+            node_count: 1,
+            template_path: "config/data/configs/node.config.toml".into(),
+            static_ports: false,
+            output_dir: "configs".into(),
+            force_discovery: false,
+            is_ipv4: false,
+            key_seed: None,
+            faucet_account_keypair_filepath: None,
+            faucet_account_keypair: None,
+            storage_dir: None,
+        }
+    }
+}
+
+impl SwarmConfigBuilder {
+    pub fn new() -> SwarmConfigBuilder {
+        SwarmConfigBuilder::default()
+    }
+
+    pub fn randomize_ports(&mut self) -> &mut Self {
+        self.static_ports = false;
+        self
+    }
+
+    pub fn static_ports(&mut self) -> &mut Self {
+        self.static_ports = true;
+        self
+    }
+
+    pub fn with_base<P: AsRef<Path>>(&mut self, base_template_path: P) -> &mut Self {
+        self.template_path = base_template_path.as_ref().to_path_buf();
+        self
+    }
+
+    pub fn with_output_dir<P: AsRef<Path>>(&mut self, output_dir: P) -> &mut Self {
+        self.output_dir = output_dir.as_ref().to_path_buf();
+        self
+    }
+    pub fn with_faucet_keypair_filepath<P: AsRef<Path>>(&mut self, keypair_file: P) -> &mut Self {
+        self.faucet_account_keypair_filepath = Some(keypair_file.as_ref().to_path_buf());
+        self
+    }
+
+    pub fn with_faucet_keypair(&mut self, keypair: KeyPair) -> &mut Self {
+        self.faucet_account_keypair = Some(keypair);
+        self
+    }
+
+    pub fn with_nodes(&mut self, n: usize) -> &mut Self {
+        self.node_count = n;
+        self
+    }
+
+    pub fn force_discovery(&mut self) -> &mut Self {
+        self.force_discovery = true;
+        self
+    }
+
+    pub fn with_ipv4(&mut self) -> &mut Self {
+        self.is_ipv4 = true;
+        self
+    }
+
+    pub fn with_ipv6(&mut self) -> &mut Self {
+        self.is_ipv4 = false;
+        self
+    }
+
+    pub fn with_key_seed(&mut self, seed: [u8; 32]) -> &mut Self {
+        self.key_seed = Some(seed);
+        self
+    }
+
+    pub fn build(&self) -> Result<SwarmConfig> {
+        // verify required fields
+        let faucet_key_path = self.faucet_account_keypair_filepath.clone();
+        let faucet_key_option = self.faucet_account_keypair.clone();
+        let faucet_key = faucet_key_option.unwrap_or_else(|| {
+            generate_keypair::load_key_from_file(
+                faucet_key_path.expect("Must provide faucet key file"),
+            )
+            .expect("Faucet account key is required to generate config")
+        });
+
+        // generate all things needed for generation
+        if !self.output_dir.is_dir() {
+            if !self.output_dir.exists() {
+                // generate if it doesn't exist
+                std::fs::create_dir(&self.output_dir).expect("Failed to create output dir");
+            }
+            assert!(
+                !self.output_dir.is_file(),
+                "Output-dir is a file, expecting a directory"
+            );
+        }
+
+        // read template
+        let mut template = NodeConfig::load_template(&self.template_path)?;
+        // update everything in the template and then generate swarm config
+        let listen_address = if self.is_ipv4 { "0.0.0.0" } else { "::1" };
+        let listen_address = listen_address.to_string();
+        template.admission_control.address = listen_address.clone();
+        template.debug_interface.address = listen_address;
+
+        template.execution.genesis_file_location = "genesis.blob".to_string();
+
+        // Set and generate trusted peers config file
+        if template.base.trusted_peers_file.is_empty() {
+            template.base.trusted_peers_file = "trusted_peers.config.toml".to_string();
+        };
+        // Set seed peers file and config. The config itself is populated inside SwarmConfig::new
+        if template.network.seed_peers_file.is_empty() {
+            template.network.seed_peers_file = "seed_peers.config.toml".to_string();
+        };
+
+        SwarmConfig::new(
+            template,
+            self.node_count,
+            faucet_key,
+            self.force_discovery,
+            self.is_ipv4,
+            self.key_seed,
+            &self.output_dir,
+            self.static_ports,
+            self.storage_dir.clone(),
+        )
+    }
+}
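+
+// Usage sketch (this mirrors what src/bin/libra-config.rs above does):
+//
+//     let mut builder = SwarmConfigBuilder::new();
+//     builder
+//         .with_nodes(4)
+//         .with_base("config/data/configs/node.config.toml")
+//         .with_output_dir("configs")
+//         .with_faucet_keypair(faucet_keypair);
+//     let swarm = builder.build().expect("Unable to generate configs");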
diff --git a/config/config_builder/src/util.rs b/config/config_builder/src/util.rs
new file mode 100644
index 0000000000000..228d8412aa2ca
--- /dev/null
+++ b/config/config_builder/src/util.rs
@@ -0,0 +1,56 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use config::{
+    config::{NodeConfig, NodeConfigHelpers},
+    trusted_peers::{TrustedPeersConfig, TrustedPeersConfigHelpers},
+};
+use crypto::signing::KeyPair;
+use failure::prelude::*;
+use proto_conv::IntoProtoBytes;
+use std::{convert::TryFrom, fs::File, io::prelude::*, path::Path};
+use types::{account_address::AccountAddress, validator_public_keys::ValidatorPublicKeys};
+use vm_genesis::encode_genesis_transaction_with_validator;
+
+pub fn gen_genesis_transaction<P: AsRef<Path>>(
+    path: P,
+    faucet_account_keypair: &KeyPair,
+    trusted_peer_config: &TrustedPeersConfig,
+) -> Result<()> {
+    let validator_set = trusted_peer_config
+        .peers
+        .iter()
+        .map(|(peer_id, peer)| {
+            ValidatorPublicKeys::new(
+                AccountAddress::try_from(peer_id.clone()).expect("[config] invalid peer_id"),
+                peer.get_consensus_public(),
+                peer.get_network_signing_public(),
+                peer.get_network_identity_public(),
+            )
+        })
+        .collect();
+    let transaction = encode_genesis_transaction_with_validator(
+        faucet_account_keypair.private_key(),
+        faucet_account_keypair.public_key(),
+        validator_set,
+    );
+    let mut file = File::create(path)?;
+    file.write_all(&transaction.into_proto_bytes()?)?;
+    Ok(())
+}
+
+/// Returns the config as well as the genesis keypair
+pub fn get_test_config() -> (NodeConfig, KeyPair) {
+    // TODO: test config should be moved here instead of config crate
+    let config = NodeConfigHelpers::get_single_node_test_config(true);
+    let (private_key, _) = ::crypto::signing::generate_keypair();
+    let keypair = KeyPair::new(private_key);
+
+    gen_genesis_transaction(
+        &config.execution.genesis_file_location,
+        &keypair,
+        &TrustedPeersConfigHelpers::get_test_config(1, None).1,
+    )
+    .expect("[config] failed to create genesis transaction");
+    (config, keypair)
+}
diff --git a/config/data/configs/node.config.toml b/config/data/configs/node.config.toml
new file mode 100644
index 0000000000000..516af045478e5
--- /dev/null
+++ b/config/data/configs/node.config.toml
@@ -0,0 +1,81 @@
+[base]
+peer_id = ''
+peer_keypairs_file = ''
+data_dir_path = ''
+trusted_peers_file = ''
+node_sync_batch_size = 1000
+node_sync_retries = 3
+node_sync_channel_buffer_size = 10
+node_async_log_chan_size = 256
+
+[metrics]
+dir = 'metrics'
+collection_interval_ms = 1000
+push_server_addr = ''
+
+[mempool]
+broadcast_transactions = true
+shared_mempool_tick_interval_ms = 50
+shared_mempool_batch_size = 100
+shared_mempool_max_concurrent_inbound_syncs = 100
+capacity = 10000000
+capacity_per_user = 100
+sequence_cache_capacity = 1000
+system_transaction_timeout_secs = 86400
+address = 'localhost'
+mempool_service_port = 55555
+system_transaction_gc_interval_ms = 180000
+
+[execution]
+address = 'localhost'
+port = 55558
+testnet_genesis = false
+genesis_file_location = ''
+
+[storage]
+address = 'localhost'
+port = 30305
+dir = 'libradb'
+
+[admission_control]
+address = 'localhost'
+admission_control_service_port = 30307
+need_to_check_mempool_before_validation = false
+
+[secret_service]
+address = 'localhost'
+secret_service_port = 30333
+
+[consensus]
+max_block_size = 100
+proposer_type = 'rotating_proposer'
+contiguous_rounds = 2
+
+[network]
+seed_peers_file = ''
+listen_address = '/ip4/0.0.0.0/tcp/30303'
+advertised_address = '/ip4/127.0.0.1/tcp/30303'
+discovery_interval_ms = 1000
+connectivity_check_interval_ms = 5000
+enable_encryption_and_authentication = true
+
+[debug_interface]
+admission_control_node_debug_port = 50313
+storage_node_debug_port = 50315
+secret_service_node_debug_port = 50316
+metrics_server_port = 14297
+address = 'localhost'
+
+[log_collector]
+is_async = true
+use_std_output = true
+
+[vm_config]
+    [vm_config.publishing_options]
+    type = "Locked"
+    whitelist = [
+        "88c0c64595f6cec7d0c0bfe29e1be1886c736ec3d26888d049e30909f7a72836",
+        "d3493756a00b7a9e4d9ca8482e80fd055411ce53882bdcb08fec97d42eef0bde",
+        "ee31d65b559ad5a300e6a508ff3edb2d23f1589ef68d0ead124d8f0374073d84",
+        "2bb3828f55bc640a85b17d9c6e120e84f8c068c9fd850e1a1d61d2f91ed295fd"
+    ]
diff --git a/config/data/configs/overrides/persistent_data.node.config.override.toml b/config/data/configs/overrides/persistent_data.node.config.override.toml
new file mode 100644
index 0000000000000..8ded8c1f1ee58
--- /dev/null
+++ b/config/data/configs/overrides/persistent_data.node.config.override.toml
@@ -0,0 +1,2 @@
+[base]
+data_dir_path = '/tmp/libra/test_persistent'
diff --git a/config/data/configs/overrides/testnet.node.config.override.toml b/config/data/configs/overrides/testnet.node.config.override.toml
new file mode 100644
index 0000000000000..780d35844e823
--- /dev/null
+++ b/config/data/configs/overrides/testnet.node.config.override.toml
@@ -0,0 +1,12 @@
+[base]
+data_dir_path = "/opt/libra/data"
+trusted_peers_file = "/opt/libra/etc/trusted_peers.config.toml"
+
+[admission_control]
+address = "0.0.0.0"
+admission_control_service_port = 30307
+
+[network]
+seed_peers_file = "/opt/libra/etc/seed_peers.config.toml"
+listen_address = "/ip4/0.0.0.0/tcp/30303"
+advertised_address = "/ip4/SELF_IP/tcp/30303"
diff --git a/config/data/metrics/prometheus.yml b/config/data/metrics/prometheus.yml
new file mode 100644
index 0000000000000..804555f9cffc2
--- /dev/null
+++ b/config/data/metrics/prometheus.yml
@@ -0,0 +1,30 @@
+global:
+  scrape_interval: 15s # Set the scrape interval to every 15 seconds. Default is every 1 minute.
+  evaluation_interval: 15s # Evaluate rules every 15 seconds. The default is every 1 minute.
+  # scrape_timeout is set to the global default (10s).
+
+# Alertmanager configuration
+alerting:
+  alertmanagers:
+  - static_configs:
+    - targets:
+      # - alertmanager:9093
+
+# Load rules once and periodically evaluate them according to the global 'evaluation_interval'.
+rule_files:
+  # - "first_rules.yml"
+  # - "second_rules.yml"
+
+# A scrape configuration containing exactly one endpoint to scrape:
+# Here it's Prometheus itself.
+scrape_configs:
+  # The job name is added as a label `job=<job_name>` to any timeseries scraped from this config.
+  - job_name: 'libra_node'
+    honor_labels: true
+
+    # metrics_path defaults to '/metrics'
+    # scheme defaults to 'http'.
+
+    static_configs:
+    - targets: ['127.0.0.1:9091']
+
diff --git a/config/generate_keypair/Cargo.toml b/config/generate_keypair/Cargo.toml
new file mode 100644
index 0000000000000..b45e1152e0279
--- /dev/null
+++ b/config/generate_keypair/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "generate_keypair"
+version = "0.1.0"
+authors = ["Libra Association "]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+bincode = "1.1.1"
+clap = {version = "2.32"}
+serde = { version = "1.0.89", features = ["derive"] }
+serde_json = "1.0.38"
+tempdir = "0.3.7"
+
+crypto = { path = "../../crypto/legacy_crypto" }
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
\ No newline at end of file
diff --git a/config/generate_keypair/src/lib.rs b/config/generate_keypair/src/lib.rs
new file mode 100644
index 0000000000000..45f42aa4414b9
--- /dev/null
+++ b/config/generate_keypair/src/lib.rs
@@ -0,0 +1,72 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use bincode::serialize;
+use crypto::signing::KeyPair;
+use failure::prelude::*;
+use std::{
+    fs::{self, File},
+    io::Write,
+    path::Path,
+};
+use tempdir::TempDir;
+
+pub fn create_faucet_key_file(output_file: &str) -> KeyPair {
+    let output_file_path = Path::new(&output_file);
+
+    if output_file_path.exists() && !output_file_path.is_file() {
+        panic!("Specified output file path is a directory");
+    }
+
+    let (private_key, _) = ::crypto::signing::generate_keypair();
+    let keypair = KeyPair::new(private_key);
+
+    // Write to disk
+    let encoded: Vec<u8> = serialize(&keypair).expect("Unable to serialize keys");
+    let mut file =
+        File::create(output_file_path).expect("Unable to create/truncate file at specified path");
+    file.write_all(&encoded)
+        .expect("Unable to write keys to file at specified path");
+    keypair
+}
+
+/// Tries to load a keypair from the path given as argument
+pub fn load_key_from_file<P: AsRef<Path>>(path: P) -> Result<KeyPair> {
+    bincode::deserialize(&fs::read(path)?[..]).map_err(|b| b.into())
+}
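+
+// Sketch: a keypair written by `create_faucet_key_file` can be read back with
+// `load_key_from_file` (the path below is only illustrative):
+//
+//     let keypair = create_faucet_key_file("/tmp/faucet_keys");
+//     let reloaded = load_key_from_file("/tmp/faucet_keys")?;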
+
+/// Returns the generated or loaded keypair, the path to the file where this keypair is saved,
+/// and a reference to the temp directory that was possibly created (a handle so that
+/// it doesn't go out of scope)
+pub fn load_faucet_key_or_create_default(
+    file_path: Option<String>,
+) -> (KeyPair, String, Option<TempDir>) {
+    // If there is already a faucet key file, then open it and parse the keypair. If there
+    // isn't one, then create a temp directory and generate the keypair
+    if let Some(faucet_account_file) = file_path {
+        match load_key_from_file(faucet_account_file.clone()) {
+            Ok(keypair) => (keypair, faucet_account_file.to_string(), None),
+            Err(e) => {
+                panic!(
+                    "Unable to read faucet account file: {}, {}",
+                    faucet_account_file, e
+                );
+            }
+        }
+    } else {
+        // Generate keypair in temp directory
+        let tmp_dir =
+            TempDir::new("keypair").expect("Unable to create temp dir for faucet keypair");
+        let faucet_key_file_path = tmp_dir
+            .path()
+            .join("temp_faucet_keys")
+            .to_str()
+            .unwrap()
+            .to_string();
+        (
+            crate::create_faucet_key_file(&faucet_key_file_path),
+            faucet_key_file_path,
+            Some(tmp_dir),
+        )
+    }
+}
diff --git a/config/generate_keypair/src/main.rs b/config/generate_keypair/src/main.rs
new file mode 100644
index 0000000000000..db78dc59bc732
--- /dev/null
+++ b/config/generate_keypair/src/main.rs
@@ -0,0 +1,26 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use clap::{value_t, App, Arg};
+use generate_keypair::create_faucet_key_file;
+
+const OUTPUT_ARG: &str = "output";
+
+fn main() {
+    let args = App::new("Libra Key Generation Tool")
+        .version("0.1.0")
+        .author("Libra Association ")
+        .about("Tool to generate public/private keypairs")
+        .arg(
+            Arg::with_name(OUTPUT_ARG)
+                .short("o")
+                .long(OUTPUT_ARG)
+                .takes_value(true)
+                .help("Output file path. Keypair is written to this file"),
+        )
+        .get_matches();
+
+    let output_file =
+        value_t!(args, OUTPUT_ARG, String).expect("Missing output file path argument");
+    create_faucet_key_file(&output_file);
+}
diff --git a/config/src/config.rs b/config/src/config.rs
new file mode 100644
index 0000000000000..6614f6e2b920f
--- /dev/null
+++ b/config/src/config.rs
@@ -0,0 +1,658 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::utils::{deserialize_whitelist, get_local_ip, serialize_whitelist};
+use parity_multiaddr::{Multiaddr, Protocol};
+use std::{
+    collections::HashSet,
+    fs::File,
+    io::{Read, Write},
+    path::{Path, PathBuf},
+    string::ToString,
+};
+
+use crypto::{
+    signing,
+    x25519::{self, X25519PrivateKey, X25519PublicKey},
+};
+use logger::LoggerType;
+use serde::{Deserialize, Serialize};
+use tempfile::TempDir;
+use toml;
+
+use failure::prelude::*;
+use proto_conv::FromProtoBytes;
+use types::transaction::{SignedTransaction, SCRIPT_HASH_LENGTH};
+
+use crate::{
+    config::ConsensusProposerType::{FixedProposer, RotatingProposer},
+    seed_peers::{SeedPeersConfig, SeedPeersConfigHelpers},
+    trusted_peers::{
+        deserialize_key, serialize_key, TrustedPeerPrivateKeys, TrustedPeersConfig,
+        TrustedPeersConfigHelpers,
+    },
+    utils::get_available_port,
+};
+
+#[cfg(test)]
+#[path = "unit_tests/config_test.rs"]
+mod config_test;
+
+pub const DISPOSABLE_DIR_MARKER: &str = "";
+
+// path is relative to this file location
+static CONFIG_TEMPLATE: &[u8] = include_bytes!("../data/configs/node.config.toml");
+
+/// Config pulls in configuration information from the config file.
+/// This is used to set up the nodes and configure various parameters.
+/// The config file is broken up into sections for each module
+/// so that only that module can be passed around
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct NodeConfig {
+    //TODO Add configuration for multiple chains in a future diff
+    pub base: BaseConfig,
+    pub metrics: MetricsConfig,
+    pub execution: ExecutionConfig,
+    pub admission_control: AdmissionControlConfig,
+    pub debug_interface: DebugInterfaceConfig,
+
+    pub storage: StorageConfig,
+    pub network: NetworkConfig,
+    pub consensus: ConsensusConfig,
+    pub mempool: MempoolConfig,
+    pub log_collector: LoggerConfig,
+    pub vm_config: VMConfig,
+
+    pub secret_service: SecretServiceConfig,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct BaseConfig {
+    pub peer_id: String,
+    // peer_keypairs contains all the node's private keys,
+    // it is filled later on from a different file
+    #[serde(skip)]
+    pub peer_keypairs: KeyPairs,
+    // peer_keypairs_file contains the configuration file containing all the node's private keys.
+    pub peer_keypairs_file: PathBuf,
+    pub data_dir_path: PathBuf,
+    #[serde(skip)]
+    temp_data_dir: Option<TempDir>,
+    //TODO move set of trusted peers into genesis file
+    pub trusted_peers_file: String,
+    #[serde(skip)]
+    pub trusted_peers: TrustedPeersConfig,
+
+    // Size of chunks to request when performing restart sync to catchup
+    pub node_sync_batch_size: u64,
+
+    // Number of retries per chunk download
+    pub node_sync_retries: usize,
+
+    // Buffer size for sync_channel used for node syncing (number of elements that it can
+    // hold before it blocks on sends)
+    pub node_sync_channel_buffer_size: u64,
+
+    // chan_size of slog async drain for node logging.
+    pub node_async_log_chan_size: usize,
+}
+
+// KeyPairs is used to store all of a node's private keys.
+// It is filled via a config file at the moment.
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct KeyPairs {
+    #[serde(serialize_with = "serialize_key")]
+    #[serde(deserialize_with = "deserialize_key")]
+    network_signing_private_key: signing::PrivateKey,
+    #[serde(serialize_with = "serialize_key")]
+    #[serde(deserialize_with = "deserialize_key")]
+    network_signing_public_key: signing::PublicKey,
+
+    #[serde(serialize_with = "serialize_key")]
+    #[serde(deserialize_with = "deserialize_key")]
+    network_identity_private_key: X25519PrivateKey,
+    #[serde(serialize_with = "serialize_key")]
+    #[serde(deserialize_with = "deserialize_key")]
+    network_identity_public_key: X25519PublicKey,
+
+    #[serde(serialize_with = "serialize_key")]
+    #[serde(deserialize_with = "deserialize_key")]
+    consensus_private_key: signing::PrivateKey,
+    #[serde(serialize_with = "serialize_key")]
+    #[serde(deserialize_with = "deserialize_key")]
+    consensus_public_key: signing::PublicKey,
+}
+
+// required for serialization
+impl Default for KeyPairs {
+    fn default() -> Self {
+        let (private_sig, public_sig) = signing::generate_keypair();
+        let (private_kex, public_kex) = x25519::generate_keypair();
+        Self {
+            network_signing_private_key: private_sig.clone(),
+            network_signing_public_key: public_sig,
+            network_identity_private_key: private_kex.clone(),
+            network_identity_public_key: public_kex,
+            consensus_private_key: private_sig.clone(),
+            consensus_public_key: public_sig,
+        }
+    }
+}
+
+impl KeyPairs {
+    // used to deserialize keypairs from a configuration file
+    pub fn load_config<P: AsRef<Path>>(path: P) -> Self {
+        let path = path.as_ref();
+        let mut file = File::open(path)
+            .unwrap_or_else(|_| panic!("Cannot open KeyPair Config file {:?}", path));
+        let mut contents = String::new();
+        file.read_to_string(&mut contents)
+            .unwrap_or_else(|_| panic!("Error reading KeyPair Config file {:?}", path));
+
+        Self::parse(&contents)
+    }
+    fn parse(config_string: &str) -> Self {
+        toml::from_str(config_string).expect("Unable to parse Config")
+    }
+    // used to serialize keypairs to a configuration file
+    pub fn save_config<P: AsRef<Path>>(&self, output_file: P) {
+        let contents = toml::to_vec(&self).expect("Error serializing");
+
+        let mut file = File::create(output_file).expect("Error opening file");
+
+        file.write_all(&contents).expect("Error writing file");
+    }
+    // used in testing to fill the structure with test keypairs
+    pub fn load(private_keys: &TrustedPeerPrivateKeys) -> Self {
+        let network_signing_private_key = private_keys.get_network_signing_private();
+        let network_signing_public_key = (&network_signing_private_key).into();
+        let network_identity_private_key = private_keys.get_network_identity_private();
+        let network_identity_public_key = (&network_identity_private_key).into();
+        let consensus_private_key = private_keys.get_consensus_private();
+        let consensus_public_key = (&consensus_private_key).into();
+        Self {
+            network_signing_private_key,
+            network_signing_public_key,
+            network_identity_private_key,
+            network_identity_public_key,
+            consensus_private_key,
+            consensus_public_key,
+        }
+    }
+    // getters for private keys
+    pub fn get_network_signing_private(&self) -> signing::PrivateKey {
+        self.network_signing_private_key.clone()
+    }
+    pub fn get_network_identity_private(&self) -> X25519PrivateKey {
+        self.network_identity_private_key.clone()
+    }
+    pub fn get_consensus_private(&self) -> signing::PrivateKey {
+        self.consensus_private_key.clone()
+    }
+    // getters for public keys
+    pub fn get_network_signing_public(&self) -> signing::PublicKey {
+        self.network_signing_public_key
+    }
+    pub fn get_network_identity_public(&self) -> X25519PublicKey {
+        self.network_identity_public_key
+    }
+    pub fn get_consensus_public(&self) -> signing::PublicKey {
+        self.consensus_public_key
+    }
+    // getters for keypairs
+    pub fn get_network_signing_keypair(&self) -> (signing::PrivateKey, signing::PublicKey) {
+        (
+            self.get_network_signing_private(),
+            self.get_network_signing_public(),
+        )
+    }
+    pub fn get_network_identity_keypair(&self) -> (X25519PrivateKey, X25519PublicKey) {
+        (
+            self.get_network_identity_private(),
+            self.get_network_identity_public(),
+        )
+    }
+    pub fn get_consensus_keypair(&self) -> (signing::PrivateKey, signing::PublicKey) {
+        (self.get_consensus_private(), self.get_consensus_public())
+    }
+}
+
+impl Clone for BaseConfig {
+    fn clone(&self) -> Self {
+        Self {
+            peer_id: self.peer_id.clone(),
+            peer_keypairs: self.peer_keypairs.clone(),
+            peer_keypairs_file: self.peer_keypairs_file.clone(),
+            data_dir_path: self.data_dir_path.clone(),
+            temp_data_dir: None,
+            trusted_peers_file: self.trusted_peers_file.clone(),
+            trusted_peers: self.trusted_peers.clone(),
+            node_sync_batch_size: self.node_sync_batch_size,
+            node_sync_retries: self.node_sync_retries,
+            node_sync_channel_buffer_size: self.node_sync_channel_buffer_size,
+            node_async_log_chan_size: self.node_async_log_chan_size,
+        }
+    }
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct MetricsConfig {
+    pub dir: PathBuf,
+    pub collection_interval_ms: u64,
+    pub push_server_addr: String,
+}
+
+#[derive(Clone, Debug, Default, Deserialize, Serialize)]
+pub struct ExecutionConfig {
+    pub address: String,
+    pub port: u16,
+    // directive to load the testnet genesis block or the default genesis block.
+    // There are semantic differences between the 2 genesis related to minting and
+    // account creation
+    pub testnet_genesis: bool,
+    pub genesis_file_location: String,
+}
+
+impl ExecutionConfig {
+    pub fn get_genesis_transaction(&self) -> Result<SignedTransaction> {
+        let mut file = File::open(self.genesis_file_location.clone())?;
+        let mut buffer = vec![];
+        file.read_to_end(&mut buffer)?;
+        SignedTransaction::from_proto_bytes(&buffer)
+    }
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct LoggerConfig {
+    pub http_endpoint: Option<String>,
+    pub is_async: bool,
+    pub chan_size: Option<usize>,
+    pub use_std_output: bool,
+}
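+
+// For instance (sketch; the endpoint value is only illustrative): with both outputs
+// configured, the HTTP endpoint takes priority over std output:
+//
+//     let config = LoggerConfig {
+//         http_endpoint: Some("http://127.0.0.1:1234".to_string()),
+//         is_async: true,
+//         chan_size: None,
+//         use_std_output: true,
+//     };
+//     // config.get_log_collector_type() returns Some(LoggerType::Http(..))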
+
+impl LoggerConfig {
+    pub fn get_log_collector_type(&self) -> Option<LoggerType> {
+        // Loggers are prioritized: if multiple outputs are configured, only the
+        // highest-priority one is returned (HTTP before std output).
+        if self.http_endpoint.is_some() {
+            return Some(LoggerType::Http(
+                self.http_endpoint
+                    .clone()
+                    .expect("Http endpoint not available for logger"),
+            ));
+        } else if self.use_std_output {
+            return Some(LoggerType::StdOutput);
+        }
+        None
+    }
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct SecretServiceConfig {
+    pub address: String,
+    pub secret_service_port: u16,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct AdmissionControlConfig {
+    pub address: String,
+    pub admission_control_service_port: u16,
+    pub need_to_check_mempool_before_validation: bool,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct DebugInterfaceConfig {
+    pub admission_control_node_debug_port: u16,
+    pub secret_service_node_debug_port: u16,
+    pub storage_node_debug_port: u16,
+    // This has similar use to the core-node-debug-server itself
+    pub metrics_server_port: u16,
+    pub address: String,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct StorageConfig {
+    pub address: String,
+    pub port: u16,
+    pub dir: PathBuf,
+}
+
+impl StorageConfig {
+    pub fn get_dir(&self) -> &Path {
+        &self.dir
+    }
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct NetworkConfig {
+    pub seed_peers_file: String,
+    #[serde(skip)]
+    pub seed_peers: SeedPeersConfig,
+    // TODO: Add support for multiple listen/advertised addresses in config.
+    // The address that this node is listening on for new connections.
+    pub listen_address: Multiaddr,
+    // The address that this node advertises to other nodes for the discovery protocol.
+    pub advertised_address: Multiaddr,
+    pub discovery_interval_ms: u64,
+    pub connectivity_check_interval_ms: u64,
+    pub enable_encryption_and_authentication: bool,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct ConsensusConfig {
+    max_block_size: u64,
+    proposer_type: String,
+    contiguous_rounds: u32,
+    max_pruned_blocks_in_mem: Option<u64>,
+    pacemaker_initial_timeout_ms: Option<u64>,
+}
+
+#[derive(Debug, Eq, PartialEq)]
+pub enum ConsensusProposerType {
+    // Choose the smallest PeerId as the proposer
+    FixedProposer,
+    // Round robin rotation of proposers
+    RotatingProposer,
+}
+
+impl ConsensusConfig {
+    pub fn get_proposer_type(&self) -> ConsensusProposerType {
+        match self.proposer_type.as_str() {
+            "fixed_proposer" => FixedProposer,
+            "rotating_proposer" => RotatingProposer,
+            &_ => unimplemented!("Invalid proposer type: {}", self.proposer_type),
+        }
+    }
+
+    pub fn contiguous_rounds(&self) -> u32 {
+        self.contiguous_rounds
+    }
+
+    pub fn max_block_size(&self) -> u64 {
+        self.max_block_size
+    }
+
+    pub fn max_pruned_blocks_in_mem(&self) -> &Option<u64> {
+        &self.max_pruned_blocks_in_mem
+    }
+
+    pub fn pacemaker_initial_timeout_ms(&self) -> &Option<u64> {
+        &self.pacemaker_initial_timeout_ms
+    }
+}
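+
+// Mapping note: the toml value `proposer_type = 'fixed_proposer'` yields
+// ConsensusProposerType::FixedProposer, 'rotating_proposer' yields RotatingProposer,
+// and any other string hits `unimplemented!` at runtime.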
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct MempoolConfig {
+    pub broadcast_transactions: bool,
+    pub shared_mempool_tick_interval_ms: u64,
+    pub shared_mempool_batch_size: usize,
+    pub shared_mempool_max_concurrent_inbound_syncs: usize,
+    pub capacity: usize,
+    // max number of transactions per user in Mempool
+    pub capacity_per_user: usize,
+    pub sequence_cache_capacity: usize,
+    pub system_transaction_timeout_secs: u64,
+    pub system_transaction_gc_interval_ms: u64,
+    pub mempool_service_port: u16,
+    pub address: String,
+}
+
+impl NodeConfig {
+    /// Reads the config file and returns the configuration object
+    pub fn load_template<P: AsRef<Path>>(path: P) -> Result<Self> {
+        let path = path.as_ref();
+        let mut file =
+            File::open(path).with_context(|_| format!("Cannot open NodeConfig file {:?}", path))?;
+        let mut config_string = String::new();
+        file.read_to_string(&mut config_string)
+            .with_context(|_| format!("Cannot read NodeConfig file {:?}", path))?;
+
+        let config = Self::parse(&config_string)
+            .with_context(|_| format!("Cannot parse NodeConfig file {:?}", path))?;
+
+        Ok(config)
+    }
+
+    /// Reads the config file and returns the configuration object in addition to doing some
+    /// post-processing of the config
+    /// Paths used in the config are either absolute or relative to the config location
+    pub fn load_config<P: AsRef<Path>>(peer_id: Option<String>, path: P) -> Result<Self> {
+        let mut config = Self::load_template(&path)?;
+        // Allow peer_id override if set
+        if let Some(peer_id) = peer_id {
+            config.base.peer_id = peer_id;
+        }
+        if !config.base.trusted_peers_file.is_empty() {
+            config.base.trusted_peers = TrustedPeersConfig::load_config(
+                path.as_ref()
+                    .with_file_name(&config.base.trusted_peers_file),
+            );
+        }
+        if !config.base.peer_keypairs_file.as_os_str().is_empty() {
+            config.base.peer_keypairs = KeyPairs::load_config(
+                path.as_ref()
+                    .with_file_name(&config.base.peer_keypairs_file),
+            );
+        }
+        if !config.network.seed_peers_file.is_empty() {
+            config.network.seed_peers = SeedPeersConfig::load_config(
+                path.as_ref()
+                    .with_file_name(&config.network.seed_peers_file),
+            );
+        }
+        if config.network.advertised_address.to_string().is_empty() {
+            config.network.advertised_address =
+                get_local_ip().ok_or_else(|| ::failure::err_msg("No local IP"))?;
+        }
+        if config.network.listen_address.to_string().is_empty() {
+            config.network.listen_address =
+                get_local_ip().ok_or_else(|| ::failure::err_msg("No local IP"))?;
+        }
+        NodeConfigHelpers::update_data_dir_path_if_needed(&mut config, &path)?;
+        Ok(config)
+    }
+
+    pub fn save_config<P: AsRef<Path>>(&self, output_file: P) {
+        let contents = toml::to_vec(&self).expect("Error serializing");
+        let mut file = File::create(output_file).expect("Error opening file");
+
+        file.write_all(&contents).expect("Error writing file");
+    }
+
+    /// Parses the config file into a Config object
+    pub fn parse(config_string: &str) -> Result<Self> {
+        assert!(!config_string.is_empty());
+        Ok(toml::from_str(config_string)?)
+    }
+
+    /// Returns the peer info for this node
+    pub fn own_addrs(&self) -> (String, Vec<Multiaddr>) {
+        let own_peer_id = self.base.peer_id.clone();
+        let own_addrs = vec![self.network.advertised_address.clone()];
+        (own_peer_id, own_addrs)
+    }
+}
+
+// Given a multiaddr, randomizes its Tcp port if present.
+fn randomize_tcp_port(addr: &Multiaddr) -> Multiaddr {
+    let mut new_addr = Multiaddr::empty();
+    for p in addr.iter() {
+        if let Protocol::Tcp(_) = p {
+            new_addr.push(Protocol::Tcp(get_available_port()));
+        } else {
+            new_addr.push(p);
+        }
+    }
+    new_addr
+}
+
+fn get_tcp_port(addr: &Multiaddr) -> Option<u16> {
+    for p in addr.iter() {
+        if let Protocol::Tcp(port) = p {
+            return Some(port);
+        }
+    }
+    None
+}
+
+pub struct NodeConfigHelpers {}
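+
+// Note: `NodeConfig::load_config` (above) resolves trusted_peers_file,
+// peer_keypairs_file and seed_peers_file relative to the main config's location via
+// `Path::with_file_name`, so companion files are expected to sit next to
+// node.config.toml.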
+
+impl NodeConfigHelpers {
+    /// Returns a simple test config for single node. It does not have correct trusted_peers_file,
+    /// peer_keypairs_file, and seed_peers_file set; it is expected that the caller will provide
+    /// these.
+    pub fn get_single_node_test_config(random_ports: bool) -> NodeConfig {
+        Self::get_single_node_test_config_publish_options(random_ports, None)
+    }
+
+    /// Returns a simple test config for single node. It does not have correct trusted_peers_file,
+    /// peer_keypairs_file, and seed_peers_file set; it is expected that the caller will provide
+    /// these.
+    /// `publishing_options`, if provided, must be either `Open` or `CustomScripts`.
+    pub fn get_single_node_test_config_publish_options(
+        random_ports: bool,
+        publishing_options: Option<VMPublishingOption>,
+    ) -> NodeConfig {
+        let config_string = String::from_utf8_lossy(CONFIG_TEMPLATE);
+        let mut config =
+            NodeConfig::parse(&config_string).expect("Error parsing single node test config");
+        if random_ports {
+            NodeConfigHelpers::randomize_config_ports(&mut config);
+        }
+
+        if let Some(vm_publishing_option) = publishing_options {
+            config.vm_config.publishing_options = vm_publishing_option;
+        }
+
+        let (peers_private_keys, trusted_peers_test) =
+            TrustedPeersConfigHelpers::get_test_config(1, None);
+        let peer_id = trusted_peers_test.peers.keys().collect::<Vec<_>>()[0];
+        config.base.peer_id = peer_id.clone();
+        // load node's keypairs
+        let private_keys = peers_private_keys.get(peer_id.as_str()).unwrap();
+        config.base.peer_keypairs = KeyPairs::load(private_keys);
+        config.base.trusted_peers = trusted_peers_test;
+        config.network.seed_peers = SeedPeersConfigHelpers::get_test_config(
+            &config.base.trusted_peers,
+            get_tcp_port(&config.network.advertised_address),
+        );
+        NodeConfigHelpers::update_data_dir_path_if_needed(&mut config, ".")
+            .expect("creating tempdir");
+        config
+    }
+
+    /// Replaces the temp dir marker with the actual path and keeps a holder to the temp dir.
+    fn update_data_dir_path_if_needed<P: AsRef<Path>>(
+        config: &mut NodeConfig,
+        base_path: P,
+    ) -> Result<()> {
+        if config.base.data_dir_path == Path::new(DISPOSABLE_DIR_MARKER) {
+            let dir = tempfile::tempdir().context("error creating tempdir")?;
+            config.base.data_dir_path = dir.path().to_owned();
+            config.base.temp_data_dir = Some(dir);
+        }
+        config.metrics.dir = config.base.data_dir_path.join(&config.metrics.dir);
+        config.storage.dir = config.base.data_dir_path.join(config.storage.get_dir());
+        if config.execution.genesis_file_location == DISPOSABLE_DIR_MARKER {
+            config.execution.genesis_file_location = config
+                .base
+                .data_dir_path
+                .join("genesis.blob")
+                .to_str()
+                .unwrap()
+                .to_string();
+        }
+        config.execution.genesis_file_location = base_path
+            .as_ref()
+            .with_file_name(&config.execution.genesis_file_location)
+            .to_str()
+            .unwrap()
+            .to_string();
+
+        Ok(())
+    }
+
+    pub fn randomize_config_ports(config: &mut NodeConfig) {
+        config.admission_control.admission_control_service_port = get_available_port();
+        config.debug_interface.admission_control_node_debug_port = get_available_port();
+        config.debug_interface.metrics_server_port = get_available_port();
+        config.debug_interface.secret_service_node_debug_port = get_available_port();
+        config.debug_interface.storage_node_debug_port = get_available_port();
+        config.execution.port = get_available_port();
+        config.mempool.mempool_service_port = get_available_port();
+        config.network.advertised_address = randomize_tcp_port(&config.network.advertised_address);
+        config.network.listen_address = randomize_tcp_port(&config.network.listen_address);
+        config.secret_service.secret_service_port = get_available_port();
+        config.storage.port = get_available_port();
+    }
+}
+
+/// Holds the VM configuration, currently this is only the publishing options for scripts and
+/// modules, but in the future this may need to be expanded to hold more information.
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct VMConfig {
+    pub publishing_options: VMPublishingOption,
+}
+
+/// Defines and holds the publishing policies for the VM. There are three possible configurations:
+/// 1. No module publishing, only whitelisted scripts are allowed.
+/// 2. No module publishing, custom scripts are allowed.
+/// 3. Both module publishing and custom scripts are allowed.
+/// We represent these as an enum instead of a struct since whitelisting and module/script
+/// publishing are mutually exclusive options.
+#[derive(Clone, Debug, Deserialize, Serialize)]
+#[serde(tag = "type", content = "whitelist")]
+pub enum VMPublishingOption {
+    /// Only allow scripts on a whitelist to be run
+    #[serde(deserialize_with = "deserialize_whitelist")]
+    #[serde(serialize_with = "serialize_whitelist")]
+    Locked(HashSet<[u8; SCRIPT_HASH_LENGTH]>),
+    /// Allow custom scripts, but _not_ custom module publishing
+    CustomScripts,
+    /// Allow both custom scripts and custom module publishing
+    Open,
+}
+
+impl VMPublishingOption {
+    pub fn custom_scripts_only(&self) -> bool {
+        !self.is_open() && !self.is_locked()
+    }
+
+    pub fn is_open(&self) -> bool {
+        match self {
+            VMPublishingOption::Open => true,
+            _ => false,
+        }
+    }
+
+    pub fn is_locked(&self) -> bool {
+        match self {
+            VMPublishingOption::Locked { .. } => true,
+            _ => false,
+        }
+    }
+
+    pub fn get_whitelist_set(&self) -> Option<&HashSet<[u8; SCRIPT_HASH_LENGTH]>> {
+        match self {
+            VMPublishingOption::Locked(whitelist) => Some(&whitelist),
+            _ => None,
+        }
+    }
+}
+
+impl VMConfig {
+    /// Creates a new `VMConfig` where the whitelist is empty. This should only be used for
+    /// testing.
+    #[allow(non_snake_case)]
+    #[doc(hidden)]
+    pub fn empty_whitelist_FOR_TESTING() -> Self {
+        VMConfig {
+            publishing_options: VMPublishingOption::Locked(HashSet::new()),
+        }
+    }
+
+    pub fn save_config<P: AsRef<Path>>(&self, output_file: P) {
+        let contents = toml::to_vec(&self).expect("Error serializing");
+
+        let mut file = File::create(output_file).expect("Error opening file");
+
+        file.write_all(&contents).expect("Error writing file");
+    }
+}
diff --git a/config/src/lib.rs b/config/src/lib.rs
new file mode 100644
index 0000000000000..6ad1df2689c27
--- /dev/null
+++ b/config/src/lib.rs
@@ -0,0 +1,7 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod config;
+pub mod seed_peers;
+pub mod trusted_peers;
+pub mod utils;
diff --git a/config/src/seed_peers.rs b/config/src/seed_peers.rs
new file mode 100644
index 0000000000000..3f217360a7e04
--- /dev/null
+++ b/config/src/seed_peers.rs
@@ -0,0 +1,99 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{trusted_peers::TrustedPeersConfig, utils::get_available_port};
+use parity_multiaddr::{Multiaddr, Protocol};
+use serde::{Deserialize, Serialize};
+use std::{
+    collections::HashMap,
+    fs::File,
+    io::{Read, Write},
+    path::Path,
+};
+
+#[cfg(test)]
+#[path = "unit_tests/seed_peers_test.rs"]
+mod seed_peers_test;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SeedPeersConfig {
+    // All peers config. Key: a unique peer id (will be a PK in the future); Value: peer discovery info.
+    pub seed_peers: HashMap<String, Vec<Multiaddr>>,
+}
+
+impl SeedPeersConfig {
+    pub fn load_config<P: AsRef<Path>>(path: P) -> Self {
+        let path = path.as_ref();
+        let mut file =
+            File::open(path).unwrap_or_else(|_| panic!("Cannot open Seed Peers file {:?}", path));
+        let mut contents = String::new();
+        file.read_to_string(&mut contents)
+            .unwrap_or_else(|_| panic!("Error reading Seed Peers file {:?}", path));
+        Self::parse(&contents)
+    }
+
+    pub fn save_config<P: AsRef<Path>>(&self, output_file: P) {
+        let contents = toml::to_vec(&self).expect("Error serializing");
+        let mut file = File::create(output_file).expect("Error opening file");
+        file.write_all(&contents).expect("Error writing file");
+    }
+
+    fn parse(config_string: &str) -> Self {
+        toml::from_str(config_string).expect("Unable to parse Config")
+    }
+}
+
+impl Default for SeedPeersConfig {
+    fn default() -> SeedPeersConfig {
+        Self {
+            seed_peers: HashMap::new(),
+        }
+    }
+}
+
+pub struct SeedPeersConfigHelpers {}
+
+impl SeedPeersConfigHelpers {
+    /// Creates a new SeedPeersConfig based on the provided TrustedPeersConfig.
+    /// Each node gets a random port, unless we have only 1 peer and the port is supplied.
+    pub fn get_test_config(
+        trusted_peers: &TrustedPeersConfig,
+        port: Option<u16>,
+    ) -> SeedPeersConfig {
+        Self::get_test_config_with_ipver(trusted_peers, port, true)
+    }
+
+    /// Creates a new SeedPeersConfig based on the provided TrustedPeersConfig.
+    /// Each node gets a random port, unless we have only 1 peer and the port is supplied.
+    pub fn get_test_config_with_ipver(
+        trusted_peers: &TrustedPeersConfig,
+        port: Option<u16>,
+        ipv4: bool,
+    ) -> SeedPeersConfig {
+        let mut seed_peers = HashMap::new();
+        // sort to have the same repeatable order
+        let mut peers: Vec<String> = trusted_peers
+            .peers
+            .clone()
+            .into_iter()
+            .map(|(peer_id, _)| peer_id)
+            .collect();
+        peers.sort_unstable_by_key(std::clone::Clone::clone);
+        // If a port is supplied, we should have only 1 peer.
+        if port.is_some() {
+            assert_eq!(1, peers.len());
+        }
+        for peer_id in peers {
+            // Create a new PeerInfo and increment the ports
+            let mut addr = Multiaddr::empty();
+            if ipv4 {
+                addr.push(Protocol::Ip4("0.0.0.0".parse().unwrap()));
+            } else {
+                addr.push(Protocol::Ip6("::1".parse().unwrap()));
+            }
+            addr.push(Protocol::Tcp(port.unwrap_or_else(get_available_port)));
+            seed_peers.insert(peer_id.clone(), vec![addr]);
+        }
+        SeedPeersConfig { seed_peers }
+    }
+}
diff --git a/config/src/trusted_peers.rs b/config/src/trusted_peers.rs
new file mode 100644
index 0000000000000..60127748f07a7
--- /dev/null
+++ b/config/src/trusted_peers.rs
@@ -0,0 +1,213 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crypto::{
+    signing,
+    utils::{encode_to_string, from_encoded_string},
+    x25519::{self, X25519PrivateKey, X25519PublicKey},
+};
+use rand::{rngs::StdRng, SeedableRng};
+use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize, Serializer};
+use std::{
+    collections::HashMap,
+    convert::TryFrom,
+    fs::File,
+    io::{Read, Write},
+    path::Path,
+};
+use types::account_address::AccountAddress;
+
+#[cfg(test)]
+#[path = "unit_tests/trusted_peers_test.rs"]
+mod trusted_peers_test;
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct TrustedPeer {
+    #[serde(serialize_with = "serialize_key")]
+    #[serde(deserialize_with = "deserialize_key")]
+    network_signing_pubkey: signing::PublicKey,
+    #[serde(serialize_with = "serialize_key")]
+    #[serde(deserialize_with = "deserialize_key")]
+    network_identity_pubkey: X25519PublicKey,
+    #[serde(serialize_with = "serialize_key")]
+    #[serde(deserialize_with = "deserialize_key")]
+    consensus_pubkey: signing::PublicKey,
+}
+
+pub struct TrustedPeerPrivateKeys {
+    network_signing_private_key: signing::PrivateKey,
+    network_identity_private_key: X25519PrivateKey,
+    consensus_private_key: signing::PrivateKey,
+}
+
+impl TrustedPeerPrivateKeys {
+    pub fn get_network_signing_private(&self) -> signing::PrivateKey {
+        self.network_signing_private_key.clone()
+    }
+    pub fn get_network_identity_private(&self) -> X25519PrivateKey {
+        self.network_identity_private_key.clone()
+    }
+    pub fn get_consensus_private(&self) -> signing::PrivateKey {
+        self.consensus_private_key.clone()
+    }
+}
+
+impl TrustedPeer {
+    pub fn get_network_signing_public(&self) -> signing::PublicKey {
+        self.network_signing_pubkey
+    }
+    pub fn get_network_identity_public(&self) -> X25519PublicKey {
+        self.network_identity_pubkey
+    }
+    pub fn get_consensus_public(&self) -> signing::PublicKey {
+        self.consensus_pubkey
+    }
+}
+
+pub fn serialize_key<S, K>(key: &K, serializer: S) -> Result<S::Ok, S::Error>
+where
+    S: Serializer,
+    K: Serialize,
+{
+    serializer.serialize_str(&encode_to_string(key))
+}
+
+pub fn deserialize_key<'de, D, K>(deserializer: D) -> Result<K, D::Error>
+where
+    D: Deserializer<'de>,
+    K: DeserializeOwned + 'static,
+{
+    let encoded_key: String = Deserialize::deserialize(deserializer)?;
+
+    Ok(from_encoded_string(encoded_key))
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct TrustedPeersConfig {
+    pub peers: HashMap<String, TrustedPeer>,
+}
+
+impl TrustedPeersConfig {
+    pub fn load_config<P: AsRef<Path>>(path: P) -> Self {
+        let path = path.as_ref();
+        let mut file = File::open(path)
+            .unwrap_or_else(|_| panic!("Cannot open Trusted Peers Config file {:?}", path));
+        let mut contents = String::new();
+        file.read_to_string(&mut contents)
+            .unwrap_or_else(|_| panic!("Error reading Trusted Peers Config file {:?}", path));
+
+        Self::parse(&contents)
+    }
+
+    pub fn save_config<P: AsRef<Path>>(&self, output_file: P) {
+        let contents = toml::to_vec(&self).expect("Error serializing");
+
+        let mut file = File::create(output_file).expect("Error opening file");
+
+        file.write_all(&contents).expect("Error writing file");
+    }
+
+    pub fn get_public_keys(&self, peer_id: &str) -> TrustedPeer {
+        self.peers
+            .get(peer_id)
+            .unwrap_or_else(|| panic!("Missing keys for {}", peer_id))
+            .clone()
+    }
+
+    pub fn get_consensus_keys(&self, peer_id: &str) -> signing::PublicKey {
+        self.get_public_keys(peer_id).consensus_pubkey
+    }
+
+    pub fn get_network_signing_keys(&self, peer_id: &str) -> signing::PublicKey {
+        self.get_public_keys(peer_id).network_signing_pubkey
+    }
+
+    pub fn get_network_identity_keys(&self, peer_id: &str) -> X25519PublicKey {
+        self.get_public_keys(peer_id).network_identity_pubkey
+    }
+
+    /// Returns a map of AccountAddress to its PublicKey for consensus.
+    pub fn get_trusted_consensus_peers(&self) -> HashMap<AccountAddress, signing::PublicKey> {
+        let mut res = HashMap::new();
+        for (account, keys) in &self.peers {
+            res.insert(
+                AccountAddress::try_from(account.clone()).expect("Failed to parse account addr"),
+                keys.consensus_pubkey,
+            );
+        }
+        res
+    }
+
+    /// Returns a map of AccountAddress to a pair of PublicKeys for network peering. The first
+    /// PublicKey is the one used for signing, whereas the second is used to determine eligible
+    /// members of the network.
+    pub fn get_trusted_network_peers(
+        &self,
+    ) -> HashMap<AccountAddress, (signing::PublicKey, X25519PublicKey)> {
+        self.peers
+            .iter()
+            .map(|(account, keys)| {
+                (
+                    AccountAddress::try_from(account.clone())
+                        .expect("Failed to parse account addr"),
+                    (keys.network_signing_pubkey, keys.network_identity_pubkey),
+                )
+            })
+            .collect()
+    }
+
+    fn parse(config_string: &str) -> Self {
+        toml::from_str(config_string).expect("Unable to parse Config")
+    }
+}
+
+impl Default for TrustedPeersConfig {
+    fn default() -> TrustedPeersConfig {
+        Self {
+            peers: HashMap::new(),
+        }
+    }
+}
+
+pub struct TrustedPeersConfigHelpers {}
+
+impl TrustedPeersConfigHelpers {
+    /// Creates a new TrustedPeersConfig with the given number of peers,
+    /// as well as a hashmap of all the test validator nodes' private keys.
+    pub fn get_test_config(
+        number_of_peers: usize,
+        seed: Option<[u8; 32]>,
+    ) -> (HashMap<String, TrustedPeerPrivateKeys>, TrustedPeersConfig) {
+        let mut peers = HashMap::new();
+        let mut peers_private_keys = HashMap::new();
+        // deterministically derive keypairs from a seeded rng
+        let seed = if let Some(seed) = seed {
+            seed
+        } else {
+            [0u8; 32]
+        };
+
+        let mut fast_rng = StdRng::from_seed(seed);
+        for _ in 0..number_of_peers {
+            let (private0, public0) = signing::generate_keypair_for_testing(&mut fast_rng);
+            let (private1, public1) = x25519::generate_keypair_for_testing(&mut fast_rng);
+            let (private2, public2) = signing::generate_keypair_for_testing(&mut fast_rng);
+            // save the public keys in the peers hashmap
+            let peer = TrustedPeer {
+                network_signing_pubkey: public0,
+                network_identity_pubkey: public1,
+                consensus_pubkey: public2,
+            };
+            let peer_id = AccountAddress::from(peer.consensus_pubkey);
+            peers.insert(peer_id.to_string(), peer);
+            // save the private keys in a different hashmap
+            let private_keys = TrustedPeerPrivateKeys {
+                network_signing_private_key: private0,
+                network_identity_private_key: private1,
+                consensus_private_key: private2,
+            };
+            peers_private_keys.insert(peer_id.to_string(), private_keys);
+        }
+        (peers_private_keys, TrustedPeersConfig { peers })
+    }
+}
diff --git a/config/src/unit_tests/config_test.rs b/config/src/unit_tests/config_test.rs
new file mode 100644
index 0000000000000..b10a80e9a0692
--- /dev/null
+++ b/config/src/unit_tests/config_test.rs
@@ -0,0 +1,28 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use std::fs;
+
+#[test]
+fn verify_test_config() {
+    // This test verifies that the default config in config.toml is valid
+    let _ = NodeConfigHelpers::get_single_node_test_config(false);
+}
+
+#[test]
+fn verify_all_configs() {
+    // This test verifies that all configs in data/configs are valid
+    let paths = fs::read_dir("data/configs").expect("cannot read config dir");
+
+    for path in paths {
+        let config_path = path.unwrap().path();
+        let config_path_str = config_path.to_str().unwrap();
+        if config_path_str.ends_with(".toml") {
+            println!("Loading {}", config_path_str);
+            let _ = NodeConfig::load_config(None, config_path_str).expect("NodeConfig");
+        } else {
+            println!("Invalid file {} for verifying", config_path_str);
+        }
+    }
+}
diff --git a/config/src/unit_tests/seed_peers_test.rs b/config/src/unit_tests/seed_peers_test.rs
new file mode 100644
index 0000000000000..09e9007b6ccfd
--- /dev/null
+++ b/config/src/unit_tests/seed_peers_test.rs
@@ -0,0 +1,11 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use super::SeedPeersConfigHelpers;
+use crate::trusted_peers::TrustedPeersConfigHelpers;
+
+#[test]
+fn generate_test_config() {
+    let (_, trusted_peers) = TrustedPeersConfigHelpers::get_test_config(10, None);
+    let _ = SeedPeersConfigHelpers::get_test_config(&trusted_peers, None);
+}
diff --git a/config/src/unit_tests/trusted_peers_test.rs b/config/src/unit_tests/trusted_peers_test.rs
new file mode 100644
index 0000000000000..5a64165175a0b
--- /dev/null
+++ b/config/src/unit_tests/trusted_peers_test.rs
@@ -0,0 +1,9 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use super::TrustedPeersConfigHelpers;
+
+#[test]
+fn generate_test_config() {
+    let (_, _) = TrustedPeersConfigHelpers::get_test_config(10, None);
+}
diff --git a/config/src/utils.rs b/config/src/utils.rs
new file mode 100644
index 0000000000000..b8562651205d8
--- /dev/null
+++ b/config/src/utils.rs
@@ -0,0 +1,96 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use get_if_addrs::get_if_addrs;
+use parity_multiaddr::{Multiaddr, Protocol};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use std::{
+    collections::HashSet,
+    hash::BuildHasher,
+    net::{IpAddr, TcpListener, TcpStream},
+};
+use types::transaction::SCRIPT_HASH_LENGTH;
+
+/// Return an ephemeral, available port. On unix systems, the port returned will be in the
+/// TIME_WAIT state, ensuring that the OS won't hand out this port for some grace period.
+/// Callers should be able to bind to this port given they use SO_REUSEADDR.
+pub fn get_available_port() -> u16 {
+    const MAX_PORT_RETRIES: u32 = 1000;
+
+    for _ in 0..MAX_PORT_RETRIES {
+        if let Ok(port) = get_ephemeral_port() {
+            return port;
+        }
+    }
+
+    panic!("Error: could not find an available port");
+}
+
+fn get_ephemeral_port() -> ::std::io::Result<u16> {
+    // Request a random available port from the OS
+    let listener = TcpListener::bind(("localhost", 0))?;
+    let addr = listener.local_addr()?;
+
+    // Create and accept a connection (which we'll promptly drop) in order to force the port
+    // into the TIME_WAIT state, ensuring that the port will be reserved for some limited
+    // amount of time (roughly 60s on some Linux systems)
+    let _sender = TcpStream::connect(addr)?;
+    let _incoming = listener.accept()?;
+
+    Ok(addr.port())
+}
+
+/// Extracts one local non-loopback IP address, if one exists. Otherwise returns None.
+pub fn get_local_ip() -> Option<Multiaddr> {
+    get_if_addrs().ok().and_then(|if_addrs| {
+        if_addrs
+            .into_iter()
+            .filter(|if_addr| !if_addr.is_loopback())
+            .nth(0)
+            .and_then(|if_addr| {
+                let mut addr = Multiaddr::empty();
+                match if_addr.ip() {
+                    IpAddr::V4(a) => {
+                        addr.push(Protocol::Ip4(a));
+                    }
+                    IpAddr::V6(a) => {
+                        addr.push(Protocol::Ip6(a));
+                    }
+                }
+                Some(addr)
+            })
+    })
+}
+
+pub fn deserialize_whitelist<'de, D>(
+    deserializer: D,
+) -> ::std::result::Result<HashSet<[u8; SCRIPT_HASH_LENGTH]>, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    let whitelisted_scripts: Vec<String> = Deserialize::deserialize(deserializer)?;
+    let whitelist = whitelisted_scripts
+        .iter()
+        .map(|s| {
+            let mut hash = [0u8; SCRIPT_HASH_LENGTH];
+            let decoded_hash =
+                hex::decode(s).expect("Unable to decode script hash from configuration file.");
+            assert!(decoded_hash.len() == SCRIPT_HASH_LENGTH);
+            hash.copy_from_slice(decoded_hash.as_slice());
+            hash
+        })
+        .collect();
+    Ok(whitelist)
+}
+
+pub fn serialize_whitelist<S, H>(
+    whitelist: &HashSet<[u8; SCRIPT_HASH_LENGTH], H>,
+    serializer: S,
+) -> Result<S::Ok, S::Error>
+where
+    S: Serializer,
+    H: BuildHasher,
+{
+    let encoded_whitelist: Vec<String> = whitelist.iter().map(hex::encode).collect();
+    encoded_whitelist.serialize(serializer)
+}
diff --git a/consensus/Cargo.toml b/consensus/Cargo.toml
new file mode 100644
index 0000000000000..b1342ba28a670
--- /dev/null
+++ b/consensus/Cargo.toml
@@ -0,0 +1,63 @@
+[package]
+name = "consensus"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+byteorder = "1.3.1"
+bytes = "0.4.12"
+grpcio = "0.4.3"
+futures = { version = "=0.3.0-alpha.16", package = "futures-preview", features = ["io-compat", "compat"] }
+futures_locks = { version = "=0.3.0", package = "futures-locks", features = ["tokio"] }
+mirai-annotations = "0.1.0"
+num-traits = "0.2"
+num-derive = "0.2"
+proptest = "0.9"
+protobuf = "2.6"
+rand = "0.6.5"
+serde = { version = "1.0.87", features = ["derive"] }
+tokio = "0.1.11"
"1.5.1" +lazy_static = "1.3.0" +rmp-serde = "0.13.7" + +canonical_serialization = { path = "../common/canonical_serialization" } +channel = { path = "../common/channel" } +config = { path = "../config" } +crypto = { path = "../crypto/legacy_crypto" } +execution_proto = { path = "../execution/execution_proto" } +failure = { path = "../common/failure_ext", package = "failure_ext" } +grpc_helpers = { path = "../common/grpc_helpers" } +logger = { path = "../common/logger" } +mempool = { path = "../mempool" } +metrics = { path = "../common/metrics" } +network = { path = "../network" } +proto_conv = { path = "../common/proto_conv" } +schemadb = { path = "../storage/schemadb" } +storage_client = { path = "../storage/storage_client" } +storage_proto = { path = "../storage/storage_proto" } +tools = { path = "../common/tools" } +types = { path = "../types" } + +[dependencies.prometheus] +version = "0.4.2" +default-features = false +features = ["nightly", "push"] + +[build-dependencies] +build_helpers = { path = "../common/build_helpers" } + +[dev-dependencies] +cached = "0.8.0" +tempfile = "3.0.6" +parity-multiaddr = "0.4.0" +rusty-fork = "0.2.1" + +config_builder = { path = "../config/config_builder" } +execution_service = { path = "../execution/execution_service" } +storage_service = { path = "../storage/storage_service" } +vm_genesis = { path = "../language/vm/vm_genesis" } +vm_validator = { path = "../vm_validator" } diff --git a/consensus/README.md b/consensus/README.md new file mode 100644 index 0000000000000..9bc2c0244f7dc --- /dev/null +++ b/consensus/README.md @@ -0,0 +1,63 @@ +--- +id: consensus +title: Consensus +custom_edit_url: https://github.com/libra/libra/edit/master/consensus/README.md +--- + +# Consensus + +The consensus component supports state machine replication using the LibraBFT consensus protocol. + +## Overview + +A consensus protocol allows a set of validators to create the logical appearance of a single database. The consensus protocol replicates submitted transactions among the validators, executes potential transactions against the current database, and then agrees on a binding commitment to the ordering of transactions and resulting execution. As a result, all validators can maintain an identical database for a given version number following the [state machine replication paradigm](https://dl.acm.org/citation.cfm?id=98167). The Libra Blockchain uses a variant of the [HotStuff consensus protocool](https://arxiv.org/pdf/1803.05069.pdf), a recent Byzantine fault-tolerant ([BFT](https://en.wikipedia.org/wiki/Byzantine_fault)) consensus protocol, called LibraBFT. It provides safety (all honest validators agree on commits and execution) and liveness (commits are continually produced) in the partial synchrony model defined in the paper "Consensus in the Presence of Partial Synchrony" by Dwork, Lynch, and Stockmeyer ([DLS](https://groups.csail.mit.edu/tds/papers/Lynch/jacm88.pdf)) and used in [PBFT](http://pmg.csail.mit.edu/papers/osdi99.pdf) as well as newer protocols such as [Tendermint](https://arxiv.org/abs/1807.04938). In this document, we present a high-level description of the LibraBFT protocol and discuss how the code is organized. See [here](https://developers.libra.org/docs/the-libra-blockchain-paper) to learn how LibraBFT fits into the Libra Blockchain. For details on the specifications and proofs of LibraBFT, read the full [technical report](https://developers.libra.org/docs/state-machine-replication-paper). 
+
+Agreement on the database state must be reached between validators, even if
+there are Byzantine faults. The Byzantine failures model allows some validators
+to arbitrarily deviate from the protocol without constraint, with the exception
+of being computationally bound (and thus not able to break cryptographic assumptions). Byzantine faults are worst-case errors where validators collude and behave maliciously to try to sabotage system behavior. A consensus protocol that tolerates Byzantine faults caused by malicious or hacked validators can also mitigate arbitrary hardware and software failures.
+
+LibraBFT assumes that a set of 3f + 1 votes is distributed among a set of validators that may be honest or Byzantine. LibraBFT remains safe, preventing attacks such as double spends and forks, when at most f votes are controlled by Byzantine validators — also implying that at least 2f + 1 votes are honest. LibraBFT remains live, committing transactions from clients, as long as there exists a global stabilization time (GST), after which all messages between honest validators are delivered to other honest validators within a maximal network delay $\Delta$ (this is the partial synchrony model introduced in [DLS](https://groups.csail.mit.edu/tds/papers/Lynch/jacm88.pdf)). In addition to traditional guarantees, LibraBFT maintains safety when validators crash and restart — even if all validators restart at the same time.
+
+### LibraBFT Overview
+
+In LibraBFT, validators receive transactions from clients and share them with each other through a shared mempool protocol. The LibraBFT protocol then proceeds in a sequence of rounds. In each round, a validator takes the role of leader and proposes a block of transactions to extend a certified sequence of blocks (see quorum certificates below) that contain the full previous transaction history. A validator receives the proposed block and checks its voting rules to determine if it should vote for certifying this block. These simple rules ensure the safety of LibraBFT — and their implementation can be cleanly separated and audited. If the validator intends to vote for this block, it executes the block’s transactions speculatively and without external effect. This results in the computation of an authenticator for the database that results from the execution of the block. The validator then sends a signed vote for the block and the database authenticator to the leader. The leader gathers these votes to form a quorum certificate that provides evidence of $\ge$ 2f + 1 votes for this block and broadcasts the quorum certificate to all validators.
+
+A block is committed when a contiguous 3-chain commit rule is met. A block at round k is committed if it has a quorum certificate and is confirmed by two more blocks and quorum certificates at rounds k + 1 and k + 2. The commit rule eventually allows honest validators to commit a block. LibraBFT guarantees that all honest validators will eventually commit the block (and the preceding sequence of blocks linked from it). Once a sequence of blocks has committed, the state resulting from executing their transactions can be persisted and forms a replicated database.
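+
+The 3-chain commit rule above can be captured in a few lines. The sketch below is illustrative only; its types are invented here and do not mirror the crate's `Block` or `QuorumCert`:
+
+```rust
+/// A block whose round we know, plus the round of the parent certified by its QC.
+struct CertifiedBlock {
+    round: u64,
+    parent_round: u64,
+}
+
+/// Commit rule sketch: a block at round k is committed once it heads a chain
+/// of certified blocks at contiguous rounds k, k + 1, and k + 2.
+fn commits(b_k: &CertifiedBlock, b_k1: &CertifiedBlock, b_k2: &CertifiedBlock) -> bool {
+    b_k1.parent_round == b_k.round
+        && b_k2.parent_round == b_k1.round
+        && b_k1.round == b_k.round + 1
+        && b_k2.round == b_k1.round + 1
+}
+
+fn main() {
+    let (a, b, c) = (
+        CertifiedBlock { round: 5, parent_round: 4 },
+        CertifiedBlock { round: 6, parent_round: 5 },
+        CertifiedBlock { round: 7, parent_round: 6 },
+    );
+    // Rounds 5, 6, 7 are contiguous and each QC certifies its parent.
+    assert!(commits(&a, &b, &c));
+}
+```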
+
+### Advantages of the HotStuff Paradigm
+
+We evaluated several BFT-based protocols against the dimensions of performance, reliability, security, ease of robust implementation, and operational overhead for validators. Our goal was to choose a protocol that would initially support at least 100 validators and would be able to evolve over time to support 500–1,000 validators. We had three reasons for selecting the HotStuff protocol as the basis for LibraBFT: (i) simplicity and modularity; (ii) ability to easily integrate consensus with execution; and (iii) promising performance in early experiments.
+
+The HotStuff protocol decomposes into modules for safety (voting and commit rules) and liveness (pacemaker). This decoupling provides the ability to develop and experiment independently and on different modules in parallel. Due to the simple voting and commit rules, protocol safety is easy to implement and verify. It is straightforward to integrate execution as a part of consensus to avoid forking issues that arise from non-deterministic execution in a leader-based protocol. Finally, our early prototypes confirmed high throughput and low transaction latency as independently measured in [HotStuff](https://arxiv.org/pdf/1803.05069.pdf). We did not consider proof-of-work based protocols, such as [Bitcoin](https://bitcoin.org/bitcoin.pdf), due to their poor performance
+and high energy (and environmental) costs.
+
+### HotStuff Extensions and Modifications
+
+In LibraBFT, to better support the goals of the Libra ecosystem, we extend and adapt the core HotStuff protocol and implementation in several ways. Importantly, we reformulate the safety conditions and provide extended proofs of safety, liveness, and optimistic responsiveness. We also implement a number of additional features. First, we make the protocol more resistant to non-determinism bugs by having validators collectively sign the resulting state of a block rather than just the sequence of transactions. This also allows clients to use quorum certificates to authenticate reads from the database. Second, we design a pacemaker that emits explicit timeouts, and validators rely on a quorum of those to move to the next round — without requiring synchronized clocks. Third, we intend to design an unpredictable leader election mechanism in which the leader of a round is determined by the proposer of the latest committed block using a verifiable random function [VRF](https://people.csail.mit.edu/silvio/Selected%20Scientific%20Papers/Pseudo%20Randomness/Verifiable_Random_Functions.pdf). This mechanism limits the window of time in which an adversary can launch an effective denial-of-service attack against a leader. Fourth, we use aggregate signatures that preserve the identity of validators who sign quorum certificates. This allows us to provide incentives to validators that contribute to quorum certificates. Aggregate signatures also do not require a complex [threshold key setup](https://www.cypherpunks.ca/~iang/pubs/DKG.pdf).
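+
+The timeout-quorum idea behind the pacemaker can be illustrated with a small, self-contained sketch. This is not the crate's Pacemaker: the types below are invented for the example, and real timeout messages are signed and exchanged between validators over the network.
+
+```rust
+use std::collections::{HashMap, HashSet};
+
+type Author = u64;
+
+/// Toy aggregator: a validator advances to round r + 1 once it has seen a
+/// quorum of distinct timeout messages for round r.
+struct TimeoutAggregator {
+    quorum_size: usize,
+    timeouts: HashMap<u64, HashSet<Author>>, // round -> authors that timed out
+}
+
+impl TimeoutAggregator {
+    fn new(quorum_size: usize) -> Self {
+        Self { quorum_size, timeouts: HashMap::new() }
+    }
+
+    /// Record a timeout for `round`; returns the next round to enter once a
+    /// quorum of validators has timed out on the current one.
+    fn add_timeout(&mut self, round: u64, author: Author) -> Option<u64> {
+        let authors = self.timeouts.entry(round).or_insert_with(HashSet::new);
+        authors.insert(author);
+        if authors.len() >= self.quorum_size {
+            Some(round + 1)
+        } else {
+            None
+        }
+    }
+}
+
+fn main() {
+    // With 3f + 1 = 4 validators and f = 1, a quorum is 2f + 1 = 3.
+    let mut agg = TimeoutAggregator::new(3);
+    assert_eq!(agg.add_timeout(8, 0), None);
+    assert_eq!(agg.add_timeout(8, 1), None);
+    // The third distinct timeout for round 8 forms a quorum: advance to round 9.
+    assert_eq!(agg.add_timeout(8, 2), Some(9));
+}
+```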
+
+## Implementation Details
+
+The consensus component is mostly implemented in the [Actor](https://en.wikipedia.org/wiki/Actor_model) programming model — i.e., it uses message-passing to communicate between different subcomponents, with the [tokio](https://tokio.rs/) framework used as the task runtime. The primary exception to the actor model (as it is accessed in parallel by several subcomponents) is the consensus data structure *BlockStore*, which manages the blocks, execution, quorum certificates, and other shared data structures. A schematic sketch of this event-loop style appears after the module layout below. The major subcomponents in the consensus component are:
+
+* **TxnManager** is the interface to the mempool component and supports the pulling of transactions as well as removing committed transactions. A proposer pulls transactions from the mempool on demand to form a proposal block.
+* **StateComputer** is the interface for accessing the execution component. It can execute blocks, commit blocks, and synchronize state.
+* **BlockStore** maintains the tree of proposal blocks, block execution, votes, quorum certificates, and persistent storage. It is responsible for maintaining the consistency of the combination of these data structures and can be concurrently accessed by other subcomponents.
+* **EventProcessor** is responsible for processing the individual events (e.g., process_new_round, process_proposal, process_vote). It exposes the async processing functions for each event type and drives the protocol.
+* **Pacemaker** is responsible for the liveness of the consensus protocol. It changes rounds due to timeout certificates or quorum certificates and proposes blocks when it is the proposer for the current round.
+* **SafetyRules** is responsible for the safety of the consensus protocol. It processes quorum certificates and LedgerInfo to learn about new commits and guarantees that the two voting rules are followed — even in the case of restart (since all safety data is persisted to local storage).
+
+All consensus messages are signed by their creators and verified by their receivers. Message verification occurs as close to the network layer as possible, to prevent invalid or unnecessary data from entering the consensus protocol.
+
+## How is this module organized?
+
+    consensus
+    ├── src
+    │   └── chained_bft                # Implementation of the LibraBFT protocol
+    │       ├── block_storage          # In-memory storage of blocks and related data structures
+    │       ├── consensus_types        # Consensus data types (e.g., quorum certificates)
+    │       ├── consensusdb            # Database interaction to persist consensus data for safety and liveness
+    │       ├── liveness               # Pacemaker, proposer, and other liveness related code
+    │       ├── safety                 # Safety (voting) rules
+    │       └── test_utils             # Mock implementations that are used for testing only
+    └── state_synchronizer             # Synchronization between validators to catch up on committed state
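+
+The sketch below illustrates the event-loop style described above in miniature. It is illustrative only: the event variants are invented, a std channel stands in for the async channels used with tokio, and the real EventProcessor exposes async functions rather than a blocking loop.
+
+```rust
+use std::sync::mpsc;
+
+/// Invented event variants, mirroring the process_new_round / process_proposal /
+/// process_vote split described above.
+enum Event {
+    NewRound(u64),
+    Proposal { round: u64 },
+    Vote { round: u64 },
+}
+
+fn main() {
+    let (tx, rx) = mpsc::channel();
+
+    // Other subcomponents send events...
+    tx.send(Event::NewRound(1)).unwrap();
+    tx.send(Event::Proposal { round: 1 }).unwrap();
+    tx.send(Event::Vote { round: 1 }).unwrap();
+    drop(tx);
+
+    // ...and a single processing loop dispatches them, so each event type is
+    // handled by its own function and shared state stays behind one owner.
+    for event in rx {
+        match event {
+            Event::NewRound(r) => println!("enter round {}", r),
+            Event::Proposal { round } => println!("process proposal for round {}", round),
+            Event::Vote { round } => println!("process vote for round {}", round),
+        }
+    }
+}
+```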
diff --git a/consensus/build.rs b/consensus/build.rs
new file mode 100644
index 0000000000000..59003e0d0d26e
--- /dev/null
+++ b/consensus/build.rs
@@ -0,0 +1,18 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This compiles all the `.proto` files under the `src/` directory.
+//!
+//! For example, if there is a file `src/a/b/c.proto`, it will generate `src/a/b/c.rs` and
+//! `src/a/b/c_grpc.rs`.
+
+fn main() {
+    let proto_root = "src";
+    let dependent_root = "../types/src/proto";
+
+    build_helpers::build_helpers::compile_proto(
+        proto_root,
+        vec![dependent_root],
+        false, /* generate_client_stub */
+    );
+}
diff --git a/consensus/src/.gitignore b/consensus/src/.gitignore
new file mode 100644
index 0000000000000..2e7de426c25fc
--- /dev/null
+++ b/consensus/src/.gitignore
@@ -0,0 +1 @@
+!lib.rs
diff --git a/consensus/src/chained_bft/block_storage/block_inserter.rs b/consensus/src/chained_bft/block_storage/block_inserter.rs
new file mode 100644
index 0000000000000..99dcf42b37679
--- /dev/null
+++ b/consensus/src/chained_bft/block_storage/block_inserter.rs
@@ -0,0 +1,187 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    chained_bft::{
+        block_storage::{block_tree::BlockTree, InsertError},
+        common::Payload,
+        consensus_types::block::Block,
+        persistent_storage::PersistentStorage,
+    },
+    state_replication::{ExecutedState, StateComputer},
+};
+use crypto::{hash::CryptoHash, HashValue};
+use futures::compat::Future01CompatExt;
+use logger::prelude::*;
+use std::sync::{Arc, RwLock};
+
+/// The max number of parallel executions a BlockInserterGuard allows.
+const MAX_PARALLEL_EXECUTIONS: u8 = 32;
+
+/// BlockInserterGuard uses futures_locks asynchronous locks in order to synchronize
+/// async insertions from different concurrent tasks.
+/// It keeps multiple RwLock instances in order to reduce unnecessary serialization of
+/// independent blocks: the guard is determined by a given block id.
+pub struct BlockInserterGuard<T> {
+    // The guards are a vector of async locks, all protecting the very same instance of the
+    // inserter. A guard is chosen as a function of a block id: operations on the same block
+    // must be serialized.
+    guards: Vec<futures_locks::RwLock<Arc<BlockInserter<T>>>>,
+}
+
+impl<T: Payload> BlockInserterGuard<T> {
+    pub fn new(
+        inner: Arc<RwLock<BlockTree<T>>>,
+        state_computer: Arc<dyn StateComputer<Payload = T>>,
+        enforce_increasing_timestamps: bool,
+        storage: Arc<dyn PersistentStorage<T>>,
+    ) -> Self {
+        let inserter_ref = Arc::new(BlockInserter::new(
+            inner,
+            state_computer,
+            enforce_increasing_timestamps,
+            storage,
+        ));
+        let mut guards = vec![];
+        for _ in 0..MAX_PARALLEL_EXECUTIONS {
+            guards.push(futures_locks::RwLock::new(Arc::clone(&inserter_ref)));
+        }
+
+        Self { guards }
+    }
+
+    /// Execute and insert a block if it passes all validation tests.
+    /// Returns the Arc to the block kept in the block tree.
+    ///
+    /// This function assumes that the ancestors are present (returns MissingParent otherwise).
+    ///
+    /// Duplicate inserts will return the previously inserted block (note that this is
+    /// considered a valid non-error case; for example, it can happen if a validator
+    /// receives a certificate for a block that is currently being added).
+    pub async fn execute_and_insert_block(
+        &self,
+        block: Block<T>,
+    ) -> Result<Arc<Block<T>>, InsertError> {
+        // Choose a guard deterministically as a function of the block id: different requests for
+        // the same block must be serialized.
+        let guard_idx = (*block.id().to_vec().last().unwrap() % MAX_PARALLEL_EXECUTIONS) as usize;
+        let inserter = self
+            .guards
+            .get(guard_idx)
+            .unwrap()
+            .write()
+            .compat()
+            .await
+            .unwrap();
+        inserter.execute_and_insert_block(block).await
+    }
+}
+
+struct BlockInserter<T> {
+    inner: Arc<RwLock<BlockTree<T>>>,
+    state_computer: Arc<dyn StateComputer<Payload = T>>,
+    enforce_increasing_timestamps: bool,
+    /// The persistent storage backing up the in-memory data structure; every write should go
+    /// through this before the in-memory tree.
+    storage: Arc<dyn PersistentStorage<T>>,
+}
+
+impl<T: Payload> BlockInserter<T> {
+    fn new(
+        inner: Arc<RwLock<BlockTree<T>>>,
+        state_computer: Arc<dyn StateComputer<Payload = T>>,
+        enforce_increasing_timestamps: bool,
+        storage: Arc<dyn PersistentStorage<T>>,
+    ) -> Self {
+        Self {
+            inner,
+            state_computer,
+            enforce_increasing_timestamps,
+            storage,
+        }
+    }
+
+    async fn execute_and_insert_block(
+        &self,
+        block: Block<T>,
+    ) -> Result<Arc<Block<T>>, InsertError> {
+        if let Some(existing_block) = self.inner.read().unwrap().get_block(block.id()) {
+            return Ok(existing_block);
+        }
+        let (parent_id, parent_exec_version) = match self.verify_and_get_parent_info(&block) {
+            Ok(t) => t,
+            Err(e) => {
+                security_log(SecurityEvent::InvalidBlock)
+                    .error(&e)
+                    .data(&block)
+                    .log();
+                return Err(e);
+            }
+        };
+        let compute_res = self
+            .state_computer
+            .compute(parent_id, block.id(), block.get_payload())
+            .await
+            .map_err(|e| {
+                error!("Execution failure for block {}: {:?}", block, e);
+                InsertError::StateComputerError
+            })?;
+
+        let version = parent_exec_version + compute_res.num_successful_txns;
+
+        let state = ExecutedState {
+            state_id: compute_res.new_state_id,
+            version,
+        };
+        self.storage
+            .save_tree(vec![block.clone()], vec![])
+            .map_err(|_| InsertError::StorageFailure)?;
+        self.inner
+            .write()
+            .unwrap()
+            .insert_block(block, state, compute_res)
+            .map_err(|e| e.into())
+    }
+
+    /// All the verifications of a block that is going to be added to the tree.
+    /// We assume that all the ancestors are present; returns a MissingParent error otherwise.
+    /// Returns the parent id and version in case of success.
+    fn verify_and_get_parent_info(
+        &self,
+        block: &Block<T>,
+    ) -> Result<(HashValue, u64), InsertError> {
+        if block.round() <= self.inner.read().unwrap().root().round() {
+            return Err(InsertError::OldBlock);
+        }
+
+        let block_hash = block.hash();
+        if block.id() != block_hash {
+            return Err(InsertError::InvalidBlockHash);
+        }
+
+        if block.quorum_cert().certified_block_id() != block.parent_id() {
+            return Err(InsertError::ParentNotCertified);
+        }
+
+        let parent = match self.inner.read().unwrap().get_block(block.parent_id()) {
+            None => {
+                return Err(InsertError::MissingParentBlock(block.parent_id()));
+            }
+            Some(parent) => parent,
+        };
+        if parent.height() + 1 != block.height() {
+            return Err(InsertError::InvalidBlockHeight);
+        }
+        if parent.round() >= block.round() {
+            return Err(InsertError::InvalidBlockRound);
+        }
+        if self.enforce_increasing_timestamps && parent.timestamp_usecs() >= block.timestamp_usecs()
+        {
+            return Err(InsertError::NonIncreasingTimestamp);
+        }
+        let parent_id = parent.id();
+        match self.inner.read().unwrap().get_state_for_block(parent_id) {
+            Some(ExecutedState { version, .. }) => Ok((parent.id(), version)),
+            None => Err(InsertError::ParentVersionNotFound),
+        }
+    }
+}
diff --git a/consensus/src/chained_bft/block_storage/block_store.rs b/consensus/src/chained_bft/block_storage/block_store.rs
new file mode 100644
index 0000000000000..29cf7a83fab49
--- /dev/null
+++ b/consensus/src/chained_bft/block_storage/block_store.rs
@@ -0,0 +1,472 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    chained_bft::{
+        block_storage::{
+            block_tree::BlockTree, BlockReader, BlockTreeError, InsertError, VoteReceptionResult,
+        },
+        common::{Payload, Round},
+        consensus_types::{block::Block, quorum_cert::QuorumCert},
+        persistent_storage::PersistentStorage,
+        safety::vote_msg::VoteMsg,
+    },
+    state_replication::{ExecutedState, StateComputer},
+};
+use crypto::HashValue;
+use logger::prelude::*;
+
+use crate::{
+    chained_bft::{
+        block_storage::block_inserter::BlockInserterGuard, persistent_storage::RecoveryData,
+    },
+    state_replication::StateComputeResult,
+};
+use mirai_annotations::checked_precondition;
+use std::{
+    collections::{vec_deque::VecDeque, HashMap},
+    sync::{Arc, RwLock},
+};
+use types::{ledger_info::LedgerInfo, validator_signer::ValidatorSigner};
+
+#[cfg(test)]
+#[path = "block_store_test.rs"]
+mod block_store_test;
+
+#[derive(Debug, PartialEq)]
+/// Whether we need to do block retrieval if we want to insert a Quorum Cert.
+pub enum NeedFetchResult {
+    QCRoundBeforeRoot,
+    QCAlreadyExist,
+    QCBlockExist,
+    NeedFetch,
+}
+
+/// Responsible for maintaining all the blocks of payload and the dependencies of those blocks
+/// (parent and previous QC links). It is expected to be accessed concurrently by multiple threads
+/// and is thread-safe.
+///
+/// Example tree block structure based on parent links.
+///                         | -> A3
+/// Genesis -> B0 -> B1 -> B2 -> B3
+///             | -> C1 -> C2
+///                         | -> D3
+///
+/// Example corresponding tree block structure for the QC links (must follow QC constraints).
+///                         | -> A3
+/// Genesis -> B0 -> B1 -> B2 -> B3
+///             | -> C1
+///             | -------> C2
+///             | -------------> D3
+pub struct BlockStore<T> {
+    inner: Arc<RwLock<BlockTree<T>>>,
+    block_inserter: BlockInserterGuard<T>,
+    validator_signer: ValidatorSigner,
+    state_computer: Arc<dyn StateComputer<Payload = T>>,
+    enforce_increasing_timestamps: bool,
+    /// The persistent storage backing up the in-memory data structure; every write should go
+    /// through this before the in-memory tree.
+    storage: Arc<dyn PersistentStorage<T>>,
+}
+
+impl<T: Payload> BlockStore<T> {
+    pub async fn new(
+        storage: Arc<dyn PersistentStorage<T>>,
+        initial_data: RecoveryData<T>,
+        validator_signer: ValidatorSigner,
+        state_computer: Arc<dyn StateComputer<Payload = T>>,
+        enforce_increasing_timestamps: bool,
+        max_pruned_blocks_in_mem: usize,
+    ) -> Self {
+        let (root, blocks, quorum_certs) = initial_data.take();
+        let inner = Arc::new(RwLock::new(
+            Self::build_block_tree(
+                root,
+                blocks,
+                quorum_certs,
+                Arc::clone(&state_computer),
+                max_pruned_blocks_in_mem,
+            )
+            .await,
+        ));
+        let block_inserter = BlockInserterGuard::new(
+            Arc::clone(&inner),
+            Arc::clone(&state_computer),
+            enforce_increasing_timestamps,
+            Arc::clone(&storage),
+        );
+        BlockStore {
+            inner,
+            block_inserter,
+            validator_signer,
+            state_computer,
+            enforce_increasing_timestamps,
+            storage,
+        }
+    }
+
+    async fn build_block_tree(
+        root: (Block<T>, QuorumCert, QuorumCert),
+        blocks: Vec<Block<T>>,
+        quorum_certs: Vec<QuorumCert>,
+        state_computer: Arc<dyn StateComputer<Payload = T>>,
+        max_pruned_blocks_in_mem: usize,
+    ) -> BlockTree<T> {
+        let mut tree = BlockTree::new(root.0, root.1, root.2, max_pruned_blocks_in_mem);
+        let quorum_certs = quorum_certs
+            .into_iter()
+            .map(|qc| (qc.certified_block_id(), qc))
+            .collect::<HashMap<_, _>>();
+        for block in blocks {
+            let compute_res = state_computer
+                .compute(block.parent_id(), block.id(), block.get_payload())
+                .await
+                .expect("fail to rebuild scratchpad");
+            let version = tree
+                .get_state_for_block(block.parent_id())
+                .expect("parent state does not exist")
+                .version
+                + compute_res.num_successful_txns;
+            let executed_state = ExecutedState {
+                state_id: compute_res.new_state_id,
+                version,
+            };
+            // if this block is certified, ensure we agree with the certified state.
+            if let Some(qc) = quorum_certs.get(&block.id()) {
+                assert_eq!(
+                    qc.certified_state(),
+                    executed_state,
+                    "We have inconsistent executed state with Quorum Cert for block {}",
+                    block.id()
+                );
+            }
+            tree.insert_block(block, executed_state, compute_res)
+                .expect("Block insertion failed while building the tree");
+        }
+        quorum_certs.into_iter().for_each(|(_, qc)| {
+            tree.insert_quorum_cert(qc)
+                .expect("QuorumCert insertion failed while building the tree")
+        });
+        tree
+    }
+
+    pub async fn rebuild(
+        &self,
+        root: (Block<T>, QuorumCert, QuorumCert),
+        blocks: Vec<Block<T>>,
+        quorum_certs: Vec<QuorumCert>,
+    ) {
+        let tree = Self::build_block_tree(
+            root,
+            blocks,
+            quorum_certs,
+            Arc::clone(&self.state_computer),
+            self.inner.read().unwrap().max_pruned_blocks_in_mem(),
+        )
+        .await;
+        let to_remove = self.inner.read().unwrap().get_all_block_id();
+        if let Err(e) = self.storage.prune_tree(to_remove) {
+            // it's fine to fail here; the next restart will try to clean up dangling blocks again.
+            error!("fail to delete block: {:?}", e);
+        }
+        *self.inner.write().unwrap() = tree;
+    }
+
+    pub fn signer(&self) -> &ValidatorSigner {
+        &self.validator_signer
+    }
+
+    /// Execute and insert a block if it passes all validation tests.
+    /// Returns the Arc to the block kept in the block store.
+    ///
+    /// This function assumes that the ancestors are present (returns MissingParent otherwise).
+    ///
+    /// Duplicate inserts will return the previously inserted block (note that this is
+    /// considered a valid non-error case; for example, it can happen if a validator
+    /// receives a certificate for a block that is currently being added).
+    pub async fn execute_and_insert_block(
+        &self,
+        block: Block<T>,
+    ) -> Result<Arc<Block<T>>, InsertError> {
+        self.block_inserter.execute_and_insert_block(block).await
+    }
+
+    /// Check if we're far away from this ledger info and need to sync.
+    /// Returns false if we have this block in the tree or the root's round is higher than the
+    /// block.
+    pub fn need_sync_for_quorum_cert(
+        &self,
+        committed_block_id: HashValue,
+        qc: &QuorumCert,
+    ) -> bool {
+        // LedgerInfo doesn't carry the information about the round of the committed block.
+        // However, the 3-chain safety rules specify that the round of the committed block must be
+        // certified_block_round() - 2. In case root().round() is greater than that, the committed
+        // block carried by the LI is older than my current commit.
+        !(self.block_exists(committed_block_id)
+            || self.root().round() + 2 >= qc.certified_block_round())
+    }
+
+    /// Checks if a quorum certificate can be inserted into the block store without an RPC.
+    /// Returns an enum to indicate the detailed status.
+    pub fn need_fetch_for_quorum_cert(&self, qc: &QuorumCert) -> NeedFetchResult {
+        if qc.certified_block_round() < self.root().round() {
+            return NeedFetchResult::QCRoundBeforeRoot;
+        }
+        if self
+            .get_quorum_cert_for_block(qc.certified_block_id())
+            .is_some()
+        {
+            return NeedFetchResult::QCAlreadyExist;
+        }
+        if self.block_exists(qc.certified_block_id()) {
+            return NeedFetchResult::QCBlockExist;
+        }
+        NeedFetchResult::NeedFetch
+    }
+
+    /// Validates a quorum certificate and inserts it into the block tree, assuming its
+    /// dependencies exist.
+    pub async fn insert_single_quorum_cert(&self, qc: QuorumCert) -> Result<(), InsertError> {
+        // Ensure executed state is consistent with Quorum Cert; otherwise persist the quorum's
+        // state and hopefully we restart and agree with it.
+        let executed_state = self
+            .get_state_for_block(qc.certified_block_id())
+            .ok_or_else(|| InsertError::MissingParentBlock(qc.certified_block_id()))?;
+        assert_eq!(
+            executed_state,
+            qc.certified_state(),
+            "We have inconsistent executed state with the executed state from the quorum \
+             certificate for block {}, will kill this validator and rely on state synchronization \
+             to try to achieve consistent state with the quorum certificate.",
+            qc.certified_block_id(),
+        );
+        self.storage
+            .save_tree(vec![], vec![qc.clone()])
+            .map_err(|_| InsertError::StorageFailure)?;
+        self.inner
+            .write()
+            .unwrap()
+            .insert_quorum_cert(qc)
+            .map_err(|e| e.into())
+    }
+
+    /// Adds a vote for the block.
+    /// The returned value either contains the vote result (with a new / old QC etc.) or a
+    /// verification error.
+    /// A block store does not verify that the block, which is voted for, is present locally.
+    /// It returns a QC if one is formed, but does not insert it into the block store, because it
+    /// might not have the required dependencies yet.
+    /// Different execution ids are treated as different blocks (e.g., if some proposal is
+    /// executed in a non-deterministic fashion due to a bug, then the votes for execution result
+    /// A and the votes for execution result B are aggregated separately).
+    pub async fn insert_vote(
+        &self,
+        vote_msg: VoteMsg,
+        min_votes_for_qc: usize,
+    ) -> VoteReceptionResult {
+        self.inner
+            .write()
+            .unwrap()
+            .insert_vote(&vote_msg, min_votes_for_qc)
+    }
+
+    /// Prune the tree up to next_root_id (keep next_root_id's block). Any branches not part of
+    /// the next_root_id's tree should be removed as well.
+    ///
+    /// For example, root = B_0
+    /// B_0 -> B_1 -> B_2
+    ///  | -> B_3 -> B_4
+    ///
+    /// prune_tree(B_3) should be left with
+    /// B_3 -> B_4, root = B_3
+    ///
+    /// Returns the block ids of the blocks removed.
+    pub async fn prune_tree(&self, next_root_id: HashValue) -> VecDeque<HashValue> {
+        let id_to_remove = self
+            .inner
+            .read()
+            .unwrap()
+            .find_blocks_to_prune(next_root_id);
+        if let Err(e) = self
+            .storage
+            .prune_tree(id_to_remove.clone().into_iter().collect())
+        {
+            // it's fine to fail here; as long as the commit succeeds, the next restart will clean
+            // up dangling blocks, and we need to prune the tree to keep the root consistent with
+            // the executor.
+            error!("fail to delete block: {:?}", e);
+        }
+        self.inner
+            .write()
+            .unwrap()
+            .process_pruned_blocks(next_root_id, id_to_remove.clone());
+        id_to_remove
+    }
+
+    /// If block id information is found, returns the ledger info placeholder; otherwise, returns
+    /// a placeholder with the info of the genesis block.
+    pub fn ledger_info_placeholder(&self, id: Option<HashValue>) -> LedgerInfo {
+        let block_id = match id {
+            None => return Self::zero_ledger_info_placeholder(),
+            Some(id) => id,
+        };
+        let block = match self.get_block(block_id) {
+            Some(b) => b,
+            None => {
+                return Self::zero_ledger_info_placeholder();
+            }
+        };
+        let (state_id, version) = match self.get_state_for_block(block_id) {
+            Some(state) => (state.state_id, state.version),
+            None => {
+                return Self::zero_ledger_info_placeholder();
+            }
+        };
+        LedgerInfo::new(
+            version,
+            state_id,
+            HashValue::zero(),
+            block_id,
+            0, // TODO [Reconfiguration] use the real epoch number.
+            block.timestamp_usecs(),
+        )
+    }
+
+    /// Used in case we're using a ledger info just as a placeholder for signing the votes / QCs
+    /// and there is no real block committed.
+    /// It's all pretty much zeroes.
+    fn zero_ledger_info_placeholder() -> LedgerInfo {
+        LedgerInfo::new(
+            0,
+            HashValue::zero(),
+            HashValue::zero(),
+            HashValue::zero(),
+            0,
+            0,
+        )
+    }
+}
+
+impl<T: Payload> BlockReader for BlockStore<T> {
+    type Payload = T;
+
+    fn block_exists(&self, block_id: HashValue) -> bool {
+        self.inner.read().unwrap().block_exists(block_id)
+    }
+
+    fn get_block(&self, block_id: HashValue) -> Option<Arc<Block<T>>> {
+        self.inner.read().unwrap().get_block(block_id)
+    }
+
+    fn get_state_for_block(&self, block_id: HashValue) -> Option<ExecutedState> {
+        self.inner.read().unwrap().get_state_for_block(block_id)
+    }
+
+    fn get_compute_result(&self, block_id: HashValue) -> Option<Arc<StateComputeResult>> {
+        self.inner.read().unwrap().get_compute_result(block_id)
+    }
+
+    fn root(&self) -> Arc<Block<T>> {
+        self.inner.read().unwrap().root()
+    }
+
+    fn get_quorum_cert_for_block(&self, block_id: HashValue) -> Option<Arc<QuorumCert>> {
+        self.inner
+            .read()
+            .unwrap()
+            .get_quorum_cert_for_block(block_id)
+    }
+
+    fn is_ancestor(
+        &self,
+        ancestor: &Block<T>,
+        block: &Block<T>,
+    ) -> Result<bool, BlockTreeError> {
+        self.inner.read().unwrap().is_ancestor(ancestor, block)
+    }
+
+    fn path_from_root(&self, block: Arc<Block<T>>) -> Option<Vec<Arc<Block<T>>>> {
+        self.inner.read().unwrap().path_from_root(block)
+    }
+
+    fn create_block(
+        &self,
+        parent: Arc<Block<T>>,
+        payload: Self::Payload,
+        round: Round,
+        timestamp_usecs: u64,
+    ) -> Block<T> {
+        if self.enforce_increasing_timestamps {
+            checked_precondition!(parent.timestamp_usecs() < timestamp_usecs);
+        }
+        let quorum_cert = self
+            .get_quorum_cert_for_block(parent.id())
+            .expect("Parent for the newly created block is not certified!")
+            .as_ref()
+            .clone();
+        Block::make_block(
+            parent.as_ref(),
+            payload,
+            round,
+            timestamp_usecs,
+            quorum_cert,
+            &self.validator_signer,
+        )
+    }
+
+    fn highest_certified_block(&self) -> Arc<Block<T>> {
+        self.inner.read().unwrap().highest_certified_block()
+    }
+
+    fn highest_quorum_cert(&self) -> Arc<QuorumCert> {
+        self.inner.read().unwrap().highest_quorum_cert()
+    }
+
+    fn highest_ledger_info(&self) -> Arc<QuorumCert> {
+        self.inner.read().unwrap().highest_ledger_info()
+    }
+}
+
+#[cfg(test)]
+impl<T: Payload> BlockStore<T> {
+    /// Returns the number of blocks in the tree
+    fn len(&self) -> usize {
+        self.inner.read().unwrap().len()
+    }
+
+    /// Returns the number of child links in the tree
+    fn child_links(&self) -> usize {
+        self.inner.read().unwrap().child_links()
+    }
+
+    /// The number of pruned blocks that are still available in memory
+    pub(super) fn pruned_blocks_in_mem(&self) -> usize {
+        self.inner.read().unwrap().pruned_blocks_in_mem()
+    }
+
+    /// Helper to insert a vote and a qc.
+    /// Can't be used in production, because production insertion potentially requires state sync.
+    pub async fn insert_vote_and_qc(
+        &self,
+        vote_msg: VoteMsg,
+        qc_size: usize,
+    ) -> VoteReceptionResult {
+        let r = self.insert_vote(vote_msg, qc_size).await;
+        if let VoteReceptionResult::NewQuorumCertificate(ref qc) = r {
+            self.insert_single_quorum_cert(qc.as_ref().clone())
+                .await
+                .unwrap();
+        }
+        r
+    }
+
+    /// Helper function to insert the block together with its qc
+    pub async fn insert_block_with_qc(
+        &self,
+        block: Block<T>,
+    ) -> Result<Arc<Block<T>>, InsertError> {
+        self.insert_single_quorum_cert(block.quorum_cert().clone())
+            .await?;
+        Ok(self.execute_and_insert_block(block).await?)
+    }
+}
diff --git a/consensus/src/chained_bft/block_storage/block_store_test.rs b/consensus/src/chained_bft/block_storage/block_store_test.rs
new file mode 100644
index 0000000000000..6a9214736bf3b
--- /dev/null
+++ b/consensus/src/chained_bft/block_storage/block_store_test.rs
@@ -0,0 +1,488 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::chained_bft::{
+    block_storage::{BlockReader, BlockStore, InsertError, NeedFetchResult, VoteReceptionResult},
+    common::Author,
+    consensus_types::{
+        block::{block_test, Block},
+        quorum_cert::QuorumCert,
+    },
+    safety::vote_msg::VoteMsg,
+    test_utils::{
+        build_empty_tree, build_empty_tree_with_custom_signing, placeholder_certificate_for_block,
+        placeholder_ledger_info, TreeInserter,
+    },
+};
+use crypto::HashValue;
+use futures::executor::block_on;
+use proptest::prelude::*;
+use std::{cmp::min, collections::HashSet, sync::Arc};
+use types::{account_address::AccountAddress, validator_signer::ValidatorSigner};
+
+fn build_simple_tree() -> (
+    Vec<Arc<Block<Vec<usize>>>>,
+    Arc<BlockStore<Vec<usize>>>,
+) {
+    let block_store = build_empty_tree();
+    let genesis = block_store.root();
+    let genesis_block_id = genesis.id();
+    let genesis_block = block_store
+        .get_block(genesis_block_id)
+        .expect("genesis block must exist");
+    assert_eq!(block_store.len(), 1);
+    assert_eq!(block_store.child_links(), block_store.len() - 1);
+    assert_eq!(block_store.block_exists(genesis_block.id()), true);
+
+    //       | -> A1 -> A2 -> A3
+    // Genesis -> B1 -> B2
+    //             | -> C1
+    let mut inserter = TreeInserter::new(block_store.clone());
+    let a1 = inserter.insert_block(genesis_block.as_ref(), 1);
+    let a2 = inserter.insert_block(a1.as_ref(), 2);
+    let a3 = inserter.insert_block(a2.as_ref(), 3);
+    let b1 = inserter.insert_block(genesis_block.as_ref(), 4);
+    let b2 = inserter.insert_block(b1.as_ref(), 5);
+    let c1 = inserter.insert_block(b1.as_ref(), 6);
+
+    assert_eq!(block_store.len(), 7);
+    assert_eq!(block_store.child_links(), block_store.len() - 1);
+
+    (vec![genesis_block, a1, a2, a3, b1, b2, c1], block_store)
+}
+
+#[test]
+fn test_block_store_create_block() {
+    let block_store = build_empty_tree();
+    let genesis = block_store.root();
+    let a1 = block_store.create_block(Arc::clone(&genesis), vec![1], 1, 1);
+    assert_eq!(a1.parent_id(), genesis.id());
+    assert_eq!(a1.round(), 1);
+    assert_eq!(a1.height(), 1);
+    assert_eq!(a1.quorum_cert().certified_block_id(), genesis.id());
+
+    let a1_ref = block_on(block_store.execute_and_insert_block(a1)).unwrap();
+
+    // certify a1
+    let vote_msg = VoteMsg::new(
+        a1_ref.id(),
+        block_store.get_state_for_block(a1_ref.id()).unwrap(),
+        a1_ref.round(),
+        block_store.signer().author(),
+        placeholder_ledger_info(),
+        block_store.signer(),
+    );
+    block_on(block_store.insert_vote_and_qc(vote_msg, 1));
+
+    let b1 = block_store.create_block(Arc::clone(&a1_ref), vec![2], 2, 2);
+    assert_eq!(b1.parent_id(), a1_ref.id());
+    assert_eq!(b1.round(), 2);
+    assert_eq!(b1.height(), 2);
+    assert_eq!(b1.quorum_cert().certified_block_id(), a1_ref.id());
+}
+
+#[test]
+fn test_highest_block_and_quorum_cert() {
+    let block_store = build_empty_tree();
+    assert_eq!(
+        block_store.highest_certified_block().as_ref(),
+        &Block::make_genesis_block()
+    );
+    assert_eq!(
+        block_store.highest_quorum_cert().as_ref(),
+        &QuorumCert::certificate_for_genesis()
+    );
+
+    let genesis = block_store.root();
+    let mut inserter = TreeInserter::new(block_store.clone());
+
+    // Genesis block and quorum certificate is still the highest
+    let block_round_1 = inserter.insert_block(genesis.as_ref(), 1);
+    assert_eq!(
+        block_store.highest_certified_block().as_ref(),
+        &Block::make_genesis_block()
+    );
+    assert_eq!(
+        block_store.highest_quorum_cert().as_ref(),
+        &QuorumCert::certificate_for_genesis()
+    );
+
+    // block_round_1 block and quorum certificate is now the highest
+    let block_round_3 = inserter.insert_block(block_round_1.as_ref(), 3);
+    assert_eq!(
+        block_store.highest_certified_block().as_ref(),
+        block_round_1.as_ref()
+    );
+    assert_eq!(
+        block_store.highest_quorum_cert().as_ref(),
+        block_store
+            .get_block(block_round_3.id())
+            .expect("block_round_3 should exist")
+            .quorum_cert()
+    );
+
+    // block_round_1 block and quorum certificate is still the highest, since block_round_4
+    // also builds on block_round_1
+    let block_round_4 = inserter.insert_block(block_round_1.as_ref(), 4);
+    assert_eq!(
+        block_store.highest_certified_block().as_ref(),
+        block_round_1.as_ref()
+    );
+    assert_eq!(
+        block_store.highest_quorum_cert().as_ref(),
+        block_store
+            .get_block(block_round_4.id())
+            .expect("block_round_4 should exist")
+            .quorum_cert()
+    );
+}
+
+#[test]
+fn test_qc_ancestry() {
+    let block_store = build_empty_tree();
+    let genesis = block_store.root();
+    let mut inserter = TreeInserter::new(block_store.clone());
+    let block_a_1 = inserter.insert_block(genesis.as_ref(), 1);
+    let block_a_2 = inserter.insert_block(block_a_1.as_ref(), 2);
+
+    assert_eq!(
+        block_store.get_block(genesis.quorum_cert().certified_block_id()),
+        None
+    );
+    assert_eq!(
+        block_store.get_block(block_a_1.quorum_cert().certified_block_id()),
+        Some(genesis)
+    );
+    assert_eq!(
+        block_store.get_block(block_a_2.quorum_cert().certified_block_id()),
+        Some(block_a_1)
+    );
+}
+
+// This test should be continuously extended to eventually become the
+// single-page spec for the logic of our block storage.
+proptest! {
+
+    #[test]
+    fn test_block_store_insert(
+        (keypairs, blocks) in block_test::block_forest_and_its_keys(
+            // quorum size
+            10,
+            // recursion depth
+            50)
+    ){
+        let authors: HashSet<Author> = keypairs.iter().map(|(_, public_key)| AccountAddress::from(*public_key)).collect();
+        let (priv_key, pub_key) = keypairs.first().expect("several keypairs generated");
+        let signer = ValidatorSigner::new(AccountAddress::from(*pub_key), *pub_key, priv_key.clone());
+        let block_store = build_empty_tree_with_custom_signing(signer.clone());
+        for block in blocks {
+            if block.round() > 0 && authors.contains(&block.author()) {
+                let known_parent = block_store.block_exists(block.parent_id());
+                let certified_parent = block.quorum_cert().certified_block_id() == block.parent_id();
+                let res = block_on(block_store.execute_and_insert_block(block.clone()));
+                if !certified_parent {
+                    prop_assert_eq!(res.err(), Some(InsertError::ParentNotCertified));
+                } else if !known_parent {
+                    // We cannot really bring in these blocks in this test, because the block
+                    // retrieval functionality invokes event processing, which is not set up here.
+                    assert!(res.is_err());
+                }
+                else {
+                    // The parent must be present if we get to this line.
+                    let parent = block_store.get_block(block.parent_id()).unwrap();
+                    if block.height() != parent.height() + 1 {
+                        prop_assert_eq!(res.err(), Some(InsertError::InvalidBlockHeight));
+                    } else if block.round() <= parent.round() {
+                        prop_assert_eq!(res.err(), Some(InsertError::InvalidBlockRound));
+                    } else {
+                        prop_assert_eq!(res.clone().ok(),
+                            Some(Arc::new(block.clone())),
+                            "expected ok on block: {:#?}, got {:#?}", block, res);
+                    }
+                }
+            }
+        }
+    }
+}
+
+#[test]
+fn test_block_store_prune() {
+    let (blocks, block_store) = build_simple_tree();
+    // Attempt to prune the genesis block (should be a no-op)
+    assert_eq!(block_on(block_store.prune_tree(blocks[0].id())).len(), 0);
+    assert_eq!(block_store.len(), 7);
+    assert_eq!(block_store.child_links(), block_store.len() - 1);
+    assert_eq!(block_store.pruned_blocks_in_mem(), 0);
+
+    let (blocks, block_store) = build_simple_tree();
+    // Prune up to block A1
+    assert_eq!(block_on(block_store.prune_tree(blocks[1].id())).len(), 4);
+    assert_eq!(block_store.len(), 3);
+    assert_eq!(block_store.child_links(), block_store.len() - 1);
+    assert_eq!(block_store.pruned_blocks_in_mem(), 4);
+
+    let (blocks, block_store) = build_simple_tree();
+    // Prune up to block A2
+    assert_eq!(block_on(block_store.prune_tree(blocks[2].id())).len(), 5);
+    assert_eq!(block_store.len(), 2);
+    assert_eq!(block_store.child_links(), block_store.len() - 1);
+    assert_eq!(block_store.pruned_blocks_in_mem(), 5);
+
+    let (blocks, block_store) = build_simple_tree();
+    // Prune up to block A3
+    assert_eq!(block_on(block_store.prune_tree(blocks[3].id())).len(), 6);
+    assert_eq!(block_store.len(), 1);
+    assert_eq!(block_store.child_links(), block_store.len() - 1);
+
+    let (blocks, block_store) = build_simple_tree();
+    // Prune up to block B1
+    assert_eq!(block_on(block_store.prune_tree(blocks[4].id())).len(), 4);
+    assert_eq!(block_store.len(), 3);
+    assert_eq!(block_store.child_links(), block_store.len() - 1);
+
+    let (blocks, block_store) = build_simple_tree();
+    // Prune up to block B2
+    assert_eq!(block_on(block_store.prune_tree(blocks[5].id())).len(), 6);
+    assert_eq!(block_store.len(), 1);
+    assert_eq!(block_store.child_links(), block_store.len() - 1);
+
+    let (blocks, block_store) = build_simple_tree();
+    // Prune up to block C1
+    assert_eq!(block_on(block_store.prune_tree(blocks[6].id())).len(), 6);
assert_eq!(block_store.len(), 1); + assert_eq!(block_store.child_links(), block_store.len() - 1); + + // Prune the chain of Genesis -> B1 -> B2 + let (blocks, block_store) = build_simple_tree(); + // Prune up to block B1 + assert_eq!(block_on(block_store.prune_tree(blocks[4].id())).len(), 4); + assert_eq!(block_store.len(), 3); + assert_eq!(block_store.child_links(), block_store.len() - 1); + // Prune up to block B2 + assert_eq!(block_on(block_store.prune_tree(blocks[5].id())).len(), 2); + assert_eq!(block_store.len(), 1); + assert_eq!(block_store.child_links(), block_store.len() - 1); +} + +#[test] +fn test_block_tree_gc() { + // build a tree with 100 nodes, max_pruned_nodes_in_mem = 10 + let block_store = build_empty_tree(); + let genesis = block_store.root(); + let mut cur_node = block_store.get_block(genesis.id()).unwrap(); + let mut added_blocks = vec![]; + + let mut inserter = TreeInserter::new(block_store.clone()); + for round in 1..100 { + cur_node = inserter.insert_block(cur_node.as_ref(), round); + added_blocks.push(cur_node.clone()); + } + + for (i, block) in added_blocks.iter().enumerate() { + assert_eq!(block_store.len(), 100 - i); + assert_eq!(block_store.pruned_blocks_in_mem(), min(i, 10)); + block_on(block_store.prune_tree(block.id())); + } +} + +#[test] +fn test_path_from_root() { + let block_store = build_empty_tree(); + let genesis = block_store.get_block(block_store.root().id()).unwrap(); + let mut inserter = TreeInserter::new(block_store.clone()); + let b1 = inserter.insert_block(genesis.as_ref(), 1); + let b2 = inserter.insert_block(b1.as_ref(), 2); + let b3 = inserter.insert_block(b2.as_ref(), 3); + + assert_eq!( + block_store.path_from_root(b3.clone()), + Some(vec![b3.clone(), b2.clone(), b1.clone()]) + ); + assert_eq!(block_store.path_from_root(genesis.clone()), Some(vec![])); + + block_on(block_store.prune_tree(b2.id())); + + assert_eq!( + block_store.path_from_root(b3.clone()), + Some(vec![b3.clone()]) + ); + assert_eq!(block_store.path_from_root(genesis.clone()), None); +} + +#[test] +fn test_insert_vote() { + // Set up enough different authors to support different votes for the same block. 
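+    // Expected progression below: votes 1..=9 from distinct signers each return
+    // VoteReceptionResult::VoteAdded(n) with no QC yet; the 10th vote (the quorum
+    // size used by this test) returns VoteReceptionResult::NewQuorumCertificate.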
+ let qc_size = 10; + let mut signers = vec![]; + let mut author_public_keys = vec![]; + for _ in 0..qc_size { + let signer = ValidatorSigner::random(); + author_public_keys.push(( + AccountAddress::from(signer.public_key()), + signer.public_key(), + )); + signers.push(signer); + } + let my_signer = ValidatorSigner::random(); + author_public_keys.push(( + AccountAddress::from(my_signer.public_key()), + my_signer.public_key(), + )); + let block_store = build_empty_tree_with_custom_signing(my_signer); + let genesis = block_store.root(); + let mut inserter = TreeInserter::new(block_store.clone()); + let block = inserter.insert_block(genesis.as_ref(), 1); + + assert!(block_store.get_quorum_cert_for_block(block.id()).is_none()); + let qc_size = 10; + for (i, voter) in signers.iter().enumerate().take(10).skip(1) { + let vote_msg = VoteMsg::new( + block.id(), + block_store.get_state_for_block(block.id()).unwrap(), + block.round(), + voter.author(), + placeholder_ledger_info(), + voter, + ); + let vote_res = block_on(block_store.insert_vote_and_qc(vote_msg.clone(), qc_size)); + + // first vote of an author is accepted + assert_eq!(vote_res, VoteReceptionResult::VoteAdded(i)); + // filter out duplicates + assert_eq!( + block_on(block_store.insert_vote_and_qc(vote_msg, qc_size)), + VoteReceptionResult::DuplicateVote, + ); + // qc is still not there + assert!(block_store.get_quorum_cert_for_block(block.id()).is_none()); + } + + // Add the final vote to form a QC + let final_voter = &signers[0]; + let vote_msg = VoteMsg::new( + block.id(), + block_store.get_state_for_block(block.id()).unwrap(), + block.round(), + final_voter.author(), + placeholder_ledger_info(), + final_voter, + ); + match block_on(block_store.insert_vote_and_qc(vote_msg, qc_size)) { + VoteReceptionResult::NewQuorumCertificate(qc) => { + assert_eq!(qc.certified_block_id(), block.id()); + } + _ => { + panic!("QC not formed!"); + } + } + + let block_qc = block_store.get_quorum_cert_for_block(block.id()).unwrap(); + assert_eq!(block_qc.certified_block_id(), block.id()); +} + +#[test] +fn test_illegal_timestamp() { + let block_store = build_empty_tree(); + let genesis = block_store.root(); + let block_with_illegal_timestamp = Block::>::new_internal( + vec![], + genesis.id(), + 1, + 1, + // This timestamp is illegal, it is the same as genesis + genesis.timestamp_usecs(), + QuorumCert::certificate_for_genesis(), + block_store.signer(), + ); + let result = block_on(block_store.execute_and_insert_block(block_with_illegal_timestamp)); + assert!(result.is_err()); + assert_eq!(result.err().unwrap(), InsertError::NonIncreasingTimestamp); +} + +#[test] +fn test_highest_qc() { + let block_tree = build_empty_tree(); + let mut inserter = TreeInserter::new(block_tree.clone()); + + // build a tree of the following form + // genesis <- a1 <- a2 <- a3 + let genesis = block_tree.root(); + let a1 = inserter.insert_block(genesis.as_ref(), 1); + assert_eq!(block_tree.highest_certified_block(), genesis.clone()); + let a2 = inserter.insert_block(a1.as_ref(), 2); + assert_eq!(block_tree.highest_certified_block(), a1.clone()); + let _a3 = inserter.insert_block(a2.as_ref(), 3); + assert_eq!(block_tree.highest_certified_block(), a2.clone()); +} + +#[test] +fn test_need_fetch_for_qc() { + let block_tree = build_empty_tree(); + let mut inserter = TreeInserter::new(block_tree.clone()); + + // build a tree of the following form + // genesis <- a1 <- a2 <- a3 + let genesis = block_tree.root(); + let a1 = inserter.insert_block(genesis.as_ref(), 1); + let a2 = 
inserter.insert_block(a1.as_ref(), 2); + let a3 = inserter.insert_block(a2.as_ref(), 3); + block_on(block_tree.prune_tree(a2.id())); + let need_fetch_qc = placeholder_certificate_for_block( + vec![block_tree.signer().clone()], + HashValue::zero(), + a3.round() + 1, + ); + let too_old_qc = QuorumCert::certificate_for_genesis(); + let can_insert_qc = + placeholder_certificate_for_block(vec![block_tree.signer().clone()], a3.id(), a3.round()); + let duplicate_qc = block_tree.get_quorum_cert_for_block(a2.id()).unwrap(); + assert_eq!( + block_tree.need_fetch_for_quorum_cert(&need_fetch_qc), + NeedFetchResult::NeedFetch + ); + assert_eq!( + block_tree.need_fetch_for_quorum_cert(&too_old_qc), + NeedFetchResult::QCRoundBeforeRoot, + ); + assert_eq!( + block_tree.need_fetch_for_quorum_cert(&can_insert_qc), + NeedFetchResult::QCBlockExist, + ); + assert_eq!( + block_tree.need_fetch_for_quorum_cert(duplicate_qc.as_ref()), + NeedFetchResult::QCAlreadyExist, + ); +} + +#[test] +fn test_need_sync_for_qc() { + let block_tree = build_empty_tree(); + let mut inserter = TreeInserter::new(block_tree.clone()); + + // build a tree of the following form + // genesis <- a1 <- a2 <- a3 + let genesis = block_tree.root(); + let a1 = inserter.insert_block(genesis.as_ref(), 1); + let a2 = inserter.insert_block(a1.as_ref(), 2); + let a3 = inserter.insert_block(a2.as_ref(), 3); + block_on(block_tree.prune_tree(a3.id())); + let qc = placeholder_certificate_for_block( + vec![block_tree.signer().clone()], + HashValue::zero(), + a3.round() + 3, + ); + assert_eq!( + block_tree.need_sync_for_quorum_cert(HashValue::zero(), &qc), + true + ); + let qc = placeholder_certificate_for_block( + vec![block_tree.signer().clone()], + HashValue::zero(), + a3.round() + 2, + ); + assert_eq!( + block_tree.need_sync_for_quorum_cert(HashValue::zero(), &qc), + false, + ); + assert_eq!( + block_tree.need_sync_for_quorum_cert(genesis.id(), &QuorumCert::certificate_for_genesis()), + false + ); +} diff --git a/consensus/src/chained_bft/block_storage/block_tree.rs b/consensus/src/chained_bft/block_storage/block_tree.rs new file mode 100644 index 0000000000000..7ce26cbc24e6e --- /dev/null +++ b/consensus/src/chained_bft/block_storage/block_tree.rs @@ -0,0 +1,464 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + block_storage::{BlockTreeError, VoteReceptionResult}, + consensus_types::{block::Block, quorum_cert::QuorumCert}, + safety::vote_msg::VoteMsg, + }, + counters, + state_replication::{ExecutedState, StateComputeResult}, + time_service::duration_since_epoch, +}; +use canonical_serialization::CanonicalSerialize; +use crypto::HashValue; +use logger::prelude::*; +use mirai_annotations::checked_verify_eq; +use serde::Serialize; +use std::{ + collections::{ + hash_map::Entry::{Occupied, Vacant}, + vec_deque::VecDeque, + HashMap, + }, + fmt::Debug, + sync::Arc, + time::Duration, +}; +use types::ledger_info::LedgerInfoWithSignatures; + +/// This structure maintains a consistent block tree of parent and children links. Blocks contain +/// parent links and are immutable. For all parent links, a child link exists. This structure +/// should only be used internally in BlockStore. +pub struct BlockTree { + /// All the blocks known to this replica (with parent links) + id_to_block: HashMap>>, + /// All child links (i.e. reverse parent links) for easy cleaning. Note that a block may + /// have multiple child links. There should be id_to_blocks.len() - 1 total + /// id_to_child entries. 
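+    /// (For example, a tree with chains genesis <- a1 <- a2 and a1 <- b1 has three
+    /// child entries in total: genesis -> [a1] and a1 -> [a2, b1].)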
+ id_to_child: HashMap>>>, + /// Mapping between proposals(Block) to execution results. + id_to_state: HashMap, + /// Keeps the state compute results of the executed blocks. + /// The state compute results is calculated for all the pending blocks prior to insertion to + /// the tree (the initial root node might not have it, because it's been already + /// committed). The execution results are not persisted: they're recalculated again for the + /// pending blocks upon restart. + id_to_compute_result: HashMap>, + /// Root of the tree. + root: Arc>, + /// A certified block with highest round + highest_certified_block: Arc>, + /// The quorum certificate of highest_certified_block + highest_quorum_cert: Arc, + /// The quorum certificate that carries a highest ledger info + highest_ledger_info: Arc, + + /// `id_to_votes` might keep multiple LedgerInfos per proposed block in order + /// to tolerate non-determinism in execution: given a proposal, a QuorumCertificate is going + /// to be collected only for all the votes that have identical state id. + /// The vote digest is a hash that covers both the proposal id and the state id. + /// Thus, the structure of `id_to_votes` is as follows: + /// HashMap> + id_to_votes: HashMap>, + /// Map of block id to its completed quorum certificate (2f + 1 votes) + id_to_quorum_cert: HashMap>, + /// To keep the IDs of the elements that have been pruned from the tree but not cleaned up yet. + pruned_block_ids: VecDeque, + /// Num pruned blocks to keep in memory. + max_pruned_blocks_in_mem: usize, +} + +impl BlockTree +where + T: Serialize + Default + Debug + CanonicalSerialize, +{ + pub(super) fn new( + root: Block, + root_quorum_cert: QuorumCert, + root_ledger_info: QuorumCert, + max_pruned_blocks_in_mem: usize, + ) -> Self { + assert_eq!( + root.id(), + root_ledger_info + .ledger_info() + .ledger_info() + .consensus_block_id(), + "inconsistent root and ledger info" + ); + let root = Arc::new(root); + let mut id_to_block = HashMap::new(); + id_to_block.insert(root.id(), root.clone()); + counters::NUM_BLOCKS_IN_TREE.set(1); + + let root_quorum_cert = Arc::new(root_quorum_cert); + let mut id_to_quorum_cert = HashMap::new(); + id_to_quorum_cert.insert( + root_quorum_cert.certified_block_id(), + Arc::clone(&root_quorum_cert), + ); + + let mut id_to_state = HashMap::new(); + id_to_state.insert(root.id(), root_quorum_cert.certified_state()); + + let pruned_block_ids = VecDeque::with_capacity(max_pruned_blocks_in_mem); + + BlockTree { + id_to_block, + id_to_child: HashMap::new(), + id_to_state, + id_to_compute_result: HashMap::new(), + root: Arc::clone(&root), + highest_certified_block: Arc::clone(&root), + highest_quorum_cert: Arc::clone(&root_quorum_cert), + highest_ledger_info: Arc::new(root_ledger_info), + id_to_votes: HashMap::new(), + id_to_quorum_cert, + pruned_block_ids, + max_pruned_blocks_in_mem, + } + } + + fn remove_block(&mut self, block_id: HashValue) { + // Delete my child links + self.id_to_child.remove(&block_id); + // Remove the block from the store + self.id_to_block.remove(&block_id); + self.id_to_state.remove(&block_id); + self.id_to_compute_result.remove(&block_id); + self.id_to_votes.remove(&block_id); + self.id_to_quorum_cert.remove(&block_id); + } + + pub(super) fn block_exists(&self, block_id: HashValue) -> bool { + self.id_to_block.contains_key(&block_id) + } + + pub(super) fn get_block(&self, block_id: HashValue) -> Option>> { + self.id_to_block.get(&block_id).cloned() + } + + pub(super) fn get_state_for_block(&self, block_id: HashValue) -> 
Option { + self.id_to_state.get(&block_id).cloned() + } + + pub(super) fn get_compute_result( + &self, + block_id: HashValue, + ) -> Option> { + self.id_to_compute_result.get(&block_id).cloned() + } + + pub(super) fn root(&self) -> Arc> { + self.root.clone() + } + + pub(super) fn highest_certified_block(&self) -> Arc> { + Arc::clone(&self.highest_certified_block) + } + + pub(super) fn highest_quorum_cert(&self) -> Arc { + Arc::clone(&self.highest_quorum_cert) + } + + pub(super) fn highest_ledger_info(&self) -> Arc { + Arc::clone(&self.highest_ledger_info) + } + + pub(super) fn get_quorum_cert_for_block(&self, block_id: HashValue) -> Option> { + self.id_to_quorum_cert.get(&block_id).cloned() + } + + pub(super) fn is_ancestor( + &self, + ancestor: &Block, + block: &Block, + ) -> Result { + let mut current_block = block; + while current_block.round() >= ancestor.round() { + let parent_id = current_block.parent_id(); + current_block = self + .id_to_block + .get(&parent_id) + .ok_or(BlockTreeError::BlockNotFound { id: parent_id })?; + if current_block.id() == ancestor.id() { + return Ok(true); + } + } + Ok(false) + } + + pub(super) fn insert_block( + &mut self, + block: Block, + state: ExecutedState, + compute_result: StateComputeResult, + ) -> Result>, BlockTreeError> { + if !self.block_exists(block.parent_id()) { + return Err(BlockTreeError::BlockNotFound { + id: block.parent_id(), + }); + } + let block = Arc::new(block); + + match self.id_to_block.get(&block.id()) { + Some(previous_block) => { + debug!("Already had block {:?} for id {:?} when trying to add another block {:?} for the same id", + previous_block, + block.id(), + block); + checked_verify_eq!(*self.id_to_state.get(&block.id()).unwrap(), state); + Ok(previous_block.clone()) + } + _ => { + let children = match self.id_to_child.entry(block.parent_id()) { + Vacant(entry) => entry.insert(Vec::new()), + Occupied(entry) => entry.into_mut(), + }; + children.push(block.clone()); + counters::NUM_BLOCKS_IN_TREE.inc(); + self.id_to_block.insert(block.id(), block.clone()); + self.id_to_state.insert(block.id(), state); + self.id_to_compute_result + .insert(block.id(), Arc::new(compute_result)); + Ok(block) + } + } + } + + pub(super) fn insert_quorum_cert(&mut self, qc: QuorumCert) -> Result<(), BlockTreeError> { + let block_id = qc.certified_block_id(); + let qc = Arc::new(qc); + match self.id_to_block.get(&block_id) { + Some(block) => { + if block.round() > self.highest_certified_block.round() { + self.highest_certified_block = block.clone(); + self.highest_quorum_cert = Arc::clone(&qc); + } + } + None => return Err(BlockTreeError::BlockNotFound { id: block_id }), + } + + self.id_to_quorum_cert + .entry(block_id) + .or_insert_with(|| Arc::clone(&qc)); + + let committed_block_id = qc.ledger_info().ledger_info().consensus_block_id(); + if let Some(block) = self.id_to_block.get(&committed_block_id) { + if block.round() + > self + .id_to_block + .get( + &self + .highest_ledger_info + .ledger_info() + .ledger_info() + .consensus_block_id(), + ) + .expect("Highest ledger info's block should exist") + .round() + { + self.highest_ledger_info = qc; + } + } + Ok(()) + } + + pub(super) fn insert_vote( + &mut self, + vote_msg: &VoteMsg, + min_votes_for_qc: usize, + ) -> VoteReceptionResult { + let block_id = vote_msg.proposed_block_id(); + if let Some(old_qc) = self.id_to_quorum_cert.get(&block_id) { + return VoteReceptionResult::OldQuorumCertificate(Arc::clone(old_qc)); + } + + // All the votes collected for all the execution results of a given 
proposal. + let block_votes = self + .id_to_votes + .entry(block_id) + .or_insert_with(HashMap::new); + + // Note that the digest covers not just the proposal id, but also the resulting + // state id as well as the round number. In other words, if two different voters have the + // same digest then they reached the same state following the same proposals. + let digest = vote_msg.vote_hash(); + let li_with_sig = block_votes.entry(digest).or_insert_with(|| { + LedgerInfoWithSignatures::new(vote_msg.ledger_info().clone(), HashMap::new()) + }); + let author = vote_msg.author(); + if li_with_sig.signatures().contains_key(&author) { + return VoteReceptionResult::DuplicateVote; + } + li_with_sig.add_signature(author, vote_msg.signature().clone()); + + let num_votes = li_with_sig.signatures().len(); + if num_votes >= min_votes_for_qc { + let quorum_cert = QuorumCert::new( + block_id, + vote_msg.executed_state(), + vote_msg.round(), + li_with_sig.clone(), + ); + // Note that the block might not be present locally, in which case we cannot calculate + // time between block creation and qc + if let Some(block) = self.get_block(block_id) { + if let Some(time_to_qc) = duration_since_epoch() + .checked_sub(Duration::from_micros(block.timestamp_usecs())) + { + counters::CREATION_TO_QC_MS.observe(time_to_qc.as_millis() as f64); + } + } + return VoteReceptionResult::NewQuorumCertificate(Arc::new(quorum_cert)); + } + VoteReceptionResult::VoteAdded(num_votes) + } + + /// Find the blocks to prune up to next_root_id (keep next_root_id's block). Any branches not + /// part of the next_root_id's tree should be removed as well. + /// + /// For example, root = B_0 + /// B_0 -> B_1 -> B_2 + /// | -> B_3 -> B4 + /// + /// prune_tree(B_3) should be left with + /// B_3 -> B_4, root = B_3 + /// + /// Note this function is read-only, use with process_pruned_blocks to do the actual prune. + pub(super) fn find_blocks_to_prune(&self, next_root_id: HashValue) -> VecDeque { + // Nothing to do if this is the root + if next_root_id == self.root.id() { + return VecDeque::new(); + } + + let mut blocks_pruned = VecDeque::new(); + let mut blocks_to_be_pruned = Vec::new(); + blocks_to_be_pruned.push(self.root.clone()); + while let Some(block_to_remove) = blocks_to_be_pruned.pop() { + // Add the children to the blocks to be pruned (if any), but stop when it reaches the + // new root + if let Some(children) = self.id_to_child.get(&block_to_remove.id()) { + for child in children { + if next_root_id == child.id() { + continue; + } + + blocks_to_be_pruned.push(child.clone()); + } + } + // Track all the block ids removed + blocks_pruned.push_back(block_to_remove.id()); + } + blocks_pruned + } + + /// Process the data returned by the prune_tree, they're separated because caller might + /// be interested in doing extra work e.g. delete from persistent storage. + /// Note that we do not necessarily remove the pruned blocks: they're kept in a separate buffer + /// for some time in order to enable other peers to retrieve the blocks even after they've + /// been committed. + pub(super) fn process_pruned_blocks( + &mut self, + root_id: HashValue, + mut newly_pruned_blocks: VecDeque, + ) { + // Update the next root + self.root = self + .id_to_block + .get(&root_id) + .expect("next_root_id must exist") + .clone(); + + counters::NUM_BLOCKS_IN_TREE.sub(newly_pruned_blocks.len() as i64); + // The newly pruned blocks are pushed back to the deque pruned_block_ids. 
+ // In case the overall number of the elements is greater than the predefined threshold, + // the oldest elements (in the front of the deque) are removed from the tree. + self.pruned_block_ids.append(&mut newly_pruned_blocks); + if self.pruned_block_ids.len() > self.max_pruned_blocks_in_mem { + let num_blocks_to_remove = self.pruned_block_ids.len() - self.max_pruned_blocks_in_mem; + for _ in 0..num_blocks_to_remove { + if let Some(id) = self.pruned_block_ids.pop_front() { + self.remove_block(id); + } + } + } + } + + /// Returns all the blocks between the root and the given block, including the given block + /// but excluding the root. + /// In case a given block is not the successor of the root, return None. + /// While generally the provided blocks should always belong to the active tree, there might be + /// a race, in which the root of the tree is propagated forward between retrieving the block + /// and getting its path from root (e.g., at proposal generator). Hence, we don't want to panic + /// and prefer to return None instead. + pub(super) fn path_from_root(&self, block: Arc>) -> Option>>> { + let mut res = vec![]; + let mut cur_block = block; + while cur_block.round() > self.root.round() { + res.push(Arc::clone(&cur_block)); + cur_block = match self.get_block(cur_block.parent_id()) { + None => { + return None; + } + Some(b) => b, + }; + } + // At this point cur_block.round() <= self.root.round() + if cur_block.id() != self.root.id() { + return None; + } + Some(res) + } + + pub(super) fn max_pruned_blocks_in_mem(&self) -> usize { + self.max_pruned_blocks_in_mem + } + + pub(super) fn get_all_block_id(&self) -> Vec { + self.id_to_block.keys().cloned().collect() + } + + #[allow(dead_code)] + fn print_all_blocks(&self) { + println!("Printing all {} blocks", self.id_to_block.len()); + for block in self.id_to_block.values() { + println!("{:?}", block); + } + } +} + +#[cfg(test)] +impl BlockTree +where + T: Serialize + Default + Debug + CanonicalSerialize, +{ + /// Returns the number of blocks in the tree + pub(super) fn len(&self) -> usize { + // BFS over the tree to find the number of blocks in the tree. 
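+        // (Strictly, the Vec-based `to_visit` stack makes this a depth-first walk;
+        // the visit order does not matter for counting nodes.)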
+        let mut res = 0;
+        let mut to_visit = Vec::new();
+        to_visit.push(Arc::clone(&self.root));
+        while let Some(block) = to_visit.pop() {
+            res += 1;
+            if let Some(children) = self.id_to_child.get(&block.id()) {
+                for child in children {
+                    to_visit.push(Arc::clone(&child));
+                }
+            }
+        }
+        res
+    }
+
+    /// Returns the number of child links in the tree
+    pub(super) fn child_links(&self) -> usize {
+        self.len() - 1
+    }
+
+    /// The number of pruned blocks that are still available in memory
+    pub(super) fn pruned_blocks_in_mem(&self) -> usize {
+        self.pruned_block_ids.len()
+    }
+}
diff --git a/consensus/src/chained_bft/block_storage/mod.rs b/consensus/src/chained_bft/block_storage/mod.rs
new file mode 100644
index 0000000000000..c6775273a05a1
--- /dev/null
+++ b/consensus/src/chained_bft/block_storage/mod.rs
@@ -0,0 +1,217 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::chained_bft::{
+    common::{Author, Round},
+    consensus_types::{block::Block, quorum_cert::QuorumCert},
+};
+use crypto::HashValue;
+use std::sync::Arc;
+
+mod block_inserter;
+mod block_store;
+mod block_tree;
+
+use crate::{
+    chained_bft::safety::vote_msg::VoteMsgVerificationError,
+    state_replication::{ExecutedState, StateComputeResult},
+};
+pub use block_store::{BlockStore, NeedFetchResult};
+use network::protocols::rpc::error::RpcError;
+use types::validator_verifier::VerifyError;
+
+#[derive(Debug, PartialEq, Fail)]
+/// The possible reasons for failing to retrieve a block by id from a given peer.
+#[allow(dead_code)]
+#[derive(Clone)]
+pub enum BlockRetrievalFailure {
+    /// Could not find a given author
+    #[fail(display = "Unknown author: {:?}", author)]
+    UnknownAuthor { author: Author },
+
+    /// Any sort of a network failure (should probably have an enum for network failures).
+    #[fail(display = "Network failure: {:?}", msg)]
+    NetworkFailure { msg: String },
+
+    /// The remote peer did not recognize the given block id.
+    #[fail(display = "Block id {:?} not recognized by the peer", block_id)]
+    UnknownBlockId { block_id: HashValue },
+
+    /// Cannot retrieve a block from itself
+    #[fail(display = "Attempt of a self block retrieval.")]
+    SelfRetrieval,
+
+    /// The event is not correctly signed.
+    #[fail(display = "InvalidSignature")]
+    InvalidSignature,
+
+    /// The response is not valid: the status doesn't match the blocks, blocks fail to deserialize, etc.
+    #[fail(display = "InvalidResponse")]
+    InvalidResponse,
+}
+
+#[allow(dead_code)]
+#[derive(Clone, Debug, PartialEq, Fail)]
+/// Status after trying to insert a block into the BlockStore
+pub enum InsertError {
+    /// The parent block does not exist, hence not inserting this block
+    #[fail(display = "MissingParentBlock")]
+    MissingParentBlock(HashValue),
+    /// The block hash is invalid
+    #[fail(display = "InvalidBlockHash")]
+    InvalidBlockHash,
+    /// The block height is not parent's height + 1.
+    #[fail(display = "InvalidBlockHeight")]
+    InvalidBlockHeight,
+    /// The block round is not greater than that of the parent.
+    #[fail(display = "InvalidBlockRound")]
+    InvalidBlockRound,
+    /// The block's timestamp is not greater than that of the parent.
+    #[fail(display = "InvalidTimestamp")]
+    NonIncreasingTimestamp,
+    /// The block is not newer than the root of the tree.
+    #[fail(display = "OldBlock")]
+    OldBlock,
+    /// The event is from an unknown author.
+    #[fail(display = "UnknownAuthor")]
+    UnknownAuthor,
+    /// The event is not correctly signed.
+ #[fail(display = "InvalidSignature")] + InvalidSignature, + /// The external state computer failure. + #[fail(display = "StateComputeError")] + StateComputerError, + /// Block's parent is not certified with the QC carried by the block. + #[fail(display = "ParentNotCertified")] + ParentNotCertified, + /// State version corresponding to block's parent not found. + #[fail(display = "ParentVersionNotFound")] + ParentVersionNotFound, + /// Some of the block's ancestors could not be retrieved. + #[fail(display = "AncestorRetrievalError")] + AncestorRetrievalError, + #[fail(display = "StorageFailure")] + StorageFailure, +} + +impl From for BlockRetrievalFailure { + fn from(source: RpcError) -> Self { + BlockRetrievalFailure::NetworkFailure { + msg: source.to_string(), + } + } +} + +impl From for InsertError { + fn from(error: VerifyError) -> Self { + match error { + VerifyError::UnknownAuthor => InsertError::UnknownAuthor, + VerifyError::InvalidSignature => InsertError::InvalidSignature, + VerifyError::TooFewSignatures { .. } => InsertError::InvalidSignature, + VerifyError::TooManySignatures { .. } => InsertError::InvalidSignature, + } + } +} + +impl From for InsertError { + fn from(_error: VoteMsgVerificationError) -> Self { + InsertError::InvalidSignature + } +} + +/// Result of the vote processing. The failure case (Verification error) is returned +/// as the Error part of the result. +#[derive(Debug, PartialEq)] +pub enum VoteReceptionResult { + /// The vote has been added but QC has not been formed yet. Return the number of votes for + /// the given (proposal, execution) pair. + VoteAdded(usize), + /// The very same vote message has been processed in past. + DuplicateVote, + /// This block has been already certified. + OldQuorumCertificate(Arc), + /// This block has just been certified after adding the vote. + NewQuorumCertificate(Arc), +} + +#[derive(Debug, Fail)] +/// Tree query error types. +pub enum BlockTreeError { + #[fail(display = "Block not found: {:?}", id)] + BlockNotFound { id: HashValue }, +} + +impl From for InsertError { + fn from(error: BlockTreeError) -> InsertError { + match error { + BlockTreeError::BlockNotFound { id } => InsertError::MissingParentBlock(id), + } + } +} + +pub trait BlockReader: Send + Sync { + type Payload; + + /// Check if a block with the block_id exist in the BlockTree. + fn block_exists(&self, block_id: HashValue) -> bool; + + /// Try to get a block with the block_id, return an Arc of it if found. + fn get_block(&self, block_id: HashValue) -> Option>>; + + /// Try to get a state id (HashValue) of the system corresponding to block execution. + fn get_state_for_block(&self, block_id: HashValue) -> Option; + + /// Try to get an execution result given the specified block id. + fn get_compute_result(&self, block_id: HashValue) -> Option>; + + /// Get the current root block of the BlockTree. + fn root(&self) -> Arc>; + + fn get_quorum_cert_for_block(&self, block_id: HashValue) -> Option>; + + /// Returns true if a given "ancestor" block is an ancestor of a given "block". + /// Returns a failure if not all the blocks are present between the block's height and the + /// parent's height. + fn is_ancestor( + &self, + ancestor: &Block, + block: &Block, + ) -> Result; + + /// Returns all the blocks between the root and the given block, including the given block + /// but excluding the root. + /// In case a given block is not the successor of the root, return None. 
+ /// For example if a tree is b0 <- b1 <- b2 <- b3, then + /// path_from_root(b2) -> Some([b2, b1]) + /// path_from_root(b0) -> Some([]) + /// path_from_root(a) -> None + fn path_from_root( + &self, + block: Arc>, + ) -> Option>>>; + + /// Generates and returns a block with the given parent and payload. + /// Note that it does not add the block to the tree, just generates it. + /// The main reason we want this function in the BlockStore is the fact that the signer required + /// for signing the newly created block is held by the block store. + /// The function panics in the following cases: + /// * If the parent or its quorum certificate are not present in the tree, + /// * If the given round (which is typically calculated by Pacemaker) is not greater than that + /// of a parent. + fn create_block( + &self, + parent: Arc>, + payload: Self::Payload, + round: Round, + timestamp_usecs: u64, + ) -> Block; + + /// Return the certified block with the highest round. + fn highest_certified_block(&self) -> Arc>; + + /// Return the quorum certificate with the highest round + fn highest_quorum_cert(&self) -> Arc; + + /// Return the quorum certificate that carries ledger info with the highest round + fn highest_ledger_info(&self) -> Arc; +} diff --git a/consensus/src/chained_bft/chained_bft_consensus_provider.rs b/consensus/src/chained_bft/chained_bft_consensus_provider.rs new file mode 100644 index 0000000000000..6f6bedd7cda73 --- /dev/null +++ b/consensus/src/chained_bft/chained_bft_consensus_provider.rs @@ -0,0 +1,168 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + chained_bft_smr::ChainedBftSMR, network::ConsensusNetworkImpl, + persistent_storage::PersistentStorage, + }, + consensus_provider::ConsensusProvider, + state_computer::ExecutionProxy, + state_replication::StateMachineReplication, + txn_manager::MempoolProxy, +}; +use network::validator_network::{ConsensusNetworkEvents, ConsensusNetworkSender}; + +use crate::{ + chained_bft::{ + chained_bft_smr::ChainedBftSMRConfig, common::Author, persistent_storage::StorageWriteProxy, + }, + state_synchronizer::{setup_state_synchronizer, StateSynchronizer}, +}; +use config::config::{ConsensusProposerType::FixedProposer, NodeConfig}; +use execution_proto::proto::execution_grpc::ExecutionClient; +use failure::prelude::*; +use logger::prelude::*; +use mempool::proto::mempool_grpc::MempoolClient; +use std::{convert::TryFrom, sync::Arc}; +use tokio::runtime; +use types::{ + account_address::AccountAddress, transaction::SignedTransaction, + validator_signer::ValidatorSigner, validator_verifier::ValidatorVerifier, +}; + +struct InitialSetup { + author: Author, + signer: ValidatorSigner, + quorum_size: usize, + peers: Arc>, + validator: Arc, +} + +/// Supports the implementation of ConsensusProvider using LibraBFT. 
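+/// It wires the ChainedBftSMR state machine together with the mempool client, the
+/// execution client and the state synchronizer held in the fields below.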
+pub struct ChainedBftProvider { + smr: ChainedBftSMR, Author>, + mempool_client: Arc, + execution_client: Arc, + synchronizer_client: Arc, +} + +impl ChainedBftProvider { + pub fn new( + node_config: &NodeConfig, + network_sender: ConsensusNetworkSender, + network_events: ConsensusNetworkEvents, + mempool_client: Arc, + execution_client: Arc, + ) -> Self { + let runtime = runtime::Builder::new() + .name_prefix("consensus-") + .build() + .expect("Failed to create Tokio runtime!"); + + let initial_setup = Self::initialize_setup(node_config); + let network = ConsensusNetworkImpl::new( + initial_setup.author, + network_sender.clone(), + network_events, + Arc::clone(&initial_setup.peers), + Arc::clone(&initial_setup.validator), + ); + let synchronizer = + setup_state_synchronizer(network_sender, runtime.executor(), node_config); + let proposer = { + if node_config.consensus.get_proposer_type() == FixedProposer { + vec![Self::choose_leader(&initial_setup)] + } else { + initial_setup.validator.get_ordered_account_addresses() + } + }; + debug!("[Consensus] My peer: {:?}", initial_setup.author); + debug!("[Consensus] Chosen proposer: {:?}", proposer); + let config = ChainedBftSMRConfig::from_node_config(&node_config.consensus); + let (storage, initial_data) = StorageWriteProxy::start(node_config); + info!( + "Starting up the consensus state machine with recovery data - {:?}, {:?}", + initial_data.state(), + initial_data.highest_timeout_certificates() + ); + let smr = ChainedBftSMR::new( + initial_setup.author, + initial_setup.quorum_size, + initial_setup.signer, + proposer, + network, + runtime, + config, + storage, + initial_data, + ); + Self { + smr, + mempool_client, + execution_client, + synchronizer_client: Arc::new(synchronizer), + } + } + + /// Retrieve the initial "state" for consensus. This function is synchronous and returns after + /// reading the local persistent store and retrieving the initial state from the executor. + fn initialize_setup(node_config: &NodeConfig) -> InitialSetup { + // Keeping the initial set of validators in a node config is embarrassing and we should + // all feel bad about it. + let peer_id_str = node_config.base.peer_id.clone(); + let author = + AccountAddress::try_from(peer_id_str).expect("Failed to parse peer id of a validator"); + let private_key = node_config.base.peer_keypairs.get_consensus_private(); + let public_key = node_config.base.peer_keypairs.get_consensus_public(); + let signer = ValidatorSigner::new(author, public_key, private_key); + let peers_with_public_keys = node_config.base.trusted_peers.get_trusted_consensus_peers(); + let peers = Arc::new( + peers_with_public_keys + .keys() + .map(AccountAddress::clone) + .collect(), + ); + let quorum_size = peers_with_public_keys.len() * 2 / 3 + 1; + let validator = Arc::new(ValidatorVerifier::new( + peers_with_public_keys.clone(), + quorum_size, + )); + debug!("[Consensus]: quorum_size = {:?}", quorum_size); + InitialSetup { + author, + signer, + quorum_size, + peers, + validator, + } + } + + /// Choose a proposer that is going to be the single leader (relevant for a mock fixed proposer + /// election only). + fn choose_leader(initial_setup: &InitialSetup) -> Author { + // As it is just a tmp hack function, pick the smallest PeerId to be a proposer. 
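+        // (Any deterministic rule works here; note that the expression below actually
+        // selects the peer with the *largest* account address via `.max()`.)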
+ *initial_setup + .peers + .iter() + .max() + .expect("No trusted peers found!") + } +} + +impl ConsensusProvider for ChainedBftProvider { + fn start(&mut self) -> Result<()> { + let txn_manager = Arc::new(MempoolProxy::new(self.mempool_client.clone())); + let state_computer = Arc::new(ExecutionProxy::new( + self.execution_client.clone(), + self.synchronizer_client.clone(), + )); + debug!("Starting consensus provider."); + self.smr.start(txn_manager, state_computer) + } + + fn stop(&mut self) { + self.smr.stop(); + debug!("Consensus provider stopped."); + } +} diff --git a/consensus/src/chained_bft/chained_bft_smr.rs b/consensus/src/chained_bft/chained_bft_smr.rs new file mode 100644 index 0000000000000..6c3e4a4929cfb --- /dev/null +++ b/consensus/src/chained_bft/chained_bft_smr.rs @@ -0,0 +1,527 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + block_storage::{BlockReader, BlockStore}, + common::{Payload, Round}, + event_processor::{EventProcessor, ProcessProposalResult}, + liveness::{ + local_pacemaker::{ExponentialTimeInterval, LocalPacemaker}, + new_round_msg::NewRoundMsg, + pacemaker::{NewRoundEvent, Pacemaker, PacemakerEvent}, + pacemaker_timeout_manager::HighestTimeoutCertificates, + proposal_generator::ProposalGenerator, + proposer_election::{ProposalInfo, ProposerElection, ProposerInfo}, + rotating_proposer_election::RotatingProposer, + }, + network::{ + BlockRetrievalRequest, ChunkRetrievalRequest, ConsensusNetworkImpl, NetworkReceivers, + }, + persistent_storage::{PersistentLivenessStorage, PersistentStorage, RecoveryData}, + safety::{safety_rules::SafetyRules, vote_msg::VoteMsg}, + }, + counters, + state_replication::{StateComputer, StateMachineReplication, TxnManager}, + state_synchronizer::SyncStatus, + stream_utils::start_event_processing_loop, + time_service::{ClockTimeService, TimeService}, +}; +use failure::prelude::*; +use futures::{ + channel::mpsc, + compat::Future01CompatExt, + executor::block_on, + future::{FutureExt, TryFutureExt}, + stream::StreamExt, +}; +use types::validator_signer::ValidatorSigner; + +use config::config::ConsensusConfig; +use futures::SinkExt; +use logger::prelude::*; +use std::{ + sync::{Arc, RwLock}, + thread, + time::{Duration, Instant}, +}; +use tokio::runtime::{Runtime, TaskExecutor}; + +type ConcurrentEventProcessor = Arc>>; + +/// Consensus configuration derived from ConsensusConfig +pub struct ChainedBftSMRConfig { + /// Keep up to this number of committed blocks before cleaning them up from the block store. + pub max_pruned_blocks_in_mem: usize, + /// Initial timeout for pacemaker + pub pacemaker_initial_timeout: Duration, + /// Contiguous rounds for proposer + pub contiguous_rounds: u32, + /// Max block size (number of transactions) that consensus pulls from mempool + pub max_block_size: u64, +} + +impl ChainedBftSMRConfig { + pub fn from_node_config(cfg: &ConsensusConfig) -> ChainedBftSMRConfig { + let pacemaker_initial_timeout_ms = cfg.pacemaker_initial_timeout_ms().unwrap_or(1000); + ChainedBftSMRConfig { + max_pruned_blocks_in_mem: cfg.max_pruned_blocks_in_mem().unwrap_or(10000) as usize, + pacemaker_initial_timeout: Duration::from_millis(pacemaker_initial_timeout_ms), + contiguous_rounds: cfg.contiguous_rounds(), + max_block_size: cfg.max_block_size(), + } + } +} + +/// ChainedBFTSmr is the one to generate the components (BLockStore, Proposer, etc.) and start the +/// driver. 
+/// ChainedBftSMR implements StateMachineReplication and is used by
+/// ConsensusProvider for the e2e flow.
+pub struct ChainedBftSMR {
+    author: P,
+    // TODO [Reconfiguration] quorum size is just a function of current validator set.
+    quorum_size: usize,
+    signer: ValidatorSigner,
+    proposers: Vec<P>
, + runtime: Option, + block_store: Option>>, + network: ConsensusNetworkImpl, + config: ChainedBftSMRConfig, + storage: Arc>, + initial_data: Option>, +} + +#[allow(dead_code)] +impl ChainedBftSMR { + pub fn new( + author: P, + quorum_size: usize, + signer: ValidatorSigner, + proposers: Vec
<P>
, + network: ConsensusNetworkImpl, + runtime: Runtime, + config: ChainedBftSMRConfig, + storage: Arc>, + initial_data: RecoveryData, + ) -> Self { + Self { + author, + quorum_size, + signer, + proposers, + runtime: Some(runtime), + block_store: None, + network, + config, + storage, + initial_data: Some(initial_data), + } + } + + pub fn block_store(&self) -> Option>> { + self.block_store.clone() + } + + fn create_pacemaker( + &self, + persistent_liveness_storage: Box, + highest_committed_round: Round, + highest_certified_round: Round, + highest_timeout_certificates: HighestTimeoutCertificates, + time_service: Arc, + pacemaker_timeout_sender: channel::Sender, + ) -> Arc { + // 1.5^6 ~= 11 + // Timeout goes from initial_timeout to initial_timeout*11 in 6 steps + let time_interval = Box::new(ExponentialTimeInterval::new( + self.config.pacemaker_initial_timeout, + 1.5, + 6, + )); + Arc::new(LocalPacemaker::new( + persistent_liveness_storage, + time_interval, + highest_committed_round, + highest_certified_round, + time_service, + pacemaker_timeout_sender, + self.quorum_size, + highest_timeout_certificates, + )) + } + + /// Create a proposer election handler based on proposers + fn create_proposer_election(&self) -> Arc + Send + Sync> { + assert!(!self.proposers.is_empty()); + Arc::new(RotatingProposer::new( + self.proposers.clone(), + self.config.contiguous_rounds, + )) + } + + async fn process_new_round_events( + mut receiver: mpsc::Receiver, + event_processor: ConcurrentEventProcessor, + ) { + while let Some(new_round_event) = receiver.next().await { + let guard = event_processor.read().compat().await.unwrap(); + guard.process_new_round_event(new_round_event).await; + } + } + + async fn process_proposals( + executor: TaskExecutor, + mut receiver: channel::Receiver>, + event_processor: ConcurrentEventProcessor, + ) { + while let Some(proposal_info) = receiver.next().await { + let guard = event_processor.read().compat().await.unwrap(); + match guard.process_proposal(proposal_info).await { + ProcessProposalResult::Done => (), + // Spawn a new task that would start retrieving the missing + // blocks in the background. + ProcessProposalResult::NeedFetch(deadline, proposal) => executor.spawn( + Self::fetch_and_process_proposal( + Arc::clone(&event_processor), + deadline, + proposal, + ) + .boxed() + .unit_error() + .compat(), + ), + // Spawn a new task that would start state synchronization + // in the background. 
+ ProcessProposalResult::NeedSync(deadline, proposal) => executor.spawn( + Self::sync_and_process_proposal( + Arc::clone(&event_processor), + deadline, + proposal, + ) + .boxed() + .unit_error() + .compat(), + ), + } + } + } + + async fn fetch_and_process_proposal( + event_processor: ConcurrentEventProcessor, + deadline: Instant, + proposal: ProposalInfo, + ) { + let guard = event_processor.read().compat().await.unwrap(); + guard.fetch_and_process_proposal(deadline, proposal).await + } + + async fn sync_and_process_proposal( + event_processor: ConcurrentEventProcessor, + deadline: Instant, + proposal: ProposalInfo, + ) { + let mut guard = event_processor.write().compat().await.unwrap(); + guard.sync_and_process_proposal(deadline, proposal).await + } + + async fn process_winning_proposals( + mut receiver: mpsc::Receiver>, + event_processor: ConcurrentEventProcessor, + ) { + while let Some(proposal_info) = receiver.next().await { + let guard = event_processor.read().compat().await.unwrap(); + guard.process_winning_proposal(proposal_info).await; + } + } + + async fn process_votes( + mut receiver: channel::Receiver, + event_processor: ConcurrentEventProcessor, + quorum_size: usize, + ) { + while let Some(vote) = receiver.next().await { + let guard = event_processor.read().compat().await.unwrap(); + guard.process_vote(vote, quorum_size).await; + } + } + + async fn process_new_round_msg( + mut receiver: channel::Receiver, + event_processor: ConcurrentEventProcessor, + mut sender: mpsc::Sender, + ) { + while let Some(new_round_msg) = receiver.next().await { + let pacemaker_timeout = new_round_msg.pacemaker_timeout().clone(); + let mut guard = event_processor.write().compat().await.unwrap(); + guard.process_new_round_msg(new_round_msg).await; + if let Err(e) = sender + .send(PacemakerEvent::RemoteTimeout { pacemaker_timeout }) + .await + { + error!("Failed to send event to pacemaker {:?}", e); + return; + } + } + } + + async fn process_outgoing_pacemaker_timeouts( + mut receiver: channel::Receiver, + event_processor: ConcurrentEventProcessor, + mut network: ConsensusNetworkImpl, + ) { + while let Some(round) = receiver.next().await { + // Update the last voted round and generate the timeout message + let guard = event_processor.read().compat().await.unwrap(); + let timeout_msg = guard.process_outgoing_pacemaker_timeout(round).await; + match timeout_msg { + Some(timeout_msg) => { + network.broadcast_new_round(timeout_msg).await; + } + None => { + info!("Broadcast not sent as the processing of the timeout failed. 
Will retry again on the next timeout."); + } + } + } + } + + async fn process_block_retrievals( + mut receiver: channel::Receiver>, + event_processor: ConcurrentEventProcessor, + ) { + while let Some(request) = receiver.next().await { + let guard = event_processor.read().compat().await.unwrap(); + guard.process_block_retrieval(request).await; + } + } + + async fn process_chunk_retrievals( + mut receiver: channel::Receiver, + event_processor: ConcurrentEventProcessor, + ) { + while let Some(request) = receiver.next().await { + let guard = event_processor.read().compat().await.unwrap(); + guard.process_chunk_retrieval(request).await; + } + } + + fn start_event_processing( + &self, + event_processor: ConcurrentEventProcessor, + executor: TaskExecutor, + new_round_events_receiver: mpsc::Receiver, + proposal_winners_receiver: mpsc::Receiver>, + network_receivers: NetworkReceivers, + pm_events_sender: mpsc::Sender, + pacemaker_timeout_sender_rx: channel::Receiver, + ) { + executor.spawn( + Self::process_new_round_events(new_round_events_receiver, event_processor.clone()) + .boxed() + .unit_error() + .compat(), + ); + + executor.spawn( + Self::process_proposals( + executor.clone(), + network_receivers.proposals, + event_processor.clone(), + ) + .boxed() + .unit_error() + .compat(), + ); + + executor.spawn( + Self::process_winning_proposals(proposal_winners_receiver, event_processor.clone()) + .boxed() + .unit_error() + .compat(), + ); + + executor.spawn( + Self::process_block_retrievals( + network_receivers.block_retrieval, + event_processor.clone(), + ) + .boxed() + .unit_error() + .compat(), + ); + + executor.spawn( + Self::process_chunk_retrievals( + network_receivers.chunk_retrieval, + event_processor.clone(), + ) + .boxed() + .unit_error() + .compat(), + ); + + executor.spawn( + Self::process_votes( + network_receivers.votes, + event_processor.clone(), + self.quorum_size, + ) + .boxed() + .unit_error() + .compat(), + ); + + executor.spawn( + Self::process_new_round_msg( + network_receivers.new_rounds, + event_processor.clone(), + pm_events_sender, + ) + .boxed() + .unit_error() + .compat(), + ); + + executor.spawn( + Self::process_outgoing_pacemaker_timeouts( + pacemaker_timeout_sender_rx, + event_processor.clone(), + self.network.clone(), + ) + .boxed() + .unit_error() + .compat(), + ); + } +} + +impl StateMachineReplication for ChainedBftSMR { + type Payload = T; + + fn start( + &mut self, + txn_manager: Arc>, + state_computer: Arc>, + ) -> Result<()> { + let executor = self + .runtime + .as_mut() + .expect("Consensus start: No valid runtime found!") + .executor(); + let time_service = Arc::new(ClockTimeService::new(executor.clone())); + + // We first start the network and retrieve the network receivers (this function needs a + // mutable reference). + // Must do it here before giving the clones of network to other components. 
+ let network_receivers = self.network.start(&executor); + let initial_data = self + .initial_data + .take() + .expect("already started, initial data is None"); + let consensus_state = initial_data.state(); + let highest_timeout_certificates = initial_data.highest_timeout_certificates().clone(); + if initial_data.need_sync() { + loop { + // make sure we sync to the root state in case we're not + let status = block_on(state_computer.sync_to(initial_data.root_ledger_info())); + match status { + Ok(SyncStatus::Finished) => break, + Ok(SyncStatus::DownloadFailed) => { + warn!("DownloadFailed, we may not establish connection with peers yet, sleep and retry"); + // we can remove this when we start to handle NewPeer/LostPeer events. + thread::sleep(Duration::from_secs(2)); + } + Ok(e) => panic!( + "state synchronizer failure: {:?}, this validator will be killed as it can not \ + recover from this error. After the validator is restarted, synchronization will \ + be retried.", + e + ), + Err(e) => panic!( + "state synchronizer failure: {:?}, this validator will be killed as it can not \ + recover from this error. After the validator is restarted, synchronization will \ + be retried.", + e + ), + } + } + } + + let block_store = Arc::new(block_on(BlockStore::new( + Arc::clone(&self.storage), + initial_data, + self.signer.clone(), + Arc::clone(&state_computer), + true, + self.config.max_pruned_blocks_in_mem, + ))); + self.block_store = Some(Arc::clone(&block_store)); + + // txn manager is required both by proposal generator (to pull the proposers) + // and by event processor (to update their status). + let proposal_generator = ProposalGenerator::new( + block_store.clone(), + Arc::clone(&txn_manager), + time_service.clone(), + self.config.max_block_size, + true, + ); + + let safety_rules = Arc::new(RwLock::new(SafetyRules::new( + block_store.clone(), + consensus_state, + ))); + + let (pacemaker_timeout_sender_tx, pacemaker_timeout_sender_rx) = + channel::new(1_024, &counters::PENDING_PACEMAKER_TIMEOUTS); + let mut pacemaker = self.create_pacemaker( + self.storage.persistent_liveness_storage(), + safety_rules.read().unwrap().last_committed_round(), + block_store.highest_certified_block().round(), + highest_timeout_certificates, + time_service.clone(), + pacemaker_timeout_sender_tx, + ); + let (pm_events_sender, new_round_events_receiver) = + start_event_processing_loop(&mut pacemaker, executor.clone()); + + let mut proposer_election = self.create_proposer_election(); + let (proposal_candidates_sender, proposal_winners_receiver) = + start_event_processing_loop(&mut proposer_election, executor.clone()); + let event_processor = Arc::new(futures_locks::RwLock::new(EventProcessor::new( + self.author, + Arc::clone(&block_store), + Arc::clone(&pacemaker), + Arc::clone(&proposer_election), + pm_events_sender.clone(), + proposal_candidates_sender, + proposal_generator, + safety_rules, + state_computer, + txn_manager, + self.network.clone(), + Arc::clone(&self.storage), + time_service.clone(), + true, + ))); + + self.start_event_processing( + event_processor, + executor.clone(), + new_round_events_receiver, + proposal_winners_receiver, + network_receivers, + pm_events_sender.clone(), + pacemaker_timeout_sender_rx, + ); + + debug!("Chained BFT SMR started."); + Ok(()) + } + + /// Stop is synchronous: waits for all the worker threads to terminate. 
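+    /// (The guarantee comes from blocking on the Tokio runtime's shutdown_now() below.)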
+ fn stop(&mut self) { + if let Some(rt) = self.runtime.take() { + block_on(rt.shutdown_now().compat()).unwrap(); + debug!("Chained BFT SMR stopped.") + } + } +} diff --git a/consensus/src/chained_bft/chained_bft_smr_test.rs b/consensus/src/chained_bft/chained_bft_smr_test.rs new file mode 100644 index 0000000000000..ba17b1c9b0e13 --- /dev/null +++ b/consensus/src/chained_bft/chained_bft_smr_test.rs @@ -0,0 +1,630 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + block_storage::BlockReader, + chained_bft_smr::{ChainedBftSMR, ChainedBftSMRConfig}, + common::Author, + liveness::proposer_election::ProposalInfo, + network::ConsensusNetworkImpl, + network_tests::NetworkPlayground, + safety::vote_msg::VoteMsg, + test_utils::{MockStateComputer, MockStorage, MockTransactionManager, TestPayload}, + }, + state_replication::StateMachineReplication, +}; +use channel; +use crypto::hash::CryptoHash; +use futures::{channel::mpsc, executor::block_on, prelude::*}; +use network::validator_network::{ConsensusNetworkEvents, ConsensusNetworkSender}; +use proto_conv::FromProto; +use std::sync::Arc; +use types::{validator_signer::ValidatorSigner, validator_verifier::ValidatorVerifier}; + +use crate::chained_bft::{ + persistent_storage::RecoveryData, + test_utils::{consensus_runtime, with_smr_id}, +}; +use config::config::ConsensusProposerType::{self, FixedProposer, RotatingProposer}; +use std::{collections::HashMap, time::Duration}; +use tokio::runtime; +use types::ledger_info::LedgerInfoWithSignatures; + +/// Auxiliary struct that is preparing SMR for the test +struct SMRNode { + author: Author, + signer: ValidatorSigner, + validator: Arc, + peers: Arc>, + proposer: Vec, + smr_id: usize, + smr: ChainedBftSMR, + commit_cb_receiver: mpsc::UnboundedReceiver, + mempool: Arc, + mempool_notif_receiver: mpsc::Receiver, + storage: Arc>, +} + +impl SMRNode { + fn start( + quorum_size: usize, + playground: &mut NetworkPlayground, + signer: ValidatorSigner, + validator: Arc, + peers: Arc>, + proposer: Vec, + smr_id: usize, + storage: Arc>, + initial_data: RecoveryData, + ) -> Self { + let author = signer.author(); + + let (network_reqs_tx, network_reqs_rx) = channel::new_test(8); + let (consensus_tx, consensus_rx) = channel::new_test(8); + let network_sender = ConsensusNetworkSender::new(network_reqs_tx); + let network_events = ConsensusNetworkEvents::new(consensus_rx); + + playground.add_node(author, consensus_tx, network_reqs_rx); + let runtime = runtime::Builder::new() + .after_start(with_smr_id(signer.author().short_str())) + .build() + .expect("Failed to create Tokio runtime!"); + let network = ConsensusNetworkImpl::new( + author, + network_sender, + network_events, + Arc::clone(&peers), + Arc::clone(&validator), + ); + + let config = ChainedBftSMRConfig { + max_pruned_blocks_in_mem: 10000, + pacemaker_initial_timeout: Duration::from_secs(1), + contiguous_rounds: 2, + max_block_size: 50, + }; + let mut smr = ChainedBftSMR::new( + author, + quorum_size, + signer.clone(), + proposer.clone(), + network, + runtime, + config, + storage.clone(), + initial_data, + ); + let (commit_cb_sender, commit_cb_receiver) = mpsc::unbounded::(); + let mut mp = MockTransactionManager::new(); + let commit_receiver = mp.take_commit_receiver(); + let mempool = Arc::new(mp); + smr.start( + mempool.clone(), + Arc::new(MockStateComputer::new(commit_cb_sender.clone())), + ) + .expect("Failed to start SMR!"); + Self { + author, + signer, + validator, + peers, + proposer, 
+ smr_id, + smr, + commit_cb_receiver, + mempool, + mempool_notif_receiver: commit_receiver, + storage, + } + } + + fn restart(mut self, quorum_size: usize, playground: &mut NetworkPlayground) -> Self { + self.smr.stop(); + let recover_data = self + .storage + .get_recovery_data() + .unwrap_or_else(|e| panic!("fail to restart due to: {}", e)); + Self::start( + quorum_size, + playground, + self.signer, + self.validator, + self.peers, + self.proposer, + self.smr_id + 10, + self.storage, + recover_data, + ) + } + + fn start_num_nodes( + num_nodes: usize, + quorum_size: usize, + playground: &mut NetworkPlayground, + proposer_type: ConsensusProposerType, + ) -> Vec { + let mut signers = vec![]; + let mut author_to_public_keys = HashMap::new(); + for smr_id in 0..num_nodes { + // 0 -> [0000], 1 -> [1000] in the logs + let random_validator_signer = ValidatorSigner::from_int(smr_id as u8); + author_to_public_keys.insert( + random_validator_signer.author(), + random_validator_signer.public_key(), + ); + signers.push(random_validator_signer); + } + let validator_verifier = + Arc::new(ValidatorVerifier::new(author_to_public_keys, quorum_size)); + let peers: Arc> = Arc::new( + signers + .clone() + .into_iter() + .map(|signer| signer.author()) + .collect(), + ); + let proposer = { + match proposer_type { + FixedProposer => vec![peers[0]], + RotatingProposer => validator_verifier.get_ordered_account_addresses(), + } + }; + let mut nodes = vec![]; + for smr_id in 0..num_nodes { + let (storage, initial_data) = MockStorage::start_for_testing(); + nodes.push(Self::start( + quorum_size, + playground, + signers.remove(0), + Arc::clone(&validator_verifier), + Arc::clone(&peers), + proposer.clone(), + smr_id, + storage, + initial_data, + )); + } + nodes + } +} + +fn verify_finality_proof(node: &SMRNode, ledger_info_with_sig: &LedgerInfoWithSignatures) { + let ledger_info_hash = ledger_info_with_sig.ledger_info().hash(); + for (author, signature) in ledger_info_with_sig.signatures() { + assert_eq!( + Ok(()), + node.validator + .verify_signature(*author, ledger_info_hash, signature) + ); + } +} + +#[test] +/// Should receive a new proposal upon start +fn basic_start_test() { + let runtime = consensus_runtime(); + let mut playground = NetworkPlayground::new(runtime.executor()); + let nodes = SMRNode::start_num_nodes(2, 2, &mut playground, RotatingProposer); + let genesis = nodes[0] + .smr + .block_store() + .expect("No valid block store!") + .root(); + block_on(async move { + let mut msg = playground + .wait_for_messages(1, NetworkPlayground::proposals_only) + .await; + let first_proposal = + ProposalInfo::, Author>::from_proto(msg[0].1.take_proposal()).unwrap(); + assert_eq!(first_proposal.proposal.height(), 1); + assert_eq!(first_proposal.proposal.parent_id(), genesis.id()); + assert_eq!( + first_proposal.proposal.quorum_cert().certified_block_id(), + genesis.id() + ); + }); +} + +#[test] +/// Upon startup, the first proposal is sent, delivered and voted by all the participants. 
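+/// (With num_nodes = 2 and quorum_size = 2, both replicas must vote before a QC can form.)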
+fn start_with_proposal_test() { + let runtime = consensus_runtime(); + let mut playground = NetworkPlayground::new(runtime.executor()); + let nodes = SMRNode::start_num_nodes(2, 2, &mut playground, RotatingProposer); + + block_on(async move { + let _proposals = playground + .wait_for_messages(1, NetworkPlayground::proposals_only) + .await; + // Need to wait for 2 votes for the 2 replicas + let votes: Vec = playground + .wait_for_messages(2, NetworkPlayground::votes_only) + .await + .into_iter() + .map(|(_, mut msg)| VoteMsg::from_proto(msg.take_vote()).unwrap()) + .collect(); + let proposed_block_id = votes[0].proposed_block_id(); + + // Verify that the proposed block id is indeed present in the block store. + assert!(nodes[0] + .smr + .block_store() + .unwrap() + .get_block(proposed_block_id) + .is_some()); + assert!(nodes[1] + .smr + .block_store() + .unwrap() + .get_block(proposed_block_id) + .is_some()); + }); +} + +#[test] +/// Upon startup, the first proposal is sent, voted by all the participants, QC is formed and +/// then the next proposal is sent. +fn basic_full_round() { + let runtime = consensus_runtime(); + let mut playground = NetworkPlayground::new(runtime.executor()); + let _nodes = SMRNode::start_num_nodes(2, 2, &mut playground, FixedProposer); + + block_on(async move { + let _broadcast_proposals_1 = playground + .wait_for_messages(1, NetworkPlayground::proposals_only) + .await; + let _votes_1 = playground + .wait_for_messages(1, NetworkPlayground::votes_only) + .await; + let mut broadcast_proposals_2 = playground + .wait_for_messages(1, NetworkPlayground::proposals_only) + .await; + let next_proposal = ProposalInfo::, Author>::from_proto( + broadcast_proposals_2[0].1.take_proposal(), + ) + .unwrap(); + assert_eq!(next_proposal.proposal.round(), 2); + assert_eq!(next_proposal.proposal.height(), 2); + }); +} + +/// Verify the basic e2e flow: blocks are committed, txn manager is notified, block tree is +/// pruned, restart the node and we can still continue. +#[test] +fn basic_commit_and_restart() { + let runtime = consensus_runtime(); + let mut playground = NetworkPlayground::new(runtime.executor()); + let mut nodes = SMRNode::start_num_nodes(2, 2, &mut playground, RotatingProposer); + let mut block_ids = vec![]; + + block_on(async { + let num_rounds = 10; + + for round in 0..num_rounds { + let _proposals = playground + .wait_for_messages(1, NetworkPlayground::exclude_new_round) + .await; + + // A proposal is carrying a QC that commits a block of round - 3. 
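The comment above refers to the chained-BFT commit rule this test exercises: by the time the proposal for a given round arrives, the QC chain it carries can commit the block proposed three rounds earlier. A tiny sketch of that offset (hypothetical helper, not from the patch):

```rust
// Which round's block does the proposal observed at `round` commit under a
// three-chain rule? None for the first three rounds: nothing commits yet.
fn expected_commit_index(round: usize) -> Option<usize> {
    round.checked_sub(3)
}
```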
+ if round >= 3 {
+ let block_id_to_commit = block_ids[round - 3];
+ let commit_v1 = nodes[0].commit_cb_receiver.next().await.unwrap();
+ let commit_v2 = nodes[1].commit_cb_receiver.next().await.unwrap();
+ assert_eq!(
+ commit_v1.ledger_info().consensus_block_id(),
+ block_id_to_commit
+ );
+ verify_finality_proof(&nodes[0], &commit_v1);
+ assert_eq!(
+ commit_v2.ledger_info().consensus_block_id(),
+ block_id_to_commit
+ );
+ verify_finality_proof(&nodes[1], &commit_v2);
+ }
+
+ // nodes[0] and nodes[1] send votes
+ let mut votes = playground
+ .wait_for_messages(1, NetworkPlayground::votes_only)
+ .await;
+ let vote_msg = VoteMsg::from_proto(votes[0].1.take_vote()).unwrap();
+ block_ids.push(vote_msg.proposed_block_id());
+ }
+ assert!(
+ nodes[0].smr.block_store().unwrap().root().height() >= 6,
+ "height of node 0 is {}",
+ nodes[0].smr.block_store().unwrap().root().height()
+ );
+ assert!(
+ nodes[1].smr.block_store().unwrap().root().height() >= 6,
+ "height of node 1 is {}",
+ nodes[1].smr.block_store().unwrap().root().height()
+ );
+ // This message is for the proposal with round 11: it delivers the QC but does not gather
+ // the QC, so after the restart the proposer will propose round 11 again.
+ playground
+ .wait_for_messages(1, NetworkPlayground::exclude_new_round)
+ .await;
+ });
+ // Create a new playground to avoid polling potential vote messages in the previous one.
+ playground = NetworkPlayground::new(runtime.executor());
+ nodes = nodes
+ .into_iter()
+ .map(|node| node.restart(2, &mut playground))
+ .collect();
+
+ block_on(async {
+ let mut round = 0;
+
+ while round < 10 {
+ // The loop ensures that we collect a network vote (enough for a QC with 2 nodes) and
+ // then move the round forward, because there's a race: node 1 may or may not
+ // reject round 11 depending on whether it voted before the restart.
+ loop {
+ let msg = playground
+ .wait_for_messages(1, NetworkPlayground::exclude_new_round)
+ .await;
+ if msg[0].1.has_vote() {
+ round += 1;
+ break;
+ }
+ }
+ }
+ // Because of the race, we can't assert the commit reliably; instead we assert that
+ // both nodes commit to at least round 17.
+ // We cannot reliably wait for the "commit & prune" event: the only thing we know is
+ // that after receiving the vote for round 20, the root should be at least height 16.
+ assert!(
+ nodes[0].smr.block_store().unwrap().root().height() >= 16,
+ "height of node 0 is {}",
+ nodes[0].smr.block_store().unwrap().root().height()
+ );
+ assert!(
+ nodes[1].smr.block_store().unwrap().root().height() >= 16,
+ "height of node 1 is {}",
+ nodes[1].smr.block_store().unwrap().root().height()
+ );
+ });
+}
+
+#[test]
+fn basic_block_retrieval() {
+ let runtime = consensus_runtime();
+ let mut playground = NetworkPlayground::new(runtime.executor());
+ // This test depends on the fixed proposer on nodes[0]
+ let mut nodes = SMRNode::start_num_nodes(3, 2, &mut playground, FixedProposer);
+ block_on(async move {
+ let mut first_proposals = vec![];
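`basic_block_retrieval`, which begins above, exercises the catch-up path in which a replica that missed proposals fetches the missing ancestors from a peer. A standalone sketch of such a walk over parent links, with stand-in types; nothing here is the patch's actual retrieval API:

```rust
use std::collections::HashMap;

type Id = u64; // stand-in for a 32-byte block hash

struct MiniBlock {
    id: Id,
    parent_id: Id,
}

// Walk parent links backwards from a certified id, requesting each block we
// don't have locally, until we reach one we already store (or the peer fails).
fn fetch_missing_ancestors(
    local: &HashMap<Id, MiniBlock>,
    mut cursor: Id,
    fetch_from_peer: impl Fn(Id) -> Option<MiniBlock>,
) -> Vec<MiniBlock> {
    let mut fetched = Vec::new();
    while !local.contains_key(&cursor) {
        match fetch_from_peer(cursor) {
            Some(block) => {
                cursor = block.parent_id;
                fetched.push(block);
            }
            None => break, // peer pruned the block or never had it
        }
    }
    fetched
}
```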
+ // The first two proposals are delivered only to nodes[0] and nodes[1].
+ playground.drop_message_for(&nodes[0].author, nodes[2].author);
+ for _ in 0..2 {
+ playground
+ .wait_for_messages(1, NetworkPlayground::proposals_only)
+ .await;
+ let mut votes = playground
+ .wait_for_messages(1, NetworkPlayground::votes_only)
+ .await;
+ let vote_msg = VoteMsg::from_proto(votes[0].1.take_vote()).unwrap();
+ let proposal_id = vote_msg.proposed_block_id();
+ first_proposals.push(proposal_id);
+ }
+ // The next proposal is delivered to all: as a result nodes[2] should retrieve the missing
+ // blocks from nodes[0] and vote for the 3rd proposal.
+ playground.stop_drop_message_for(&nodes[0].author, &nodes[2].author);
+
+ playground
+ .wait_for_messages(2, NetworkPlayground::proposals_only)
+ .await;
+ // Wait until nodes[2] has sent out a vote; drop the vote from nodes[1] so that nodes[0]
+ // won't move too far and prune the requested block
+ playground.drop_message_for(&nodes[1].author, nodes[0].author);
+ playground
+ .wait_for_messages(1, NetworkPlayground::votes_only)
+ .await;
+ playground.stop_drop_message_for(&nodes[1].author, &nodes[0].author);
+ // The first two proposals should be present at nodes[2]
+ for block_id in &first_proposals {
+ assert!(nodes[2]
+ .smr
+ .block_store()
+ .unwrap()
+ .get_block(*block_id)
+ .is_some());
+ }
+
+ // Both nodes[1] and nodes[2] are going to vote for the 4th proposal and commit the 1st one.
+
+ // Verify that nodes[2] commits the first proposal.
+ playground
+ .wait_for_messages(2, NetworkPlayground::votes_only)
+ .await;
+ if let Some(commit_v3) = nodes[2].commit_cb_receiver.next().await {
+ assert_eq!(
+ commit_v3.ledger_info().consensus_block_id(),
+ first_proposals[0],
+ );
+ }
+ });
+}
+
+#[test]
+fn block_retrieval_with_timeout() {
+ let runtime = consensus_runtime();
+ let mut playground = NetworkPlayground::new(runtime.executor());
+ let nodes = SMRNode::start_num_nodes(3, 2, &mut playground, FixedProposer);
+ block_on(async move {
+ let mut first_proposals = vec![];
+ // The first two proposals are delivered only to nodes[0] and nodes[1].
+ playground.drop_message_for(&nodes[0].author, nodes[2].author);
+ for _ in 0..2 {
+ playground
+ .wait_for_messages(1, NetworkPlayground::proposals_only)
+ .await;
+ let mut votes = playground
+ .wait_for_messages(1, NetworkPlayground::votes_only)
+ .await;
+ let vote_msg = VoteMsg::from_proto(votes[0].1.take_vote()).unwrap();
+ let proposal_id = vote_msg.proposed_block_id();
+ first_proposals.push(proposal_id);
+ }
+ // The next proposal is delivered to all: as a result nodes[2] should retrieve the missing
+ // blocks from nodes[0] and vote for the 4th proposal.
+ playground.stop_drop_message_for(&nodes[0].author, &nodes[2].author);
+
+ playground
+ .wait_for_messages(2, NetworkPlayground::proposals_only)
+ .await;
+ playground.drop_message_for(&nodes[1].author, nodes[0].author);
+ // Block RPC and wait until timeout for current round
+ playground.drop_message_for(&nodes[2].author, nodes[0].author);
+ playground
+ .wait_for_messages(1, NetworkPlayground::new_round_only)
+ .await;
+ // Unblock RPC
+ playground.stop_drop_message_for(&nodes[2].author, &nodes[0].author);
+ // Wait until nodes[2] has sent out a vote; drop the vote from nodes[1] so that nodes[0]
+ // won't move too far and prune the requested block
+ playground
+ .wait_for_messages(1, NetworkPlayground::votes_only)
+ .await;
+ playground.stop_drop_message_for(&nodes[1].author, &nodes[0].author);
+ // The first two proposals should be present at nodes[2]
+ for block_id in &first_proposals {
+ assert!(nodes[2]
+ .smr
+ .block_store()
+ .unwrap()
+ .get_block(*block_id)
+ .is_some());
+ }
+ });
+}
+
+#[test]
+/// Verify that a node that is lagging behind can catch up via state sync after some blocks
+/// have been pruned by the others.
+fn basic_state_sync() {
+ let runtime = consensus_runtime();
+ let mut playground = NetworkPlayground::new(runtime.executor());
+ // This test depends on the fixed proposer on nodes[0]
+ let mut nodes = SMRNode::start_num_nodes(3, 2, &mut playground, FixedProposer);
+ block_on(async move {
+ let mut proposals = vec![];
+ // The first ten proposals are delivered just to nodes[0] and nodes[1], which should commit
+ // the first seven blocks.
+ playground.drop_message_for(&nodes[0].author, nodes[2].author);
+ for _ in 0..10 {
+ playground
+ .wait_for_messages(1, NetworkPlayground::proposals_only)
+ .await;
+ let mut votes = playground
+ .wait_for_messages(1, NetworkPlayground::votes_only)
+ .await;
+ let vote_msg = VoteMsg::from_proto(votes[0].1.take_vote()).unwrap();
+ let proposal_id = vote_msg.proposed_block_id();
+ proposals.push(proposal_id);
+ }
+
+ let mut node0_commits = vec![];
+ for i in 0..7 {
+ node0_commits.push(
+ nodes[0]
+ .commit_cb_receiver
+ .next()
+ .await
+ .unwrap()
+ .ledger_info()
+ .consensus_block_id(),
+ );
+ assert_eq!(node0_commits[i], proposals[i]);
+ }
+
+ // Next proposal is delivered to all: as a result nodes[2] should be able to retrieve the
+ // missing blocks from nodes[0] and commit the first eight proposals as well.
+ playground.stop_drop_message_for(&nodes[0].author, &nodes[2].author);
+ playground
+ .wait_for_messages(2, NetworkPlayground::proposals_only)
+ .await;
+ let mut node2_commits = vec![];
+ // The only notification we will receive is for the last (8th) proposal.
+ node2_commits.push(
+ nodes[2]
+ .commit_cb_receiver
+ .next()
+ .await
+ .unwrap()
+ .ledger_info()
+ .consensus_block_id(),
+ );
+ assert_eq!(node2_commits[0], proposals[7]);
+
+ // Wait for the vote from nodes[2].
+ playground
+ .wait_for_messages(1, NetworkPlayground::votes_only)
+ .await;
+ for (_, proposal) in playground
+ .wait_for_messages(2, NetworkPlayground::proposals_only)
+ .await
+ {
+ assert_eq!(proposal.has_proposal(), true);
+ }
+ // Verify that node 2 has notified its mempool about the committed txns of the next block.
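As the assertions just below check, consensus must tell the mempool which transactions committed so they can be evicted. A minimal sketch of such a commit-notification hand-off over a channel; the names and the numeric payload type are hypothetical, not the patch's types:

```rust
use std::sync::mpsc;

// Consensus side: push the ids of committed transactions; this sketch ignores
// a closed channel, where the real code would surface the error.
fn notify_mempool_of_commit(sender: &mpsc::Sender<Vec<u64>>, committed_txn_ids: Vec<u64>) {
    let _ = sender.send(committed_txn_ids);
}

// Mempool side: drain pending notifications and count the evicted txns.
fn drain_commit_notifications(receiver: &mpsc::Receiver<Vec<u64>>) -> usize {
    receiver.try_iter().map(|batch| batch.len()).sum()
}
```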
+ nodes[2] + .mempool_notif_receiver + .next() + .await + .expect("Fail to be notified by a mempool committed txns"); + assert_eq!(nodes[2].mempool.get_committed_txns().len(), 50); + }); +} + +#[test] +/// Verify that a node syncs up when receiving a timeout message with a relevant ledger info +fn state_sync_on_timeout() { + let runtime = consensus_runtime(); + let mut playground = NetworkPlayground::new(runtime.executor()); + // This test depends on the fixed proposer on nodes[0] + let mut nodes = SMRNode::start_num_nodes(3, 2, &mut playground, FixedProposer); + block_on(async move { + let mut proposals = vec![]; + // The first ten proposals are delivered just to nodes[0] and nodes[1], which should commit + // the first seven blocks. + playground.drop_message_for(&nodes[0].author, nodes[2].author); + for _ in 0..10 { + playground + .wait_for_messages(1, NetworkPlayground::proposals_only) + .await; + let mut votes = playground + .wait_for_messages(1, NetworkPlayground::votes_only) + .await; + let vote_msg = VoteMsg::from_proto(votes[0].1.take_vote()).unwrap(); + let proposal_id = vote_msg.proposed_block_id(); + proposals.push(proposal_id); + } + + // Start dropping messages from 0 to 1 as well: node 0 is now disconnected and we can + // expect timeouts from both 0 and 1. + playground.drop_message_for(&nodes[0].author, nodes[1].author); + + // Wait for a timeout message from 2 to {0, 1} and from 1 to {0, 2} + // (node 0 cannot send to anyone). Note that there are 6 messages waited on + // since 2 can timeout 2x while waiting for 1 to timeout. + playground + .wait_for_messages(6, NetworkPlayground::new_round_only) + .await; + + let mut node2_commits = vec![]; + // The only notification we will receive is for the last commit known to nodes[1]: 7th + // proposal. + node2_commits.push( + nodes[2] + .commit_cb_receiver + .next() + .await + .unwrap() + .ledger_info() + .consensus_block_id(), + ); + assert_eq!(node2_commits[0], proposals[6]); + }); +} diff --git a/consensus/src/chained_bft/common.rs b/consensus/src/chained_bft/common.rs new file mode 100644 index 0000000000000..085762ed0d971 --- /dev/null +++ b/consensus/src/chained_bft/common.rs @@ -0,0 +1,53 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![allow(clippy::unit_arg)] + +use canonical_serialization::{CanonicalDeserialize, CanonicalSerialize}; +use serde::{de::DeserializeOwned, Serialize}; +use std::fmt::Debug; +use types::account_address::AccountAddress; + +/// The round of a block is a consensus-internal counter, which starts with 0 and increases +/// monotonically. It is used for the protocol safety and liveness (please see the detailed +/// protocol description). +pub type Round = u64; +/// Height refers to the chain depth of a consensus block in a tree with respect to parent links. +/// The genesis block starts at height 0. The round of a block is always >= height. Height is +/// only used for debugging and testing as it is not required for implementing LibraBFT. +pub type Height = u64; +/// Author refers to the author's account address +pub type Author = AccountAddress; + +/// Trait alias for the Block Payload. 
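The `Payload` item following this aside is a "trait alias" built from an empty trait plus a blanket impl, since stable Rust has no first-class trait aliases. A minimal standalone illustration of the same pattern with simpler bounds:

```rust
use std::fmt::Debug;

// Any type satisfying the bounds automatically implements the alias trait,
// so call sites can write one bound instead of repeating the whole list.
trait Blob: Clone + Debug + Default + Send + 'static {}
impl<T> Blob for T where T: Clone + Debug + Default + Send + 'static {}

fn describe<B: Blob>(value: B) {
    println!("{:?}", value);
}
```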
+pub trait Payload:
+ Clone
+ + Send
+ + Sync
+ + CanonicalSerialize
+ + CanonicalDeserialize
+ + DeserializeOwned
+ + Serialize
+ + Default
+ + Debug
+ + PartialEq
+ + Eq
+ + 'static
+{
+}
+
+impl Payload for T where
+ T: Clone
+ + Send
+ + Sync
+ + CanonicalSerialize
+ + CanonicalDeserialize
+ + DeserializeOwned
+ + Serialize
+ + Default
+ + Debug
+ + PartialEq
+ + Eq
+ + 'static
+{
+}
diff --git a/consensus/src/chained_bft/consensus_types/block.rs b/consensus/src/chained_bft/consensus_types/block.rs
new file mode 100644
index 0000000000000..a6c337d0be04c
--- /dev/null
+++ b/consensus/src/chained_bft/consensus_types/block.rs
@@ -0,0 +1,391 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+ chained_bft::{
+ common::{Author, Height, Round},
+ consensus_types::quorum_cert::QuorumCert,
+ safety::vote_msg::VoteMsgVerificationError,
+ },
+ state_replication::ExecutedState,
+};
+use canonical_serialization::{
+ CanonicalDeserialize, CanonicalSerialize, CanonicalSerializer, SimpleSerializer,
+};
+use crypto::{
+ hash::{BlockHasher, CryptoHash, CryptoHasher, GENESIS_BLOCK_ID},
+ HashValue, Signature,
+};
+use failure::Result;
+use mirai_annotations::{checked_precondition, checked_precondition_eq};
+use network::proto::Block as ProtoBlock;
+use proto_conv::{FromProto, IntoProto};
+use rmp_serde::{from_slice, to_vec_named};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
+use std::{
+ collections::HashMap,
+ convert::TryFrom,
+ fmt::{Display, Formatter},
+};
+use types::{
+ ledger_info::{LedgerInfo, LedgerInfoWithSignatures},
+ validator_signer::ValidatorSigner,
+ validator_verifier::ValidatorVerifier,
+};
+
+#[cfg(test)]
+#[path = "block_test.rs"]
+pub mod block_test;
+
+#[derive(Debug)]
+pub enum BlockVerificationError {
+ /// The verification of the quorum cert of this block failed.
+ QCVerificationError(VoteMsgVerificationError),
+ /// The signature verification of this block failed.
+ SigVerifyError,
+}
+
+/// Blocks are managed in a speculative tree; the committed blocks form a chain.
+/// Each block must know the id of its parent and keep the QuorumCertificate to that parent.
+#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)]
+pub struct Block {
+ /// This block's id as a hash value
+ id: HashValue,
+ /// Parent block id of this block as a hash value (all zeros to indicate the genesis block)
+ parent_id: HashValue,
+ /// The payload T of the block (e.g., one or more transactions)
+ payload: T,
+ /// The round of a block is an internal monotonically increasing counter used by the Consensus
+ /// protocol.
+ round: Round,
+ /// The height of a block is its position in the chain (block height = parent block height + 1)
+ height: Height,
+ /// The approximate physical time a block is proposed by a proposer. This timestamp is used
+ /// for
+ /// * Time-dependent logic in smart contracts (the current time of execution)
+ /// * Clients determining if they are relatively up-to-date with respect to the block chain.
+ ///
+ /// It makes the following guarantees:
+ /// 1. Time Monotonicity: Time is monotonically increasing in the block
+ /// chain. (i.e. If H1 < H2, H1.Time < H2.Time).
+ /// 2. If a block of transactions B is agreed on with timestamp T, then at least f+1
+ /// honest replicas think that T is in the past. An honest replica will only vote
+ /// on a block when its own clock >= timestamp T.
+ /// 3.
If a block of transactions B is agreed on with timestamp T, then at least f+1 honest + /// replicas saw the contents of B no later than T + delta for some delta. + /// If T = 3:00 PM and delta is 10 minutes, then an honest replica would not have + /// voted for B unless its clock was between 3:00 PM to 3:10 PM at the time the + /// proposal was received. After 3:10 PM, an honest replica would no longer vote + /// on B, noting it was too far in the past. + timestamp_usecs: u64, + /// Contains the quorum certified ancestor and whether the quorum certified ancestor was + /// voted on successfully + quorum_cert: QuorumCert, + /// Author of the block that can be validated by the author's public key and the signature + author: Author, + /// Signature that the hash of this block has been authored by the owner of the private key + signature: Signature, +} + +impl Display for Block { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!( + f, + "[id: {}, round: {:02}, parent_id: {}]", + self.id, self.round, self.parent_id + ) + } +} + +impl Block +where + T: Serialize + Default + CanonicalSerialize, +{ + // Make an empty genesis block + pub fn make_genesis_block() -> Self { + let ancestor_id = HashValue::zero(); + let genesis_validator_signer = ValidatorSigner::genesis(); + let state = ExecutedState::state_for_genesis(); + // Genesis carries a placeholder quorum certificate to its parent id with LedgerInfo + // carrying information about version `0`. + let genesis_quorum_cert = QuorumCert::new( + ancestor_id, + state, + 0, + LedgerInfoWithSignatures::new( + LedgerInfo::new( + 0, + state.state_id, + HashValue::zero(), + HashValue::zero(), + 0, + 0, + ), + HashMap::new(), + ), + ); + let genesis_id = *GENESIS_BLOCK_ID; + let signature = genesis_validator_signer + .sign_message(genesis_id) + .expect("Failed to sign genesis id."); + + Block { + id: genesis_id, + payload: T::default(), + parent_id: HashValue::zero(), + round: 0, + height: 0, + timestamp_usecs: 0, // The beginning of UNIX TIME + quorum_cert: genesis_quorum_cert, + author: genesis_validator_signer.author(), + signature, + } + } + + // Create a block directly. Most users should prefer make_block() as it ensures correct block + // chaining. This functionality should typically only be used for testing. + pub fn new_internal( + payload: T, + parent_id: HashValue, + round: Round, + height: Height, + timestamp_usecs: u64, + quorum_cert: QuorumCert, + validator_signer: &ValidatorSigner, + ) -> Self { + let block_internal = BlockSerializer { + parent_id, + payload: &payload, + round, + height, + timestamp_usecs, + quorum_cert: &quorum_cert, + author: validator_signer.author(), + }; + + let id = block_internal.hash(); + let signature = validator_signer + .sign_message(id) + .expect("Failed to sign message"); + + Block { + id, + payload, + parent_id, + round, + height, + timestamp_usecs, + quorum_cert, + author: validator_signer.author(), + signature, + } + } + + pub fn make_block( + parent_block: &Block, + payload: T, + round: Round, + timestamp_usecs: u64, + quorum_cert: QuorumCert, + validator_signer: &ValidatorSigner, + ) -> Self { + // A block must carry a QC to its parent. + checked_precondition_eq!(quorum_cert.certified_block_id(), parent_block.id()); + checked_precondition!(round > parent_block.round()); + Block::new_internal( + payload, + parent_block.id(), + round, + // Height is always parent's height + 1 because it's just the position in the chain. 
+ parent_block.height() + 1, + timestamp_usecs, + quorum_cert, + validator_signer, + ) + } + + pub fn get_payload(&self) -> &T { + &self.payload + } + + pub fn verify( + &self, + validator: &ValidatorVerifier, + ) -> ::std::result::Result<(), BlockVerificationError> { + if self.is_genesis_block() { + return Ok(()); + } + validator + .verify_signature(self.author(), self.hash(), self.signature()) + .map_err(|_| BlockVerificationError::SigVerifyError)?; + self.quorum_cert + .verify(validator) + .map_err(BlockVerificationError::QCVerificationError) + } + + pub fn id(&self) -> HashValue { + self.id + } + + pub fn parent_id(&self) -> HashValue { + self.parent_id + } + + pub fn height(&self) -> Height { + self.height + } + + pub fn round(&self) -> Round { + self.round + } + + pub fn timestamp_usecs(&self) -> u64 { + self.timestamp_usecs + } + + pub fn quorum_cert(&self) -> &QuorumCert { + &self.quorum_cert + } + + pub fn author(&self) -> Author { + self.author + } + + pub fn signature(&self) -> &Signature { + &self.signature + } + + pub fn is_genesis_block(&self) -> bool { + self.id() == *GENESIS_BLOCK_ID + } +} + +impl CryptoHash for Block +where + T: canonical_serialization::CanonicalSerialize, +{ + type Hasher = BlockHasher; + + fn hash(&self) -> HashValue { + let block_internal = BlockSerializer { + parent_id: self.parent_id, + payload: &self.payload, + round: self.round, + height: self.height, + timestamp_usecs: self.timestamp_usecs, + quorum_cert: &self.quorum_cert, + author: self.author, + }; + block_internal.hash() + } +} + +// Internal use only. Contains all the fields in Block that contributes to the computation of +// Block Id +struct BlockSerializer<'a, T> { + parent_id: HashValue, + payload: &'a T, + round: Round, + height: Height, + timestamp_usecs: u64, + quorum_cert: &'a QuorumCert, + author: Author, +} + +impl<'a, T> CryptoHash for BlockSerializer<'a, T> +where + T: CanonicalSerialize, +{ + type Hasher = BlockHasher; + + fn hash(&self) -> HashValue { + let bytes = + SimpleSerializer::>::serialize(self).expect("block serialization failed"); + let mut state = Self::Hasher::default(); + state.write(bytes.as_ref()); + state.finish() + } +} + +impl<'a, T> CanonicalSerialize for BlockSerializer<'a, T> +where + T: CanonicalSerialize, +{ + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer + .encode_u64(self.timestamp_usecs)? + .encode_u64(self.round)? + .encode_u64(self.height)? + .encode_struct(self.payload)? + .encode_raw_bytes(self.parent_id.as_ref())? + .encode_raw_bytes(self.quorum_cert.certified_block_id().as_ref())? + .encode_struct(&self.author)?; + Ok(()) + } +} + +#[cfg(test)] +impl Block +where + T: Default + Serialize + CanonicalSerialize, +{ + // Is this block a parent of the parameter block? 
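`BlockSerializer` above makes the layering explicit: the block id is a digest over every identity-relevant field except the id and signature themselves, and the signature then covers only the id. A toy sketch of that shape; the hash and "signature" here are stand-ins for the real crypto, not the patch's implementation:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Digest over the fields that define the block's identity.
fn block_id(parent_id: u64, round: u64, payload: &[u8], timestamp_usecs: u64) -> u64 {
    let mut hasher = DefaultHasher::new();
    (parent_id, round, payload, timestamp_usecs).hash(&mut hasher);
    hasher.finish()
}

// Signing covers the id alone, so verifying a block means recomputing the id
// and checking the signature against it.
fn toy_sign(id: u64, secret: u64) -> u64 {
    id ^ secret
}
```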
+ pub fn is_parent_of(&self, block: &Self) -> bool { + block.parent_id == self.id + } +} + +impl IntoProto for Block +where + T: Serialize + Default + CanonicalSerialize, +{ + type ProtoType = ProtoBlock; + + fn into_proto(self) -> Self::ProtoType { + let mut proto = Self::ProtoType::new(); + proto.set_timestamp_usecs(self.timestamp_usecs); + proto.set_id(self.id().into()); + proto.set_parent_id(self.parent_id().into()); + proto.set_payload( + to_vec_named(self.get_payload()) + .expect("fail to serialize payload") + .into(), + ); + proto.set_round(self.round()); + proto.set_height(self.height()); + proto.set_quorum_cert(self.quorum_cert().clone().into_proto()); + proto.set_signature(self.signature().to_compact().as_ref().into()); + proto.set_author(self.author.into()); + proto + } +} + +impl FromProto for Block +where + T: DeserializeOwned + CanonicalDeserialize, +{ + type ProtoType = ProtoBlock; + + fn from_proto(mut object: Self::ProtoType) -> Result { + let id = HashValue::from_slice(object.get_id())?; + let parent_id = HashValue::from_slice(object.get_parent_id())?; + let payload = from_slice(object.get_payload())?; + let timestamp_usecs = object.get_timestamp_usecs(); + let round = object.get_round(); + let height = object.get_height(); + let quorum_cert = QuorumCert::from_proto(object.take_quorum_cert())?; + let author = Author::try_from(object.take_author())?; + let signature = Signature::from_compact(object.get_signature())?; + Ok(Block { + id, + parent_id, + payload, + round, + timestamp_usecs, + height, + quorum_cert, + author, + signature, + }) + } +} diff --git a/consensus/src/chained_bft/consensus_types/block_test.rs b/consensus/src/chained_bft/consensus_types/block_test.rs new file mode 100644 index 0000000000000..19a5dd95271e4 --- /dev/null +++ b/consensus/src/chained_bft/consensus_types/block_test.rs @@ -0,0 +1,276 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::chained_bft::{ + common::{Height, Round}, + consensus_types::{block::Block, quorum_cert::QuorumCert}, + test_utils::placeholder_certificate_for_block, +}; + +use crypto::{HashValue, PrivateKey, PublicKey}; +use proptest::prelude::*; +use std::{ + panic, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; +use types::validator_signer::{self, ValidatorSigner}; + +type LinearizedBlockForest = Vec>; + +prop_compose! { + /// This strategy is a swiss-army tool to produce a low-level block + /// dependent on signer, round, parent and ancestor_id. + /// Note that the quorum certificate carried by this block is still placeholder: one will have + /// to generate it later on when adding to the tree. + pub fn make_block( + _ancestor_id: HashValue, + parent_id_strategy: impl Strategy, + round_strategy: impl Strategy, + height: Height, + signer_strategy: impl Strategy, + )( + parent_id in parent_id_strategy, + round in round_strategy, + payload in 0usize..10usize, + height in Just(height), + signer in signer_strategy, + ) -> Block> { + Block::new_internal( + vec![payload], + parent_id, + round, + height, + get_current_timestamp().as_micros() as u64, + QuorumCert::certificate_for_genesis(), + &signer, + ) + } +} + +/// This produces the genesis block +pub fn genesis_strategy() -> impl Strategy>> { + Just(Block::make_genesis_block()) +} + +prop_compose! 
{ + /// This produces an unmoored block, with arbitrary parent & QC ancestor + pub fn unmoored_block(ancestor_id_strategy: impl Strategy)( + ancestor_id in ancestor_id_strategy, + )( + block in make_block( + ancestor_id, + HashValue::arbitrary(), + Round::arbitrary(), + 123, + validator_signer::arb_signer(), + ) + ) -> Block> { + block + } +} + +/// Offers the genesis block. +pub fn leaf_strategy() -> impl Strategy>> { + genesis_strategy().boxed() +} + +prop_compose! { + /// This produces a block with an invalid id (and therefore signature) + /// given a valid block + pub fn fake_id(block_strategy: impl Strategy>>) + (fake_id in HashValue::arbitrary(), + block in block_strategy) -> Block> { + Block { + timestamp_usecs: get_current_timestamp().as_micros() as u64, + id: fake_id, + payload: block.get_payload().clone(), + round: block.round(), + height: block.height(), + parent_id: block.parent_id(), + quorum_cert: block.quorum_cert().clone(), + author: block.author(), + signature: *block.signature(), + } + } +} + +prop_compose! { + fn bigger_round(initial_round: Round)( + increment in 2..8, + initial_round in Just(initial_round), + ) -> Round { + initial_round + increment as u64 + } +} + +/// This produces a round that is often higher than the parent, but not +/// too high +pub fn some_round(initial_round: Round) -> impl Strategy { + prop_oneof![ + 9 => Just(1 + initial_round), + 1 => bigger_round(initial_round), + ] +} + +prop_compose! { + /// This creates a child with a parent on its left, and a QC on the left + /// of the parent. This, depending on branching, does not require the + /// QC to always be an ancestor or the parent to always be the highest QC + fn child( + signer_strategy: impl Strategy, + block_forest_strategy: impl Strategy>>, + )( + signer in signer_strategy, + (forest_vec, parent_idx, qc_idx) in block_forest_strategy + .prop_flat_map(|forest_vec| { + let len = forest_vec.len(); + (Just(forest_vec), 0..len) + }) + .prop_flat_map(|(forest_vec, parent_idx)| { + (Just(forest_vec), Just(parent_idx), 0..=parent_idx) + }), + )( block in make_block( + // ancestor_id + forest_vec[qc_idx].id(), + // parent_id + Just(forest_vec[parent_idx].id()), + // round + some_round(forest_vec[parent_idx].round()), + // height, + forest_vec[parent_idx].height() + 1, + // signer + Just(signer), + ), mut forest in Just(forest_vec), + ) -> LinearizedBlockForest> { + forest.push(block); + forest + } +} + +/// This creates a block forest with keys extracted from a specific +/// vector +fn block_forest_from_keys( + depth: u32, + key_pairs: Vec<(PrivateKey, PublicKey)>, +) -> impl Strategy>> { + let leaf = leaf_strategy().prop_map(|block| vec![block]); + // Note that having `expected_branch_size` of 1 seems to generate significantly larger trees + // than desired (this is my understanding after reading the documentation: + // https://docs.rs/proptest/0.3.0/proptest/strategy/trait.Strategy.html#method.prop_recursive) + leaf.prop_recursive(depth, depth, 2, move |inner| { + child( + validator_signer::mostly_in_keypair_pool(key_pairs.clone()), + inner, + ) + }) +} + +/// This returns keys and a block forest created from them +pub fn block_forest_and_its_keys( + quorum_size: usize, + depth: u32, +) -> impl Strategy< + Value = ( + Vec<(PrivateKey, PublicKey)>, + LinearizedBlockForest>, + ), +> { + proptest::collection::vec(validator_signer::arb_keypair(), quorum_size).prop_flat_map( + move |key_pairs| { + ( + Just(key_pairs.clone()), + block_forest_from_keys(depth, key_pairs), + ) + }, + ) +} + +#[test] +fn 
test_genesis() { + // Test genesis and the next block + let genesis_block = Block::::make_genesis_block(); + assert_eq!(genesis_block.height(), 0); + assert_eq!(genesis_block.parent_id(), HashValue::zero()); + assert_ne!(genesis_block.id(), HashValue::zero()); + assert!(genesis_block.is_genesis_block()); +} + +#[test] +fn test_block_relation() { + let signer = ValidatorSigner::random(); + // Test genesis and the next block + let genesis_block = Block::make_genesis_block(); + let quorum_cert = QuorumCert::certificate_for_genesis(); + let payload = 101; + let next_block = Block::make_block( + &genesis_block, + payload, + 1, + get_current_timestamp().as_micros() as u64, + quorum_cert, + &signer, + ); + assert_eq!(next_block.round(), 1); + assert_eq!(next_block.height(), 1); + assert_eq!(genesis_block.is_parent_of(&next_block), true); + assert_eq!( + next_block.quorum_cert().certified_block_id(), + genesis_block.id() + ); + assert_eq!(next_block.get_payload(), &payload); + + let cloned_block = next_block.clone(); + assert_eq!(cloned_block.round(), next_block.round()); +} + +#[test] +fn test_block_qc() { + // Verify that it's impossible to create a block with QC that doesn't point to a parent. + let signer = ValidatorSigner::random(); + // Test genesis and the next block + let genesis_block = Block::make_genesis_block(); + let genesis_qc = QuorumCert::certificate_for_genesis(); + + let payload = 42; + let a1 = Block::make_block( + &genesis_block, + payload, + 1, + get_current_timestamp().as_micros() as u64, + genesis_qc.clone(), + &signer, + ); + let a1_qc = placeholder_certificate_for_block(vec![signer.clone()], a1.id(), a1.round()); + + let result = panic::catch_unwind(|| { + // should panic because qc does not point to parent + Block::make_block( + &a1, + payload, + 2, + get_current_timestamp().as_micros() as u64, + genesis_qc.clone(), + &signer, + ); + }); + assert!(result.is_err()); + + // once qc is correct, should not panic + let a2 = Block::make_block( + &a1, + payload, + 2, + get_current_timestamp().as_micros() as u64, + a1_qc.clone(), + &signer, + ); + assert_eq!(a2.height(), 2); +} + +// Using current_timestamp in this test +// because it's a bit hard to generate incremental timestamps in proptests +fn get_current_timestamp() -> Duration { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Timestamp generated is before the UNIX_EPOCH!") +} diff --git a/consensus/src/chained_bft/consensus_types/mod.rs b/consensus/src/chained_bft/consensus_types/mod.rs new file mode 100644 index 0000000000000..57b1d38a1db70 --- /dev/null +++ b/consensus/src/chained_bft/consensus_types/mod.rs @@ -0,0 +1,5 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +pub(crate) mod block; +pub(crate) mod quorum_cert; diff --git a/consensus/src/chained_bft/consensus_types/quorum_cert.rs b/consensus/src/chained_bft/consensus_types/quorum_cert.rs new file mode 100644 index 0000000000000..8582946a6b50d --- /dev/null +++ b/consensus/src/chained_bft/consensus_types/quorum_cert.rs @@ -0,0 +1,182 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + common::Round, + safety::vote_msg::{VoteMsg, VoteMsgVerificationError}, + }, + state_replication::ExecutedState, +}; +use crypto::{ + hash::{CryptoHash, ACCUMULATOR_PLACEHOLDER_HASH, GENESIS_BLOCK_ID}, + HashValue, +}; +use failure::Result; +use network::proto::QuorumCert as ProtoQuorumCert; +use proto_conv::{FromProto, IntoProto}; +use serde::{Deserialize, 
Serialize}; +use std::{ + collections::HashMap, + fmt::{Display, Formatter}, +}; +use types::{ + ledger_info::{LedgerInfo, LedgerInfoWithSignatures}, + validator_signer::ValidatorSigner, + validator_verifier::ValidatorVerifier, +}; + +#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] +pub struct QuorumCert { + /// The id of a block that is certified by this QuorumCertificate. + certified_block_id: HashValue, + /// The execution state of the corresponding block. + certified_state: ExecutedState, + /// The round of a certified block. + certified_block_round: Round, + /// The signed LedgerInfo of a committed block that carries the data about the certified block. + signed_ledger_info: LedgerInfoWithSignatures, +} + +impl Display for QuorumCert { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!( + f, + "QC: [\n\ + \tCertified block id: {},\n\ + \tround: {:02},\n\ + \tledger info: {}\n\ + ]", + self.certified_block_id, self.certified_block_round, self.signed_ledger_info + ) + } +} + +#[allow(dead_code)] +impl QuorumCert { + pub fn new( + block_id: HashValue, + state: ExecutedState, + round: Round, + signed_ledger_info: LedgerInfoWithSignatures, + ) -> Self { + QuorumCert { + certified_block_id: block_id, + certified_state: state, + certified_block_round: round, + signed_ledger_info, + } + } + + pub fn certified_block_id(&self) -> HashValue { + self.certified_block_id + } + + pub fn certified_state(&self) -> ExecutedState { + self.certified_state + } + + pub fn certified_block_round(&self) -> Round { + self.certified_block_round + } + + pub fn ledger_info(&self) -> &LedgerInfoWithSignatures { + &self.signed_ledger_info + } + + pub fn committed_block_id(&self) -> Option { + let id = self.ledger_info().ledger_info().consensus_block_id(); + if id.is_zero() { + None + } else { + Some(id) + } + } + + /// QuorumCert for the genesis block: + /// - the ID of the block is predetermined by the `GENESIS_BLOCK_ID` constant. + /// - the accumulator root hash of the LedgerInfo is set to `ACCUMULATOR_PLACEHOLDER_HASH` + /// constant. + /// - the map of signatures is empty because genesis block is implicitly agreed. + pub fn certificate_for_genesis() -> QuorumCert { + let genesis_digest = + VoteMsg::vote_digest(*GENESIS_BLOCK_ID, ExecutedState::state_for_genesis(), 0); + let signer = ValidatorSigner::genesis(); + let li = LedgerInfo::new( + 0, + *ACCUMULATOR_PLACEHOLDER_HASH, + genesis_digest, + *GENESIS_BLOCK_ID, + 0, + 0, + ); + let signature = signer + .sign_message(li.hash()) + .expect("Fail to sign genesis ledger info"); + let mut signatures = HashMap::new(); + signatures.insert(signer.author(), signature); + QuorumCert::new( + *GENESIS_BLOCK_ID, + ExecutedState::state_for_genesis(), + 0, + LedgerInfoWithSignatures::new(li, signatures), + ) + } + + pub fn verify( + &self, + validator: &ValidatorVerifier, + ) -> ::std::result::Result<(), VoteMsgVerificationError> { + let vote_hash = VoteMsg::vote_digest( + self.certified_block_id, + self.certified_state, + self.certified_block_round, + ); + if self.ledger_info().ledger_info().consensus_data_hash() != vote_hash { + return Err(VoteMsgVerificationError::ConsensusDataMismatch); + } + // Genesis is implicitly agreed upon, it doesn't have real signatures. 
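`verify` continues below with the genesis carve-out flagged in the comment above. The overall shape is two checks layered around that carve-out, sketched here with scalar digests standing in for `HashValue` (a hypothetical helper, not the patch's API):

```rust
fn verify_qc(
    is_genesis: bool,
    expected_vote_digest: u64,
    signed_digest: u64,
    valid_signatures: usize,
    quorum_size: usize,
) -> Result<(), &'static str> {
    // The ledger info must commit to the exact (block id, state, round) digest
    // the voters signed.
    if signed_digest != expected_vote_digest {
        return Err("consensus data mismatch");
    }
    // Genesis is implicitly agreed upon and carries no real signatures.
    if is_genesis {
        return Ok(());
    }
    // Otherwise a quorum of valid signatures is required.
    if valid_signatures < quorum_size {
        return Err("not enough valid signatures for a quorum");
    }
    Ok(())
}
```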
+ if self.certified_block_round == 0 + && self.certified_block_id == *GENESIS_BLOCK_ID + && self.certified_state == ExecutedState::state_for_genesis() + { + return Ok(()); + } + self.ledger_info() + .verify(validator) + .map_err(VoteMsgVerificationError::SigVerifyError) + } +} + +impl IntoProto for QuorumCert { + type ProtoType = ProtoQuorumCert; + + fn into_proto(self) -> Self::ProtoType { + let mut proto = Self::ProtoType::new(); + proto.set_block_id(self.certified_block_id.into()); + proto.set_state_id(self.certified_state.state_id.into()); + proto.set_version(self.certified_state.version); + proto.set_round(self.certified_block_round); + proto.set_signed_ledger_info(self.signed_ledger_info.into_proto()); + proto + } +} + +impl FromProto for QuorumCert { + type ProtoType = ProtoQuorumCert; + + fn from_proto(object: Self::ProtoType) -> Result { + let certified_block_id = HashValue::from_slice(object.get_block_id())?; + let state_id = HashValue::from_slice(object.get_state_id())?; + let version = object.get_version(); + let certified_block_round = object.get_round(); + let signed_ledger_info = + LedgerInfoWithSignatures::from_proto(object.get_signed_ledger_info().clone())?; + Ok(QuorumCert { + certified_block_id, + certified_state: ExecutedState { state_id, version }, + certified_block_round, + signed_ledger_info, + }) + } +} diff --git a/consensus/src/chained_bft/consensusdb/consensusdb_test.rs b/consensus/src/chained_bft/consensusdb/consensusdb_test.rs new file mode 100644 index 0000000000000..77ee83332618b --- /dev/null +++ b/consensus/src/chained_bft/consensusdb/consensusdb_test.rs @@ -0,0 +1,55 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use tempfile::tempdir; + +#[test] +fn test_put_get() { + let tmp_dir = tempdir().unwrap(); + let db = ConsensusDB::new(&tmp_dir); + + let block = Block::::make_genesis_block(); + let blocks = vec![block]; + + let old_blocks = db.get_blocks::().unwrap(); + assert!(db.get_state().unwrap().is_none()); + assert_eq!(old_blocks.len(), 0); + assert_eq!(db.get_quorum_certificates().unwrap().len(), 0); + + db.save_state(vec![0x01, 0x02, 0x03]).unwrap(); + + let qcs = vec![QuorumCert::certificate_for_genesis()]; + + db.save_blocks_and_quorum_certificates(blocks, qcs).unwrap(); + + assert_eq!(db.get_blocks::().unwrap().len(), 1); + assert_eq!(db.get_quorum_certificates().unwrap().len(), 1); + assert!(!db.get_state().unwrap().is_none()); +} + +#[test] +fn test_delete_block_and_qc() { + let tmp_dir = tempdir().unwrap(); + let db = ConsensusDB::new(&tmp_dir); + + assert!(db.get_state().unwrap().is_none()); + assert_eq!(db.get_blocks::().unwrap().len(), 0); + assert_eq!(db.get_quorum_certificates().unwrap().len(), 0); + + let blocks = vec![Block::::make_genesis_block()]; + let block_id = blocks[0].id(); + + let qcs = vec![QuorumCert::certificate_for_genesis()]; + let qc_id = qcs[0].certified_block_id(); + + db.save_blocks_and_quorum_certificates(blocks, qcs).unwrap(); + assert_eq!(db.get_blocks::().unwrap().len(), 1); + assert_eq!(db.get_quorum_certificates().unwrap().len(), 1); + + // Start to delete + db.delete_blocks_and_quorum_certificates::(vec![block_id, qc_id]) + .unwrap(); + assert_eq!(db.get_blocks::().unwrap().len(), 0); + assert_eq!(db.get_quorum_certificates().unwrap().len(), 0); +} diff --git a/consensus/src/chained_bft/consensusdb/mod.rs b/consensus/src/chained_bft/consensusdb/mod.rs new file mode 100644 index 0000000000000..e842234aae418 --- /dev/null +++ 
b/consensus/src/chained_bft/consensusdb/mod.rs @@ -0,0 +1,179 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#[cfg(test)] +mod consensusdb_test; +mod schema; + +use crate::chained_bft::{ + common::Payload, + consensus_types::{block::Block, quorum_cert::QuorumCert}, + consensusdb::schema::{ + block::BlockSchema, + quorum_certificate::QCSchema, + single_entry::{SingleEntryKey, SingleEntrySchema}, + }, +}; +use crypto::HashValue; +use failure::prelude::*; +use logger::prelude::*; +use schema::{BLOCK_CF_NAME, QC_CF_NAME, SINGLE_ENTRY_CF_NAME}; +use schemadb::{ + ColumnFamilyOptions, ColumnFamilyOptionsMap, ReadOptions, SchemaBatch, DB, DEFAULT_CF_NAME, +}; +use std::{collections::HashMap, iter::Iterator, path::Path, time::Instant}; + +type HighestTimeoutCertificates = Vec; +type ConsensusStateData = Vec; + +pub struct ConsensusDB { + db: DB, +} + +impl ConsensusDB { + pub fn new + Clone>(db_root_path: P) -> Self { + let cf_opts_map: ColumnFamilyOptionsMap = [ + ( + /* UNUSED CF = */ DEFAULT_CF_NAME, + ColumnFamilyOptions::default(), + ), + (BLOCK_CF_NAME, ColumnFamilyOptions::default()), + (QC_CF_NAME, ColumnFamilyOptions::default()), + (SINGLE_ENTRY_CF_NAME, ColumnFamilyOptions::default()), + ] + .iter() + .cloned() + .collect(); + + let path = db_root_path.as_ref().join("consensusdb"); + let instant = Instant::now(); + let db = DB::open(path.clone(), cf_opts_map).unwrap_or_else(|e| { + panic!("ConsensusDB open failed due to {:?}, unable to continue", e) + }); + + info!( + "Opened ConsensusDB at {:?} in {} ms", + path, + instant.elapsed().as_millis() + ); + + Self { db } + } + + pub fn get_data( + &self, + ) -> Result<( + Option, + Option, + Vec>, + Vec, + )> { + let consensus_state = self.get_state()?; + let highest_timeout_certificates = self.get_highest_timeout_certificates()?; + self.db + .get::(&SingleEntryKey::ConsensusState)?; + let consensus_blocks = self + .get_blocks()? + .into_iter() + .map(|(_block_hash, block_content)| block_content) + .collect::>(); + let consensus_qcs = self + .get_quorum_certificates()? + .into_iter() + .map(|(_block_hash, qc)| qc) + .collect::>(); + Ok(( + consensus_state, + highest_timeout_certificates, + consensus_blocks, + consensus_qcs, + )) + } + + pub fn save_highest_timeout_certificates( + &self, + highest_timeout_certificates: HighestTimeoutCertificates, + ) -> Result<()> { + let mut batch = SchemaBatch::new(); + batch.put::( + &SingleEntryKey::HighestTimeoutCertificates, + &highest_timeout_certificates, + )?; + self.commit(batch) + } + + pub fn save_state(&self, state: ConsensusStateData) -> Result<()> { + let mut batch = SchemaBatch::new(); + batch.put::(&SingleEntryKey::ConsensusState, &state)?; + self.commit(batch) + } + + pub fn save_blocks_and_quorum_certificates( + &self, + block_data: Vec>, + qc_data: Vec, + ) -> Result<()> { + ensure!( + !block_data.is_empty() || !qc_data.is_empty(), + "Consensus block and qc data is empty!" 
+ );
+ let mut batch = SchemaBatch::new();
+ block_data
+ .iter()
+ .map(|block| batch.put::>(&block.id(), block))
+ .collect::>()?;
+ qc_data
+ .iter()
+ .map(|qc| batch.put::(&qc.certified_block_id(), qc))
+ .collect::>()?;
+ self.commit(batch)
+ }
+
+ pub fn delete_blocks_and_quorum_certificates(
+ &self,
+ block_ids: Vec,
+ ) -> Result<()> {
+ ensure!(!block_ids.is_empty(), "Consensus block ids is empty!");
+ let mut batch = SchemaBatch::new();
+ block_ids
+ .iter()
+ .map(|hash| {
+ batch.delete::>(hash)?;
+ batch.delete::(hash)
+ })
+ .collect::>()?;
+ self.commit(batch)
+ }
+
+ /// Write the whole schema batch, including all data necessary to mutate the ledger
+ /// state of some transaction, leveraging rocksdb atomicity support.
+ fn commit(&self, batch: SchemaBatch) -> Result<()> {
+ self.db.write_schemas(batch)
+ }
+
+ /// Get the latest timeout certificates (we only store the latest highest timeout certificates).
+ fn get_highest_timeout_certificates(&self) -> Result>> {
+ self.db
+ .get::(&SingleEntryKey::HighestTimeoutCertificates)
+ }
+
+ /// Get the latest consensus state (we only store the latest state).
+ fn get_state(&self) -> Result>> {
+ self.db
+ .get::(&SingleEntryKey::ConsensusState)
+ }
+
+ /// Get all consensus blocks.
+ fn get_blocks(&self) -> Result>> {
+ let mut iter = self.db.iter::>(ReadOptions::default())?;
+ iter.seek_to_first();
+ iter.collect::>>>()
+ }
+
+ /// Get all consensus QCs.
+ fn get_quorum_certificates(&self) -> Result> {
+ let mut iter = self.db.iter::(ReadOptions::default())?;
+ iter.seek_to_first();
+ iter.collect::>>()
+ }
+}
diff --git a/consensus/src/chained_bft/consensusdb/schema/block/mod.rs b/consensus/src/chained_bft/consensusdb/schema/block/mod.rs
new file mode 100644
index 0000000000000..8c42d9b68f0ef
--- /dev/null
+++ b/consensus/src/chained_bft/consensusdb/schema/block/mod.rs
@@ -0,0 +1,51 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module defines physical storage schema for consensus block.
+//!
+//! Serialized block bytes identified by block_hash.
+//! ```text
+//! |<---key---->|<---value--->|
+//! | block_hash | block |
+//! ```
+
+use super::BLOCK_CF_NAME;
+use crate::chained_bft::{common::Payload, consensus_types::block::Block};
+use crypto::HashValue;
+use failure::prelude::*;
+use proto_conv::{FromProtoBytes, IntoProtoBytes};
+use schemadb::schema::{KeyCodec, Schema, ValueCodec};
+use std::marker::PhantomData;
+
+pub struct BlockSchema {
+ phantom: PhantomData,
+}
+
+impl Schema for BlockSchema {
+ const COLUMN_FAMILY_NAME: schemadb::ColumnFamilyName = BLOCK_CF_NAME;
+ type Key = HashValue;
+ type Value = Block;
+}
+
+impl KeyCodec> for HashValue {
+ fn encode_key(&self) -> Result> {
+ Ok(self.to_vec())
+ }
+
+ fn decode_key(data: &[u8]) -> Result {
+ Ok(HashValue::from_slice(data)?)
+ }
+}
+
+impl ValueCodec> for Block {
+ fn encode_value(&self) -> Result> {
+ Ok(self.clone().into_proto_bytes()?)
+ }
+
+ fn decode_value(data: &[u8]) -> Result {
+ Ok(Self::from_proto_bytes(data)?)
+ } +} + +#[cfg(test)] +mod test; diff --git a/consensus/src/chained_bft/consensusdb/schema/block/test.rs b/consensus/src/chained_bft/consensusdb/schema/block/test.rs new file mode 100644 index 0000000000000..d1517faaa097f --- /dev/null +++ b/consensus/src/chained_bft/consensusdb/schema/block/test.rs @@ -0,0 +1,10 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use schemadb::schema::assert_encode_decode; + +#[test] +fn test_encode_decode() { + assert_encode_decode::>(&HashValue::random(), &Block::make_genesis_block()); +} diff --git a/consensus/src/chained_bft/consensusdb/schema/mod.rs b/consensus/src/chained_bft/consensusdb/schema/mod.rs new file mode 100644 index 0000000000000..44582f949a6ff --- /dev/null +++ b/consensus/src/chained_bft/consensusdb/schema/mod.rs @@ -0,0 +1,23 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +pub(crate) mod block; +pub(crate) mod quorum_certificate; +pub(crate) mod single_entry; + +use failure::prelude::*; +use schemadb::ColumnFamilyName; + +pub(super) const BLOCK_CF_NAME: ColumnFamilyName = "block"; +pub(super) const QC_CF_NAME: ColumnFamilyName = "quorum_certificate"; +pub(super) const SINGLE_ENTRY_CF_NAME: ColumnFamilyName = "single_entry"; + +fn ensure_slice_len_eq(data: &[u8], len: usize) -> Result<()> { + ensure!( + data.len() == len, + "Unexpected data len {}, expected {}.", + data.len(), + len, + ); + Ok(()) +} diff --git a/consensus/src/chained_bft/consensusdb/schema/quorum_certificate/mod.rs b/consensus/src/chained_bft/consensusdb/schema/quorum_certificate/mod.rs new file mode 100644 index 0000000000000..8dba66060e1ab --- /dev/null +++ b/consensus/src/chained_bft/consensusdb/schema/quorum_certificate/mod.rs @@ -0,0 +1,45 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines physical storage schema for consensus quorum certificate (of a block). +//! +//! Serialized quorum certificate bytes identified by block_hash. +//! ```text +//! |<---key---->|<----value--->| +//! | block_hash | QuorumCert | +//! ``` + +use super::QC_CF_NAME; +use crate::chained_bft::consensus_types::quorum_cert::QuorumCert; +use crypto::HashValue; +use failure::prelude::*; +use proto_conv::{FromProtoBytes, IntoProtoBytes}; +use schemadb::{ + define_schema, + schema::{KeyCodec, ValueCodec}, +}; + +define_schema!(QCSchema, HashValue, QuorumCert, QC_CF_NAME); + +impl KeyCodec for HashValue { + fn encode_key(&self) -> Result> { + Ok(self.to_vec()) + } + + fn decode_key(data: &[u8]) -> Result { + Ok(HashValue::from_slice(data)?) 
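The schema modules in this area all follow the layout their module docs describe: one column family per data kind, keyed by a block hash, with protobuf-serialized values written atomically in batches. A toy sketch of that layout using plain maps instead of RocksDB column families; all names here are hypothetical:

```rust
use std::collections::HashMap;

struct ToyConsensusDb {
    blocks: HashMap<Vec<u8>, Vec<u8>>, // block_hash -> serialized block
    qcs: HashMap<Vec<u8>, Vec<u8>>,    // certified block_hash -> serialized QC
}

impl ToyConsensusDb {
    // The real store commits both column families in one atomic schema batch;
    // a sketch can only hint at that by grouping the inserts.
    fn save(&mut self, block_hash: Vec<u8>, block: Vec<u8>, qc: Vec<u8>) {
        self.blocks.insert(block_hash.clone(), block);
        self.qcs.insert(block_hash, qc);
    }
}
```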
+ }
+}
+
+impl ValueCodec for QuorumCert {
+ fn encode_value(&self) -> Result> {
+ self.clone().into_proto_bytes()
+ }
+
+ fn decode_value(data: &[u8]) -> Result {
+ Self::from_proto_bytes(data)
+ }
+}
+
+#[cfg(test)]
+mod test;
diff --git a/consensus/src/chained_bft/consensusdb/schema/quorum_certificate/test.rs b/consensus/src/chained_bft/consensusdb/schema/quorum_certificate/test.rs
new file mode 100644
index 0000000000000..b5110e1253b53
--- /dev/null
+++ b/consensus/src/chained_bft/consensusdb/schema/quorum_certificate/test.rs
@@ -0,0 +1,11 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use schemadb::schema::assert_encode_decode;
+
+#[test]
+fn test_encode_decode() {
+ let qc = QuorumCert::certificate_for_genesis();
+ assert_encode_decode::(&qc.certified_block_id(), &qc);
+}
diff --git a/consensus/src/chained_bft/consensusdb/schema/single_entry/mod.rs b/consensus/src/chained_bft/consensusdb/schema/single_entry/mod.rs
new file mode 100644
index 0000000000000..c99056470a4df
--- /dev/null
+++ b/consensus/src/chained_bft/consensusdb/schema/single_entry/mod.rs
@@ -0,0 +1,65 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module defines physical storage schema for any single-entry data.
+//!
+//! There will be only one row in this column family for each type of data.
+//! The key is a serialized enum value designating the data type; it carries no
+//! meaning beyond selecting the row.
+//! ```text
+//! |<-------key------->|<-----value----->|
+//! | single entry key | raw value bytes |
+//! ```
+
+use super::{ensure_slice_len_eq, SINGLE_ENTRY_CF_NAME};
+use byteorder::ReadBytesExt;
+use failure::prelude::*;
+use num_derive::{FromPrimitive, ToPrimitive};
+use num_traits::{FromPrimitive, ToPrimitive};
+use schemadb::{
+ define_schema,
+ schema::{KeyCodec, ValueCodec},
+};
+use std::mem::size_of;
+
+define_schema!(
+ SingleEntrySchema,
+ SingleEntryKey,
+ Vec,
+ SINGLE_ENTRY_CF_NAME
+);
+
+#[derive(Debug, Eq, PartialEq, FromPrimitive, ToPrimitive)]
+#[repr(u8)]
+pub enum SingleEntryKey {
+ // Used to store ConsensusState
+ ConsensusState = 0,
+ // Used to store the highest timeout certificates
+ HighestTimeoutCertificates = 1,
+}
+
+impl KeyCodec for SingleEntryKey {
+ fn encode_key(&self) -> Result> {
+ Ok(vec![self
+ .to_u8()
+ .ok_or_else(|| format_err!("ToPrimitive failed."))?])
+ }
+
+ fn decode_key(data: &[u8]) -> Result {
+ ensure_slice_len_eq(data, size_of::())?;
+ let key = (&data[..]).read_u8()?;
+ SingleEntryKey::from_u8(key).ok_or_else(|| format_err!("FromPrimitive failed."))
+ }
+}
+
+impl ValueCodec for Vec {
+ fn encode_value(&self) -> Result> {
+ Ok(self.clone())
+ }
+
+ fn decode_value(data: &[u8]) -> Result {
+ Ok(data.to_vec())
+ }
+}
+
+#[cfg(test)]
+mod test;
diff --git a/consensus/src/chained_bft/consensusdb/schema/single_entry/test.rs b/consensus/src/chained_bft/consensusdb/schema/single_entry/test.rs
new file mode 100644
index 0000000000000..a96a21cfc3165
--- /dev/null
+++ b/consensus/src/chained_bft/consensusdb/schema/single_entry/test.rs
@@ -0,0 +1,13 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use schemadb::schema::assert_encode_decode;
+
+#[test]
+fn test_single_entry_schema() {
+ assert_encode_decode::(
+ &SingleEntryKey::ConsensusState,
+ &vec![1u8, 2u8, 3u8],
+ );
+}
diff --git a/consensus/src/chained_bft/event_processor.rs b/consensus/src/chained_bft/event_processor.rs
new file mode 100644
index 0000000000000..777ca4592be58
--- /dev/null
+++ b/consensus/src/chained_bft/event_processor.rs
@@ -0,0 +1,804 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+#[cfg(test)]
+use crate::chained_bft::safety::safety_rules::ConsensusState;
+use crate::{
+ chained_bft::{
+ block_storage::{BlockReader, BlockStore, NeedFetchResult, VoteReceptionResult},
+ common::{Author, Payload, Round},
+ consensus_types::block::Block,
+ liveness::{
+ new_round_msg::{NewRoundMsg, PacemakerTimeout},
+ pacemaker::{NewRoundEvent, NewRoundReason, Pacemaker, PacemakerEvent},
+ proposal_generator::ProposalGenerator,
+ proposer_election::{ProposalInfo, ProposerElection, ProposerInfo},
+ },
+ network::{
+ BlockRetrievalRequest, BlockRetrievalResponse, ChunkRetrievalRequest,
+ ConsensusNetworkImpl,
+ },
+ persistent_storage::PersistentStorage,
+ safety::{safety_rules::SafetyRules, vote_msg::VoteMsg},
+ sync_manager::{SyncInfo, SyncManager},
+ },
+ counters,
+ state_replication::{StateComputer, TxnManager},
+ time_service::{
+ duration_since_epoch, wait_if_possible, TimeService, WaitingError, WaitingSuccess,
+ },
+};
+use crypto::HashValue;
+use futures::{channel::mpsc, SinkExt};
+use logger::prelude::*;
+use network::proto::BlockRetrievalStatus;
+use std::{
+ sync::{Arc, RwLock},
+ time::{Duration, Instant},
+};
+use termion::color::*;
+use types::ledger_info::LedgerInfoWithSignatures;
+
+/// Result of initial proposal processing.
+/// NeedFetch means a separate task must be spawned to fetch the missing block.
+/// The caller should call fetch_and_process_proposal in a separate task when NeedFetch is returned.
+pub enum ProcessProposalResult {
+ Done,
+ NeedFetch(Instant, ProposalInfo),
+ NeedSync(Instant, ProposalInfo),
+}
+
+/// Consensus SMR works in an event-based fashion: EventProcessor is responsible for
+/// processing the individual events (e.g., process_new_round, process_proposal, process_vote,
+/// etc.). It exposes the async processing functions for each event type.
+/// The caller is responsible for running the event loops and driving the execution via some
+/// executors.
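`ProcessProposalResult` above is a tri-state answer to "can this proposal be processed right now?". One plausible way a caller could drive it, sketched with placeholder work; the spawned tasks and their arguments are simplified away and the names are hypothetical:

```rust
use std::time::Instant;

enum Outcome {
    Done,
    NeedFetch(Instant),
    NeedSync(Instant),
}

fn drive(outcome: Outcome) {
    match outcome {
        // Nothing left to do; the proposal was handled or rejected inline.
        Outcome::Done => {}
        // Spawn a dedicated task to fetch the missing ancestors before the
        // round deadline, then finish processing the proposal.
        Outcome::NeedFetch(_deadline) => { /* fetch_and_process_proposal */ }
        // Fall back to state sync when peers may already have pruned blocks.
        Outcome::NeedSync(_deadline) => { /* sync_and_process_proposal */ }
    }
}
```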
+pub struct EventProcessor { + author: P, + block_store: Arc>, + pacemaker: Arc, + proposer_election: Arc + Send + Sync>, + pm_events_sender: mpsc::Sender, + proposal_candidates_sender: mpsc::Sender>, + proposal_generator: ProposalGenerator, + safety_rules: Arc>>, + state_computer: Arc>, + txn_manager: Arc>, + network: ConsensusNetworkImpl, + storage: Arc>, + sync_manager: SyncManager, + time_service: Arc, + enforce_increasing_timestamps: bool, +} + +impl EventProcessor { + pub fn new( + author: P, + block_store: Arc>, + pacemaker: Arc, + proposer_election: Arc + Send + Sync>, + pm_events_sender: mpsc::Sender, + proposal_candidates_sender: mpsc::Sender>, + proposal_generator: ProposalGenerator, + safety_rules: Arc>>, + state_computer: Arc>, + txn_manager: Arc>, + network: ConsensusNetworkImpl, + storage: Arc>, + time_service: Arc, + enforce_increasing_timestamps: bool, + ) -> Self { + let sync_manager = SyncManager::new( + Arc::clone(&block_store), + Arc::clone(&storage), + network.clone(), + Arc::clone(&state_computer), + ); + Self { + author, + block_store, + pacemaker, + proposer_election, + pm_events_sender, + proposal_candidates_sender, + proposal_generator, + safety_rules, + state_computer, + txn_manager, + network, + storage, + sync_manager, + time_service, + enforce_increasing_timestamps, + } + } + + /// Leader: + /// + /// This event is triggered by a new quorum certificate at the previous round or a + /// timeout certificate at the previous round. In either case, if this replica is the new + /// proposer for this round, it is ready to propose and guarantee that it can create a proposal + /// that all honest replicas can vote for. While this method should only be invoked at most + /// once per round, we ensure that only at most one proposal can get generated per round to + /// avoid accidental equivocation of proposals. + /// + /// Replica: + /// + /// Do nothing + pub async fn process_new_round_event(&self, new_round_event: NewRoundEvent) { + debug!("Processing {:?}", new_round_event); + counters::CURRENT_ROUND.set(new_round_event.round as i64); + counters::ROUND_TIMEOUT_MS.set(new_round_event.timeout.as_millis() as i64); + match new_round_event.reason { + NewRoundReason::QCReady => { + counters::QC_ROUNDS_COUNT.inc(); + } + NewRoundReason::Timeout { .. } => { + counters::TIMEOUT_ROUNDS_COUNT.inc(); + } + }; + let proposer_info = match self + .proposer_election + .is_valid_proposer(self.author, new_round_event.round) + { + Some(pi) => pi, + None => { + return; + } + }; + + // Proposal generator will ensure that at most one proposal is generated per round + let proposal = match self + .proposal_generator + .generate_proposal( + new_round_event.round, + self.pacemaker.current_round_deadline(), + ) + .await + { + Err(e) => { + error!("Error while generating proposal: {:?}", e); + return; + } + Ok(proposal) => proposal, + }; + let mut network = self.network.clone(); + debug!("Propose {}", proposal); + let timeout_certificate = match new_round_event.reason { + NewRoundReason::Timeout { cert } => Some(cert), + _ => None, + }; + let highest_ledger_info = (*self.block_store.highest_ledger_info()).clone(); + network + .broadcast_proposal(ProposalInfo { + proposal, + proposer_info, + timeout_certificate, + highest_ledger_info, + }) + .await; + counters::PROPOSALS_COUNT.inc(); + } + + /// The function is responsible for processing the incoming proposals and the Quorum + /// Certificate. 1. commit to the committed state the new QC carries + /// 2. 
fetch all the blocks from the committed state to the QC + /// 3. forwarding the proposals to the ProposerElection queue, + /// which is going to eventually trigger one winning proposal per round + /// (to be processed via a separate function). + /// The reason for separating `process_proposal` from `process_winning_proposal` is to + /// (a) asynchronously prefetch dependencies and + /// (b) allow the proposer election to choose one proposal out of many. + pub async fn process_proposal( + &self, + proposal: ProposalInfo, + ) -> ProcessProposalResult { + debug!("Receive proposal {}", proposal); + let qc = proposal.proposal.quorum_cert(); + + self.pacemaker + .process_certificates_from_proposal( + qc.certified_block_round(), + proposal.timeout_certificate.as_ref(), + ) + .await; + + if self.pacemaker.current_round() != proposal.proposal.round() { + if self.pacemaker.current_round() < proposal.proposal.round() { + warn!( + "Received proposal {} is ignored as it is from a future round {} and does not match the pacemaker round {}", + proposal, + proposal.proposal.round(), + self.pacemaker.current_round(), + ); + } else { + warn!( + "Received proposal {} is ignored as it is from a past round {} and does not match the pacemaker round {}", + proposal, + proposal.proposal.round(), + self.pacemaker.current_round(), + ); + } + return ProcessProposalResult::Done; + } + + if self + .proposer_election + .is_valid_proposer(proposal.proposer_info, proposal.proposal.round()) + .is_none() + { + warn!( + "Proposer {} for block {} is not a valid proposer for this round", + proposal.proposal.author(), + proposal.proposal + ); + return ProcessProposalResult::Done; + } + + let deadline = self.pacemaker.current_round_deadline(); + if let Some(committed_block_id) = proposal.highest_ledger_info.committed_block_id() { + if self + .block_store + .need_sync_for_quorum_cert(committed_block_id, &proposal.highest_ledger_info) + { + return ProcessProposalResult::NeedSync(deadline, proposal); + } + } else { + warn!("Highest ledger info {} has no committed block", proposal); + return ProcessProposalResult::Done; + } + + match self.block_store.need_fetch_for_quorum_cert(&qc) { + NeedFetchResult::NeedFetch => { + return ProcessProposalResult::NeedFetch(deadline, proposal) + } + NeedFetchResult::QCRoundBeforeRoot => { + warn!("Proposal {} has a highest quorum certificate with round older than root round {}", proposal, self.block_store.root().round()); + return ProcessProposalResult::Done; + } + NeedFetchResult::QCBlockExist => { + if let Err(e) = self.block_store.insert_single_quorum_cert(qc.clone()).await { + warn!( + "Quorum certificate for proposal {} could not be inserted to the block store: {:?}", + proposal, e + ); + return ProcessProposalResult::Done; + } + } + NeedFetchResult::QCAlreadyExist => (), + } + + self.finish_proposal_processing(proposal).await; + ProcessProposalResult::Done + } + + /// Finish proposal processing: note that multiple tasks can execute this function in parallel + /// so be careful with the updates. The safest thing to do is to pass the proposal further + /// to the proposal election. 
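// ---------------------------------------------------------------------------
// Editor's aside (illustrative sketch, not part of the original change): the
// fan-in pattern the comment above describes. Any number of processing tasks may
// clone the sender, but a single receiver task consumes candidates one at a time,
// so proposer election state is never raced. `pick_winner_for_round` is a
// hypothetical stand-in; generic parameters are elided as elsewhere in this listing.
use futures::StreamExt;

async fn election_loop(mut candidates: mpsc::Receiver<ProposalInfo>) {
    while let Some(candidate) = candidates.next().await {
        // Candidates arrive serialized in channel order; at most one winner per
        // round is forwarded on to process_winning_proposal.
        pick_winner_for_round(candidate);
    }
}
// ---------------------------------------------------------------------------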
+ async fn finish_proposal_processing(&self, proposal: ProposalInfo) { + let mut sender = self.proposal_candidates_sender.clone(); + if sender.send(proposal).await.is_err() { + error!("Error sending the received proposal to proposal election."); + } + } + + /// Fetches and completes processing proposal in dedicated task + pub async fn fetch_and_process_proposal( + &self, + deadline: Instant, + proposal: ProposalInfo, + ) { + if let Err(e) = self + .sync_manager + .fetch_quorum_cert( + proposal.proposal.quorum_cert().clone(), + proposal.proposer_info.get_author(), + deadline, + ) + .await + { + warn!( + "Quorum certificate for proposal {} could not be added to the block store: {:?}", + proposal, e + ); + return; + } + self.finish_proposal_processing(proposal).await; + } + + /// Takes mutable reference to avoid race with other processing and perform state + /// synchronization, then completes processing proposal in dedicated task + pub async fn sync_and_process_proposal( + &mut self, + deadline: Instant, + proposal: ProposalInfo, + ) { + // check if we still need sync + if let Err(e) = self + .sync_manager + .sync_to( + deadline, + SyncInfo { + highest_ledger_info: proposal.highest_ledger_info.clone(), + highest_quorum_cert: proposal.proposal.quorum_cert().clone(), + peer: proposal.proposer_info.get_author(), + }, + ) + .await + { + warn!( + "Quorum certificate for proposal {} could not be added to the block store: {:?}", + proposal, e + ); + return; + } + self.finish_proposal_processing(proposal).await; + } + + /// Upon receiving NewRoundMsg, ensure that any branches with higher quorum certificates are + /// populated to this replica prior to processing the pacemaker timeout. This ensures that when + /// a pacemaker timeout certificate is formed with 2f+1 timeouts, the next proposer will be + /// able to chain a proposal block to a highest quorum certificate such that all honest replicas + /// can vote for it. + pub async fn process_new_round_msg(&mut self, new_round_msg: NewRoundMsg) { + debug!( + "Received a new round msg for round {} from {}", + new_round_msg.pacemaker_timeout().round(), + new_round_msg.author() + ); + let deadline = self.pacemaker.current_round_deadline(); + let current_highest_quorum_cert_round = self + .block_store + .highest_quorum_cert() + .certified_block_round(); + let new_round_highest_quorum_cert_round = new_round_msg + .highest_quorum_certificate() + .certified_block_round(); + + if current_highest_quorum_cert_round >= new_round_highest_quorum_cert_round { + return; + } + + match self + .sync_manager + .sync_to( + deadline, + SyncInfo { + highest_ledger_info: new_round_msg.highest_ledger_info().clone(), + highest_quorum_cert: new_round_msg.highest_quorum_certificate().clone(), + peer: new_round_msg.author(), + }, + ) + .await + { + Ok(()) => debug!( + "Successfully added new highest quorum certificate at round {} from old round {}", + new_round_highest_quorum_cert_round, current_highest_quorum_cert_round + ), + Err(e) => warn!( + "Unable to insert new highest quorum certificate {} from old round {} due to {:?}", + new_round_msg.highest_quorum_certificate(), + current_highest_quorum_cert_round, + e + ), + } + } + + /// The replica stops voting for this round and saves its consensus state. Voting is halted + /// to ensure that the next proposer can make a proposal that can be voted on by all replicas. + /// Saving the consensus state ensures that on restart, the replicas will not waste time + /// on previous rounds. 
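// ---------------------------------------------------------------------------
// Editor's aside (illustrative sketch, not part of the original change): the
// invariant the comment above depends on, reduced to a minimal "last vote round"
// rule. The round is bumped and persisted *before* the timeout message goes out,
// so a replica that crashes and restarts can never vote in a round it already
// promised to skip. The type and field names here are hypothetical.
struct MiniConsensusState {
    last_vote_round: u64,
}

impl MiniConsensusState {
    /// Returns the new watermark to persist only when the round actually increases.
    fn increase_last_vote_round(&mut self, round: u64) -> Option<u64> {
        if round > self.last_vote_round {
            self.last_vote_round = round;
            Some(round) // the caller must persist this before sending NewRoundMsg
        } else {
            None
        }
    }

    /// Voting is only allowed in rounds above the persisted watermark.
    fn may_vote(&self, round: u64) -> bool {
        round > self.last_vote_round
    }
}
// ---------------------------------------------------------------------------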
+ pub async fn process_outgoing_pacemaker_timeout(&self, round: Round) -> Option { + // Stop voting at this round, persist the consensus state to support restarting from + // a recent round (i.e. > the last vote round) and then send the highest quorum + // certificate known + let consensus_state = self + .safety_rules + .write() + .unwrap() + .increase_last_vote_round(round); + if let Some(consensus_state) = consensus_state { + if let Err(e) = self.storage.save_consensus_state(consensus_state) { + error!("Failed to persist consensus state after increasing the last vote round due to {:?}", e); + return None; + } + } + debug!( + "Sending new round message at round {} due to timeout and will not vote at this round", + round + ); + + Some(NewRoundMsg::new( + self.block_store.highest_quorum_cert().as_ref().clone(), + self.block_store.highest_ledger_info().as_ref().clone(), + PacemakerTimeout::new(round, self.block_store.signer()), + self.block_store.signer(), + )) + } + + /// This function processes a proposal that was chosen as a representative of its round: + /// 1. Add it to a block store. + /// 2. Try to vote for it following the safety rules. + /// 3. In case a validator chooses to vote, send the vote to the representatives at the next + /// position. + pub async fn process_winning_proposal(&self, proposal: ProposalInfo) { + let qc = proposal.proposal.quorum_cert(); + let update_res = self.safety_rules.write().unwrap().update(qc); + if let Some(new_commit) = update_res { + let finality_proof = qc.ledger_info().clone(); + self.process_commit(new_commit, finality_proof).await; + } + + if let Some(time_to_receival) = duration_since_epoch() + .checked_sub(Duration::from_micros(proposal.proposal.timestamp_usecs())) + { + counters::CREATION_TO_RECEIVAL_MS.observe(time_to_receival.as_millis() as f64); + } + let block = match self + .block_store + .execute_and_insert_block(proposal.proposal) + .await + { + Err(e) => { + debug!( + "Block proposal could not be added to the block store: {:?}", + e + ); + return; + } + Ok(block) => block, + }; + + // Checking pacemaker round again, because multiple proposal can now race + // during async block retrieval + if self.pacemaker.current_round() != block.round() { + debug!( + "Skip voting for winning proposal {} rejected because round is incorrect. Pacemaker: {}, proposal: {}", + block, + self.pacemaker.current_round(), + block.round() + ); + return; + } + + let current_round_deadline = self.pacemaker.current_round_deadline(); + if self.enforce_increasing_timestamps { + match wait_if_possible( + self.time_service.as_ref(), + Duration::from_micros(block.timestamp_usecs()), + current_round_deadline, + ) + .await + { + Ok(waiting_success) => { + debug!("Success with {:?} for being able to vote", waiting_success); + + match waiting_success { + WaitingSuccess::WaitWasRequired { wait_duration, .. } => { + counters::VOTE_SUCCESS_WAIT_MS + .observe(wait_duration.as_millis() as f64); + counters::VOTE_WAIT_WAS_REQUIRED_COUNT.inc(); + } + WaitingSuccess::NoWaitRequired { .. 
} => { + counters::VOTE_SUCCESS_WAIT_MS.observe(0.0); + counters::VOTE_NO_WAIT_REQUIRED_COUNT.inc(); + } + } + } + Err(waiting_error) => { + match waiting_error { + WaitingError::MaxWaitExceeded => { + error!( + "Waiting until proposal block timestamp usecs {:?} would exceed the round duration {:?}, hence will not vote for this round", + block.timestamp_usecs(), + current_round_deadline); + counters::VOTE_FAILURE_WAIT_MS.observe(0.0); + counters::VOTE_MAX_WAIT_EXCEEDED_COUNT.inc(); + return; + } + WaitingError::WaitFailed { + current_duration_since_epoch, + wait_duration, + } => { + error!( + "Even after waiting for {:?}, proposal block timestamp usecs {:?} >= current timestamp usecs {:?}, will not vote for this round", + wait_duration, + block.timestamp_usecs(), + current_duration_since_epoch); + counters::VOTE_FAILURE_WAIT_MS + .observe(wait_duration.as_millis() as f64); + counters::VOTE_WAIT_FAILED_COUNT.inc(); + return; + } + }; + } + } + } + + let vote_info = match self + .safety_rules + .write() + .unwrap() + .voting_rule(Arc::clone(&block)) + { + Err(e) => { + debug!("{}Rejected{} {}: {:?}", Fg(Red), Fg(Reset), block, e); + return; + } + Ok(vote_info) => vote_info, + }; + if let Err(e) = self + .storage + .save_consensus_state(vote_info.consensus_state().clone()) + { + debug!("Fail to persist consensus state: {:?}", e); + return; + } + let proposal_id = vote_info.proposal_id(); + let executed_state = self + .block_store + .get_state_for_block(proposal_id) + .expect("Block proposal: no execution state found for inserted block."); + + let ledger_info_placeholder = self + .block_store + .ledger_info_placeholder(vote_info.potential_commit_id()); + let vote_msg = VoteMsg::new( + proposal_id, + executed_state, + block.round(), + self.author.get_author(), + ledger_info_placeholder, + self.block_store.signer(), + ); + + let recipients: Vec = self + .proposer_election + .get_valid_proposers(block.round() + 1) + .iter() + .map(ProposerInfo::get_author) + .collect(); + debug!( + "{}Voted for{} {}, potential commit {}", + Fg(Green), + Fg(Reset), + block, + vote_info + .potential_commit_id() + .unwrap_or_else(HashValue::zero) + ); + self.network.send_vote(vote_msg, recipients).await; + } + + /// Upon new vote: + /// 1. Filter out votes for rounds that should not be processed by this validator (to avoid + /// potential attacks). + /// 2. Add the vote to the store and check whether it finishes a QC. + /// 3. Once the QC successfully formed, notify the Pacemaker. + #[allow(clippy::collapsible_if)] // Collapsing here would make if look ugly + pub async fn process_vote(&self, vote: VoteMsg, quorum_size: usize) { + // Check whether this validator is a valid recipient of the vote. + let next_round = vote.round() + 1; + if self + .proposer_election + .is_valid_proposer(self.author, next_round) + .is_none() + { + debug!( + "Received {}, but I am not a valid proposer for round {}, ignore.", + vote, next_round + ); + security_log(SecurityEvent::InvalidConsensusVote) + .error("InvalidProposer") + .data(vote) + .data(next_round) + .log(); + return; + } + + let deadline = self.pacemaker.current_round_deadline(); + // TODO [Reconfiguration] Verify epoch of the vote message. + // Add the vote and check whether it completes a new QC. + match self + .block_store + .insert_vote(vote.clone(), quorum_size) + .await + { + VoteReceptionResult::DuplicateVote => { + // This should not happen in general. 
+ security_log(SecurityEvent::DuplicateConsensusVote) + .error(VoteReceptionResult::DuplicateVote) + .data(vote) + .log(); + return; + } + VoteReceptionResult::NewQuorumCertificate(qc) => { + if self.block_store.need_fetch_for_quorum_cert(&qc) == NeedFetchResult::NeedFetch { + if let Err(e) = self + .sync_manager + .fetch_quorum_cert(qc.as_ref().clone(), vote.author(), deadline) + .await + { + error!("Error syncing to qc {}: {:?}", qc, e); + return; + } + } else { + if let Err(e) = self + .block_store + .insert_single_quorum_cert(qc.as_ref().clone()) + .await + { + error!("Error inserting qc {}: {:?}", qc, e); + return; + } + } + // Notify the Pacemaker. + let mut pm_events_sender = self.pm_events_sender.clone(); + if let Err(e) = pm_events_sender + .send(PacemakerEvent::QuorumCertified { + round: vote.round(), + }) + .await + { + error!("Delivering pacemaker event failed: {:?}", e); + } + } + // nothing interesting with votes arriving for the QC that has been formed + _ => { + return; + } + }; + } + + /// Upon new commit: + /// 1. Notify state computer with the finality proof. + /// 2. After the state is finalized, update the txn manager with the status of the committed + /// transactions. + /// 3. Prune the tree. + async fn process_commit( + &self, + committed_block: Arc>, + finality_proof: LedgerInfoWithSignatures, + ) { + // Verify that the ledger info is indeed for the block we're planning to + // commit. + assert_eq!( + finality_proof.ledger_info().consensus_block_id(), + committed_block.id() + ); + + // Update the pacemaker with the highest committed round so that on the next round + // duration it calculates, the initial round index is reset + self.pacemaker + .update_highest_committed_round(committed_block.round()); + + if let Err(e) = self.state_computer.commit(finality_proof).await { + // We assume that state computer cannot enter an inconsistent state that might + // violate safety of the protocol. Specifically, an executor service is going to panic + // if it fails to persist the commit requests, which would crash the whole process + // including consensus. + error!( + "Failed to persist commit, mempool will not be notified: {:?}", + e + ); + return; + } + // At this moment the new state is persisted and we can notify the clients. + // Multiple blocks might be committed at once: notify about all the transactions in the + // path from the old root to the new root. 
+        for committed in self
+            .block_store
+            .path_from_root(Arc::clone(&committed_block))
+            .unwrap_or_else(Vec::new)
+        {
+            if let Some(time_to_commit) = duration_since_epoch()
+                .checked_sub(Duration::from_micros(committed.timestamp_usecs()))
+            {
+                counters::CREATION_TO_COMMIT_MS.observe(time_to_commit.as_millis() as f64);
+            }
+            let compute_result = self
+                .block_store
+                .get_compute_result(committed.id())
+                .expect("Compute result of a pending block is unknown");
+            if let Err(e) = self
+                .txn_manager
+                .commit_txns(
+                    committed.get_payload(),
+                    compute_result.as_ref(),
+                    committed.timestamp_usecs(),
+                )
+                .await
+            {
+                error!("Failed to notify mempool: {:?}", e);
+            }
+        }
+        counters::LAST_COMMITTED_ROUND.set(committed_block.round() as i64);
+        debug!("{}Committed{} {}", Fg(Blue), Fg(Reset), *committed_block);
+        self.block_store.prune_tree(committed_block.id()).await;
+    }
+
+    /// Retrieve n chained blocks from the block store starting from
+    /// an initial parent id, returning with fewer than n (as many as possible) if
+    /// the id doesn't exist or its ancestry chain is shorter than n.
+    pub async fn process_block_retrieval(&self, request: BlockRetrievalRequest) {
+        let mut blocks = vec![];
+        let mut status = BlockRetrievalStatus::SUCCEEDED;
+        let mut id = request.block_id;
+        while (blocks.len() as u64) < request.num_blocks {
+            if let Some(block) = self.block_store.get_block(id) {
+                id = block.parent_id();
+                blocks.push(Block::clone(block.as_ref()));
+            } else {
+                status = BlockRetrievalStatus::NOT_ENOUGH_BLOCKS;
+                break;
+            }
+        }
+
+        if blocks.is_empty() {
+            status = BlockRetrievalStatus::ID_NOT_FOUND;
+        }
+
+        if let Err(e) = request
+            .response_sender
+            .send(BlockRetrievalResponse { status, blocks })
+        {
+            error!("Failed to return the requested block: {:?}", e);
+        }
+    }
+
+    /// Retrieve the chunk from storage and send it back.
+    /// We'll also try to add the QuorumCert into the block store if it's for an existing block,
+    /// and potentially commit.
+    pub async fn process_chunk_retrieval(&self, request: ChunkRetrievalRequest) {
+        if self
+            .block_store
+            .block_exists(request.target.certified_block_id())
+            && self
+                .block_store
+                .get_quorum_cert_for_block(request.target.certified_block_id())
+                .is_none()
+        {
+            if let Err(e) = self
+                .block_store
+                .insert_single_quorum_cert(request.target.clone())
+                .await
+            {
+                error!(
+                    "Failed to insert QuorumCert {} from ChunkRetrievalRequest: {}",
+                    request.target, e
+                );
+                return;
+            }
+            let update_res = self
+                .safety_rules
+                .write()
+                .expect("[state synchronizer handler] unable to lock safety rules")
+                .process_ledger_info(&request.target.ledger_info());
+
+            if let Some(block) = update_res {
+                self.process_commit(block, request.target.ledger_info().clone())
+                    .await;
+            }
+        }
+
+        let target_version = request.target.ledger_info().ledger_info().version();
+
+        let response = self
+            .sync_manager
+            .get_chunk(request.start_version, target_version, request.batch_size)
+            .await;
+
+        if let Err(e) = request.response_sender.send(response) {
+            error!("Failed to return the requested chunk: {:?}", e);
+        }
+    }
+
+    /// Inspect the current consensus state.
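// ---------------------------------------------------------------------------
// Editor's aside (illustrative sketch, not part of the original change): the
// retrieval walk in process_block_retrieval above, reduced to a parent-pointer
// map. The three outcomes line up with BlockRetrievalStatus: all n blocks found
// => SUCCEEDED; the ancestry chain ran out first => NOT_ENOUGH_BLOCKS; nothing
// found at all => ID_NOT_FOUND. (The tests later in this change exercise all
// three cases end to end.)
use std::collections::HashMap;

fn collect_chain(parents: &HashMap<u64, u64>, start: u64, n: u64) -> (Vec<u64>, &'static str) {
    let (mut blocks, mut id) = (Vec::new(), start);
    while (blocks.len() as u64) < n {
        match parents.get(&id) {
            // The block is known: record it and step to its parent.
            Some(&parent) => {
                blocks.push(id);
                id = parent;
            }
            // Unknown id: either the start was bogus or the chain is too short.
            None => break,
        }
    }
    let status = if (blocks.len() as u64) == n {
        "SUCCEEDED"
    } else if blocks.is_empty() {
        "ID_NOT_FOUND"
    } else {
        "NOT_ENOUGH_BLOCKS"
    };
    (blocks, status)
}
// ---------------------------------------------------------------------------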
+ #[cfg(test)] + pub fn consensus_state(&self) -> ConsensusState { + self.safety_rules.read().unwrap().consensus_state() + } +} diff --git a/consensus/src/chained_bft/event_processor_test.rs b/consensus/src/chained_bft/event_processor_test.rs new file mode 100644 index 0000000000000..8230fb5898acd --- /dev/null +++ b/consensus/src/chained_bft/event_processor_test.rs @@ -0,0 +1,914 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + block_storage::{BlockReader, BlockStore}, + common::Author, + consensus_types::{block::Block, quorum_cert::QuorumCert}, + event_processor::EventProcessor, + liveness::{ + local_pacemaker::{ExponentialTimeInterval, LocalPacemaker}, + new_round_msg::{NewRoundMsg, PacemakerTimeout, PacemakerTimeoutCertificate}, + pacemaker::{NewRoundEvent, NewRoundReason, Pacemaker}, + pacemaker_timeout_manager::HighestTimeoutCertificates, + proposal_generator::ProposalGenerator, + proposer_election::{ProposalInfo, ProposerElection, ProposerInfo}, + rotating_proposer_election::RotatingProposer, + }, + network::{ + BlockRetrievalRequest, BlockRetrievalResponse, ChunkRetrievalRequest, + ConsensusNetworkImpl, + }, + network_tests::NetworkPlayground, + persistent_storage::{PersistentStorage, RecoveryData}, + safety::{ + safety_rules::{ConsensusState, SafetyRules}, + vote_msg::VoteMsg, + }, + test_utils::{ + consensus_runtime, placeholder_certificate_for_block, placeholder_ledger_info, + MockStateComputer, MockStorage, MockTransactionManager, TestPayload, TreeInserter, + }, + }, + state_replication::ExecutedState, + stream_utils::start_event_processing_loop, + time_service::{ClockTimeService, TimeService}, +}; +use channel; +use crypto::HashValue; +use futures::{ + channel::{mpsc, oneshot}, + compat::Future01CompatExt, + executor::block_on, + prelude::*, +}; +use network::{ + proto::BlockRetrievalStatus, + validator_network::{ConsensusNetworkEvents, ConsensusNetworkSender}, +}; +use proto_conv::FromProto; +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, + time::Duration, +}; +use tokio::runtime::TaskExecutor; +use types::{ + account_address::AccountAddress, + ledger_info::{LedgerInfo, LedgerInfoWithSignatures}, + validator_signer::ValidatorSigner, + validator_verifier::ValidatorVerifier, +}; + +/// Auxiliary struct that is setting up node environment for the test. 
+#[allow(dead_code)] +struct NodeSetup { + author: Author, + block_store: Arc>, + event_processor: EventProcessor, + new_rounds_receiver: mpsc::Receiver, + proposal_winners_receiver: mpsc::Receiver>, + storage: Arc>, + signer: ValidatorSigner, + proposer_author: Author, + peers: Arc>, + pacemaker: Arc, + commit_cb_receiver: mpsc::UnboundedReceiver, +} + +impl NodeSetup { + fn build_empty_store( + signer: ValidatorSigner, + storage: Arc>, + initial_data: RecoveryData, + ) -> Arc> { + let (commit_cb_sender, _commit_cb_receiver) = mpsc::unbounded::(); + + Arc::new(block_on(BlockStore::new( + storage, + initial_data, + signer, + Arc::new(MockStateComputer::new(commit_cb_sender)), + true, + 10, // max pruned blocks in mem + ))) + } + + fn create_pacemaker(time_service: Arc) -> Arc { + let base_timeout = Duration::new(5, 0); + let time_interval = Box::new(ExponentialTimeInterval::fixed(base_timeout)); + let highest_certified_round = 0; + let (pacemaker_timeout_sender, _) = channel::new_test(1_024); + Arc::new(LocalPacemaker::new( + MockStorage::::start_for_testing() + .0 + .persistent_liveness_storage(), + time_interval, + 0, + highest_certified_round, + time_service, + pacemaker_timeout_sender, + 1, + HighestTimeoutCertificates::new(None, None), + )) + } + + fn create_proposer_election( + author: Author, + ) -> Arc + Send + Sync> { + Arc::new(RotatingProposer::new(vec![author], 1)) + } + + fn create_nodes( + playground: &mut NetworkPlayground, + executor: TaskExecutor, + num_nodes: usize, + ) -> Vec { + let mut signers = vec![]; + let mut peers = vec![]; + for _ in 0..num_nodes { + let signer = ValidatorSigner::random(); + peers.push(signer.author()); + signers.push(signer); + } + let proposer_author = peers[0]; + let peers_ref = Arc::new(peers); + let mut nodes = vec![]; + for signer in signers.iter().take(num_nodes) { + let (storage, initial_data) = MockStorage::::start_for_testing(); + nodes.push(Self::new( + playground, + executor.clone(), + signer.clone(), + proposer_author, + Arc::clone(&peers_ref), + storage, + initial_data, + )); + } + nodes + } + + fn new( + playground: &mut NetworkPlayground, + executor: TaskExecutor, + signer: ValidatorSigner, + proposer_author: Author, + peers: Arc>, + storage: Arc>, + initial_data: RecoveryData, + ) -> Self { + let (network_reqs_tx, network_reqs_rx) = channel::new_test(8); + let (consensus_tx, consensus_rx) = channel::new_test(8); + let network_sender = ConsensusNetworkSender::new(network_reqs_tx); + let network_events = ConsensusNetworkEvents::new(consensus_rx); + let author = signer.author(); + + playground.add_node(author, consensus_tx, network_reqs_rx); + let validator = ValidatorVerifier::new_single(signer.author(), signer.public_key()); + + let network = ConsensusNetworkImpl::new( + signer.author(), + network_sender, + network_events, + Arc::clone(&peers), + Arc::new(validator), + ); + let consensus_state = initial_data.state(); + + let block_store = Self::build_empty_store(signer.clone(), storage.clone(), initial_data); + let time_service = Arc::new(ClockTimeService::new(executor.clone())); + let proposal_generator = ProposalGenerator::new( + block_store.clone(), + Arc::new(MockTransactionManager::new()), + time_service.clone(), + 1, + true, + ); + let safety_rules = Arc::new(RwLock::new(SafetyRules::new( + block_store.clone(), + consensus_state, + ))); + + let mut pacemaker = Self::create_pacemaker(time_service.clone()); + let (pm_events_sender, new_rounds_receiver) = + start_event_processing_loop(&mut pacemaker, executor.clone()); + let 
mut proposer_election = Self::create_proposer_election(proposer_author); + let (proposal_candidates_sender, proposal_winners_receiver) = + start_event_processing_loop(&mut proposer_election, executor.clone()); + let (commit_cb_sender, commit_cb_receiver) = mpsc::unbounded::(); + let event_processor = EventProcessor::new( + author, + Arc::clone(&block_store), + Arc::clone(&pacemaker), + Arc::clone(&proposer_election), + pm_events_sender, + proposal_candidates_sender, + proposal_generator, + safety_rules, + Arc::new(MockStateComputer::new(commit_cb_sender)), + Arc::new(MockTransactionManager::new()), + network, + storage.clone(), + time_service, + true, + ); + Self { + author, + block_store, + event_processor, + new_rounds_receiver, + proposal_winners_receiver, + storage, + signer, + proposer_author, + peers, + pacemaker, + commit_cb_receiver, + } + } + + pub fn restart(self, playground: &mut NetworkPlayground, executor: TaskExecutor) -> Self { + let recover_data = self + .storage + .get_recovery_data() + .unwrap_or_else(|e| panic!("fail to restart due to: {}", e)); + Self::new( + playground, + executor, + self.signer, + self.proposer_author, + self.peers, + self.storage, + recover_data, + ) + } +} + +#[test] +fn basic_new_rank_event_test() { + let runtime = consensus_runtime(); + let mut playground = NetworkPlayground::new(runtime.executor()); + let nodes = NodeSetup::create_nodes(&mut playground, runtime.executor(), 2); + let node = &nodes[0]; + let genesis = node.block_store.root(); + let mut inserter = TreeInserter::new(node.block_store.clone()); + let a1 = + inserter.insert_block_with_qc(QuorumCert::certificate_for_genesis(), genesis.as_ref(), 1); + block_on(async move { + let new_round = 1; + node.event_processor + .process_new_round_event(NewRoundEvent { + round: new_round, + reason: NewRoundReason::QCReady, + timeout: Duration::new(5, 0), + }) + .await; + let pending_messages = playground + .wait_for_messages(1, NetworkPlayground::proposals_only) + .await; + let pending_proposals = pending_messages + .into_iter() + .filter(|m| m.1.has_proposal()) + .map(|mut m| { + ProposalInfo::::from_proto(m.1.take_proposal()).unwrap() + }) + .collect::>(); + assert_eq!(pending_proposals.len(), 1); + assert_eq!(pending_proposals[0].proposal.round(), new_round,); + assert_eq!( + pending_proposals[0] + .proposal + .quorum_cert() + .certified_block_id(), + genesis.id() + ); + assert_eq!(pending_proposals[0].proposer_info.get_author(), node.author); + + // Simulate a case with a1 receiving enough votes for a QC: a new proposal + // should be a child of a1 and carry its QC. 
+ let vote_msg = VoteMsg::new( + a1.id(), + node.block_store.get_state_for_block(a1.id()).unwrap(), + a1.round(), + node.block_store.signer().author(), + placeholder_ledger_info(), + node.block_store.signer(), + ); + node.block_store.insert_vote_and_qc(vote_msg, 0).await; + node.event_processor + .process_new_round_event(NewRoundEvent { + round: 2, + reason: NewRoundReason::QCReady, + timeout: Duration::new(5, 0), + }) + .await; + let pending_messages = playground + .wait_for_messages(1, NetworkPlayground::proposals_only) + .await; + let pending_proposals = pending_messages + .into_iter() + .filter(|m| m.1.has_proposal()) + .map(|mut m| { + ProposalInfo::::from_proto(m.1.take_proposal()).unwrap() + }) + .collect::>(); + assert_eq!(pending_proposals.len(), 1); + assert_eq!(pending_proposals[0].proposal.round(), 2); + assert_eq!(pending_proposals[0].proposal.parent_id(), a1.id()); + assert_eq!(pending_proposals[0].proposal.height(), 2); + assert_eq!( + pending_proposals[0] + .proposal + .quorum_cert() + .certified_block_id(), + a1.id() + ); + }); +} + +#[test] +/// If the proposal is valid, a vote should be sent +fn process_successful_proposal_test() { + let runtime = consensus_runtime(); + let mut playground = NetworkPlayground::new(runtime.executor()); + // In order to observe the votes we're going to check proposal processing on the non-proposer + // node (which will send the votes to the proposer). + let nodes = NodeSetup::create_nodes(&mut playground, runtime.executor(), 2); + let node = &nodes[1]; + + let genesis = node.block_store.root(); + let genesis_qc = QuorumCert::certificate_for_genesis(); + block_on(async move { + let proposal_info = ProposalInfo:: { + proposal: Block::make_block( + genesis.as_ref(), + vec![1], + 1, + 1, + genesis_qc.clone(), + node.block_store.signer(), + ), + proposer_info: node.author, + timeout_certificate: None, + highest_ledger_info: genesis_qc.clone(), + }; + let proposal_id = proposal_info.proposal.id(); + node.event_processor + .process_winning_proposal(proposal_info) + .await; + let pending_messages = playground + .wait_for_messages(1, NetworkPlayground::votes_only) + .await; + let pending_for_proposer = pending_messages + .into_iter() + .filter(|m| m.1.has_vote() && m.0 == node.author) + .map(|mut m| VoteMsg::from_proto(m.1.take_vote()).unwrap()) + .collect::>(); + assert_eq!(pending_for_proposer.len(), 1); + assert_eq!(pending_for_proposer[0].author(), node.author); + assert_eq!(pending_for_proposer[0].proposed_block_id(), proposal_id); + assert_eq!( + *node.storage.shared_storage.state.lock().unwrap(), + ConsensusState::new(1, 0, 0), + ); + }); +} + +#[test] +/// If the proposal does not pass voting rules, +/// No votes are sent, but the block is still added to the block tree. +fn process_old_proposal_test() { + let runtime = consensus_runtime(); + let mut playground = NetworkPlayground::new(runtime.executor()); + // In order to observe the votes we're going to check proposal processing on the non-proposer + // node (which will send the votes to the proposer). 
+    let nodes = NodeSetup::create_nodes(&mut playground, runtime.executor(), 2);
+    let node = &nodes[1];
+    let genesis = node.block_store.root();
+    let genesis_qc = QuorumCert::certificate_for_genesis();
+    let new_block = Block::make_block(
+        genesis.as_ref(),
+        vec![1],
+        1,
+        1,
+        genesis_qc.clone(),
+        node.block_store.signer(),
+    );
+    let new_block_id = new_block.id();
+    let old_block = Block::make_block(
+        genesis.as_ref(),
+        vec![1],
+        1,
+        2,
+        genesis_qc.clone(),
+        node.block_store.signer(),
+    );
+    let old_block_id = old_block.id();
+    block_on(async move {
+        node.event_processor
+            .process_winning_proposal(ProposalInfo:: {
+                proposal: new_block,
+                proposer_info: node.author,
+                timeout_certificate: None,
+                highest_ledger_info: genesis_qc.clone(),
+            })
+            .await;
+        node.event_processor
+            .process_winning_proposal(ProposalInfo:: {
+                proposal: old_block,
+                proposer_info: node.author,
+                timeout_certificate: None,
+                highest_ledger_info: genesis_qc.clone(),
+            })
+            .await;
+        let pending_messages = playground
+            .wait_for_messages(1, NetworkPlayground::votes_only)
+            .await;
+        let pending_for_me = pending_messages
+            .into_iter()
+            .filter(|m| m.1.has_vote() && m.0 == node.author)
+            .map(|mut m| VoteMsg::from_proto(m.1.take_vote()).unwrap())
+            .collect::>();
+        // just the new one
+        assert_eq!(pending_for_me.len(), 1);
+        assert_eq!(pending_for_me[0].proposed_block_id(), new_block_id);
+        assert!(node.block_store.get_block(old_block_id).is_some());
+    });
+}
+
+#[test]
+/// We don't vote for proposals that skip rounds.
+/// When we then receive a proposal for the correct round, we vote for it.
+/// Basically this checks that an adversary cannot send a proposal that skips rounds in
+/// violation of the pacemaker rules.
+fn process_round_mismatch_test() {
+    let runtime = consensus_runtime();
+    let mut playground = NetworkPlayground::new(runtime.executor());
+    // In order to observe the votes we're going to check proposal processing on the non-proposer
+    // node (which will send the votes to the proposer).
+    let mut node = NodeSetup::create_nodes(&mut playground, runtime.executor(), 1)
+        .pop()
+        .unwrap();
+    let genesis = node.block_store.root();
+    let genesis_qc = QuorumCert::certificate_for_genesis();
+    let correct_block = Block::make_block(
+        genesis.as_ref(),
+        vec![1],
+        1,
+        1,
+        genesis_qc.clone(),
+        node.block_store.signer(),
+    );
+    let correct_block_id = correct_block.id();
+    let block_skip_round = Block::make_block(
+        genesis.as_ref(),
+        vec![1],
+        2,
+        2,
+        genesis_qc.clone(),
+        node.block_store.signer(),
+    );
+    block_on(async move {
+        node.event_processor
+            .process_proposal(ProposalInfo:: {
+                proposal: block_skip_round,
+                proposer_info: node.author,
+                timeout_certificate: None,
+                highest_ledger_info: genesis_qc.clone(),
+            })
+            .await;
+        node.event_processor
+            .process_proposal(ProposalInfo:: {
+                proposal: correct_block,
+                proposer_info: node.author,
+                timeout_certificate: None,
+                highest_ledger_info: genesis_qc.clone(),
+            })
+            .await;
+
+        let winning = node
+            .proposal_winners_receiver
+            .next()
+            .await
+            .expect("No winning proposal");
+        assert_eq!(winning.proposal.id(), correct_block_id);
+    });
+}
+
+#[test]
+/// Ensure that the receivers have the latest quorum certificate after new round
+/// messages are sent.
+fn process_new_round_msg_test() {
+    let runtime = consensus_runtime();
+    let mut playground = NetworkPlayground::new(runtime.executor());
+    let mut nodes = NodeSetup::create_nodes(&mut playground, runtime.executor(), 2);
+    let non_proposer = nodes.pop().unwrap();
+    let mut static_proposer = nodes.pop().unwrap();
+
+    let genesis = non_proposer.block_store.root();
+    let block_0 = non_proposer
+        .block_store
+        .create_block(genesis, vec![1], 1, 1);
+    let block_0_id = block_0.id();
+    block_on(
+        non_proposer
+            .block_store
+            .execute_and_insert_block(block_0.clone()),
+    )
+    .unwrap();
+    block_on(
+        static_proposer
+            .block_store
+            .execute_and_insert_block(block_0),
+    )
+    .unwrap();
+
+    // Populate block_0 and a quorum certificate for block_0 on non_proposer
+    let block_0_quorum_cert = placeholder_certificate_for_block(
+        vec![static_proposer.signer.clone(), non_proposer.signer.clone()],
+        block_0_id,
+        1,
+    );
+    block_on(
+        non_proposer
+            .block_store
+            .insert_single_quorum_cert(block_0_quorum_cert.clone()),
+    )
+    .unwrap();
+    assert_eq!(
+        static_proposer
+            .block_store
+            .highest_quorum_cert()
+            .certified_block_round(),
+        0
+    );
+    assert_eq!(
+        non_proposer
+            .block_store
+            .highest_quorum_cert()
+            .certified_block_round(),
+        1
+    );
+
+    // As the static proposer processes the new round message it should learn about
+    // block_0_quorum_cert at round 1.
+    block_on(
+        static_proposer
+            .event_processor
+            .process_new_round_msg(NewRoundMsg::new(
+                block_0_quorum_cert,
+                QuorumCert::certificate_for_genesis(),
+                PacemakerTimeout::new(2, &non_proposer.signer),
+                &non_proposer.signer,
+            )),
+    );
+    assert_eq!(
+        static_proposer
+            .block_store
+            .highest_quorum_cert()
+            .certified_block_round(),
+        1
+    );
+}
+
+#[test]
+/// We don't vote for proposals that come from proposers that are not valid proposers
+/// for the round.
fn process_proposer_mismatch_test() {
+    let runtime = consensus_runtime();
+    let mut playground = NetworkPlayground::new(runtime.executor());
+    // In order to observe the votes we're going to check proposal processing on the non-proposer
+    // node (which will send the votes to the proposer).
+    let mut nodes = NodeSetup::create_nodes(&mut playground, runtime.executor(), 2);
+    let incorrect_proposer = nodes.pop().unwrap();
+    let mut node = nodes.pop().unwrap();
+    let genesis = node.block_store.root();
+    let genesis_qc = QuorumCert::certificate_for_genesis();
+    let correct_block = Block::make_block(
+        genesis.as_ref(),
+        vec![1],
+        1,
+        1,
+        genesis_qc.clone(),
+        node.block_store.signer(),
+    );
+    let correct_block_id = correct_block.id();
+    let block_incorrect_proposer = Block::make_block(
+        genesis.as_ref(),
+        vec![1],
+        1,
+        1,
+        genesis_qc.clone(),
+        incorrect_proposer.block_store.signer(),
+    );
+    block_on(async move {
+        node.event_processor
+            .process_proposal(ProposalInfo:: {
+                proposal: block_incorrect_proposer,
+                proposer_info: incorrect_proposer.author,
+                timeout_certificate: None,
+                highest_ledger_info: genesis_qc.clone(),
+            })
+            .await;
+
+        node.event_processor
+            .process_proposal(ProposalInfo:: {
+                proposal: correct_block,
+                proposer_info: node.author,
+                timeout_certificate: None,
+                highest_ledger_info: genesis_qc.clone(),
+            })
+            .await;
+
+        let winning = node
+            .proposal_winners_receiver
+            .next()
+            .await
+            .expect("No winning proposal");
+        assert_eq!(winning.proposal.id(), correct_block_id);
+    });
+}
+
+#[test]
+/// We allow skipping a round if the proposal carries a timeout certificate for the next round.
+fn process_timeout_certificate_test() {
+    let runtime = consensus_runtime();
+    let mut playground = NetworkPlayground::new(runtime.executor());
+    // In order to observe the votes we're going to check proposal processing on the non-proposer
+    // node (which will send the votes to the proposer).
+    let mut node = NodeSetup::create_nodes(&mut playground, runtime.executor(), 1)
+        .pop()
+        .unwrap();
+    let genesis = node.block_store.root();
+    let genesis_qc = QuorumCert::certificate_for_genesis();
+    let correct_block = Block::make_block(
+        genesis.as_ref(),
+        vec![1],
+        1,
+        1,
+        genesis_qc.clone(),
+        node.block_store.signer(),
+    );
+    let _correct_block_id = correct_block.id();
+    let block_skip_round = Block::make_block(
+        genesis.as_ref(),
+        vec![1],
+        2,
+        2,
+        genesis_qc.clone(),
+        node.block_store.signer(),
+    );
+    let block_skip_round_id = block_skip_round.id();
+    let tc = PacemakerTimeoutCertificate::new(1, vec![PacemakerTimeout::new(1, &node.signer)]);
+    block_on(async move {
+        node.event_processor
+            .process_proposal(ProposalInfo:: {
+                proposal: block_skip_round,
+                proposer_info: node.author,
+                timeout_certificate: Some(tc),
+                highest_ledger_info: genesis_qc.clone(),
+            })
+            .await;
+        node.event_processor
+            .process_proposal(ProposalInfo:: {
+                proposal: correct_block,
+                proposer_info: node.author,
+                timeout_certificate: None,
+                highest_ledger_info: genesis_qc.clone(),
+            })
+            .await;
+
+        let winning = node
+            .proposal_winners_receiver
+            .next()
+            .await
+            .expect("No winning proposal");
+        assert_eq!(winning.proposal.id(), block_skip_round_id);
+    });
+}
+
+#[test]
+/// Happy path for vote processing:
+/// 1) if a new QC is formed and a block is present, send a PM event
+fn process_votes_basic_test() {
+    let runtime = consensus_runtime();
+    let mut playground = NetworkPlayground::new(runtime.executor());
+    let mut node = NodeSetup::create_nodes(&mut playground, runtime.executor(), 1)
+        .pop()
+        .unwrap();
+    let genesis = node.block_store.root();
+    let mut inserter = TreeInserter::new(node.block_store.clone());
+    let a1 =
+        inserter.insert_block_with_qc(QuorumCert::certificate_for_genesis(), genesis.as_ref(), 1);
+    let vote_msg = VoteMsg::new(
+        a1.id(),
node.block_store.get_state_for_block(a1.id()).unwrap(), + a1.round(), + node.block_store.signer().author(), + placeholder_ledger_info(), + node.block_store.signer(), + ); + block_on(async move { + node.event_processor.process_vote(vote_msg, 1).await; + let new_round_event = node.new_rounds_receiver.next().await.unwrap(); + assert_eq!(new_round_event.reason, NewRoundReason::QCReady); + assert_eq!(new_round_event.round, a1.round()); + }); + block_on(runtime.shutdown_now().compat()).unwrap(); +} + +#[test] +fn process_chunk_retrieval() { + let runtime = consensus_runtime(); + let mut playground = NetworkPlayground::new(runtime.executor()); + let node = NodeSetup::create_nodes(&mut playground, runtime.executor(), 1) + .pop() + .unwrap(); + + let genesis = node.block_store.root(); + let genesis_qc = QuorumCert::certificate_for_genesis(); + + let block = Block::make_block( + genesis.as_ref(), + vec![1], + 1, + 1, + genesis_qc.clone(), + node.block_store.signer(), + ); + let proposal_info = ProposalInfo:: { + proposal: block.clone(), + proposer_info: node.author, + timeout_certificate: None, + highest_ledger_info: genesis_qc.clone(), + }; + node.pacemaker + .process_certificates_from_proposal(proposal_info.proposal.round() - 1, None); + + block_on(async move { + node.event_processor + .process_winning_proposal(proposal_info) + .await; + let ledger_info = + LedgerInfo::new(1, HashValue::zero(), HashValue::zero(), block.id(), 0, 0); + let target = QuorumCert::new( + block.id(), + ExecutedState::state_for_genesis(), + 0, + LedgerInfoWithSignatures::new(ledger_info, HashMap::new()), + ); + let req = ChunkRetrievalRequest { + start_version: 0, + target, + batch_size: 1, + response_sender: oneshot::channel().0, + }; + node.event_processor.process_chunk_retrieval(req).await; + assert_eq!(node.block_store.root().round(), 1); + }); +} + +#[test] +fn process_block_retrieval() { + let runtime = consensus_runtime(); + let mut playground = NetworkPlayground::new(runtime.executor()); + let node = NodeSetup::create_nodes(&mut playground, runtime.executor(), 1) + .pop() + .unwrap(); + + let genesis = node.block_store.root(); + let genesis_qc = QuorumCert::certificate_for_genesis(); + + let block = Block::make_block( + genesis.as_ref(), + vec![1], + 1, + 1, + genesis_qc.clone(), + node.block_store.signer(), + ); + let block_id = block.id(); + let proposal_info = ProposalInfo:: { + proposal: block.clone(), + proposer_info: node.author, + timeout_certificate: None, + highest_ledger_info: genesis_qc.clone(), + }; + node.pacemaker + .process_certificates_from_proposal(proposal_info.proposal.round() - 1, None); + + block_on(async move { + node.event_processor + .process_winning_proposal(proposal_info) + .await; + + // first verify that we can retrieve the block if it's in the tree + let (tx1, rx1) = oneshot::channel(); + let single_block_request = BlockRetrievalRequest { + block_id, + num_blocks: 1, + response_sender: tx1, + }; + node.event_processor + .process_block_retrieval(single_block_request) + .await; + match rx1.await { + Ok(BlockRetrievalResponse { status, blocks }) => { + assert_eq!(status, BlockRetrievalStatus::SUCCEEDED); + assert_eq!(block_id, blocks.get(0).unwrap().id()); + } + _ => panic!("block retrieval failure"), + } + + // verify that if a block is not there, return ID_NOT_FOUND + let (tx2, rx2) = oneshot::channel(); + let missing_block_request = BlockRetrievalRequest { + block_id: HashValue::random(), + num_blocks: 1, + response_sender: tx2, + }; + node.event_processor + 
.process_block_retrieval(missing_block_request) + .await; + match rx2.await { + Ok(BlockRetrievalResponse { status, blocks }) => { + assert_eq!(status, BlockRetrievalStatus::ID_NOT_FOUND); + assert!(blocks.is_empty()); + } + _ => panic!("block retrieval failure"), + } + + // if asked for many blocks, return NOT_ENOUGH_BLOCKS + let (tx3, rx3) = oneshot::channel(); + let many_block_request = BlockRetrievalRequest { + block_id, + num_blocks: 3, + response_sender: tx3, + }; + node.event_processor + .process_block_retrieval(many_block_request) + .await; + match rx3.await { + Ok(BlockRetrievalResponse { status, blocks }) => { + assert_eq!(status, BlockRetrievalStatus::NOT_ENOUGH_BLOCKS); + assert_eq!(block_id, blocks.get(0).unwrap().id()); + assert_eq!(node.block_store.root().id(), blocks.get(1).unwrap().id()); + } + _ => panic!("block retrieval failure"), + } + }); +} + +#[test] +/// rebuild a node from previous storage without violating safety guarantees. +fn basic_restart_test() { + let runtime = consensus_runtime(); + let mut playground = NetworkPlayground::new(runtime.executor()); + let mut node = NodeSetup::create_nodes(&mut playground, runtime.executor(), 1) + .pop() + .unwrap(); + let node_mut = &mut node; + + let genesis = node_mut.block_store.root(); + let genesis_qc = QuorumCert::certificate_for_genesis(); + let mut proposals = Vec::new(); + let proposals_mut = &mut proposals; + let num_proposals = 100; + // insert a few successful proposals + block_on(async move { + for i in 1..=num_proposals { + let proposal_info = ProposalInfo:: { + proposal: Block::make_block( + genesis.as_ref(), + vec![1], + i, + 1, + genesis_qc.clone(), + node_mut.block_store.signer(), + ), + proposer_info: node_mut.author, + timeout_certificate: None, + highest_ledger_info: genesis_qc.clone(), + }; + let proposal_id = proposal_info.proposal.id(); + proposals_mut.push(proposal_id); + node_mut + .pacemaker + .process_certificates_from_proposal(proposal_info.proposal.round() - 1, None); + node_mut + .event_processor + .process_winning_proposal(proposal_info) + .await; + } + }); + // verify after restart we recover the data + node = node.restart(&mut playground, runtime.executor()); + assert_eq!( + node.event_processor.consensus_state(), + ConsensusState::new(num_proposals, 0, 0,), + ); + for id in proposals { + assert_eq!(node.block_store.block_exists(id), true); + } +} diff --git a/consensus/src/chained_bft/liveness/local_pacemaker.rs b/consensus/src/chained_bft/liveness/local_pacemaker.rs new file mode 100644 index 0000000000000..612d1ca45345e --- /dev/null +++ b/consensus/src/chained_bft/liveness/local_pacemaker.rs @@ -0,0 +1,460 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + common::Round, + liveness::{ + new_round_msg::PacemakerTimeoutCertificate, + pacemaker::{NewRoundEvent, NewRoundReason, Pacemaker, PacemakerEvent}, + pacemaker_timeout_manager::{HighestTimeoutCertificates, PacemakerTimeoutManager}, + }, + persistent_storage::PersistentLivenessStorage, + }, + counters, + stream_utils::EventBasedActor, + time_service::{SendTask, TimeService}, +}; +use channel; +use futures::{channel::mpsc, Future, FutureExt, SinkExt}; +use logger::prelude::*; +use std::{ + cmp::{self, max}, + pin::Pin, + sync::{Arc, RwLock}, + time::{Duration, Instant}, +}; +use termion::color::*; + +/// Determines the maximum round duration based on the round difference between the current +/// round and the committed round +pub trait PacemakerTimeInterval: Send + 
Sync + 'static {
+    /// Use the index of the round after the highest quorum certificate to commit a block and
+    /// return the duration for this round
+    ///
+    /// Round indices start at 0 (round index = 0 is the first round after the round that led
+    /// to the highest committed round). Given that round r is the highest round to commit a
+    /// block, then round index 0 is round r+1. Note that genesis does not follow the
+    /// 3-chain rule for commits, so round 1 has round index 0. For example, if one wants
+    /// to calculate the round duration of round 6 and the highest committed round is 3
+    /// (meaning the highest round to commit a block is round 5), then the round index is 0.
+    fn get_round_duration(&self, round_index_after_committed_qc: usize) -> Duration;
+}
+
+/// Round durations increase exponentially.
+/// Basically the time interval is base * mul^power,
+/// where power = min(rounds_since_qc, max_exponent).
+#[derive(Clone)]
+pub struct ExponentialTimeInterval {
+    // Initial time interval duration after a successful quorum commit.
+    base_ms: u64,
+    // By how much we increase the interval each time
+    exponent_base: f64,
+    // Maximum time interval won't exceed base * mul^max_exponent.
+    // Theoretically, setting it means
+    // that we rely on synchrony assumptions when the known max messaging delay is
+    // max_interval. Alternatively, we can consider using max_interval to meet partial synchrony
+    // assumptions where while delta is unknown, it is <= max_interval.
+    max_exponent: usize,
+}
+
+impl ExponentialTimeInterval {
+    #[cfg(test)]
+    pub fn fixed(duration: Duration) -> Self {
+        Self::new(duration, 1.0, 0)
+    }
+
+    pub fn new(base: Duration, exponent_base: f64, max_exponent: usize) -> Self {
+        assert!(
+            max_exponent < 32,
+            "max_exponent for PacemakerTimeInterval should be <32"
+        );
+        assert!(
+            exponent_base.powf(max_exponent as f64).ceil() < f64::from(std::u32::MAX),
+            "Maximum interval multiplier should be less than u32::MAX"
+        );
+        ExponentialTimeInterval {
+            base_ms: base.as_millis() as u64, // any reasonable ms timeout fits u64 perfectly
+            exponent_base,
+            max_exponent,
+        }
+    }
+}
+
+impl PacemakerTimeInterval for ExponentialTimeInterval {
+    fn get_round_duration(&self, round_index_after_committed_qc: usize) -> Duration {
+        let pow = round_index_after_committed_qc.min(self.max_exponent) as u32;
+        let base_multiplier = self.exponent_base.powf(f64::from(pow));
+        let duration_ms = ((self.base_ms as f64) * base_multiplier).ceil() as u64;
+        Duration::from_millis(duration_ms)
+    }
+}
+
+/// `LocalPacemakerInner` is a Pacemaker implementation that relies on increasing local timeouts
+/// in order to eventually come up with the timeout that is large enough to guarantee overlap of
+/// the "current round" of multiple participants.
+///
+/// The protocol is as follows:
+/// * `LocalPacemakerInner` manages the `highest_certified_round` that keeps the round of the
+///   highest certified block known to the validator.
+/// * Once a new QC arrives with a round larger than that of `highest_certified_round`,
+///   local pacemaker is going to increment a round with a default timeout.
+/// * Upon every timeout `LocalPacemaker` increments a round and doubles the timeout.
+///
+/// `LocalPacemakerInner` does not require clock synchronization to maintain the property of
+/// liveness - although clock synchronization can improve the time necessary to get a large enough
+/// timeout overlap.
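// ---------------------------------------------------------------------------
// Editor's aside (illustrative, not part of the original change): the timeout
// schedule above in numbers. A standalone sketch of get_round_duration, using the
// same parameters as test_pacemaker_time_interval further down in this change
// (base = 3000ms, exponent_base = 1.5, max_exponent = 2):
//
//   round index:  0       1       2       1000 (capped at max_exponent)
//   duration:     3000ms  4500ms  6750ms  6750ms
fn round_duration_ms(base_ms: u64, exponent_base: f64, max_exponent: usize, idx: usize) -> u64 {
    let pow = idx.min(max_exponent) as u32;
    // ceil() matches the implementation above, so durations never round down.
    ((base_ms as f64) * exponent_base.powf(f64::from(pow))).ceil() as u64
}
// E.g. round_duration_ms(3000, 1.5, 2, 1) == 4500.
// ---------------------------------------------------------------------------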
+/// It does rely on an assumption that when an honest replica receives a quorum certificate +/// indicating to move to the next round, all other honest replicas will move to the next round +/// within a bounded time. This can be guaranteed via all honest replicas gossiping their highest +/// QC to f+1 other replicas for instance. +struct LocalPacemakerInner { + // Determines the time interval for a round interval + time_interval: Box, + // Highest round that a block was committed + highest_committed_round: Round, + // Highest round known certified by QC. + highest_qc_round: Round, + // Current round (current_round - highest_qc_round determines the timeout). + // Current round is basically max(highest_qc_round, highest_received_tc, highest_local_tc) + 1 + // update_current_round take care of updating current_round and sending new round event if + // it changes + current_round: Round, + // Approximate deadline when current round ends + current_round_deadline: Option, + // Service for timer + time_service: Arc, + // To send timeout events to itself. + timeout_sender: Option>, + // To send new round events to the client. + new_round_sender: Option>, + // To send timeout events to other pacemakers + pacemaker_timeout_sender: channel::Sender, + // Manages the PacemakerTimeout and PacemakerTimeoutCertificate structs + pacemaker_timeout_manager: PacemakerTimeoutManager, +} + +impl LocalPacemakerInner { + pub fn new( + persistent_liveness_storage: Box, + time_interval: Box, + highest_committed_round: Round, + highest_qc_round: Round, + time_service: Arc, + pacemaker_timeout_sender: channel::Sender, + pacemaker_timeout_quorum_size: usize, + highest_timeout_certificates: HighestTimeoutCertificates, + ) -> Self { + assert!(pacemaker_timeout_quorum_size > 0); + // The starting round is maximum(highest quorum certificate, + // highest timeout certificate round) + 1. Note that it is possible this + // replica already voted at this round and will until a round timeout + // or another replica convinces it via a quorum certificate or a timeout + // certificate to advance to a higher round. + let current_round = { + match highest_timeout_certificates.highest_timeout_certificate() { + Some(highest_timeout_certificate) => { + cmp::max(highest_qc_round, highest_timeout_certificate.round()) + } + None => highest_qc_round, + } + } + 1; + // Our counters are initialized via lazy_static, so they're not going to appear in + // Prometheus if some conditions never happen. Invoking get() function enforces creation. + counters::QC_ROUNDS_COUNT.get(); + counters::TIMEOUT_ROUNDS_COUNT.get(); + counters::TIMEOUT_COUNT.get(); + Self { + time_interval, + highest_committed_round, + highest_qc_round, + current_round, + current_round_deadline: None, + time_service, + timeout_sender: None, + new_round_sender: None, + pacemaker_timeout_sender, + pacemaker_timeout_manager: PacemakerTimeoutManager::new( + pacemaker_timeout_quorum_size, + highest_timeout_certificates, + persistent_liveness_storage, + ), + } + } + + /// Trigger an event to create a new round interval and ignore any events from previous round + /// intervals. The reason for the event is given by the caller, the timeout is + /// deterministically determined by the reason and the internal state. 
+ fn create_new_round_task(&mut self, reason: NewRoundReason) -> impl Future + Send { + let round = self.current_round; + let timeout = self.setup_timeout(); + let mut sender = self.new_round_sender.as_ref().unwrap().clone(); + async move { + if let Err(e) = sender + .send(NewRoundEvent { + round, + reason, + timeout, + }) + .await + { + debug!("Error in sending new round interval event: {:?}", e); + } + } + } + + /// Broadcasts timeout messages to validators + /// This task will also re-schedule a new timeout task for the same round, in order to ensure + /// that machines that are down will eventually get the timeouts and be able to form a timeout + /// certificate. + fn broadcast_timeout_task(&mut self) -> impl Future + Send { + self.setup_timeout(); + counters::TIMEOUT_COUNT.inc(); + let mut sender = self.pacemaker_timeout_sender.clone(); + let current_round = self.current_round; + async move { + if let Err(e) = sender.send(current_round).await { + warn!("Can't send pacemaker timeout message: {:?}", e) + } + } + } + + /// Setup the timeout task and return the duration of the current timeout + fn setup_timeout(&mut self) -> Duration { + let timeout_sender = self.timeout_sender.as_ref().unwrap().clone(); + let timeout = self.setup_deadline(); + // Note that the timeout should not be driven sequentially with any other events as it can + // become the head of the line blocker. + trace!( + "Scheduling to {} for round {}", + timeout.as_millis(), + self.current_round + ); + self.time_service.run_after( + timeout, + SendTask::make( + timeout_sender, + PacemakerEvent::Timeout { + round: self.current_round, + }, + ), + ); + timeout + } + + /// Setup the current round deadline and return the duration of the current round + fn setup_deadline(&mut self) -> Duration { + let round_index_after_committed_round = { + if self.highest_committed_round == 0 { + // Genesis doesn't require the 3-chain rule for commit, hence start the index at + // the round after genesis. + self.current_round - 1 + } else { + if self.current_round - self.highest_committed_round < 3 { + warn!("Finding a deadline for a round {} that should have already been completed since the highest committed round is {}", + self.current_round, + self.highest_committed_round); + } + + max(0, self.current_round - self.highest_committed_round - 3) + } + } as usize; + let timeout = self + .time_interval + .get_round_duration(round_index_after_committed_round); + self.current_round_deadline = Some(Instant::now() + timeout); + timeout + } + + /// Attempts to update highest_qc_certified_round when receiving QC for given round. + /// Returns true if highest_qc_certified_round of this pacemaker has changed + fn update_highest_qc_round(&mut self, round: Round) -> bool { + if round > self.highest_qc_round { + debug!( + "{}QuorumCertified at {}{}", + Fg(LightBlack), + round, + Fg(Reset) + ); + self.highest_qc_round = round; + return true; + } + false + } + + /// Combines highest_qc_certified_round, highest_local_tc and highest_received_tc into + /// effective round of this pacemaker. 
+ /// Generates new_round event if effective round changes and ensures it is + /// monotonically increasing + fn update_current_round(&mut self) -> Pin + Send>> { + let (mut best_round, mut best_reason) = (self.highest_qc_round, NewRoundReason::QCReady); + if let Some(highest_timeout_certificate) = + self.pacemaker_timeout_manager.highest_timeout_certificate() + { + if highest_timeout_certificate.round() > best_round { + best_round = highest_timeout_certificate.round(); + best_reason = NewRoundReason::Timeout { + cert: highest_timeout_certificate.clone(), + }; + } + } + + let new_round = best_round + 1; + if self.current_round == new_round { + debug!( + "{}Round did not change: {}{}", + Fg(LightBlack), + new_round, + Fg(Reset) + ); + return async {}.boxed(); + } + assert!( + new_round > self.current_round, + "Round illegally decreased from {} to {}", + self.current_round, + new_round + ); + self.current_round = new_round; + self.create_new_round_task(best_reason).boxed() + } + + /// Validate timeout certificate and update local state if it's correct + fn check_and_update_highest_received_tc( + &mut self, + tc: Option<&PacemakerTimeoutCertificate>, + ) -> bool { + if let Some(tc) = tc { + return self + .pacemaker_timeout_manager + .update_highest_received_timeout_certificate(tc); + } + false + } +} + +/// `LocalPacemaker` is a wrapper to make the `LocalPacemakerInner` thread-safe. +pub struct LocalPacemaker { + inner: RwLock, +} + +impl LocalPacemaker { + pub fn new( + persistent_liveness_storage: Box, + time_interval: Box, + highest_committed_round: Round, + highest_qc_round: Round, + time_service: Arc, + pacemaker_timeout_sender: channel::Sender, + pacemaker_timeout_quorum_size: usize, + highest_timeout_certificates: HighestTimeoutCertificates, + ) -> Self { + LocalPacemaker { + inner: RwLock::new(LocalPacemakerInner::new( + persistent_liveness_storage, + time_interval, + highest_committed_round, + highest_qc_round, + time_service, + pacemaker_timeout_sender, + pacemaker_timeout_quorum_size, + highest_timeout_certificates, + )), + } + } +} + +impl EventBasedActor for LocalPacemaker { + type InputEvent = PacemakerEvent; + type OutputEvent = NewRoundEvent; + + fn init( + &mut self, + input_stream_sender: mpsc::Sender, + output_stream_sender: mpsc::Sender, + ) { + let mut guard = self.inner.write().unwrap(); + guard.timeout_sender = Some(input_stream_sender); + guard.new_round_sender = Some(output_stream_sender); + guard.setup_deadline(); + } + + fn on_startup(&self) -> Pin + Send>> { + // To jump start the execution return the new round event for the current round. + self.inner + .write() + .unwrap() + .create_new_round_task(NewRoundReason::QCReady) + .boxed() + } + + fn process_event(&self, event: Self::InputEvent) -> Pin + Send>> { + let mut guard = self.inner.write().unwrap(); + match event { + // Upon learning about a new quorum certificate, the pacemaker + // should advance to round r+1 if it's current round < r+1 + PacemakerEvent::QuorumCertified { round } => { + if guard.update_highest_qc_round(round) { + return guard.update_current_round(); + } + } + // Upon receiving a notification that a round has timed out, broadcast this event + // as a NewRoundMsg to all replicas. 
+ PacemakerEvent::Timeout { round } => { + if round == guard.current_round { + warn!( + "Round {} has timed out, broadcasting new round message to all replicas", + round + ); + return guard.broadcast_timeout_task().boxed(); + } + } + // Upon receiving a pacemaker timeout from another replica, check to see if a + // timeout certificate is formed and the round should be advanced. + PacemakerEvent::RemoteTimeout { pacemaker_timeout } => { + if guard + .pacemaker_timeout_manager + .update_received_timeout(pacemaker_timeout) + { + return guard.update_current_round(); + } + } + } + async {}.boxed() + } +} + +impl Pacemaker for LocalPacemaker { + fn current_round_deadline(&self) -> Instant { + self.inner + .read() + .unwrap() + .current_round_deadline + .expect("Round deadline was not set") + } + + fn current_round(&self) -> Round { + self.inner.read().unwrap().current_round + } + + fn process_certificates_from_proposal( + &self, + qc_round: Round, + timeout_certificate: Option<&PacemakerTimeoutCertificate>, + ) -> Pin + Send>> { + let mut guard = self.inner.write().unwrap(); + let tc_round_updated = guard.check_and_update_highest_received_tc(timeout_certificate); + let qc_round_updated = guard.update_highest_qc_round(qc_round); + if tc_round_updated || qc_round_updated { + return guard.update_current_round(); + } + async {}.boxed() + } + + fn update_highest_committed_round(&self, highest_committed_round: Round) { + let mut guard = self.inner.write().unwrap(); + if guard.highest_committed_round < highest_committed_round { + guard.highest_committed_round = highest_committed_round; + } + } +} diff --git a/consensus/src/chained_bft/liveness/local_pacemaker_test.rs b/consensus/src/chained_bft/liveness/local_pacemaker_test.rs new file mode 100644 index 0000000000000..c155d02265743 --- /dev/null +++ b/consensus/src/chained_bft/liveness/local_pacemaker_test.rs @@ -0,0 +1,155 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + liveness::{ + local_pacemaker::{ExponentialTimeInterval, LocalPacemaker, PacemakerTimeInterval}, + new_round_msg::PacemakerTimeout, + pacemaker::{NewRoundEvent, NewRoundReason, PacemakerEvent}, + pacemaker_timeout_manager::HighestTimeoutCertificates, + }, + persistent_storage::PersistentStorage, + test_utils::{consensus_runtime, MockStorage, TestPayload}, + }, + mock_time_service::SimulatedTimeService, + stream_utils::start_event_processing_loop, +}; +use channel; +use futures::{channel::mpsc, executor::block_on, SinkExt, StreamExt}; +use std::{sync::Arc, time::Duration, u64}; +use tokio::runtime; +use types::validator_signer::ValidatorSigner; + +#[test] +fn test_pacemaker_time_interval() { + let interval = ExponentialTimeInterval::new(Duration::from_millis(3000), 1.5, 2); + assert_eq!(3000, interval.get_round_duration(0).as_millis()); + assert_eq!(4500, interval.get_round_duration(1).as_millis()); + assert_eq!( + 6750, /* 4500*1.5 */ + interval.get_round_duration(2).as_millis() + ); + // Test that there is no integer overflow + assert_eq!(6750, interval.get_round_duration(1000).as_millis()); +} + +#[test] +/// Verify that LocalPacemaker properly outputs PacemakerTimeoutMsg upon timeout +fn test_basic_timeout() { + let runtime = consensus_runtime(); + let time_interval = Box::new(ExponentialTimeInterval::fixed(Duration::from_millis(2))); + let highest_certified_round = 1; + let simulated_time = SimulatedTimeService::auto_advance_until(Duration::from_millis(4)); + let (pacemaker_timeout_tx, mut 
pacemaker_timeout_rx) = channel::new_test(1_024); + let mut pm = Arc::new(LocalPacemaker::new( + MockStorage::::start_for_testing() + .0 + .persistent_liveness_storage(), + time_interval, + 0, + highest_certified_round, + Arc::new(simulated_time.clone()), + pacemaker_timeout_tx, + 1, + HighestTimeoutCertificates::new(None, None), + )); + + start_event_processing_loop(&mut pm, runtime.executor()); + + block_on(async move { + for _ in 0..2 { + let round = pacemaker_timeout_rx.next().await.unwrap(); + // Here we just test timeout send retry, + // round for timeout is not changed as no timeout certificate was gathered at this point + assert_eq!(2, round); + } + }); +} + +#[test] +/// Verify that LocalPacemaker forms a timeout certificate on receiving sufficient timeout messages +fn test_timeout_certificate() { + let runtime = consensus_runtime(); + let rounds = 5; + let mut signers: Vec = vec![]; + for _round in 1..rounds { + let signer = ValidatorSigner::random(); + signers.push(signer); + } + let (mut tx, mut new_round_events_receiver) = make_pacemaker(&runtime); + + block_on(async move { + // Send timeout for rounds 1..5, each from a different author, so that they can be + // accumulated into single timeout certificate + for round in 1..rounds { + let signer = &signers[round - 1]; + let pacemaker_timeout = PacemakerTimeout::new(round as u64, signer); + tx.send(PacemakerEvent::RemoteTimeout { pacemaker_timeout }) + .await + .unwrap(); + } + // First event sent automatically on pacemaker startup + expect_qc(1, &mut new_round_events_receiver).await; + // Then timeout quorum for previous round (1,2,3) generates new round event for round 2 + expect_timeout(2, &mut new_round_events_receiver).await; + // Then timeout quorum for previous round (2,3,4) generates new round event for round 3 + expect_timeout(3, &mut new_round_events_receiver).await; + }); +} + +#[test] +fn test_basic_qc() { + let runtime = consensus_runtime(); + let (mut tx, mut new_round_intervals_receiver) = make_pacemaker(&runtime); + + block_on(async move { + tx.send(PacemakerEvent::QuorumCertified { round: 2 }) + .await + .unwrap(); + tx.send(PacemakerEvent::QuorumCertified { round: 3 }) + .await + .unwrap(); + + // The first event is just the initial round, the next two events are the new QCs. + expect_qc(1, &mut new_round_intervals_receiver).await; + expect_qc(3, &mut new_round_intervals_receiver).await; + expect_qc(4, &mut new_round_intervals_receiver).await; + }); +} + +fn make_pacemaker( + runtime: &runtime::Runtime, +) -> (mpsc::Sender, mpsc::Receiver) { + let time_interval = Box::new(ExponentialTimeInterval::fixed(Duration::from_millis(2))); + let simulated_time = SimulatedTimeService::new(); + let (pacemaker_timeout_tx, _) = channel::new_test(1_024); + let mut pm = Arc::new(LocalPacemaker::new( + MockStorage::::start_for_testing() + .0 + .persistent_liveness_storage(), + time_interval, + 0, + 0, + Arc::new(simulated_time.clone()), + pacemaker_timeout_tx, + 3, + HighestTimeoutCertificates::new(None, None), + )); + start_event_processing_loop(&mut pm, runtime.executor()) +} + +async fn expect_qc(round: u64, rx: &mut mpsc::Receiver) { + let event: NewRoundEvent = rx.next().await.unwrap(); + assert_eq!(round, event.round); + assert_eq!(NewRoundReason::QCReady, event.reason); +} + +async fn expect_timeout(round: u64, rx: &mut mpsc::Receiver) { + let event: NewRoundEvent = rx.next().await.unwrap(); + assert_eq!(round, event.round); + match event.reason { + NewRoundReason::Timeout { .. 
} => (), + x => panic!("Expected timeout for round {}, got {:?}", round, x), + }; +} diff --git a/consensus/src/chained_bft/liveness/mod.rs b/consensus/src/chained_bft/liveness/mod.rs new file mode 100644 index 0000000000000..0d9f2a1246a4a --- /dev/null +++ b/consensus/src/chained_bft/liveness/mod.rs @@ -0,0 +1,15 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +pub(crate) mod local_pacemaker; +pub(crate) mod new_round_msg; +pub(crate) mod pacemaker; +pub(crate) mod pacemaker_timeout_manager; +pub(crate) mod proposal_generator; +pub(crate) mod proposer_election; +pub(crate) mod rotating_proposer_election; + +#[cfg(test)] +mod local_pacemaker_test; +#[cfg(test)] +mod rotating_proposer_test; diff --git a/consensus/src/chained_bft/liveness/new_round_msg.rs b/consensus/src/chained_bft/liveness/new_round_msg.rs new file mode 100644 index 0000000000000..50a59491f059a --- /dev/null +++ b/consensus/src/chained_bft/liveness/new_round_msg.rs @@ -0,0 +1,392 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::chained_bft::{ + common::{Author, Round}, + consensus_types::quorum_cert::QuorumCert, + liveness::new_round_msg::PacemakerTimeoutCertificateVerificationError::*, +}; +use canonical_serialization::{CanonicalSerialize, CanonicalSerializer, SimpleSerializer}; +use crypto::{ + hash::{CryptoHash, CryptoHasher, NewRoundMsgHasher, PacemakerTimeoutHasher}, + HashValue, Signature, +}; +use network; +use proto_conv::{FromProto, IntoProto}; +use protobuf::RepeatedField; +use serde::{Deserialize, Serialize}; +use std::{collections::HashSet, convert::TryFrom, fmt, iter::FromIterator}; +use types::{ + account_address::AccountAddress, + validator_signer::ValidatorSigner, + validator_verifier::{ValidatorVerifier, VerifyError}, +}; + +// Internal use only. Contains all the fields in PaceMakerTimeout that contributes to the +// computation of its hash. +struct PacemakerTimeoutSerializer { + round: Round, + author: Author, +} + +impl CanonicalSerialize for PacemakerTimeoutSerializer { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> failure::Result<()> { + serializer.encode_u64(self.round)?; + serializer.encode_struct(&self.author)?; + Ok(()) + } +} + +impl CryptoHash for PacemakerTimeoutSerializer { + type Hasher = PacemakerTimeoutHasher; + + fn hash(&self) -> HashValue { + let mut state = Self::Hasher::default(); + state.write(&SimpleSerializer::>::serialize(self).expect("Should serialize.")); + state.finish() + } +} + +/// This message will be broadcast by a pacemaker as part of NewRoundMsg when its local +/// timeout for a round is reached. Once f+1 PacemakerTimeout structs +/// from unique authors is gathered it forms a TimeoutCertificate. A TimeoutCertificate is +/// a proof that will cause a replica to advance to the minimum round in the TimeoutCertificate. 
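The round arithmetic of this aggregation is worth pinning down with a worked example. The sketch below is illustrative only (the free function and its names are not part of this diff): a certificate is formed from the quorum-size highest timeouts gathered from unique authors, and it certifies the smallest round among those, since that is the round every signer in the slice has already passed.

```rust
/// Illustrative sketch: the round a timeout certificate would certify, given
/// the timeout rounds gathered so far (at most one per unique author).
/// Assumes quorum_size >= 1.
fn certified_round(mut rounds: Vec<u64>, quorum_size: usize) -> Option<u64> {
    if rounds.len() < quorum_size {
        // Fewer than the required quorum of unique authors: no certificate yet.
        return None;
    }
    // Keep the `quorum_size` highest rounds; the certificate round is the
    // smallest of those.
    rounds.sort_unstable_by(|a, b| b.cmp(a));
    Some(rounds[quorum_size - 1])
}

#[test]
fn certified_round_examples() {
    // Unique-author timeouts at rounds 1, 2, 3, 4 with quorum 3 certify round 2.
    assert_eq!(certified_round(vec![1, 2, 3, 4], 3), Some(2));
    // Two timeouts are not enough for a quorum of three.
    assert_eq!(certified_round(vec![1, 2], 3), None);
}
```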
+#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)] +pub struct PacemakerTimeout { + round: Round, + author: Author, + signature: Signature, +} + +impl PacemakerTimeout { + /// Creates new PacemakerTimeoutMsg + pub fn new(round: Round, validator_signer: &ValidatorSigner) -> Self { + let author = validator_signer.author(); + let digest = PacemakerTimeoutSerializer { round, author }.hash(); + let signature = validator_signer + .sign_message(digest) + .expect("Failed to sign PacemakerTimeoutMsg"); + PacemakerTimeout { + round, + author, + signature, + } + } + + fn pacemaker_timeout_digest(author: AccountAddress, round: Round) -> HashValue { + PacemakerTimeoutSerializer { round, author }.hash() + } + + /// Calculates digest for this struct + pub fn digest(&self) -> HashValue { + Self::pacemaker_timeout_digest(self.author, self.round) + } + + pub fn round(&self) -> Round { + self.round + } + + /// Verifies that this message has valid signature + pub fn verify(&self, validator: &ValidatorVerifier) -> Result<(), VerifyError> { + validator.verify_signature(self.author, self.digest(), &self.signature) + } + + /// Returns the author of the timeout + pub fn author(&self) -> Author { + self.author + } + + /// Returns the signature of the author for this timeout + pub fn signature(&self) -> &Signature { + &self.signature + } +} + +impl IntoProto for PacemakerTimeout { + type ProtoType = network::proto::PacemakerTimeout; + + fn into_proto(self) -> Self::ProtoType { + let mut proto = Self::ProtoType::new(); + proto.set_round(self.round); + proto.set_author(self.author.into()); + proto.set_signature(self.signature.to_compact().as_ref().into()); + proto + } +} + +impl FromProto for PacemakerTimeout { + type ProtoType = network::proto::PacemakerTimeout; + + fn from_proto(mut object: Self::ProtoType) -> failure::Result { + let round = object.get_round(); + let author = Author::try_from(object.take_author())?; + let signature = Signature::from_compact(object.get_signature())?; + Ok(PacemakerTimeout { + round, + author, + signature, + }) + } +} + +// Internal use only. Contains all the fields in NewRoundMsg that contributes to the computation of +// its hash. +struct NewRoundMsgSerializer { + highest_quorum_certificate_block_id: HashValue, + pacemaker_timeout_digest: HashValue, +} + +impl CanonicalSerialize for NewRoundMsgSerializer { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> failure::Result<()> { + serializer.encode_raw_bytes(self.highest_quorum_certificate_block_id.as_ref())?; + serializer.encode_raw_bytes(self.pacemaker_timeout_digest.as_ref())?; + Ok(()) + } +} + +impl CryptoHash for NewRoundMsgSerializer { + type Hasher = NewRoundMsgHasher; + + fn hash(&self) -> HashValue { + let mut state = Self::Hasher::default(); + state.write(&SimpleSerializer::>::serialize(self).expect("Should serialize.")); + state.finish() + } +} + +/// This message will be broadcast by a pacemaker when its local timeout for a round is reached. +/// Once the broadcasts start, retries will continue for every timeout until the round changes. +/// Retries are required since, say if a proposer for a round r was unresponsive, it might not +/// propose if it misses even only one PacemakerTimeoutMsg. +/// +/// The expected proposer will wait until n-f such messages are received before proposing to +/// ensure liveness (a next proposal has the highest quorum certificate across all replicas +/// as justification). 
If the expected proposer has a quorum certificate on round r-1, it need +/// not wait until n-f such messages are received and can make a proposal justified +/// by this quorum certificate. +#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)] +pub struct NewRoundMsg { + highest_quorum_certificate: QuorumCert, + // Used for fast state synchronization. + highest_ledger_info: QuorumCert, + pacemaker_timeout: PacemakerTimeout, + author: Author, + signature: Signature, +} + +impl NewRoundMsg { + /// Creates new PacemakerTimeoutMsg + pub fn new( + highest_quorum_certificate: QuorumCert, + highest_ledger_info: QuorumCert, + pacemaker_timeout: PacemakerTimeout, + validator_signer: &ValidatorSigner, + ) -> NewRoundMsg { + let author = validator_signer.author(); + let digest = Self::new_round_digest( + highest_quorum_certificate.certified_block_id(), + pacemaker_timeout.digest(), + ); + let signature = validator_signer + .sign_message(digest) + .expect("Failed to sign PacemakerTimeoutMsg"); + NewRoundMsg { + highest_quorum_certificate, + highest_ledger_info, + pacemaker_timeout, + author, + signature, + } + } + + fn new_round_digest( + highest_quorum_certificate_block_id: HashValue, + pacemaker_timeout_digest: HashValue, + ) -> HashValue { + NewRoundMsgSerializer { + highest_quorum_certificate_block_id, + pacemaker_timeout_digest, + } + .hash() + } + + /// Calculates digest for this message + pub fn digest(&self) -> HashValue { + Self::new_round_digest( + self.highest_quorum_certificate.certified_block_id(), + self.pacemaker_timeout.digest(), + ) + } + + /// Highest QC carried by the new round message. + pub fn highest_quorum_certificate(&self) -> &QuorumCert { + &self.highest_quorum_certificate + } + + /// Returns a reference to a QuorumCert that has the highest round LedgerInfo + pub fn highest_ledger_info(&self) -> &QuorumCert { + &self.highest_ledger_info + } + + /// Returns a reference to the included PacemakerTimeout + pub fn pacemaker_timeout(&self) -> &PacemakerTimeout { + &self.pacemaker_timeout + } + + /// Verifies that this message has valid signature + pub fn verify(&self, validator: &ValidatorVerifier) -> Result<(), VerifyError> { + validator.verify_signature(self.author, self.digest(), &self.signature)?; + self.pacemaker_timeout.verify(validator) + } + + /// Returns the author of the NewRoundMsg + pub fn author(&self) -> Author { + self.author + } + + /// Returns a reference to the signature of the author + #[allow(dead_code)] + pub fn signature(&self) -> &Signature { + &self.signature + } +} + +impl IntoProto for NewRoundMsg { + type ProtoType = network::proto::NewRound; + + fn into_proto(self) -> Self::ProtoType { + let mut proto = Self::ProtoType::new(); + proto.set_highest_quorum_cert(self.highest_quorum_certificate.into_proto()); + proto.set_highest_ledger_info(self.highest_ledger_info.into_proto()); + proto.set_pacemaker_timeout(self.pacemaker_timeout.into_proto()); + proto.set_author(self.author.into()); + proto.set_signature(self.signature.to_compact().as_ref().into()); + proto + } +} + +impl FromProto for NewRoundMsg { + type ProtoType = network::proto::NewRound; + + fn from_proto(mut object: Self::ProtoType) -> failure::Result { + let highest_quorum_certificate = QuorumCert::from_proto(object.take_highest_quorum_cert())?; + let highest_ledger_info = QuorumCert::from_proto(object.take_highest_ledger_info())?; + let pacemaker_timeout = PacemakerTimeout::from_proto(object.take_pacemaker_timeout())?; + let author = Author::try_from(object.take_author())?; + let 
signature = Signature::from_compact(object.get_signature())?; + Ok(NewRoundMsg { + highest_quorum_certificate, + highest_ledger_info, + pacemaker_timeout, + author, + signature, + }) + } +} + +#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)] +/// Proposal can include this timeout certificate as justification for switching to next round +pub struct PacemakerTimeoutCertificate { + round: Round, + timeouts: Vec, +} + +/// PacemakerTimeoutCertificate verification errors. +#[derive(Debug, PartialEq)] +pub enum PacemakerTimeoutCertificateVerificationError { + /// Number of signed timeouts is less then required quorum size + NoQuorum, + /// Round in message does not match calculated rounds based on signed timeouts + RoundMismatch { expected: Round }, + /// The signature on one of timeouts doesn't pass verification + SigVerifyError(Author, VerifyError), +} + +impl fmt::Display for PacemakerTimeoutCertificate { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "TimeoutCertificate", timeout.round())?; + if idx != self.timeouts.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, "]") + } +} + +impl PacemakerTimeoutCertificate { + /// Creates new PacemakerTimeoutCertificate + pub fn new(round: Round, timeouts: Vec) -> PacemakerTimeoutCertificate { + PacemakerTimeoutCertificate { round, timeouts } + } + + /// Verifies that timeouts in message actually certify the round + pub fn verify( + &self, + validator: &ValidatorVerifier, + ) -> Result<(), PacemakerTimeoutCertificateVerificationError> { + let mut min_round: Option = None; + let mut unique_authors = HashSet::new(); + for timeout in &self.timeouts { + if let Err(e) = + validator.verify_signature(timeout.author(), timeout.digest(), timeout.signature()) + { + return Err(SigVerifyError(timeout.author(), e)); + } + unique_authors.insert(timeout.author()); + let timeout_round = timeout.round(); + min_round = Some(min_round.map_or(timeout_round, move |x| x.min(timeout_round))) + } + if unique_authors.len() < validator.quorum_size() { + return Err(NoQuorum); + } + if min_round == Some(self.round) { + Ok(()) + } else { + Err(RoundMismatch { + expected: min_round.unwrap_or(0), + }) + } + } + + /// Returns the round of the timeout + pub fn round(&self) -> Round { + self.round + } + + /// Returns the timeouts that certify the PacemakerTimeoutCertificate + #[allow(dead_code)] + pub fn timeouts(&self) -> &Vec { + &self.timeouts + } +} + +impl IntoProto for PacemakerTimeoutCertificate { + type ProtoType = network::proto::PacemakerTimeoutCertificate; + + fn into_proto(self) -> Self::ProtoType { + let mut proto = Self::ProtoType::new(); + proto.set_timeouts(RepeatedField::from_iter( + self.timeouts.into_iter().map(PacemakerTimeout::into_proto), + )); + proto.set_round(self.round); + proto + } +} + +impl FromProto for PacemakerTimeoutCertificate { + type ProtoType = network::proto::PacemakerTimeoutCertificate; + + fn from_proto(mut object: Self::ProtoType) -> failure::Result { + let timeouts = object + .take_timeouts() + .into_iter() + .map(PacemakerTimeout::from_proto) + .collect::>>()?; + Ok(PacemakerTimeoutCertificate::new( + object.get_round(), + timeouts, + )) + } +} diff --git a/consensus/src/chained_bft/liveness/pacemaker.rs b/consensus/src/chained_bft/liveness/pacemaker.rs new file mode 100644 index 0000000000000..d8fd52dbd08fd --- /dev/null +++ b/consensus/src/chained_bft/liveness/pacemaker.rs @@ -0,0 +1,81 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + 
chained_bft::{ + common::Round, + liveness::new_round_msg::{PacemakerTimeout, PacemakerTimeoutCertificate}, + }, + stream_utils::EventBasedActor, +}; +use futures::Future; +use std::{ + pin::Pin, + time::{Duration, Instant}, +}; + +/// Pacemaker events are external signals provided by other components in order to facilitate +/// pacemaker functioning. For example, when there are enough votes to form a new QC, it can serve +/// as a signal to move to the new round. Different pacemaker implementations can choose to +/// consider / ignore different signals. +#[derive(Eq, Debug, PartialEq)] +pub enum PacemakerEvent { + // An event that denotes a block is quorum certified at a particular round. This occurs when + // 1) A replica aggregates multiple votes into a quorum certificate + // 2) A replica receives a quorum certificate from another replica (i.e. piggybacked on a + // proposal or a vote). + QuorumCertified { round: Round }, + // Used for timeouts: in the beginning of round R one can set up a timeout event for round + // R, which is going to be ignored if the overall state has progressed further. + Timeout { round: Round }, + // Used to handle pacemaker timeout information sent by pacemakers of other validators + RemoteTimeout { pacemaker_timeout: PacemakerTimeout }, +} + +/// A reason for starting a new round: introduced for monitoring / debug purposes. +#[derive(Eq, Debug, PartialEq)] +pub enum NewRoundReason { + QCReady, + Timeout { cert: PacemakerTimeoutCertificate }, +} + +/// NewRoundEvents produced by Pacemaker are guaranteed to be monotonically increasing. +/// NewRoundEvents are consumed by the rest of the system: they can cause sending new proposals +/// or voting for some proposals that wouldn't have been voted otherwise. +/// The duration is populated for debugging and testing +#[derive(Debug, PartialEq, Eq)] +pub struct NewRoundEvent { + pub round: Round, + pub reason: NewRoundReason, + pub timeout: Duration, +} + +/// Pacemaker is responsible for generating the new round events, which are driving the actions +/// of the rest of the system (e.g., for generating new proposals). +/// Ideal pacemaker provides an abstraction of a "shared clock". In reality pacemaker +/// implementations use external signals like receiving new votes / QCs plus internal +/// communication between other nodes' pacemaker instances in order to synchronize the logical +/// clocks. +/// The trait doesn't specify the starting conditions or the executor that is responsible for +/// driving the logic. +pub trait Pacemaker: + EventBasedActor + Send + Sync +{ + /// Returns deadline for current round + fn current_round_deadline(&self) -> Instant; + + /// Synchronous function to return the current round. 
+ fn current_round(&self) -> Round; + + /// Function to update current round when proposal is received + /// Both round of latest received QC and timeout certificates are taken into account + /// This function guarantees to update pacemaker state when promise that it returns is fulfilled + fn process_certificates_from_proposal( + &self, + qc_round: Round, + timeout_certificate: Option<&PacemakerTimeoutCertificate>, + ) -> Pin + Send>>; + + /// Update the highest committed round + fn update_highest_committed_round(&self, highest_committed_round: Round); +} diff --git a/consensus/src/chained_bft/liveness/pacemaker_timeout_manager.rs b/consensus/src/chained_bft/liveness/pacemaker_timeout_manager.rs new file mode 100644 index 0000000000000..64d6faf33175e --- /dev/null +++ b/consensus/src/chained_bft/liveness/pacemaker_timeout_manager.rs @@ -0,0 +1,222 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::chained_bft::{ + common::Author, + liveness::new_round_msg::{PacemakerTimeout, PacemakerTimeoutCertificate}, + persistent_storage::PersistentLivenessStorage, +}; +use logger::prelude::*; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[cfg(test)] +#[path = "pacemaker_timeout_manager_test.rs"] +mod pacemaker_timeout_manager_test; + +/// Tracks the highest round known local and received timeout certificates +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct HighestTimeoutCertificates { + // Highest timeout certificate gathered locally + highest_local_timeout_certificate: Option, + // Highest timeout certificate received from another replica + highest_received_timeout_certificate: Option, +} + +impl HighestTimeoutCertificates { + #[cfg(test)] + pub fn new( + highest_local_timeout_certificate: Option, + highest_received_timeout_certificate: Option, + ) -> Self { + Self { + highest_local_timeout_certificate, + highest_received_timeout_certificate, + } + } + + /// Return a optional reference to the highest timeout certificate (locally generated or + /// remotely received) + pub fn highest_timeout_certificate(&self) -> Option<&PacemakerTimeoutCertificate> { + if let Some(highest_received_timeout_certificate) = + self.highest_received_timeout_certificate.as_ref() + { + if let Some(highest_local_timeout_certificate) = &self.highest_local_timeout_certificate + { + if highest_local_timeout_certificate.round() + > highest_received_timeout_certificate.round() + { + self.highest_local_timeout_certificate.as_ref() + } else { + self.highest_received_timeout_certificate.as_ref() + } + } else { + self.highest_received_timeout_certificate.as_ref() + } + } else { + self.highest_local_timeout_certificate.as_ref() + } + } +} + +/// Manages the PacemakerTimeout structs received from replicas. +/// +/// A replica can generate and track TimeoutCertificates of the highest round (locally and received) +/// to allow a pacemaker to advance to the latest certificate round. 
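Choosing between the locally gathered and the remotely received certificate reduces to a round comparison. A minimal sketch of that rule, representing each certificate by its round number only (the free function is hypothetical, not part of this diff):

```rust
/// Illustrative sketch: the effective timeout certificate is whichever of the
/// local and received certificates carries the higher round.
fn effective_tc_round(local: Option<u64>, received: Option<u64>) -> Option<u64> {
    match (local, received) {
        (Some(l), Some(r)) => Some(l.max(r)),
        (one, other) => one.or(other),
    }
}

#[test]
fn effective_tc_round_examples() {
    assert_eq!(effective_tc_round(Some(5), Some(7)), Some(7));
    assert_eq!(effective_tc_round(None, Some(7)), Some(7));
    assert_eq!(effective_tc_round(None, None), None);
}
```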
+pub struct PacemakerTimeoutManager { + // The minimum quorum to generate a timeout certificate + timeout_certificate_quorum_size: usize, + // Track the PacemakerTimeoutMsg for highest timeout round received from this node + author_to_received_timeouts: HashMap, + // Highest timeout certificates + highest_timeout_certificates: HighestTimeoutCertificates, + // Used to persistently store the latest known timeout certificate + persistent_liveness_storage: Box, +} + +impl PacemakerTimeoutManager { + pub fn new( + timeout_certificate_quorum_size: usize, + highest_timeout_certificates: HighestTimeoutCertificates, + persistent_liveness_storage: Box, + ) -> Self { + // This struct maintains the invariant that the highest round timeout certificate + // that author_to_received_timeouts can generate is always equal to + // highest_timeout_certificates.highest_local_timeout_certificate. + let mut author_to_received_timeouts = HashMap::new(); + if let Some(tc) = &highest_timeout_certificates.highest_local_timeout_certificate { + author_to_received_timeouts = tc + .timeouts() + .iter() + .map(|t| (t.author(), t.clone())) + .collect(); + } + PacemakerTimeoutManager { + timeout_certificate_quorum_size, + author_to_received_timeouts, + highest_timeout_certificates, + persistent_liveness_storage, + } + } + + /// Returns the highest round PacemakerTimeoutCertificate from a map of author to + /// timeout messages or None if there are not enough timeout messages available. + /// A PacemakerTimeoutCertificate is made of the N highest timeout messages received where + /// N=timeout_quorum_size. The round of PacemakerTimeoutCertificate is determined as + /// the smallest of round of all messages used to generate this certificate. + /// + /// For example, if timeout_certificate_quorum_size=3 and we received unique author timeouts + /// for rounds (1,2,3,4), then rounds (2,3,4) would form PacemakerTimeoutCertificate with + /// round=2. + fn generate_timeout_certificate( + author_to_received_timeouts: &HashMap, + timeout_certificate_quorum_size: usize, + ) -> Option { + if author_to_received_timeouts.values().len() < timeout_certificate_quorum_size { + return None; + } + let mut values: Vec<&PacemakerTimeout> = author_to_received_timeouts.values().collect(); + values.sort_by(|x, y| y.round().cmp(&x.round())); + let slice = &values[..timeout_certificate_quorum_size]; + Some(PacemakerTimeoutCertificate::new( + // expect does not panic here because code above verifies values length + slice + .last() + .expect("Slice for timeout certificate is empty") + .round(), + slice.iter().map(|x| (*x).clone()).collect(), + )) + } + + /// Updates internal state according to received message from remote pacemaker and returns true + /// if round derived from highest PacemakerTimeoutCertificate has increased. + pub fn update_received_timeout(&mut self, pacemaker_timeout: PacemakerTimeout) -> bool { + let author = pacemaker_timeout.author(); + let prev_timeout = self.author_to_received_timeouts.get(&author).cloned(); + if let Some(prev_timeout) = &prev_timeout { + if prev_timeout.round() >= pacemaker_timeout.round() { + warn!("Received timeout message for previous round, ignoring. 
Author: {}, prev round: {}, received: {}", + author.short_str(), prev_timeout.round(), pacemaker_timeout.round()); + return false; + } + } + + self.author_to_received_timeouts + .insert(author, pacemaker_timeout.clone()); + let highest_timeout_certificate = Self::generate_timeout_certificate( + &self.author_to_received_timeouts, + self.timeout_certificate_quorum_size, + ); + let highest_round = match &highest_timeout_certificate { + Some(tc) => tc.round(), + None => return false, + }; + let prev_highest_round = self + .highest_timeout_certificates + .highest_local_timeout_certificate + .as_ref() + .map(PacemakerTimeoutCertificate::round); + assert!( + highest_round >= prev_highest_round.unwrap_or(0), + "Went down on highest timeout quorum round from {:?} to {:?}. + Received: {:?}, all: {:?}", + prev_highest_round, + highest_round, + pacemaker_timeout, + self.author_to_received_timeouts, + ); + self.highest_timeout_certificates + .highest_local_timeout_certificate = highest_timeout_certificate; + if let Err(e) = self + .persistent_liveness_storage + .save_highest_timeout_cert(self.highest_timeout_certificates.clone()) + { + warn!( + "Failed to persist local highest timeout certificate in round {} due to {}", + highest_round, e + ); + } + highest_round > prev_highest_round.unwrap_or(0) + } + + /// Attempts to update highest_received_timeout_certificate when receiving a new remote + /// timeout certificate. Returns true if highest_received_timeout_certificate has changed + pub fn update_highest_received_timeout_certificate( + &mut self, + timeout_certificate: &PacemakerTimeoutCertificate, + ) -> bool { + if timeout_certificate.round() + > self + .highest_timeout_certificates + .highest_received_timeout_certificate + .as_ref() + .map_or(0, PacemakerTimeoutCertificate::round) + { + debug!( + "Received remote timeout certificate at round {}", + timeout_certificate.round() + ); + self.highest_timeout_certificates + .highest_received_timeout_certificate = Some(timeout_certificate.clone()); + if let Err(e) = self + .persistent_liveness_storage + .save_highest_timeout_cert(self.highest_timeout_certificates.clone()) + { + warn!( + "Failed to persist received highest timeout certificate in round {} due to {}", + timeout_certificate.round(), + e + ); + } + return true; + } + false + } + + /// Return a optional reference to the highest timeout certificate (locally generated or + /// remotely received) + pub fn highest_timeout_certificate(&self) -> Option<&PacemakerTimeoutCertificate> { + self.highest_timeout_certificates + .highest_timeout_certificate() + } +} diff --git a/consensus/src/chained_bft/liveness/pacemaker_timeout_manager_test.rs b/consensus/src/chained_bft/liveness/pacemaker_timeout_manager_test.rs new file mode 100644 index 0000000000000..c4ff19cf55d71 --- /dev/null +++ b/consensus/src/chained_bft/liveness/pacemaker_timeout_manager_test.rs @@ -0,0 +1,133 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::chained_bft::{ + liveness::{ + new_round_msg::{PacemakerTimeout, PacemakerTimeoutCertificate}, + pacemaker_timeout_manager::{HighestTimeoutCertificates, PacemakerTimeoutManager}, + }, + persistent_storage::PersistentStorage, + test_utils::{MockStorage, TestPayload}, +}; +use types::validator_signer::ValidatorSigner; + +#[test] +fn test_basic() { + let mut timeout_manager = PacemakerTimeoutManager::new( + 2, + HighestTimeoutCertificates::new(None, None), + MockStorage::::start_for_testing() + .0 + .persistent_liveness_storage(), + ); + 
assert_eq!(timeout_manager.highest_timeout_certificate(), None); + let validator_signer1 = ValidatorSigner::random(); + let validator_signer2 = ValidatorSigner::random(); + + // No timeout certificate generated on adding 2 timeouts from the same author + let timeout_signer1_round1 = PacemakerTimeout::new(1, &validator_signer1); + assert_eq!( + timeout_manager.update_received_timeout(timeout_signer1_round1), + false + ); + assert_eq!(timeout_manager.highest_timeout_certificate(), None); + let timeout_signer1_round2 = PacemakerTimeout::new(2, &validator_signer1); + assert_eq!( + timeout_manager.update_received_timeout(timeout_signer1_round2), + false + ); + assert_eq!(timeout_manager.highest_timeout_certificate(), None); + + // Timeout certificate generated on adding a timeout from signer2 + let timeout_signer2_round1 = PacemakerTimeout::new(1, &validator_signer2); + assert_eq!( + timeout_manager.update_received_timeout(timeout_signer2_round1), + true + ); + assert_eq!( + timeout_manager + .highest_timeout_certificate() + .unwrap() + .round(), + 1 + ); + + // Timeout certificate increased when incrementing the round from signer 2 + let timeout_signer2_round2 = PacemakerTimeout::new(2, &validator_signer2); + assert_eq!( + timeout_manager.update_received_timeout(timeout_signer2_round2), + true + ); + assert_eq!( + timeout_manager + .highest_timeout_certificate() + .unwrap() + .round(), + 2 + ); + + // No timeout certificate generated since signer 1 is still on round 2 + let timeout_signer2_round3 = PacemakerTimeout::new(3, &validator_signer2); + assert_eq!( + timeout_manager.update_received_timeout(timeout_signer2_round3), + false + ); + assert_eq!( + timeout_manager + .highest_timeout_certificate() + .unwrap() + .round(), + 2 + ); + + // Simulate received a higher received timeout certificate + let received_timeout_certificate = PacemakerTimeoutCertificate::new( + 10, + vec![ + PacemakerTimeout::new(10, &validator_signer1), + PacemakerTimeout::new(11, &validator_signer2), + ], + ); + assert_eq!( + timeout_manager.update_highest_received_timeout_certificate(&received_timeout_certificate), + true + ); + assert_eq!( + timeout_manager + .highest_timeout_certificate() + .unwrap() + .round(), + 10 + ); +} + +#[test] +fn test_recovery_from_highest_timeout_certificate() { + let validator_signer1 = ValidatorSigner::random(); + let validator_signer2 = ValidatorSigner::random(); + + let timeout1 = PacemakerTimeout::new(10, &validator_signer1); + let timeout2 = PacemakerTimeout::new(11, &validator_signer2); + let tc = PacemakerTimeoutCertificate::new(10, vec![timeout1, timeout2]); + + let timeout_manager = PacemakerTimeoutManager::new( + 2, + HighestTimeoutCertificates::new(Some(tc), None), + MockStorage::::start_for_testing() + .0 + .persistent_liveness_storage(), + ); + + assert_eq!( + timeout_manager + .author_to_received_timeouts + .contains_key(&validator_signer1.author()), + true + ); + assert_eq!( + timeout_manager + .author_to_received_timeouts + .contains_key(&validator_signer2.author()), + true + ); +} diff --git a/consensus/src/chained_bft/liveness/proposal_generator.rs b/consensus/src/chained_bft/liveness/proposal_generator.rs new file mode 100644 index 0000000000000..b2e37185cfc42 --- /dev/null +++ b/consensus/src/chained_bft/liveness/proposal_generator.rs @@ -0,0 +1,219 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::chained_bft::{common::Round, consensus_types::block::Block}; + +use crate::{ + chained_bft::{block_storage::BlockReader, 
common::Payload}, + counters, + state_replication::TxnManager, + time_service::{wait_if_possible, TimeService, WaitingError, WaitingSuccess}, +}; +use logger::prelude::*; +use std::{ + sync::{Arc, Mutex}, + time::{Duration, Instant}, +}; + +#[cfg(test)] +#[path = "proposal_generator_test.rs"] +mod proposal_generator_test; + +#[derive(Clone, Debug, PartialEq, Fail)] +/// ProposalGeneration logical errors (e.g., given round number is low). +pub enum ProposalGenerationError { + /// The round of a certified block we'd like to extend is not lower than the provided round. + #[fail(display = "GivenRoundTooLow")] + GivenRoundTooLow(Round), + #[fail(display = "TxnRetrievalError")] + TxnRetrievalError, + /// Local clock waiting completed, but the timestamp is still not greater than its parent + #[fail(display = "CurrentTimeTooOld")] + CurrentTimeTooOld, + /// Local clock waiting would exceed round duration to allow the timestamp to be greater that + /// its parent + #[fail(display = "ExceedsMaxRoundDuration")] + ExceedsMaxRoundDuration, + /// Already proposed at this round (only a single proposal per round is allowed) + #[fail(display = "CurrentTimeTooOld")] + AlreadyProposed(Round), +} + +/// ProposalGenerator is responsible for generating the proposed block on demand: it's typically +/// used by a validator that believes it's a valid candidate for serving as a proposer at a given +/// round. +/// ProposalGenerator is the one choosing the branch to extend: +/// - height is determined as parent.height + 1, +/// - round is given by the caller (typically determined by Pacemaker). +/// The transactions for the proposed block are delivered by TxnManager. +/// +/// TxnManager should be aware of the pending transactions in the branch that it is extending, +/// such that it will filter them out to avoid transaction duplication. +pub struct ProposalGenerator { + // Block store is queried both for finding the branch to extend and for generating the + // proposed block. + block_store: Arc + Send + Sync>, + // Transaction manager is delivering the transactions. + txn_manager: Arc>, + // Time service to generate block timestamps + time_service: Arc, + // Max number of transactions to be added to a proposed block. + max_block_size: u64, + // Support increasing block timestamps + enforce_increasing_timestamps: bool, + // Last round that a proposal was generated + last_round_generated: Mutex, +} + +impl ProposalGenerator { + pub fn new( + block_store: Arc + Send + Sync>, + txn_manager: Arc>, + time_service: Arc, + max_block_size: u64, + enforce_increasing_timestamps: bool, + ) -> Self { + Self { + block_store, + txn_manager, + time_service, + max_block_size, + enforce_increasing_timestamps, + last_round_generated: Mutex::new(0), + } + } + + /// The function generates a new proposal block: the returned future is fulfilled when the + /// payload is delivered by the TxnManager implementation. At most one proposal can be + /// generated per round (no proposal equivocation allowed). + /// Errors returned by the TxnManager implementation are propagated to the caller. + /// The logic for choosing the branch to extend is as follows: + /// 1. The function gets the highest head of a one-chain from block tree. + /// The new proposal must extend hqc_block to ensure optimistic responsiveness. + /// 2. While the height is ultimately determined as the parent.height + 1, the round is provided + /// by the caller. + /// 3. In case a given round is not greater than the calculated parent, return an OldRound + /// error. 
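Condensed into a sketch, the gating described in these steps amounts to two checks before any transactions are pulled: a round may be proposed at most once, and it must be strictly greater than the round of the highest certified block it extends. The free function and string errors below are illustrative stand-ins for the generator's actual state and error type:

```rust
/// Illustrative sketch of the proposal gating checks.
fn may_propose(
    last_round_generated: &mut u64,
    hqc_round: u64,
    round: u64,
) -> Result<(), &'static str> {
    // At most one proposal per round (no proposal equivocation).
    if *last_round_generated >= round {
        return Err("already proposed at this round");
    }
    *last_round_generated = round;
    // The proposed round must be greater than the certified round it extends.
    if hqc_round >= round {
        return Err("given round too low");
    }
    Ok(())
}

#[test]
fn may_propose_examples() {
    let mut last = 0;
    assert!(may_propose(&mut last, 0, 1).is_ok());
    assert!(may_propose(&mut last, 0, 1).is_err()); // duplicate round
    assert!(may_propose(&mut last, 5, 4).is_err()); // round 4 <= hqc round 5
}
```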
+ pub async fn generate_proposal( + &self, + round: Round, + round_deadline: Instant, + ) -> Result, ProposalGenerationError> { + { + let mut last_round_generated = self.last_round_generated.lock().unwrap(); + if *last_round_generated < round { + *last_round_generated = round; + } else { + return Err(ProposalGenerationError::AlreadyProposed(round)); + } + } + + let hqc_block = self.block_store.highest_certified_block(); + if hqc_block.round() >= round { + // The given round is too low. + return Err(ProposalGenerationError::GivenRoundTooLow(hqc_block.round())); + } + + // One needs to hold the blocks with the references to the payloads while get_block is + // being executed: pending blocks vector keeps all the pending ancestors of the extended + // branch. + let pending_blocks = match self.block_store.path_from_root(Arc::clone(&hqc_block)) { + Some(res) => res, + // In case the whole system moved forward between the check of a round and getting + // path from root. + None => { + return Err(ProposalGenerationError::GivenRoundTooLow(hqc_block.round())); + } + }; + //let pending_blocks = self.get_pending_blocks(Arc::clone(&hqc_block)); + // Exclude all the pending transactions: these are all the ancestors of + // parent (including) up to the root (excluding). + let exclude_payload = pending_blocks + .iter() + .map(|block| block.get_payload()) + .collect(); + + let block_timestamp = { + if self.enforce_increasing_timestamps { + match wait_if_possible( + self.time_service.as_ref(), + Duration::from_micros(hqc_block.timestamp_usecs()), + round_deadline, + ) + .await + { + Ok(waiting_success) => { + debug!( + "Success with {:?} for getting a valid timestamp for the next proposal", + waiting_success + ); + + match waiting_success { + WaitingSuccess::WaitWasRequired { + current_duration_since_epoch, + wait_duration, + } => { + counters::PROPOSAL_SUCCESS_WAIT_MS + .observe(wait_duration.as_millis() as f64); + counters::PROPOSAL_WAIT_WAS_REQUIRED_COUNT.inc(); + current_duration_since_epoch + } + WaitingSuccess::NoWaitRequired { + current_duration_since_epoch, + .. 
+ } => { + counters::PROPOSAL_SUCCESS_WAIT_MS.observe(0.0); + counters::PROPOSAL_NO_WAIT_REQUIRED_COUNT.inc(); + current_duration_since_epoch + } + } + } + Err(waiting_error) => { + match waiting_error { + WaitingError::MaxWaitExceeded => { + error!( + "Waiting until parent block timestamp usecs {:?} would exceed the round duration {:?}, hence will not create a proposal for this round", + hqc_block.timestamp_usecs(), + round_deadline); + counters::PROPOSAL_FAILURE_WAIT_MS.observe(0.0); + counters::PROPOSAL_MAX_WAIT_EXCEEDED_COUNT.inc(); + return Err(ProposalGenerationError::ExceedsMaxRoundDuration); + } + WaitingError::WaitFailed { + current_duration_since_epoch, + wait_duration, + } => { + error!( + "Even after waiting for {:?}, parent block timestamp usecs {:?} >= current timestamp usecs {:?}, will not create a proposal for this round", + wait_duration, + hqc_block.timestamp_usecs(), + current_duration_since_epoch); + counters::PROPOSAL_FAILURE_WAIT_MS + .observe(wait_duration.as_millis() as f64); + counters::PROPOSAL_WAIT_FAILED_COUNT.inc(); + return Err(ProposalGenerationError::CurrentTimeTooOld); + } + }; + } + } + } else { + self.time_service.get_current_timestamp() + } + }; + + let block_store = Arc::clone(&self.block_store); + match self + .txn_manager + .pull_txns(self.max_block_size, exclude_payload) + .await + { + Ok(txns) => Ok(block_store.create_block( + hqc_block, + txns, + round, + block_timestamp.as_micros() as u64, + )), + Err(_) => Err(ProposalGenerationError::TxnRetrievalError), + } + } +} diff --git a/consensus/src/chained_bft/liveness/proposal_generator_test.rs b/consensus/src/chained_bft/liveness/proposal_generator_test.rs new file mode 100644 index 0000000000000..c50872aa16c23 --- /dev/null +++ b/consensus/src/chained_bft/liveness/proposal_generator_test.rs @@ -0,0 +1,140 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + block_storage::BlockReader, + liveness::proposal_generator::{ProposalGenerationError, ProposalGenerator}, + safety::vote_msg::VoteMsg, + test_utils::{ + build_empty_tree, placeholder_ledger_info, MockTransactionManager, TreeInserter, + }, + }, + mock_time_service::SimulatedTimeService, +}; +use futures::executor::block_on; +use std::{ + sync::Arc, + time::{Duration, Instant}, +}; + +fn minute_from_now() -> Instant { + Instant::now() + Duration::new(60, 0) +} + +#[test] +fn test_proposal_generation_empty_tree() { + let block_store = build_empty_tree(); + let proposal_generator = ProposalGenerator::new( + block_store.clone(), + Arc::new(MockTransactionManager::new()), + Arc::new(SimulatedTimeService::new()), + 1, + true, + ); + let genesis = block_store.root(); + + // Generate proposals for an empty tree. 
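+    // generate_proposal is async, so the test drives it to completion with block_on.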
+ let proposal = block_on(proposal_generator.generate_proposal(1, minute_from_now())).unwrap(); + assert_eq!(proposal.parent_id(), genesis.id()); + assert_eq!(proposal.round(), 1); + assert_eq!(proposal.height(), 1); + assert_eq!(proposal.quorum_cert().certified_block_id(), genesis.id()); + + // Duplicate proposals on the same round are not allowed + let proposal_err = block_on(proposal_generator.generate_proposal(1, minute_from_now())).err(); + assert_eq!( + proposal_err.unwrap(), + ProposalGenerationError::AlreadyProposed(1) + ); +} + +#[test] +fn test_proposal_generation_parent() { + let block_store = build_empty_tree(); + let mut inserter = TreeInserter::new(block_store.clone()); + let proposal_generator = ProposalGenerator::new( + block_store.clone(), + Arc::new(MockTransactionManager::new()), + Arc::new(SimulatedTimeService::new()), + 1, + true, + ); + let genesis = block_store.root(); + let a1 = inserter.insert_block(genesis.as_ref(), 1); + let b1 = inserter.insert_block(genesis.as_ref(), 2); + + // With no certifications the parent is genesis + // generate proposals for an empty tree. + assert_eq!( + block_on(proposal_generator.generate_proposal(10, minute_from_now())) + .unwrap() + .parent_id(), + genesis.id() + ); + + // Once a1 is certified, it should be the one to choose from + let vote_msg_a1 = VoteMsg::new( + a1.id(), + block_store.get_state_for_block(a1.id()).unwrap(), + a1.round(), + block_store.signer().author(), + placeholder_ledger_info(), + block_store.signer(), + ); + block_on(block_store.insert_vote_and_qc(vote_msg_a1, 1)); + let a1_child_res = + block_on(proposal_generator.generate_proposal(11, minute_from_now())).unwrap(); + assert_eq!(a1_child_res.parent_id(), a1.id()); + assert_eq!(a1_child_res.round(), 11); + assert_eq!(a1_child_res.height(), 2); + assert_eq!(a1_child_res.quorum_cert().certified_block_id(), a1.id()); + + // Once b1 is certified, it should be the one to choose from + let vote_msg_b1 = VoteMsg::new( + b1.id(), + block_store.get_state_for_block(b1.id()).unwrap(), + b1.round(), + block_store.signer().author(), + placeholder_ledger_info(), + block_store.signer(), + ); + + block_on(block_store.insert_vote_and_qc(vote_msg_b1, 1)); + let b1_child_res = + block_on(proposal_generator.generate_proposal(12, minute_from_now())).unwrap(); + assert_eq!(b1_child_res.parent_id(), b1.id()); + assert_eq!(b1_child_res.round(), 12); + assert_eq!(b1_child_res.height(), 2); + assert_eq!(b1_child_res.quorum_cert().certified_block_id(), b1.id()); +} + +#[test] +fn test_old_proposal_generation() { + let block_store = build_empty_tree(); + let mut inserter = TreeInserter::new(block_store.clone()); + let proposal_generator = ProposalGenerator::new( + block_store.clone(), + Arc::new(MockTransactionManager::new()), + Arc::new(SimulatedTimeService::new()), + 1, + true, + ); + let genesis = block_store.root(); + let a1 = inserter.insert_block(genesis.as_ref(), 1); + let vote_msg_a1 = VoteMsg::new( + a1.id(), + block_store.get_state_for_block(a1.id()).unwrap(), + a1.round(), + block_store.signer().author(), + placeholder_ledger_info(), + block_store.signer(), + ); + block_on(block_store.insert_vote_and_qc(vote_msg_a1, 1)); + + let proposal_err = block_on(proposal_generator.generate_proposal(1, minute_from_now())).err(); + assert_eq!( + proposal_err.unwrap(), + ProposalGenerationError::GivenRoundTooLow(1) + ); +} diff --git a/consensus/src/chained_bft/liveness/proposer_election.rs b/consensus/src/chained_bft/liveness/proposer_election.rs new file mode 100644 index 
0000000000000..620afadbe4156 --- /dev/null +++ b/consensus/src/chained_bft/liveness/proposer_election.rs @@ -0,0 +1,136 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + common::{Author, Payload, Round}, + consensus_types::{block::Block, quorum_cert::QuorumCert}, + liveness::new_round_msg::PacemakerTimeoutCertificate, + }, + stream_utils::EventBasedActor, +}; +use failure::Result; +use network::proto::Proposal as ProtoProposal; +use proto_conv::{FromProto, IntoProto}; +use rmp_serde::{from_slice, to_vec_named}; +use serde::{de::DeserializeOwned, Serialize}; +use std::fmt; +use types::validator_verifier::ValidatorVerifier; + +/// ProposerInfo is a general trait that can include various proposer characteristics +/// relevant to a specific protocol implementation. The author is the only common thing for now. +pub trait ProposerInfo: + Send + Sync + Clone + Copy + fmt::Debug + DeserializeOwned + Serialize + 'static +{ + fn get_author(&self) -> Author; +} + +/// Trivial ProposerInfo implementation. +impl ProposerInfo for Author { + fn get_author(&self) -> Author { + *self + } +} + +/// ProposalInfo contains the required information for the proposer election protocol to make its +/// choice (typically depends on round and proposer info). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ProposalInfo { + pub proposal: Block, + pub proposer_info: P, + pub timeout_certificate: Option, + // use to notify about last committed block and the receiver could decide to start + // a synchronization if it's behind + pub highest_ledger_info: QuorumCert, +} + +impl ProposalInfo { + pub fn verify(&self, validator: &ValidatorVerifier) -> Result<()> { + self.proposal + .verify(validator) + .map_err(|e| format_err!("{:?}", e))?; + if let Some(tc) = &self.timeout_certificate { + tc.verify(validator).map_err(|e| format_err!("{:?}", e))?; + } + if self.proposal.author() != self.proposer_info.get_author() { + return Err(format_err!("Proposal for {} has mismatching author of block and proposer info: block={}, proposer={}", self.proposal, + self.proposal.author(), self.proposer_info.get_author())); + } + self.highest_ledger_info + .verify(validator) + .map_err(|e| format_err!("{:?}", e))?; + + Ok(()) + } +} + +impl fmt::Display for ProposalInfo +where + P: ProposerInfo, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "[block {} from {}]", + self.proposal, + self.proposer_info.get_author().short_str() + ) + } +} + +/// ProposerElection incorporates the logic of choosing a leader among multiple candidates. +/// We are open to a possibility for having multiple proposers per round, the ultimate choice +/// of a proposal is exposed by the election protocol via the stream of proposals. +pub trait ProposerElection: + EventBasedActor, OutputEvent = ProposalInfo> +{ + /// If a given author is a valid candidate for being a proposer, generate the info, + /// otherwise return None. + /// Note that this function is synchronous. + fn is_valid_proposer(&self, author: P, round: Round) -> Option
<P>
; + + /// Return all the possible valid proposers for a given round (this information can be + /// used, e.g., by voters for choosing the destinations to send their votes to). + fn get_valid_proposers(&self, round: Round) -> Vec
<P>
; +} + +impl<T: Payload, P: ProposerInfo> IntoProto for ProposalInfo<T, P> { + type ProtoType = ProtoProposal; + + fn into_proto(self) -> Self::ProtoType { + let mut proto = Self::ProtoType::new(); + let (block, proposer, hli) = (self.proposal, self.proposer_info, self.highest_ledger_info); + proto.set_proposed_block(block.into_proto()); + proto.set_proposer( + to_vec_named(&proposer) + .expect("fail to serialize proposer info") + .into(), + ); + if let Some(tc) = self.timeout_certificate { + proto.set_timeout_quorum_cert(tc.into_proto()); + } + proto.set_highest_ledger_info(hli.into_proto()); + proto + } +} + +impl<T: Payload, P: ProposerInfo> FromProto for ProposalInfo<T, P> { + type ProtoType = ProtoProposal; + + fn from_proto(mut object: Self::ProtoType) -> Result<Self> { + let proposal = Block::<T>::from_proto(object.take_proposed_block())?; + let proposer_info = from_slice(object.get_proposer())?; + let highest_ledger_info = QuorumCert::from_proto(object.take_highest_ledger_info())?; + let timeout_certificate = if let Some(tc) = object.timeout_quorum_cert.into_option() { + Some(PacemakerTimeoutCertificate::from_proto(tc)?) + } else { + None + }; + Ok(ProposalInfo { + proposal, + proposer_info, + timeout_certificate, + highest_ledger_info, + }) + } +} diff --git a/consensus/src/chained_bft/liveness/rotating_proposer_election.rs b/consensus/src/chained_bft/liveness/rotating_proposer_election.rs new file mode 100644 index 0000000000000..14dc048adfa01 --- /dev/null +++ b/consensus/src/chained_bft/liveness/rotating_proposer_election.rs @@ -0,0 +1,82 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + common::{Payload, Round}, + liveness::proposer_election::{ProposalInfo, ProposerElection, ProposerInfo}, + }, + stream_utils::EventBasedActor, +}; +use futures::{channel::mpsc, Future, FutureExt, SinkExt}; +use logger::prelude::*; +use std::pin::Pin; + +/// The rotating proposer maps a round to an author according to a round-robin rotation. +/// A fixed proposer strategy loses liveness when the fixed proposer is down. Rotating proposers +/// won't gather quorum certificates due to machine loss/byzantine behavior on f/n rounds. +pub struct RotatingProposer<T, P> { + // Ordering of proposers to rotate through (all honest replicas must agree on this) + proposers: Vec
<P>
, + // Number of contiguous rounds (i.e. round numbers increase by 1) a proposer is active + // in a row + contiguous_rounds: u32, + // Output stream to send the chosen proposals + winning_proposals: Option<mpsc::Sender<ProposalInfo<T, P>>>, +} + +impl<T: Payload, P: ProposerInfo> RotatingProposer<T, P> { + /// With only one proposer in the vector, it behaves the same as a fixed proposer strategy. + pub fn new(proposers: Vec
<P>
, contiguous_rounds: u32) -> Self { + Self { + proposers, + contiguous_rounds, + winning_proposals: None, + } + } + + fn get_proposer(&self, round: Round) -> P { + self.proposers + [((round / u64::from(self.contiguous_rounds)) % self.proposers.len() as u64) as usize] + } +} + +impl<T: Payload, P: ProposerInfo> ProposerElection<T, P> for RotatingProposer<T, P> { + fn is_valid_proposer(&self, author: P, round: Round) -> Option
<P>
{ + if self.get_proposer(round).get_author() == author.get_author() { + Some(author) + } else { + None + } + } + + fn get_valid_proposers(&self, round: Round) -> Vec
<P>
( + stk: &mut ExecutionStack<'alloc, 'txn, P>, + stack_state: StackState<'alloc>, + ) -> Bytecode + where + P: ModuleCache<'alloc>, + { + // Set the value stack + stk.set_stack(stack_state.stack); + + // Perform the frame transition (if there is any needed) + frame_transitions(stk, &stack_state.instr, stack_state.module_info); + + // Populate the locals of the frame + for (local_index, local) in stack_state.local_mapping.into_iter() { + assert_ok!(stk + .top_frame_mut() + .expect("[Stack Transition] Unable to get top frame on execution stack.") + .store_local(local_index, local)); + } + stack_state.instr + } +} + +impl<'alloc, 'txn> Iterator for RandomStackGenerator<'alloc, 'txn> +where + 'alloc: 'txn, +{ + type Item = StackState<'txn>; + fn next(&mut self) -> Option { + self.next_stack() + } +} diff --git a/language/vm/cost_synthesis/src/vm_runner.rs b/language/vm/cost_synthesis/src/vm_runner.rs new file mode 100644 index 0000000000000..0bee6feec8464 --- /dev/null +++ b/language/vm/cost_synthesis/src/vm_runner.rs @@ -0,0 +1,66 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Defines the VM context for running instruction synthesis. +use types::access_path::AccessPath; +use vm::errors::VMInvariantViolation; +use vm_runtime::data_cache::RemoteCache; + +/// A fake data cache used to build a transaction processor. +/// +/// This is a simple fake data cache that doesn't cache anything. If we try `get`ing anything from +/// it then we return that we did not error, and that we did not find the data. +#[derive(Default)] +pub struct FakeDataCache; + +impl FakeDataCache { + /// Create a fake data cache. + pub fn new() -> Self { + FakeDataCache + } +} + +impl RemoteCache for FakeDataCache { + fn get(&self, _access_path: &AccessPath) -> Result>, VMInvariantViolation> { + Ok(None) + } +} + +/// Create a VM loaded with the modules defined by the module generator passed in. +/// +/// Returns back handles that can be used to reference the created VM, the root_module, and the +/// module cache of all loaded modules in the VM. +#[macro_export] +macro_rules! with_loaded_vm { + ($module_generator:expr => $vm:ident, $mod:ident, $module_cache:ident) => { + let mut modules = STDLIB_MODULES.clone(); + let mut generated_modules = $module_generator.collect(); + modules.append(&mut generated_modules); + // The last module is the root module based upon how we generate modules. 
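+        // The root module must be cached before it can be fetched back as a loaded
+        // module, which is what the entry-function reference below is built from.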
+ let root_module = modules + .last() + .expect("[VM Setup] Unable to get root module"); + let allocator = Arena::new(); + let module_id = root_module.self_code_key(); + let $module_cache = VMModuleCache::new(&allocator); + let entry_idx = FunctionDefinitionIndex::new(0); + let data_cache = FakeDataCache::new(); + $module_cache + .cache_module(root_module.clone()) + .expect("[Module Cache] Unable to cache root module."); + let $mod = $module_cache + .get_loaded_module(&module_id) + .expect("[Module Cache] Internal error encountered when fetching module.") + .expect("[Module Cache] Unable to find module in module cache."); + for m in modules { + $module_cache + .cache_module(m) + .expect("[Module Cache] Unable to cache module."); + } + let entry_func = FunctionRef::new(&$mod, entry_idx) + .expect("[Entry Function] Unable to build function reference for entry function."); + let mut $vm = + TransactionExecutor::new(&$module_cache, &data_cache, TransactionMetadata::default()); + $vm.execution_stack.push_frame(entry_func); + }; +} diff --git a/language/vm/coverage_report.sh b/language/vm/coverage_report.sh new file mode 100755 index 0000000000000..f46c6a3c00d58 --- /dev/null +++ b/language/vm/coverage_report.sh @@ -0,0 +1,81 @@ +#!/bin/bash + + +# Check that the report destination argument is provided +if [ $# -eq 0 ] +then + echo "Usage: coverage_report.sh path/to/report/destination" + exit 1 +fi + +# Set the directory to which the report will be saved +COVERAGE_DIR=$1 + +# Check that grcov is installed +if ! [ -x "$(command -v grcov)" ]; then + echo "Error: grcov is not installed." >&2 + read -p "Install grcov? [yY/*] " -n 1 -r + echo "" + if [[ ! $REPLY =~ ^[Yy]$ ]] + then + [[ "$0" = "$BASH_SOURCE" ]] && exit 1 || return 1 + fi + cargo install grcov +fi + +# Check that lcov is installed +if ! [ -x "$(command -v lcov)" ]; then + echo "Error: lcov is not installed." >&2 + echo "Assuming macOS and homebrew" + read -p "Install lcov? [yY/*] " -n 1 -r + echo "" + if [[ ! $REPLY =~ ^[Yy]$ ]] + then + [[ "$0" = "$BASH_SOURCE" ]] && exit 1 || return 1 + fi + brew install lcov +fi + +# Warn that cargo clean will happen +read -p "Generate coverage report? This will run cargo clean. [yY/*] " -n 1 -r +echo "" +if [[ ! $REPLY =~ ^[Yy]$ ]] +then + [[ "$0" = "$BASH_SOURCE" ]] && exit 1 || return 1 +fi + +# Remove existing coverage output +echo "Cleaning existing coverage info..." +find ../../target -type f -name "*.gcda" -delete +find ../../target -type f -name "*.gcno" -delete + +# Clean the project +echo "Cleaning project..." +cargo clean + +# Build with flags necessary for coverage output +echo "Building with coverage instrumentation..." +CARGO_INCREMENTAL=0 RUSTFLAGS="-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Coverflow-checks=off -Zno-landing-pads" cargo build + +# Run tests +echo "Running tests..." +while read line; do + dirline=$(realpath $(dirname $line)); + (cd $dirline; CARGO_INCREMENTAL=0 RUSTFLAGS="-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Coverflow-checks=off -Zno-landing-pads" cargo test) +done < <(find . -name 'Cargo.toml') + +# Make the coverage directory if it doesn't exist +if [ ! -d $COVERAGE_DIR ]; then + mkdir $COVERAGE_DIR; +fi + +# Generate lcov report +echo "Generating lcov report at ${COVERAGE_DIR}/lcov.info..." +grcov ../../target -t lcov --ignore-dir "/*" -o $COVERAGE_DIR/lcov.info + +# Generate HTML report +echo "Generating report at ${COVERAGE_DIR}..." 
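+# (Note) The lcov.info produced by grcov above is a standard LCOV tracefile; it
+# can also be inspected directly, e.g. with: lcov --summary $COVERAGE_DIR/lcov.info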
+# Flag "--ignore-errors source" ignores missing source files +(cd ../../; genhtml -o $COVERAGE_DIR --show-details --highlight --ignore-errors source --legend $COVERAGE_DIR/lcov.info) + +echo "Done. Please view report at ${COVERAGE_DIR}/index.html" diff --git a/language/vm/src/access.rs b/language/vm/src/access.rs new file mode 100644 index 0000000000000..1dc64c4bb2009 --- /dev/null +++ b/language/vm/src/access.rs @@ -0,0 +1,236 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Defines accessors for compiled modules. + +use std::slice; + +use types::{account_address::AccountAddress, byte_array::ByteArray, language_storage::CodeKey}; + +use crate::{ + errors::VMStaticViolation, + file_format::{ + AddressPoolIndex, ByteArrayPoolIndex, CompiledModule, CompiledScript, FieldDefinition, + FieldDefinitionIndex, FunctionDefinition, FunctionDefinitionIndex, FunctionHandle, + FunctionHandleIndex, FunctionSignature, FunctionSignatureIndex, LocalsSignature, + LocalsSignatureIndex, MemberCount, ModuleHandle, ModuleHandleIndex, StringPoolIndex, + StructDefinition, StructDefinitionIndex, StructHandle, StructHandleIndex, TypeSignature, + TypeSignatureIndex, + }, + internals::ModuleIndex, + IndexKind, +}; + +/// Represents accessors common to modules and scripts. +/// +/// This is done as a trait because in the future, we may be able to write an alternative impl for +/// bytecode that's already been checked for internal consistency. +pub trait BaseAccess: Sync { + fn module_handle_at(&self, idx: ModuleHandleIndex) -> &ModuleHandle; + fn struct_handle_at(&self, idx: StructHandleIndex) -> &StructHandle; + fn function_handle_at(&self, idx: FunctionHandleIndex) -> &FunctionHandle; + + fn type_signature_at(&self, idx: TypeSignatureIndex) -> &TypeSignature; + fn function_signature_at(&self, idx: FunctionSignatureIndex) -> &FunctionSignature; + fn locals_signature_at(&self, idx: LocalsSignatureIndex) -> &LocalsSignature; + + fn string_at(&self, idx: StringPoolIndex) -> &str; + fn byte_array_at(&self, idx: ByteArrayPoolIndex) -> &ByteArray; + fn address_at(&self, idx: AddressPoolIndex) -> &AccountAddress; + + // XXX is a partial range required here? + fn module_handles(&self) -> slice::Iter; + fn struct_handles(&self) -> slice::Iter; + fn function_handles(&self) -> slice::Iter; + + fn type_signatures(&self) -> slice::Iter; + fn function_signatures(&self) -> slice::Iter; + fn locals_signatures(&self) -> slice::Iter; + + fn byte_array_pool(&self) -> slice::Iter; + fn address_pool(&self) -> slice::Iter; + fn string_pool(&self) -> slice::Iter; +} + +/// Represents accessors for a compiled script. +/// +/// This is done as a trait because in the future, we may be able to write an alternative impl for a +/// script that's already been checked for internal consistency. +pub trait ScriptAccess: BaseAccess { + fn main(&self) -> &FunctionDefinition; +} + +/// Represents accessors for a compiled module. +/// +/// This is done as a trait because in the future, we may be able to write an alternative impl for a +/// module that's already been checked for internal consistency. 
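+///
+/// A minimal sketch of how these accessors compose (hypothetical helper, not
+/// part of this crate; index 0 is assumed to exist for illustration only):
+///
+/// ```ignore
+/// fn first_function_name(m: &impl ModuleAccess) -> &str {
+///     let def = m.function_def_at(FunctionDefinitionIndex::new(0));
+///     let handle = m.function_handle_at(def.function);
+///     m.string_at(handle.name)
+/// }
+/// ```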
+pub trait ModuleAccess: BaseAccess { + fn struct_def_at(&self, idx: StructDefinitionIndex) -> &StructDefinition; + fn field_def_at(&self, idx: FieldDefinitionIndex) -> &FieldDefinition; + fn function_def_at(&self, idx: FunctionDefinitionIndex) -> &FunctionDefinition; + + fn struct_defs(&self) -> slice::Iter; + fn field_defs(&self) -> slice::Iter; + fn function_defs(&self) -> slice::Iter; + + fn code_key_for_handle(&self, module_handle_idx: &ModuleHandle) -> CodeKey; + fn self_code_key(&self) -> CodeKey; + + fn field_def_range( + &self, + field_count: MemberCount, + first_field: FieldDefinitionIndex, + ) -> slice::Iter; +} + +macro_rules! impl_base_access { + ($ty:ty) => { + impl BaseAccess for $ty { + fn module_handle_at(&self, idx: ModuleHandleIndex) -> &ModuleHandle { + &self.module_handles[idx.into_index()] + } + + fn struct_handle_at(&self, idx: StructHandleIndex) -> &StructHandle { + &self.struct_handles[idx.into_index()] + } + + fn function_handle_at(&self, idx: FunctionHandleIndex) -> &FunctionHandle { + &self.function_handles[idx.into_index()] + } + + fn type_signature_at(&self, idx: TypeSignatureIndex) -> &TypeSignature { + &self.type_signatures[idx.into_index()] + } + + fn function_signature_at(&self, idx: FunctionSignatureIndex) -> &FunctionSignature { + &self.function_signatures[idx.into_index()] + } + + fn locals_signature_at(&self, idx: LocalsSignatureIndex) -> &LocalsSignature { + &self.locals_signatures[idx.into_index()] + } + + fn string_at(&self, idx: StringPoolIndex) -> &str { + self.string_pool[idx.into_index()].as_str() + } + + fn byte_array_at(&self, idx: ByteArrayPoolIndex) -> &ByteArray { + &self.byte_array_pool[idx.into_index()] + } + + fn address_at(&self, idx: AddressPoolIndex) -> &AccountAddress { + &self.address_pool[idx.into_index()] + } + + fn module_handles(&self) -> slice::Iter { + self.module_handles[..].iter() + } + fn struct_handles(&self) -> slice::Iter { + self.struct_handles[..].iter() + } + fn function_handles(&self) -> slice::Iter { + self.function_handles[..].iter() + } + + fn type_signatures(&self) -> slice::Iter { + self.type_signatures[..].iter() + } + fn function_signatures(&self) -> slice::Iter { + self.function_signatures[..].iter() + } + fn locals_signatures(&self) -> slice::Iter { + self.locals_signatures[..].iter() + } + + fn byte_array_pool(&self) -> slice::Iter { + self.byte_array_pool[..].iter() + } + fn address_pool(&self) -> slice::Iter { + self.address_pool[..].iter() + } + fn string_pool(&self) -> slice::Iter { + self.string_pool[..].iter() + } + } + }; +} + +impl_base_access!(CompiledModule); +impl_base_access!(CompiledScript); + +impl ModuleAccess for CompiledModule { + fn self_code_key(&self) -> CodeKey { + self.self_code_key() + } + + fn code_key_for_handle(&self, module_handle: &ModuleHandle) -> CodeKey { + self.code_key_for_handle(module_handle) + } + + fn struct_def_at(&self, idx: StructDefinitionIndex) -> &StructDefinition { + &self.struct_defs[idx.into_index()] + } + + fn field_def_at(&self, idx: FieldDefinitionIndex) -> &FieldDefinition { + &self.field_defs[idx.into_index()] + } + + fn function_def_at(&self, idx: FunctionDefinitionIndex) -> &FunctionDefinition { + &self.function_defs[idx.into_index()] + } + + fn struct_defs(&self) -> slice::Iter { + self.struct_defs[..].iter() + } + + fn field_defs(&self) -> slice::Iter { + self.field_defs[..].iter() + } + + fn function_defs(&self) -> slice::Iter { + self.function_defs[..].iter() + } + + fn field_def_range( + &self, + field_count: MemberCount, + first_field: 
FieldDefinitionIndex, + ) -> slice::Iter { + let first_field = first_field.0 as usize; + let field_count = field_count as usize; + let last_field = first_field + field_count; + self.field_defs[first_field..last_field].iter() + } +} + +impl ScriptAccess for CompiledScript { + fn main(&self) -> &FunctionDefinition { + &self.main + } +} + +impl CompiledModule { + #[inline] + pub(crate) fn check_field_range( + &self, + field_count: MemberCount, + first_field: FieldDefinitionIndex, + ) -> Option { + let first_field = first_field.into_index(); + let field_count = field_count as usize; + // Both first_field and field_count are u16 so this is guaranteed to not overflow. + // Note that last_field is exclusive, i.e. fields are in the range + // [first_field, last_field). + let last_field = first_field + field_count; + if last_field > self.field_defs.len() { + Some(VMStaticViolation::RangeOutOfBounds( + IndexKind::FieldDefinition, + self.field_defs.len(), + first_field, + last_field, + )) + } else { + None + } + } +} diff --git a/language/vm/src/checks.rs b/language/vm/src/checks.rs new file mode 100644 index 0000000000000..f29e32fd7dffd --- /dev/null +++ b/language/vm/src/checks.rs @@ -0,0 +1,8 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +pub mod bounds; +pub mod signature; + +pub use bounds::{BoundsCheck, BoundsChecker}; +pub use signature::SignatureCheck; diff --git a/language/vm/src/checks/bounds.rs b/language/vm/src/checks/bounds.rs new file mode 100644 index 0000000000000..b4b2695914a9b --- /dev/null +++ b/language/vm/src/checks/bounds.rs @@ -0,0 +1,337 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + errors::{VMStaticViolation, VerificationError}, + file_format::{ + Bytecode, CompiledModule, FieldDefinition, FunctionDefinition, FunctionHandle, + FunctionSignature, LocalsSignature, ModuleHandle, SignatureToken, StructDefinition, + StructHandle, TypeSignature, + }, + internals::ModuleIndex, + IndexKind, +}; + +pub struct BoundsChecker<'a> { + module: &'a CompiledModule, +} + +impl<'a> BoundsChecker<'a> { + pub fn new(module: &'a CompiledModule) -> Self { + Self { module } + } + + pub fn verify(self) -> Vec { + let mut errors: Vec> = vec![]; + + // A module (or script) must always have at least one module handle. (For modules the first + // handle should be the same as the sender -- the bytecode verifier is unaware of + // transactions so it does not perform this check. 
+ if self.module.module_handles.is_empty() { + errors.push(vec![VerificationError { + kind: IndexKind::ModuleHandle, + idx: 0, + err: VMStaticViolation::NoModuleHandles, + }]); + } + + errors.push(Self::verify_impl( + IndexKind::ModuleHandle, + self.module.module_handles.iter(), + self.module, + )); + errors.push(Self::verify_impl( + IndexKind::StructHandle, + self.module.struct_handles.iter(), + self.module, + )); + errors.push(Self::verify_impl( + IndexKind::FunctionHandle, + self.module.function_handles.iter(), + self.module, + )); + errors.push(Self::verify_impl( + IndexKind::StructDefinition, + self.module.struct_defs.iter(), + self.module, + )); + errors.push(Self::verify_impl( + IndexKind::FieldDefinition, + self.module.field_defs.iter(), + self.module, + )); + errors.push(Self::verify_impl( + IndexKind::FunctionDefinition, + self.module.function_defs.iter(), + self.module, + )); + errors.push(Self::verify_impl( + IndexKind::TypeSignature, + self.module.type_signatures.iter(), + self.module, + )); + errors.push(Self::verify_impl( + IndexKind::FunctionSignature, + self.module.function_signatures.iter(), + self.module, + )); + errors.push(Self::verify_impl( + IndexKind::LocalsSignature, + self.module.locals_signatures.iter(), + self.module, + )); + + let errors: Vec<_> = errors.into_iter().flatten().collect(); + if !errors.is_empty() { + return errors; + } + + // Code unit checking needs to be done once the rest of the module is validated. + self.module + .function_defs + .iter() + .enumerate() + .map(|(idx, elem)| { + elem.check_code_unit_bounds(self.module) + .into_iter() + .map(move |err| VerificationError { + kind: IndexKind::FunctionDefinition, + idx, + err, + }) + }) + .flatten() + .collect() + } + + #[inline] + fn verify_impl( + kind: IndexKind, + iter: impl Iterator, + module: &CompiledModule, + ) -> Vec { + iter.enumerate() + .map(move |(idx, elem)| { + elem.check_bounds(module) + .into_iter() + .map(move |err| VerificationError { kind, idx, err }) + }) + .flatten() + .collect() + } +} + +pub trait BoundsCheck { + fn check_bounds(&self, module: &CompiledModule) -> Vec; +} + +#[inline] +fn check_bounds_impl(pool: &[T], idx: I) -> Option +where + I: ModuleIndex, +{ + let idx = idx.into_index(); + let len = pool.len(); + if idx >= len { + Some(VMStaticViolation::IndexOutOfBounds(I::KIND, len, idx)) + } else { + None + } +} + +impl BoundsCheck for &ModuleHandle { + #[inline] + fn check_bounds(&self, module: &CompiledModule) -> Vec { + vec![ + check_bounds_impl(&module.address_pool, self.address), + check_bounds_impl(&module.string_pool, self.name), + ] + .into_iter() + .flatten() + .collect() + } +} + +impl BoundsCheck for &StructHandle { + #[inline] + fn check_bounds(&self, module: &CompiledModule) -> Vec { + vec![ + check_bounds_impl(&module.module_handles, self.module), + check_bounds_impl(&module.string_pool, self.name), + ] + .into_iter() + .flatten() + .collect() + } +} + +impl BoundsCheck for &FunctionHandle { + #[inline] + fn check_bounds(&self, module: &CompiledModule) -> Vec { + vec![ + check_bounds_impl(&module.module_handles, self.module), + check_bounds_impl(&module.string_pool, self.name), + check_bounds_impl(&module.function_signatures, self.signature), + ] + .into_iter() + .flatten() + .collect() + } +} + +impl BoundsCheck for &StructDefinition { + #[inline] + fn check_bounds(&self, module: &CompiledModule) -> Vec { + vec![ + check_bounds_impl(&module.struct_handles, self.struct_handle), + module.check_field_range(self.field_count, self.fields), + ] + .into_iter() 
+ .flatten() + .collect() + } +} + +impl BoundsCheck for &FieldDefinition { + #[inline] + fn check_bounds(&self, module: &CompiledModule) -> Vec { + vec![ + check_bounds_impl(&module.struct_handles, self.struct_), + check_bounds_impl(&module.string_pool, self.name), + check_bounds_impl(&module.type_signatures, self.signature), + ] + .into_iter() + .flatten() + .collect() + } +} + +impl BoundsCheck for &FunctionDefinition { + #[inline] + fn check_bounds(&self, module: &CompiledModule) -> Vec { + vec![ + check_bounds_impl(&module.function_handles, self.function), + if self.is_native() { + None + } else { + check_bounds_impl(&module.locals_signatures, self.code.locals) + }, + ] + .into_iter() + .flatten() + .collect() + } +} + +impl BoundsCheck for &TypeSignature { + #[inline] + fn check_bounds(&self, module: &CompiledModule) -> Vec { + self.0.check_bounds(module).into_iter().collect() + } +} + +impl BoundsCheck for &FunctionSignature { + #[inline] + fn check_bounds(&self, module: &CompiledModule) -> Vec { + self.return_types + .iter() + .filter_map(|token| token.check_bounds(module)) + .chain( + self.arg_types + .iter() + .filter_map(|token| token.check_bounds(module)), + ) + .collect() + } +} + +impl BoundsCheck for &LocalsSignature { + #[inline] + fn check_bounds(&self, module: &CompiledModule) -> Vec { + self.0 + .iter() + .filter_map(|token| token.check_bounds(module)) + .collect() + } +} + +impl SignatureToken { + #[inline] + fn check_bounds(&self, module: &CompiledModule) -> Option { + match self.struct_index() { + Some(sh_idx) => check_bounds_impl(&module.struct_handles, sh_idx), + None => None, + } + } +} + +impl FunctionDefinition { + // This is implemented separately because it depends on the locals signature index being + // checked. + fn check_code_unit_bounds(&self, module: &CompiledModule) -> Vec { + if self.is_native() { + return vec![]; + } + + let locals_len = module.locals_signatures[self.code.locals.0 as usize] + .0 + .len(); + + let code = &self.code.code; + let code_len = code.len(); + + code.iter() + .filter_map(|bytecode| { + use self::Bytecode::*; + + match bytecode { + // Instructions that refer to other pools. + LdAddr(idx) => check_bounds_impl(&module.address_pool, *idx), + LdByteArray(idx) => check_bounds_impl(&module.byte_array_pool, *idx), + LdStr(idx) => check_bounds_impl(&module.string_pool, *idx), + BorrowField(idx) => check_bounds_impl(&module.field_defs, *idx), + Call(idx) => check_bounds_impl(&module.function_handles, *idx), + Pack(idx) | Unpack(idx) | Exists(idx) | BorrowGlobal(idx) | MoveFrom(idx) + | MoveToSender(idx) => check_bounds_impl(&module.struct_defs, *idx), + // Instructions that refer to this code block. + BrTrue(offset) | BrFalse(offset) | Branch(offset) => { + // XXX IndexOutOfBounds seems correct, but IndexKind::CodeDefinition + // (and LocalPool) feel wrong. Reconsider this at some point. + let offset = *offset as usize; + if offset >= code_len { + Some(VMStaticViolation::IndexOutOfBounds( + IndexKind::CodeDefinition, + code_len, + offset, + )) + } else { + None + } + } + // Instructions that refer to the locals. + CopyLoc(idx) | MoveLoc(idx) | StLoc(idx) | BorrowLoc(idx) => { + let idx = *idx as usize; + if idx >= locals_len { + Some(VMStaticViolation::IndexOutOfBounds( + IndexKind::LocalPool, + locals_len, + idx, + )) + } else { + None + } + } + + // List out the other options explicitly so there's a compile error if a new + // bytecode gets added. 
+ FreezeRef | ReleaseRef | Pop | Ret | LdConst(_) | LdTrue | LdFalse + | ReadRef | WriteRef | Add | Sub | Mul | Mod | Div | BitOr | BitAnd | Xor + | Or | And | Not | Eq | Neq | Lt | Gt | Le | Ge | Assert + | GetTxnGasUnitPrice | GetTxnMaxGasUnits | GetGasRemaining + | GetTxnSenderAddress | CreateAccount | EmitEvent | GetTxnSequenceNumber + | GetTxnPublicKey => None, + } + }) + .collect() + } +} diff --git a/language/vm/src/checks/signature.rs b/language/vm/src/checks/signature.rs new file mode 100644 index 0000000000000..03dde45e84ae0 --- /dev/null +++ b/language/vm/src/checks/signature.rs @@ -0,0 +1,91 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + access::ModuleAccess, + errors::VMStaticViolation, + file_format::{CompiledModule, SignatureToken}, + views::{ + FieldDefinitionView, FunctionSignatureView, LocalsSignatureView, SignatureTokenView, + TypeSignatureView, ViewInternals, + }, + SignatureTokenKind, +}; + +pub trait SignatureCheck { + fn check_signatures(&self) -> Vec; +} + +impl<'a, T: ModuleAccess> SignatureCheck for FunctionSignatureView<'a, T> { + fn check_signatures(&self) -> Vec { + self.return_tokens() + .filter_map(|token| token.check_structure()) + .chain( + self.arg_tokens() + .filter_map(|token| token.check_structure()), + ) + .collect() + } +} + +impl<'a, T: ModuleAccess> SignatureCheck for TypeSignatureView<'a, T> { + fn check_signatures(&self) -> Vec { + self.token().check_structure().into_iter().collect() + } +} + +impl<'a, T: ModuleAccess> SignatureCheck for LocalsSignatureView<'a, T> { + fn check_signatures(&self) -> Vec { + self.tokens() + .filter_map(|token| token.check_structure()) + .collect() + } +} + +impl<'a> FieldDefinitionView<'a, CompiledModule> { + /// Field definitions have additional constraints on signatures -- field signatures cannot be + /// references or mutable references. + pub fn check_signature_refs(&self) -> Option { + let type_signature = self.type_signature(); + let token = type_signature.token(); + let kind = token.kind(); + match kind { + SignatureTokenKind::Reference | SignatureTokenKind::MutableReference => Some( + VMStaticViolation::InvalidFieldDefReference(token.as_inner().clone(), kind), + ), + SignatureTokenKind::Value => None, + } + } +} + +impl<'a, T: ModuleAccess> SignatureTokenView<'a, T> { + /// Check that this token is structurally correct. + /// In particular, check that the token has a reference only at the top level. + #[inline] + pub fn check_structure(&self) -> Option { + self.as_inner().check_structure() + } +} + +impl SignatureToken { + // See SignatureTokenView::check_structure for more details. 
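+    // For example (informal): `&u64` and `&mut ByteArray` are accepted, while a
+    // reference to a reference such as `& &u64` is reported as an invalid
+    // signature token.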
+ pub(crate) fn check_structure(&self) -> Option { + use SignatureToken::*; + + let inner_token_opt = match self { + Reference(token) => Some(token), + MutableReference(token) => Some(token), + Bool | U64 | String | ByteArray | Address | Struct(_) => None, + }; + if let Some(inner_token) = inner_token_opt { + if inner_token.is_reference() { + return Some(VMStaticViolation::InvalidSignatureToken( + self.clone(), + self.kind(), + inner_token.kind(), + )); + } + } + None + } +} diff --git a/language/vm/src/deserializer.rs b/language/vm/src/deserializer.rs new file mode 100644 index 0000000000000..43692afeed84d --- /dev/null +++ b/language/vm/src/deserializer.rs @@ -0,0 +1,1040 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{checks::BoundsChecker, errors::*, file_format::*, file_format_common::*}; +use byteorder::{LittleEndian, ReadBytesExt}; +use std::{ + collections::HashSet, + convert::TryInto, + io::{Cursor, Read}, + str::from_utf8, +}; +use types::{account_address::ADDRESS_LENGTH, byte_array::ByteArray}; + +impl CompiledScript { + /// Deserializes a &[u8] slice into a transaction script (`CompiledScript`) + pub fn deserialize(binary: &[u8]) -> BinaryLoaderResult { + let compiled_script = Self::deserialize_no_check_bounds(binary)?; + compiled_script.check_bounds() + } + + // exposed as a public function to enable testing the deserializer + pub fn deserialize_no_check_bounds(binary: &[u8]) -> BinaryLoaderResult { + deserialize_compiled_script(binary) + } + + /// Checks that all indexes are in bound in this `CompiledScript`. + pub fn check_bounds(self) -> BinaryLoaderResult { + let fake_module = self.into_module(); + if BoundsChecker::new(&fake_module).verify().is_empty() { + Ok(fake_module.into_script()) + } else { + Err(BinaryError::Malformed) + } + } +} + +impl CompiledModule { + /// Deserialize a &[u8] slice into a module (`CompiledModule`) + pub fn deserialize(binary: &[u8]) -> BinaryLoaderResult { + let compiled_module = Self::deserialize_no_check_bounds(binary)?; + compiled_module.check_bounds() + } + + // exposed as a public function to enable testing the deserializer + pub fn deserialize_no_check_bounds(binary: &[u8]) -> BinaryLoaderResult { + deserialize_compiled_module(binary) + } + + /// Checks that all indexes are in bound in this `CompiledModule`. + pub fn check_bounds(self) -> BinaryLoaderResult { + if BoundsChecker::new(&self).verify().is_empty() { + Ok(self) + } else { + Err(BinaryError::Malformed) + } + } +} + +/// Table info: table type, offset where the table content starts from, count of bytes for +/// the table content. +#[derive(Clone, Debug)] +struct Table { + kind: TableType, + offset: u32, + count: u32, +} + +impl Table { + fn new(kind: TableType, offset: u32, count: u32) -> Table { + Table { + kind, + offset, + count, + } + } +} + +/// Module internal function that manages deserialization of transactions. +fn deserialize_compiled_script(binary: &[u8]) -> BinaryLoaderResult { + let binary_len = binary.len() as u64; + let mut cursor = Cursor::new(binary); + let table_count = check_binary(&mut cursor)?; + let mut tables: Vec = Vec::new(); + read_tables(&mut cursor, table_count, &mut tables)?; + check_tables(&mut tables, cursor.position(), binary_len)?; + + build_compiled_script(binary, &tables) +} + +/// Module internal function that manages deserialization of modules. 
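+///
+/// Callers normally go through the public entry point instead (sketch; the byte
+/// slice here is assumed to come from storage):
+///
+/// ```ignore
+/// let module = CompiledModule::deserialize(&module_bytes)?;
+/// ```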
+fn deserialize_compiled_module(binary: &[u8]) -> BinaryLoaderResult { + let binary_len = binary.len() as u64; + let mut cursor = Cursor::new(binary); + let table_count = check_binary(&mut cursor)?; + let mut tables: Vec
<Table>
= Vec::new(); + read_tables(&mut cursor, table_count, &mut tables)?; + check_tables(&mut tables, cursor.position(), binary_len)?; + + build_compiled_module(binary, &tables) +} + +/// Verifies the correctness of the "static" part of the binary's header. +/// +/// Returns the offset where the count of tables in the binary. +fn check_binary(cursor: &mut Cursor<&[u8]>) -> BinaryLoaderResult { + let mut magic = [0u8; BinaryConstants::LIBRA_MAGIC_SIZE]; + if let Ok(count) = cursor.read(&mut magic) { + if count != BinaryConstants::LIBRA_MAGIC_SIZE { + return Err(BinaryError::Malformed); + } else if magic != BinaryConstants::LIBRA_MAGIC { + return Err(BinaryError::BadMagic); + } + } else { + return Err(BinaryError::Malformed); + } + let major_ver = 1u8; + let minor_ver = 0u8; + if let Ok(ver) = cursor.read_u8() { + if ver != major_ver { + return Err(BinaryError::UnknownVersion); + } + } else { + return Err(BinaryError::Malformed); + } + if let Ok(ver) = cursor.read_u8() { + if ver != minor_ver { + return Err(BinaryError::UnknownVersion); + } + } else { + return Err(BinaryError::Malformed); + } + if let Ok(count) = cursor.read_u8() { + Ok(count) + } else { + return Err(BinaryError::Malformed); + } +} + +/// Reads all the table headers. +/// +/// Return a Vec
<Table>
that contains all the table headers defined and checked. +fn read_tables( + cursor: &mut Cursor<&[u8]>, + table_count: u8, + tables: &mut Vec
<Table>
, +) -> BinaryLoaderResult<()> { + for _count in 0..table_count { + tables.push(read_table(cursor)?); + } + Ok(()) +} + +/// Reads a table from a slice at a given offset. +/// If a table is not recognized an error is returned. +fn read_table(cursor: &mut Cursor<&[u8]>) -> BinaryLoaderResult
<Table>
{ + if let Ok(kind) = cursor.read_u8() { + let table_offset = read_u32_internal(cursor)?; + let count = read_u32_internal(cursor)?; + Ok(Table::new(TableType::from_u8(kind)?, table_offset, count)) + } else { + Err(BinaryError::Malformed) + } +} + +/// Verify correctness of tables. +/// +/// Tables cannot have duplicates, must cover the entire blob and must be disjoint. +fn check_tables(tables: &mut Vec
<Table>
, end_tables: u64, length: u64) -> BinaryLoaderResult<()> { + // there is no real reason to pass a mutable reference but we are sorting next line + tables.sort_by(|t1, t2| t1.offset.cmp(&t2.offset)); + + let mut current_offset = end_tables; + let mut table_types = HashSet::new(); + for table in tables { + let offset = u64::from(table.offset); + if offset != current_offset { + return Err(BinaryError::BadHeaderTable); + } + if table.count == 0 { + return Err(BinaryError::BadHeaderTable); + } + let count = u64::from(table.count); + if let Some(checked_offset) = current_offset.checked_add(count) { + current_offset = checked_offset; + } + if current_offset > length { + return Err(BinaryError::BadHeaderTable); + } + if !table_types.insert(table.kind) { + return Err(BinaryError::DuplicateTable); + } + } + if current_offset != length { + return Err(BinaryError::BadHeaderTable); + } + Ok(()) +} + +// +// Trait to read common tables from CompiledScript or CompiledModule +// + +trait CommonTables { + fn get_module_handles(&mut self) -> &mut Vec; + fn get_struct_handles(&mut self) -> &mut Vec; + fn get_function_handles(&mut self) -> &mut Vec; + + fn get_type_signatures(&mut self) -> &mut TypeSignaturePool; + fn get_function_signatures(&mut self) -> &mut FunctionSignaturePool; + fn get_locals_signatures(&mut self) -> &mut LocalsSignaturePool; + + fn get_string_pool(&mut self) -> &mut StringPool; + fn get_byte_array_pool(&mut self) -> &mut ByteArrayPool; + fn get_address_pool(&mut self) -> &mut AddressPool; +} + +impl CommonTables for CompiledScript { + fn get_module_handles(&mut self) -> &mut Vec { + &mut self.module_handles + } + + fn get_struct_handles(&mut self) -> &mut Vec { + &mut self.struct_handles + } + + fn get_function_handles(&mut self) -> &mut Vec { + &mut self.function_handles + } + + fn get_type_signatures(&mut self) -> &mut TypeSignaturePool { + &mut self.type_signatures + } + + fn get_function_signatures(&mut self) -> &mut FunctionSignaturePool { + &mut self.function_signatures + } + + fn get_locals_signatures(&mut self) -> &mut LocalsSignaturePool { + &mut self.locals_signatures + } + + fn get_string_pool(&mut self) -> &mut StringPool { + &mut self.string_pool + } + + fn get_byte_array_pool(&mut self) -> &mut ByteArrayPool { + &mut self.byte_array_pool + } + + fn get_address_pool(&mut self) -> &mut AddressPool { + &mut self.address_pool + } +} + +impl CommonTables for CompiledModule { + fn get_module_handles(&mut self) -> &mut Vec { + &mut self.module_handles + } + + fn get_struct_handles(&mut self) -> &mut Vec { + &mut self.struct_handles + } + + fn get_function_handles(&mut self) -> &mut Vec { + &mut self.function_handles + } + + fn get_type_signatures(&mut self) -> &mut TypeSignaturePool { + &mut self.type_signatures + } + + fn get_function_signatures(&mut self) -> &mut FunctionSignaturePool { + &mut self.function_signatures + } + + fn get_locals_signatures(&mut self) -> &mut LocalsSignaturePool { + &mut self.locals_signatures + } + + fn get_string_pool(&mut self) -> &mut StringPool { + &mut self.string_pool + } + + fn get_byte_array_pool(&mut self) -> &mut ByteArrayPool { + &mut self.byte_array_pool + } + + fn get_address_pool(&mut self) -> &mut AddressPool { + &mut self.address_pool + } +} + +/// Builds and returns a `CompiledScript`. 
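+/// Common tables are loaded first, then the script-only `MAIN` table.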
+fn build_compiled_script(binary: &[u8], tables: &[Table]) -> BinaryLoaderResult { + let mut script = CompiledScript::default(); + build_common_tables(binary, tables, &mut script)?; + build_script_tables(binary, tables, &mut script)?; + Ok(script) +} + +/// Builds and returns a `CompiledModule`. +fn build_compiled_module(binary: &[u8], tables: &[Table]) -> BinaryLoaderResult { + let mut module = CompiledModule::default(); + build_common_tables(binary, tables, &mut module)?; + build_module_tables(binary, tables, &mut module)?; + Ok(module) +} + +/// Builds the common tables in a compiled unit. +fn build_common_tables( + binary: &[u8], + tables: &[Table], + common: &mut impl CommonTables, +) -> BinaryLoaderResult<()> { + for table in tables { + match table.kind { + TableType::MODULE_HANDLES => { + load_module_handles(binary, table, common.get_module_handles())?; + } + TableType::STRUCT_HANDLES => { + load_struct_handles(binary, table, common.get_struct_handles())?; + } + TableType::FUNCTION_HANDLES => { + load_function_handles(binary, table, common.get_function_handles())?; + } + TableType::ADDRESS_POOL => { + load_address_pool(binary, table, common.get_address_pool())?; + } + TableType::STRING_POOL => { + load_string_pool(binary, table, common.get_string_pool())?; + } + TableType::BYTE_ARRAY_POOL => { + load_byte_array_pool(binary, table, common.get_byte_array_pool())?; + } + TableType::TYPE_SIGNATURES => { + load_type_signatures(binary, table, common.get_type_signatures())?; + } + TableType::FUNCTION_SIGNATURES => { + load_function_signatures(binary, table, common.get_function_signatures())?; + } + TableType::LOCALS_SIGNATURES => { + load_locals_signatures(binary, table, common.get_locals_signatures())?; + } + TableType::FUNCTION_DEFS + | TableType::FIELD_DEFS + | TableType::STRUCT_DEFS + | TableType::MAIN => continue, + } + } + Ok(()) +} + +/// Builds tables related to a `CompiledModule`. +fn build_module_tables( + binary: &[u8], + tables: &[Table], + module: &mut CompiledModule, +) -> BinaryLoaderResult<()> { + for table in tables { + match table.kind { + TableType::STRUCT_DEFS => { + load_struct_defs(binary, table, &mut module.struct_defs)?; + } + TableType::FIELD_DEFS => { + load_field_defs(binary, table, &mut module.field_defs)?; + } + TableType::FUNCTION_DEFS => { + load_function_defs(binary, table, &mut module.function_defs)?; + } + TableType::MODULE_HANDLES + | TableType::STRUCT_HANDLES + | TableType::FUNCTION_HANDLES + | TableType::ADDRESS_POOL + | TableType::STRING_POOL + | TableType::BYTE_ARRAY_POOL + | TableType::TYPE_SIGNATURES + | TableType::FUNCTION_SIGNATURES + | TableType::LOCALS_SIGNATURES => { + continue; + } + TableType::MAIN => return Err(BinaryError::Malformed), + } + } + Ok(()) +} + +/// Builds tables related to a `CompiledScript`. 
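+/// Only the `MAIN` table is script-specific; definition tables (`STRUCT_DEFS`,
+/// `FIELD_DEFS`, `FUNCTION_DEFS`) are rejected as malformed in a script.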
+fn build_script_tables( + binary: &[u8], + tables: &[Table], + script: &mut CompiledScript, +) -> BinaryLoaderResult<()> { + for table in tables { + match table.kind { + TableType::MAIN => { + let start: usize = table.offset as usize; + let end: usize = start + table.count as usize; + let mut cursor = Cursor::new(&binary[start..end]); + let main = load_function_def(&mut cursor)?; + script.main = main; + } + TableType::MODULE_HANDLES + | TableType::STRUCT_HANDLES + | TableType::FUNCTION_HANDLES + | TableType::ADDRESS_POOL + | TableType::STRING_POOL + | TableType::BYTE_ARRAY_POOL + | TableType::TYPE_SIGNATURES + | TableType::FUNCTION_SIGNATURES + | TableType::LOCALS_SIGNATURES => { + continue; + } + TableType::STRUCT_DEFS | TableType::FIELD_DEFS | TableType::FUNCTION_DEFS => { + return Err(BinaryError::Malformed); + } + } + } + Ok(()) +} + +/// Builds the `ModuleHandle` table. +fn load_module_handles( + binary: &[u8], + table: &Table, + module_handles: &mut Vec, +) -> BinaryLoaderResult<()> { + let start = table.offset as usize; + let end = start + table.count as usize; + let mut cursor = Cursor::new(&binary[start..end]); + loop { + if cursor.position() == u64::from(table.count) { + break; + } + let address = read_uleb_u16_internal(&mut cursor)?; + let name = read_uleb_u16_internal(&mut cursor)?; + module_handles.push(ModuleHandle { + address: AddressPoolIndex(address), + name: StringPoolIndex(name), + }); + } + Ok(()) +} + +/// Builds the `StructHandle` table. +fn load_struct_handles( + binary: &[u8], + table: &Table, + struct_handles: &mut Vec, +) -> BinaryLoaderResult<()> { + let start = table.offset as usize; + let end = start + table.count as usize; + let mut cursor = Cursor::new(&binary[start..end]); + loop { + if cursor.position() == u64::from(table.count) { + break; + } + let module_handle = read_uleb_u16_internal(&mut cursor)?; + let name = read_uleb_u16_internal(&mut cursor)?; + if let Ok(is_resource) = cursor.read_u8() { + struct_handles.push(StructHandle { + module: ModuleHandleIndex(module_handle), + name: StringPoolIndex(name), + is_resource: is_resource != 0, + }); + } else { + return Err(BinaryError::Malformed); + } + } + Ok(()) +} + +/// Builds the `FunctionHandle` table. +fn load_function_handles( + binary: &[u8], + table: &Table, + function_handles: &mut Vec, +) -> BinaryLoaderResult<()> { + let start = table.offset as usize; + let end = start + table.count as usize; + let mut cursor = Cursor::new(&binary[start..end]); + loop { + if cursor.position() == u64::from(table.count) { + break; + } + let module_handle = read_uleb_u16_internal(&mut cursor)?; + let name = read_uleb_u16_internal(&mut cursor)?; + let signature = read_uleb_u16_internal(&mut cursor)?; + function_handles.push(FunctionHandle { + module: ModuleHandleIndex(module_handle), + name: StringPoolIndex(name), + signature: FunctionSignatureIndex(signature), + }); + } + Ok(()) +} + +/// Builds the `AddressPool`. +fn load_address_pool( + binary: &[u8], + table: &Table, + addresses: &mut AddressPool, +) -> BinaryLoaderResult<()> { + let mut start = table.offset as usize; + if table.count as usize % ADDRESS_LENGTH != 0 { + return Err(BinaryError::Malformed); + } + for _i in 0..table.count as usize / ADDRESS_LENGTH { + let end_addr = start + ADDRESS_LENGTH; + let address = (&binary[start..end_addr]).try_into(); + if address.is_err() { + return Err(BinaryError::Malformed); + } + start = end_addr; + + addresses.push(address.unwrap()); + } + Ok(()) +} + +/// Builds the `StringPool`. 
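+/// Each entry is a ULEB128-encoded length (at most `u16::MAX`) followed by that
+/// many UTF-8 bytes.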
+fn load_string_pool( + binary: &[u8], + table: &Table, + strings: &mut StringPool, +) -> BinaryLoaderResult<()> { + let start = table.offset as usize; + let end = start + table.count as usize; + let mut cursor = Cursor::new(&binary[start..end]); + while cursor.position() < u64::from(table.count) { + let size = read_uleb_u32_internal(&mut cursor)? as usize; + if size > std::u16::MAX as usize { + return Err(BinaryError::Malformed); + } + let mut buffer: Vec = vec![0u8; size]; + if let Ok(count) = cursor.read(&mut buffer) { + if count != size { + return Err(BinaryError::Malformed); + } + let s = match from_utf8(&buffer) { + Ok(bytes) => bytes, + Err(_) => return Err(BinaryError::Malformed), + }; + + strings.push(String::from(s)); + } + } + Ok(()) +} + +/// Builds the `ByteArrayPool`. +fn load_byte_array_pool( + binary: &[u8], + table: &Table, + byte_arrays: &mut ByteArrayPool, +) -> BinaryLoaderResult<()> { + let start = table.offset as usize; + let end = start + table.count as usize; + let mut cursor = Cursor::new(&binary[start..end]); + while cursor.position() < u64::from(table.count) { + let size = read_uleb_u32_internal(&mut cursor)? as usize; + if size > std::u16::MAX as usize { + return Err(BinaryError::Malformed); + } + let mut byte_array: Vec = vec![0u8; size]; + if let Ok(count) = cursor.read(&mut byte_array) { + if count != size { + return Err(BinaryError::Malformed); + } + + byte_arrays.push(ByteArray::new(byte_array)); + } + } + Ok(()) +} + +/// Builds the `TypeSignaturePool`. +fn load_type_signatures( + binary: &[u8], + table: &Table, + type_signatures: &mut TypeSignaturePool, +) -> BinaryLoaderResult<()> { + let start = table.offset as usize; + let end = start + table.count as usize; + let mut cursor = Cursor::new(&binary[start..end]); + while cursor.position() < u64::from(table.count) { + if let Ok(byte) = cursor.read_u8() { + if byte != SignatureType::TYPE_SIGNATURE as u8 { + return Err(BinaryError::UnexpectedSignatureType); + } + } + let token = load_signature_token(&mut cursor)?; + type_signatures.push(TypeSignature(token)); + } + Ok(()) +} + +/// Builds the `FunctionSignaturePool`. +fn load_function_signatures( + binary: &[u8], + table: &Table, + function_signatures: &mut FunctionSignaturePool, +) -> BinaryLoaderResult<()> { + let start = table.offset as usize; + let end = start + table.count as usize; + let mut cursor = Cursor::new(&binary[start..end]); + while cursor.position() < u64::from(table.count) { + if let Ok(byte) = cursor.read_u8() { + if byte != SignatureType::FUNCTION_SIGNATURE as u8 { + return Err(BinaryError::UnexpectedSignatureType); + } + } + + // Return signature + let token_count = cursor.read_u8().map_err(|_| BinaryError::Malformed)?; + let mut returns_signature: Vec = Vec::new(); + for _i in 0..token_count { + let token = load_signature_token(&mut cursor)?; + returns_signature.push(token); + } + + // Arguments signature + let token_count = cursor.read_u8().map_err(|_| BinaryError::Malformed)?; + let mut args_signature: Vec = Vec::new(); + for _i in 0..token_count { + let token = load_signature_token(&mut cursor)?; + args_signature.push(token); + } + + function_signatures.push(FunctionSignature { + return_types: returns_signature, + arg_types: args_signature, + }); + } + Ok(()) +} + +/// Builds the `LocalsSignaturePool`. 
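+/// Each entry is a `LOCAL_SIGNATURE` tag byte, a token count, and that many
+/// serialized `SignatureToken`s.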
+fn load_locals_signatures( + binary: &[u8], + table: &Table, + locals_signatures: &mut LocalsSignaturePool, +) -> BinaryLoaderResult<()> { + let start = table.offset as usize; + let end = start + table.count as usize; + let mut cursor = Cursor::new(&binary[start..end]); + while cursor.position() < u64::from(table.count) { + if let Ok(byte) = cursor.read_u8() { + if byte != SignatureType::LOCAL_SIGNATURE as u8 { + return Err(BinaryError::UnexpectedSignatureType); + } + } + + let token_count = cursor.read_u8().map_err(|_| BinaryError::Malformed)?; + let mut local_signature: Vec = Vec::new(); + for _i in 0..token_count { + let token = load_signature_token(&mut cursor)?; + local_signature.push(token); + } + + locals_signatures.push(LocalsSignature(local_signature)); + } + Ok(()) +} + +/// Deserializes a `SignatureToken`. +fn load_signature_token(cursor: &mut Cursor<&[u8]>) -> BinaryLoaderResult { + if let Ok(byte) = cursor.read_u8() { + match SerializedType::from_u8(byte)? { + SerializedType::BOOL => Ok(SignatureToken::Bool), + SerializedType::INTEGER => Ok(SignatureToken::U64), + SerializedType::STRING => Ok(SignatureToken::String), + SerializedType::BYTEARRAY => Ok(SignatureToken::ByteArray), + SerializedType::ADDRESS => Ok(SignatureToken::Address), + SerializedType::REFERENCE => { + let ref_token = load_signature_token(cursor)?; + Ok(SignatureToken::Reference(Box::new(ref_token))) + } + SerializedType::MUTABLE_REFERENCE => { + let ref_token = load_signature_token(cursor)?; + Ok(SignatureToken::MutableReference(Box::new(ref_token))) + } + SerializedType::STRUCT => { + let sh_idx = read_uleb_u16_internal(cursor)?; + Ok(SignatureToken::Struct(StructHandleIndex(sh_idx))) + } + } + } else { + Err(BinaryError::Malformed) + } +} + +/// Builds the `StructDefinition` table. +fn load_struct_defs( + binary: &[u8], + table: &Table, + struct_defs: &mut Vec, +) -> BinaryLoaderResult<()> { + let start = table.offset as usize; + let end = start + table.count as usize; + let mut cursor = Cursor::new(&binary[start..end]); + while cursor.position() < u64::from(table.count) { + let struct_handle = read_uleb_u16_internal(&mut cursor)?; + let field_count = read_uleb_u16_internal(&mut cursor)?; + let fields = read_uleb_u16_internal(&mut cursor)?; + struct_defs.push(StructDefinition { + struct_handle: StructHandleIndex(struct_handle), + field_count, + fields: FieldDefinitionIndex(fields), + }); + } + Ok(()) +} + +/// Builds the `FieldDefinition` table. +fn load_field_defs( + binary: &[u8], + table: &Table, + field_defs: &mut Vec, +) -> BinaryLoaderResult<()> { + let start = table.offset as usize; + let end = start + table.count as usize; + let mut cursor = Cursor::new(&binary[start..end]); + while cursor.position() < u64::from(table.count) { + let struct_ = read_uleb_u16_internal(&mut cursor)?; + let name = read_uleb_u16_internal(&mut cursor)?; + let signature = read_uleb_u16_internal(&mut cursor)?; + field_defs.push(FieldDefinition { + struct_: StructHandleIndex(struct_), + name: StringPoolIndex(name), + signature: TypeSignatureIndex(signature), + }); + } + Ok(()) +} + +/// Builds the `FunctionDefinition` table. 
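+/// Each entry is a function handle index, a flags byte, and a `CodeUnit` (see
+/// `load_function_def` below).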
+fn load_function_defs( + binary: &[u8], + table: &Table, + func_defs: &mut Vec, +) -> BinaryLoaderResult<()> { + let start = table.offset as usize; + let end = start + table.count as usize; + let mut cursor = Cursor::new(&binary[start..end]); + while cursor.position() < u64::from(table.count) { + let func_def = load_function_def(&mut cursor)?; + func_defs.push(func_def); + } + Ok(()) +} + +/// Deserializes a `FunctionDefinition`. +fn load_function_def(cursor: &mut Cursor<&[u8]>) -> BinaryLoaderResult { + let function = read_uleb_u16_internal(cursor)?; + + let flags = cursor.read_u8().map_err(|_| BinaryError::Malformed)?; + let code_unit = load_code_unit(cursor)?; + Ok(FunctionDefinition { + function: FunctionHandleIndex(function), + flags, + code: code_unit, + }) +} + +/// Deserializes a `CodeUnit`. +fn load_code_unit(cursor: &mut Cursor<&[u8]>) -> BinaryLoaderResult { + let max_stack_size = read_uleb_u16_internal(cursor)?; + let locals = read_uleb_u16_internal(cursor)?; + + let mut code_unit = CodeUnit { + max_stack_size, + locals: LocalsSignatureIndex(locals), + code: vec![], + }; + + load_code(cursor, &mut code_unit.code)?; + Ok(code_unit) +} + +/// Deserializes a code stream (`Bytecode`s). +fn load_code(cursor: &mut Cursor<&[u8]>, code: &mut Vec) -> BinaryLoaderResult<()> { + let bytecode_count = read_u16_internal(cursor)?; + while code.len() < bytecode_count as usize { + let byte = cursor.read_u8().map_err(|_| BinaryError::Malformed)?; + let bytecode = match Opcodes::from_u8(byte)? { + Opcodes::POP => Bytecode::Pop, + Opcodes::RET => Bytecode::Ret, + Opcodes::BR_TRUE => { + let jump = read_u16_internal(cursor)?; + Bytecode::BrTrue(jump) + } + Opcodes::BR_FALSE => { + let jump = read_u16_internal(cursor)?; + Bytecode::BrFalse(jump) + } + Opcodes::BRANCH => { + let jump = read_u16_internal(cursor)?; + Bytecode::Branch(jump) + } + Opcodes::LD_CONST => { + let value = read_u64_internal(cursor)?; + Bytecode::LdConst(value) + } + Opcodes::LD_ADDR => { + let idx = read_uleb_u16_internal(cursor)?; + Bytecode::LdAddr(AddressPoolIndex(idx)) + } + Opcodes::LD_STR => { + let idx = read_uleb_u16_internal(cursor)?; + Bytecode::LdStr(StringPoolIndex(idx)) + } + Opcodes::LD_TRUE => Bytecode::LdTrue, + Opcodes::LD_FALSE => Bytecode::LdFalse, + Opcodes::COPY_LOC => { + let idx = cursor.read_u8().map_err(|_| BinaryError::Malformed)?; + Bytecode::CopyLoc(idx) + } + Opcodes::MOVE_LOC => { + let idx = cursor.read_u8().map_err(|_| BinaryError::Malformed)?; + Bytecode::MoveLoc(idx) + } + Opcodes::ST_LOC => { + let idx = cursor.read_u8().map_err(|_| BinaryError::Malformed)?; + Bytecode::StLoc(idx) + } + Opcodes::LD_REF_LOC => { + let idx = cursor.read_u8().map_err(|_| BinaryError::Malformed)?; + Bytecode::BorrowLoc(idx) + } + Opcodes::LD_REF_FIELD => { + let idx = read_uleb_u16_internal(cursor)?; + Bytecode::BorrowField(FieldDefinitionIndex(idx)) + } + Opcodes::LD_BYTEARRAY => { + let idx = read_uleb_u16_internal(cursor)?; + Bytecode::LdByteArray(ByteArrayPoolIndex(idx)) + } + Opcodes::CALL => { + let idx = read_uleb_u16_internal(cursor)?; + Bytecode::Call(FunctionHandleIndex(idx)) + } + Opcodes::PACK => { + let idx = read_uleb_u16_internal(cursor)?; + Bytecode::Pack(StructDefinitionIndex(idx)) + } + Opcodes::UNPACK => { + let idx = read_uleb_u16_internal(cursor)?; + Bytecode::Unpack(StructDefinitionIndex(idx)) + } + Opcodes::READ_REF => Bytecode::ReadRef, + Opcodes::WRITE_REF => Bytecode::WriteRef, + Opcodes::ADD => Bytecode::Add, + Opcodes::SUB => Bytecode::Sub, + Opcodes::MUL => Bytecode::Mul, + 
Opcodes::MOD => Bytecode::Mod, + Opcodes::DIV => Bytecode::Div, + Opcodes::BIT_OR => Bytecode::BitOr, + Opcodes::BIT_AND => Bytecode::BitAnd, + Opcodes::XOR => Bytecode::Xor, + Opcodes::OR => Bytecode::Or, + Opcodes::AND => Bytecode::And, + Opcodes::NOT => Bytecode::Not, + Opcodes::EQ => Bytecode::Eq, + Opcodes::NEQ => Bytecode::Neq, + Opcodes::LT => Bytecode::Lt, + Opcodes::GT => Bytecode::Gt, + Opcodes::LE => Bytecode::Le, + Opcodes::GE => Bytecode::Ge, + Opcodes::ASSERT => Bytecode::Assert, + Opcodes::GET_TXN_GAS_UNIT_PRICE => Bytecode::GetTxnGasUnitPrice, + Opcodes::GET_TXN_MAX_GAS_UNITS => Bytecode::GetTxnMaxGasUnits, + Opcodes::GET_GAS_REMAINING => Bytecode::GetGasRemaining, + Opcodes::GET_TXN_SENDER => Bytecode::GetTxnSenderAddress, + Opcodes::EXISTS => { + let idx = read_uleb_u16_internal(cursor)?; + Bytecode::Exists(StructDefinitionIndex(idx)) + } + Opcodes::BORROW_REF => { + let idx = read_uleb_u16_internal(cursor)?; + Bytecode::BorrowGlobal(StructDefinitionIndex(idx)) + } + Opcodes::RELEASE_REF => Bytecode::ReleaseRef, + Opcodes::MOVE_FROM => { + let idx = read_uleb_u16_internal(cursor)?; + Bytecode::MoveFrom(StructDefinitionIndex(idx)) + } + Opcodes::MOVE_TO => { + let idx = read_uleb_u16_internal(cursor)?; + Bytecode::MoveToSender(StructDefinitionIndex(idx)) + } + Opcodes::CREATE_ACCOUNT => Bytecode::CreateAccount, + Opcodes::EMIT_EVENT => Bytecode::EmitEvent, + Opcodes::GET_TXN_SEQUENCE_NUMBER => Bytecode::GetTxnSequenceNumber, + Opcodes::GET_TXN_PUBLIC_KEY => Bytecode::GetTxnPublicKey, + Opcodes::FREEZE_REF => Bytecode::FreezeRef, + }; + code.push(bytecode); + } + Ok(()) +} + +// +// Helpers to read uleb128 and uncompressed integers +// + +fn read_uleb_u16_internal(cursor: &mut Cursor<&[u8]>) -> BinaryLoaderResult { + read_uleb128_as_u16(cursor).map_err(|_| BinaryError::Malformed) +} + +fn read_uleb_u32_internal(cursor: &mut Cursor<&[u8]>) -> BinaryLoaderResult { + read_uleb128_as_u32(cursor).map_err(|_| BinaryError::Malformed) +} + +fn read_u16_internal(cursor: &mut Cursor<&[u8]>) -> BinaryLoaderResult { + cursor + .read_u16::() + .map_err(|_| BinaryError::Malformed) +} + +fn read_u32_internal(cursor: &mut Cursor<&[u8]>) -> BinaryLoaderResult { + cursor + .read_u32::() + .map_err(|_| BinaryError::Malformed) +} + +fn read_u64_internal(cursor: &mut Cursor<&[u8]>) -> BinaryLoaderResult { + cursor + .read_u64::() + .map_err(|_| BinaryError::Malformed) +} + +impl TableType { + fn from_u8(value: u8) -> BinaryLoaderResult { + match value { + 0x1 => Ok(TableType::MODULE_HANDLES), + 0x2 => Ok(TableType::STRUCT_HANDLES), + 0x3 => Ok(TableType::FUNCTION_HANDLES), + 0x4 => Ok(TableType::ADDRESS_POOL), + 0x5 => Ok(TableType::STRING_POOL), + 0x6 => Ok(TableType::BYTE_ARRAY_POOL), + 0x7 => Ok(TableType::MAIN), + 0x8 => Ok(TableType::STRUCT_DEFS), + 0x9 => Ok(TableType::FIELD_DEFS), + 0xA => Ok(TableType::FUNCTION_DEFS), + 0xB => Ok(TableType::TYPE_SIGNATURES), + 0xC => Ok(TableType::FUNCTION_SIGNATURES), + 0xD => Ok(TableType::LOCALS_SIGNATURES), + _ => Err(BinaryError::UnknownTableType), + } + } +} + +#[allow(dead_code)] +impl SignatureType { + fn from_u8(value: u8) -> BinaryLoaderResult { + match value { + 0x1 => Ok(SignatureType::TYPE_SIGNATURE), + 0x2 => Ok(SignatureType::FUNCTION_SIGNATURE), + 0x3 => Ok(SignatureType::LOCAL_SIGNATURE), + _ => Err(BinaryError::UnknownSignatureType), + } + } +} + +impl SerializedType { + fn from_u8(value: u8) -> BinaryLoaderResult { + match value { + 0x1 => Ok(SerializedType::BOOL), + 0x2 => Ok(SerializedType::INTEGER), + 0x3 => 
Ok(SerializedType::STRING), + 0x4 => Ok(SerializedType::ADDRESS), + 0x5 => Ok(SerializedType::REFERENCE), + 0x6 => Ok(SerializedType::MUTABLE_REFERENCE), + 0x7 => Ok(SerializedType::STRUCT), + 0x8 => Ok(SerializedType::BYTEARRAY), + _ => Err(BinaryError::UnknownSerializedType), + } + } +} + +impl Opcodes { + fn from_u8(value: u8) -> BinaryLoaderResult { + match value { + 0x01 => Ok(Opcodes::POP), + 0x02 => Ok(Opcodes::RET), + 0x03 => Ok(Opcodes::BR_TRUE), + 0x04 => Ok(Opcodes::BR_FALSE), + 0x05 => Ok(Opcodes::BRANCH), + 0x06 => Ok(Opcodes::LD_CONST), + 0x07 => Ok(Opcodes::LD_ADDR), + 0x08 => Ok(Opcodes::LD_STR), + 0x09 => Ok(Opcodes::LD_TRUE), + 0x0A => Ok(Opcodes::LD_FALSE), + 0x0B => Ok(Opcodes::COPY_LOC), + 0x0C => Ok(Opcodes::MOVE_LOC), + 0x0D => Ok(Opcodes::ST_LOC), + 0x0E => Ok(Opcodes::LD_REF_LOC), + 0x0F => Ok(Opcodes::LD_REF_FIELD), + 0x10 => Ok(Opcodes::LD_BYTEARRAY), + 0x11 => Ok(Opcodes::CALL), + 0x12 => Ok(Opcodes::PACK), + 0x13 => Ok(Opcodes::UNPACK), + 0x14 => Ok(Opcodes::READ_REF), + 0x15 => Ok(Opcodes::WRITE_REF), + 0x16 => Ok(Opcodes::ADD), + 0x17 => Ok(Opcodes::SUB), + 0x18 => Ok(Opcodes::MUL), + 0x19 => Ok(Opcodes::MOD), + 0x1A => Ok(Opcodes::DIV), + 0x1B => Ok(Opcodes::BIT_OR), + 0x1C => Ok(Opcodes::BIT_AND), + 0x1D => Ok(Opcodes::XOR), + 0x1E => Ok(Opcodes::OR), + 0x1F => Ok(Opcodes::AND), + 0x20 => Ok(Opcodes::NOT), + 0x21 => Ok(Opcodes::EQ), + 0x22 => Ok(Opcodes::NEQ), + 0x23 => Ok(Opcodes::LT), + 0x24 => Ok(Opcodes::GT), + 0x25 => Ok(Opcodes::LE), + 0x26 => Ok(Opcodes::GE), + 0x27 => Ok(Opcodes::ASSERT), + 0x28 => Ok(Opcodes::GET_TXN_GAS_UNIT_PRICE), + 0x29 => Ok(Opcodes::GET_TXN_MAX_GAS_UNITS), + 0x2A => Ok(Opcodes::GET_GAS_REMAINING), + 0x2B => Ok(Opcodes::GET_TXN_SENDER), + 0x2C => Ok(Opcodes::EXISTS), + 0x2D => Ok(Opcodes::BORROW_REF), + 0x2E => Ok(Opcodes::RELEASE_REF), + 0x2F => Ok(Opcodes::MOVE_FROM), + 0x30 => Ok(Opcodes::MOVE_TO), + 0x31 => Ok(Opcodes::CREATE_ACCOUNT), + 0x32 => Ok(Opcodes::EMIT_EVENT), + 0x33 => Ok(Opcodes::GET_TXN_SEQUENCE_NUMBER), + 0x34 => Ok(Opcodes::GET_TXN_PUBLIC_KEY), + 0x35 => Ok(Opcodes::FREEZE_REF), + _ => Err(BinaryError::UnknownOpcode), + } + } +} diff --git a/language/vm/src/errors.rs b/language/vm/src/errors.rs new file mode 100644 index 0000000000000..c6af4cce54707 --- /dev/null +++ b/language/vm/src/errors.rs @@ -0,0 +1,759 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{file_format::SignatureToken, IndexKind, SignatureTokenKind}; +use failure::Fail; +use std::{fmt, iter::FromIterator}; +use types::{ + account_address::AccountAddress, + transaction::TransactionStatus, + vm_error::{VMStatus, VMValidationStatus, VMVerificationError, VMVerificationStatus}, +}; + +// We may want to eventually move this into the VM runtime since it is a semantic decision that +// need to be made by the VM. But for now, this will reside here. +pub fn vm_result_to_transaction_status(result: &VMResult) -> TransactionStatus { + // The decision as to whether or not a transaction should be dropped should be able to be + // determined solely by the VMStatus. This then means that we can audit/verify any decisions + // made by the VM externally on whether or not to discard or keep the transaction output by + // inspecting the contained VMStatus. + let vm_status = vm_status_of_result(result); + vm_status.into() +} + +#[derive(Debug)] +pub struct VMRuntimeError { + pub loc: Location, + pub err: VMErrorKind, +} + +// TODO: Fill in the details for Locations. 
Ideally it should be a unique handle into a function and +// a pc. +#[derive(Debug, Default)] +pub struct Location {} + +#[derive(Debug, PartialEq)] +pub enum VMErrorKind { + ArithmeticError, + TypeError, + AssertionFailure(u64), + OutOfGasError, + GlobalRefAlreadyReleased, + MissingReleaseRef, + GlobalAlreadyBorrowed, + MissingData, + DuplicateModuleName, + DataFormatError, + InvalidData, + RemoteDataError, + CannotWriteExistingResource, + ValueSerializerError, + ValueDeserializerError, + CodeSerializerError(BinaryError), + CodeDeserializerError(BinaryError), +} + +#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] +pub enum VerificationStatus { + /// A verification error was detected in a transaction script. + Script(VerificationError), + /// A verification error was detected in a module. The first element is the index of the module + /// in the transaction. + Module(u16, VerificationError), +} + +#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] +pub struct VerificationError { + /// Where the violation occurred. + pub kind: IndexKind, + /// The index where the violation occurred. + pub idx: usize, + /// The actual violation that occurred. + pub err: VMStaticViolation, +} + +impl fmt::Display for VerificationError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "at '{}' index {}: {}", self.kind, self.idx, self.err) + } +} + +#[derive(Clone, Debug, Eq, Fail, Ord, PartialEq, PartialOrd)] +pub enum VMStaticViolation { + #[fail( + display = "Index out of bounds for '{}' (expected 0..{}, found {})", + _0, _1, _2 + )] + IndexOutOfBounds(IndexKind, usize, usize), + + #[fail( + display = "Range out of bounds for '{}' (expected 0..{}, found {}..{})", + _0, _1, _2, _3 + )] + RangeOutOfBounds(IndexKind, usize, usize, usize), + + #[fail(display = "Module must have at least one module handle")] + NoModuleHandles, + + #[fail(display = "Module address does not match sender")] + ModuleAddressDoesNotMatchSender, + + #[fail( + display = "Invalid signature token {:?}: '{} of {}' is invalid", + _0, _1, _2 + )] + InvalidSignatureToken(SignatureToken, SignatureTokenKind, SignatureTokenKind), + + #[fail(display = "Duplicate element")] + DuplicateElement, + + #[fail(display = "Invalid module handle")] + InvalidModuleHandle, + + #[fail(display = "Unimplemented struct or function handle")] + UnimplementedHandle, + + #[fail(display = "Inconsistent fields in struct definition")] + InconsistentFields, + + #[fail(display = "Unused fields")] + UnusedFields, + + #[fail(display = "Field definition has invalid type: {} ({:?})", _1, _0)] + InvalidFieldDefReference(SignatureToken, SignatureTokenKind), + + #[fail(display = "Recursive struct definition")] + RecursiveStructDef, + + #[fail(display = "Resource field in non-resource struct")] + InvalidResourceField, + + #[fail(display = "Invalid fall through")] + InvalidFallThrough, + + #[fail(display = "Failure to perform join at block {}", _0)] + JoinFailure(usize), + + #[fail(display = "Negative stack size at block {} and offset {}", _0, _1)] + NegativeStackSizeInsideBlock(usize, usize), + + #[fail(display = "Positive stack size at end of block {}", _0)] + PositiveStackSizeAtBlockEnd(usize), + + #[fail(display = "Invalid signature for main function in script")] + InvalidMainFunctionSignature, + + #[fail(display = "Lookup of struct or function handle failed in module dependency")] + LookupFailed, + + #[fail(display = "Visibility mismatch for function handle in module dependency")] + VisibilityMismatch, + + #[fail(display = "Type of function in 
+#[derive(Clone, Debug, Eq, Fail, Ord, PartialEq, PartialOrd)]
+pub enum VMStaticViolation {
+    #[fail(
+        display = "Index out of bounds for '{}' (expected 0..{}, found {})",
+        _0, _1, _2
+    )]
+    IndexOutOfBounds(IndexKind, usize, usize),
+
+    #[fail(
+        display = "Range out of bounds for '{}' (expected 0..{}, found {}..{})",
+        _0, _1, _2, _3
+    )]
+    RangeOutOfBounds(IndexKind, usize, usize, usize),
+
+    #[fail(display = "Module must have at least one module handle")]
+    NoModuleHandles,
+
+    #[fail(display = "Module address does not match sender")]
+    ModuleAddressDoesNotMatchSender,
+
+    #[fail(
+        display = "Invalid signature token {:?}: '{} of {}' is invalid",
+        _0, _1, _2
+    )]
+    InvalidSignatureToken(SignatureToken, SignatureTokenKind, SignatureTokenKind),
+
+    #[fail(display = "Duplicate element")]
+    DuplicateElement,
+
+    #[fail(display = "Invalid module handle")]
+    InvalidModuleHandle,
+
+    #[fail(display = "Unimplemented struct or function handle")]
+    UnimplementedHandle,
+
+    #[fail(display = "Inconsistent fields in struct definition")]
+    InconsistentFields,
+
+    #[fail(display = "Unused fields")]
+    UnusedFields,
+
+    #[fail(display = "Field definition has invalid type: {} ({:?})", _1, _0)]
+    InvalidFieldDefReference(SignatureToken, SignatureTokenKind),
+
+    #[fail(display = "Recursive struct definition")]
+    RecursiveStructDef,
+
+    #[fail(display = "Resource field in non-resource struct")]
+    InvalidResourceField,
+
+    #[fail(display = "Invalid fall through")]
+    InvalidFallThrough,
+
+    #[fail(display = "Failure to perform join at block {}", _0)]
+    JoinFailure(usize),
+
+    #[fail(display = "Negative stack size at block {} and offset {}", _0, _1)]
+    NegativeStackSizeInsideBlock(usize, usize),
+
+    #[fail(display = "Positive stack size at end of block {}", _0)]
+    PositiveStackSizeAtBlockEnd(usize),
+
+    #[fail(display = "Invalid signature for main function in script")]
+    InvalidMainFunctionSignature,
+
+    #[fail(display = "Lookup of struct or function handle failed in module dependency")]
+    LookupFailed,
+
+    #[fail(display = "Visibility mismatch for function handle in module dependency")]
+    VisibilityMismatch,
+
+    #[fail(display = "Type of function in module dependency could not be resolved")]
+    TypeResolutionFailure,
+
+    #[fail(display = "Type mismatch for struct or function handle in module dependency")]
+    TypeMismatch,
+
+    #[fail(display = "Missing module dependency")]
+    MissingDependency,
+
+    #[fail(display = "Unable to verify Pop at offset {}", _0)]
+    PopReferenceError(usize),
+
+    #[fail(display = "Unable to verify Pop at offset {}", _0)]
+    PopResourceError(usize),
+
+    #[fail(display = "Unable to verify ReleaseRef at offset {}", _0)]
+    ReleaseRefTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify BrTrue/BrFalse at offset {}", _0)]
+    BrTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify Assert at offset {}", _0)]
+    AssertTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify StLoc at offset {}", _0)]
+    StLocTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify StLoc at offset {}", _0)]
+    StLocUnsafeToDestroyError(usize),
+
+    #[fail(display = "Unable to verify Ret at offset {}", _0)]
+    RetUnsafeToDestroyError(usize),
+
+    #[fail(display = "Unable to verify Ret at offset {}", _0)]
+    RetTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify FreezeRef at offset {}", _0)]
+    FreezeRefTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify FreezeRef at offset {}", _0)]
+    FreezeRefExistsMutableBorrowError(usize),
+
+    #[fail(display = "Unable to verify BorrowField at offset {}", _0)]
+    BorrowFieldTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify BorrowField at offset {}", _0)]
+    BorrowFieldBadFieldError(usize),
+
+    #[fail(display = "Unable to verify BorrowField at offset {}", _0)]
+    BorrowFieldExistsMutableBorrowError(usize),
+
+    #[fail(display = "Unable to verify CopyLoc at offset {}", _0)]
+    CopyLocUnavailableError(usize),
+
+    #[fail(display = "Unable to verify CopyLoc at offset {}", _0)]
+    CopyLocResourceError(usize),
+
+    #[fail(display = "Unable to verify CopyLoc at offset {}", _0)]
+    CopyLocExistsBorrowError(usize),
+
+    #[fail(display = "Unable to verify MoveLoc at offset {}", _0)]
+    MoveLocUnavailableError(usize),
+
+    #[fail(display = "Unable to verify MoveLoc at offset {}", _0)]
+    MoveLocExistsBorrowError(usize),
+
+    #[fail(display = "Unable to verify BorrowLoc at offset {}", _0)]
+    BorrowLocReferenceError(usize),
+
+    #[fail(display = "Unable to verify BorrowLoc at offset {}", _0)]
+    BorrowLocUnavailableError(usize),
+
+    #[fail(display = "Unable to verify BorrowLoc at offset {}", _0)]
+    BorrowLocExistsBorrowError(usize),
+
+    #[fail(display = "Unable to verify Call at offset {}", _0)]
+    CallTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify Call at offset {}", _0)]
+    CallBorrowedMutableReferenceError(usize),
+
+    #[fail(display = "Unable to verify Pack at offset {}", _0)]
+    PackTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify Unpack at offset {}", _0)]
+    UnpackTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify ReadRef at offset {}", _0)]
+    ReadRefTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify ReadRef at offset {}", _0)]
+    ReadRefResourceError(usize),
+
+    #[fail(display = "Unable to verify ReadRef at offset {}", _0)]
+    ReadRefExistsMutableBorrowError(usize),
+
+    #[fail(display = "Unable to verify WriteRef at offset {}", _0)]
+    WriteRefTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify WriteRef at offset {}", _0)]
+    WriteRefResourceError(usize),
+
+    #[fail(display = "Unable to verify WriteRef at offset {}", _0)]
+    WriteRefExistsBorrowError(usize),
+    #[fail(display = "Unable to verify WriteRef at offset {}", _0)]
+    WriteRefNoMutableReferenceError(usize),
+
+    #[fail(display = "Unable to verify integer operation at offset {}", _0)]
+    IntegerOpTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify boolean operation at offset {}", _0)]
+    BooleanOpTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify equality operation at offset {}", _0)]
+    EqualityOpTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify Exists at offset {}", _0)]
+    ExistsResourceTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify BorrowGlobal at offset {}", _0)]
+    BorrowGlobalTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify BorrowGlobal at offset {}", _0)]
+    BorrowGlobalNoResourceError(usize),
+
+    #[fail(display = "Unable to verify MoveFrom at offset {}", _0)]
+    MoveFromTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify MoveFrom at offset {}", _0)]
+    MoveFromNoResourceError(usize),
+
+    #[fail(display = "Unable to verify MoveToSender at offset {}", _0)]
+    MoveToSenderTypeMismatchError(usize),
+
+    #[fail(display = "Unable to verify MoveToSender at offset {}", _0)]
+    MoveToSenderNoResourceError(usize),
+
+    #[fail(display = "Unable to verify CreateAccount at offset {}", _0)]
+    CreateAccountTypeMismatchError(usize),
+}
+
+#[derive(Clone, Debug, Eq, Fail, Ord, PartialEq, PartialOrd)]
+pub enum VMInvariantViolation {
+    #[fail(
+        display = "Index out of bounds for '{}' (expected 0..{}, found {})",
+        _0, _1, _2
+    )]
+    IndexOutOfBounds(IndexKind, usize, usize),
+    #[fail(
+        display = "Range out of bounds for '{}' (expected 0..{}, found {}..{})",
+        _0, _1, _2, _3
+    )]
+    RangeOutOfBounds(IndexKind, usize, usize, usize),
+    #[fail(display = "Try to pop an empty value stack")]
+    EmptyValueStack,
+    #[fail(display = "Try to pop an empty call stack")]
+    EmptyCallStack,
+    #[fail(display = "Program Counter Overflows")]
+    ProgramCounterOverflow,
+    #[fail(display = "Linker can't find the destination code")]
+    LinkerError,
+    #[fail(display = "Owned value has multiple references")]
+    LocalReferenceError,
+    #[fail(display = "Failed to get response from storage")]
+    StorageError,
+}
+
+/// Error codes that can be emitted by the prologue. These have special significance to the VM when
+/// they are raised during the prologue. However, they can also be raised by user code during
+/// execution of a transaction script. They have no significance to the VM in that case.
+pub const EBAD_SIGNATURE: u64 = 1; // signature on transaction is invalid
+pub const EBAD_ACCOUNT_AUTHENTICATION_KEY: u64 = 2; // auth key in transaction is invalid
+pub const ESEQUENCE_NUMBER_TOO_OLD: u64 = 3; // transaction sequence number is too old
+pub const ESEQUENCE_NUMBER_TOO_NEW: u64 = 4; // transaction sequence number is too new
+pub const EACCOUNT_DOES_NOT_EXIST: u64 = 5; // transaction sender's account does not exist
+pub const ECANT_PAY_GAS_DEPOSIT: u64 = 6; // insufficient balance to pay for gas deposit
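+// Illustrative note: during the prologue these codes travel inside
+// `VMErrorKind::AssertionFailure`, and `convert_prologue_runtime_error` below
+// maps them onto validation statuses, e.g.
+//
+//     AssertionFailure(ESEQUENCE_NUMBER_TOO_OLD)
+//         => VMStatus::Validation(VMValidationStatus::SequenceNumberTooOld)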
+/// Generic error codes. These codes don't have any special meaning for the VM, but they are useful
+/// conventions for debugging.
+pub const EINSUFFICIENT_BALANCE: u64 = 10; // withdrawing more than an account contains
+pub const EASSERT_ERROR: u64 = 42; // catch-all error code for assert failures
+
+pub type VMRuntimeResult<T> = ::std::result::Result<T, VMRuntimeError>;
+pub type VMResult<T> = ::std::result::Result<VMRuntimeResult<T>, VMInvariantViolation>;
+
+impl Location {
+    pub fn new() -> Self {
+        Location {}
+    }
+}
+
+pub type BinaryLoaderResult<T> = ::std::result::Result<T, BinaryError>;
+
+// TODO: This is an initial set of errors that needs to be expanded.
+// Also, it's not clear whether we should fold this into other error types.
+#[derive(Clone, Debug, Eq, Fail, PartialEq)]
+pub enum BinaryError {
+    #[fail(display = "Malformed binary")]
+    Malformed,
+    #[fail(display = "Bad magic")]
+    BadMagic,
+    #[fail(display = "Unknown version")]
+    UnknownVersion,
+    #[fail(display = "Unknown table type")]
+    UnknownTableType,
+    #[fail(display = "Unknown signature type")]
+    UnknownSignatureType,
+    #[fail(display = "Unexpected signature type")]
+    UnexpectedSignatureType,
+    #[fail(display = "Unknown serialized type")]
+    UnknownSerializedType,
+    #[fail(display = "Unknown opcode")]
+    UnknownOpcode,
+    #[fail(display = "Wrong table header format (offset or count)")]
+    BadHeaderTable,
+    #[fail(display = "Duplicate table type")]
+    DuplicateTable,
+}
+
+#[macro_export]
+macro_rules! try_runtime {
+    ($e:expr) => {
+        match $e {
+            Ok(Ok(t)) => t,
+            Ok(Err(e)) => return Ok(Err(e)),
+            Err(e) => return Err(e),
+        }
+    };
+}
+
+#[macro_export]
+macro_rules! assert_ok {
+    ($e:expr) => {
+        assert!(match $e {
+            Ok(Ok(_)) => true,
+            Ok(Err(e)) => {
+                println!("Unexpected Runtime Error: {:?}", e);
+                false
+            }
+            Err(e) => {
+                println!("Unexpected ICE: {:?}", e);
+                false
+            }
+        })
+    };
+}
+
+////////////////////////////////////////////////////////////////////////////
+/// Conversion functions from internal VM statuses into external VM statuses
+////////////////////////////////////////////////////////////////////////////
+
+pub fn to_vm_status<'a, T, E>(result: &'a ::std::result::Result<T, E>) -> VMStatus
+where
+    VMStatus: From<&'a E>,
+    E: 'a,
+{
+    use types::vm_error::ExecutionStatus;
+    match result {
+        Ok(_) => VMStatus::Execution(ExecutionStatus::Executed),
+        Err(err) => err.into(),
+    }
+}
+
+pub fn vm_status_of_result<T>(result: &VMResult<T>) -> VMStatus {
+    match result {
+        Ok(runtime_result) => to_vm_status(runtime_result),
+        Err(err) => err.into(),
+    }
+}
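+// Illustrative sketch (the function name is hypothetical): `try_runtime!`
+// above lets a function that itself returns a `VMResult` unwrap the nested
+// result in one expression, forwarding a `VMRuntimeError` as `Ok(Err(..))`
+// and a `VMInvariantViolation` as `Err(..)`.
+#[allow(dead_code)]
+fn _try_runtime_example(input: VMResult<u64>) -> VMResult<u64> {
+    let value = try_runtime!(input);
+    Ok(Ok(value + 1))
+}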
+// FUTURE: At the moment we can't pass transaction metadata or the signed transaction due to
+// restrictions in the two places that this function is called. We therefore just pass through what
+// we need at the moment---the sender address---but we may want/need to pass more data later on.
+pub fn convert_prologue_runtime_error(
+    err: &VMRuntimeError,
+    txn_sender: &AccountAddress,
+) -> VMStatus {
+    use VMErrorKind::*;
+    match err.err {
+        // Invalid authentication key
+        AssertionFailure(EBAD_ACCOUNT_AUTHENTICATION_KEY) => {
+            VMStatus::Validation(VMValidationStatus::InvalidAuthKey)
+        }
+        // Sequence number too old
+        AssertionFailure(ESEQUENCE_NUMBER_TOO_OLD) => {
+            VMStatus::Validation(VMValidationStatus::SequenceNumberTooOld)
+        }
+        // Sequence number too new
+        AssertionFailure(ESEQUENCE_NUMBER_TOO_NEW) => {
+            VMStatus::Validation(VMValidationStatus::SequenceNumberTooNew)
+        }
+        // Sending account does not exist
+        AssertionFailure(EACCOUNT_DOES_NOT_EXIST) => {
+            let error_msg = format!("sender address: {}", txn_sender);
+            VMStatus::Validation(VMValidationStatus::SendingAccountDoesNotExist(error_msg))
+        }
+        // Can't pay for transaction gas deposit/fee
+        AssertionFailure(ECANT_PAY_GAS_DEPOSIT) => {
+            VMStatus::Validation(VMValidationStatus::InsufficientBalanceForTransactionFee)
+        }
+        _ => err.into(),
+    }
+}
+
+///////////////////////////////////////////////////////////////////
+/// Conversion from internal VM statuses into external VM statuses
+///////////////////////////////////////////////////////////////////
+
+impl From<&BinaryError> for VMStatus {
+    fn from(error: &BinaryError) -> Self {
+        use types::vm_error::BinaryError as VMBinaryError;
+        let bin_err = match error {
+            BinaryError::Malformed => VMBinaryError::Malformed,
+            BinaryError::BadMagic => VMBinaryError::BadMagic,
+            BinaryError::UnknownVersion => VMBinaryError::UnknownVersion,
+            BinaryError::UnknownTableType => VMBinaryError::UnknownTableType,
+            BinaryError::UnknownSignatureType => VMBinaryError::UnknownSignatureType,
+            BinaryError::UnknownSerializedType => VMBinaryError::UnknownSerializedType,
+            BinaryError::UnknownOpcode => VMBinaryError::UnknownOpcode,
+            BinaryError::BadHeaderTable => VMBinaryError::BadHeaderTable,
+            BinaryError::DuplicateTable => VMBinaryError::DuplicateTable,
+            BinaryError::UnexpectedSignatureType => VMBinaryError::UnexpectedSignatureType,
+        };
+        VMStatus::Deserialization(bin_err)
+    }
+}
+
+impl From<&VMInvariantViolation> for VMStatus {
+    fn from(error: &VMInvariantViolation) -> Self {
+        use types::vm_error::VMInvariantViolationError;
+        let err = match error {
+            VMInvariantViolation::IndexOutOfBounds(_, _, _) => {
+                VMInvariantViolationError::OutOfBoundsIndex
+            }
+            VMInvariantViolation::RangeOutOfBounds(_, _, _, _) => {
+                VMInvariantViolationError::OutOfBoundsRange
+            }
+            VMInvariantViolation::EmptyValueStack => VMInvariantViolationError::EmptyValueStack,
+            VMInvariantViolation::EmptyCallStack => VMInvariantViolationError::EmptyCallStack,
+            VMInvariantViolation::ProgramCounterOverflow => VMInvariantViolationError::PCOverflow,
+            VMInvariantViolation::LinkerError => VMInvariantViolationError::LinkerError,
+            VMInvariantViolation::LocalReferenceError => {
+                VMInvariantViolationError::LocalReferenceError
+            }
+            VMInvariantViolation::StorageError => VMInvariantViolationError::StorageError,
+        };
+        VMStatus::InvariantViolation(err)
+    }
+}
+
+impl From<&VerificationError> for VMVerificationError {
+    fn from(error: &VerificationError) -> Self {
+        let message = format!("{}", error);
+        match error.err {
+            VMStaticViolation::IndexOutOfBounds(_, _, _) => {
+                VMVerificationError::IndexOutOfBounds(message)
+            }
+            VMStaticViolation::RangeOutOfBounds(_, _, _, _) => {
+                VMVerificationError::RangeOutOfBounds(message)
+            }
+            VMStaticViolation::NoModuleHandles => VMVerificationError::NoModuleHandles(message),
+            VMStaticViolation::ModuleAddressDoesNotMatchSender => {
+                VMVerificationError::ModuleAddressDoesNotMatchSender(message)
+            }
+            VMStaticViolation::InvalidSignatureToken(_, _, _) => {
+                VMVerificationError::InvalidSignatureToken(message)
+            }
+            VMStaticViolation::DuplicateElement => VMVerificationError::DuplicateElement(message),
+            VMStaticViolation::InvalidModuleHandle => {
+                VMVerificationError::InvalidModuleHandle(message)
+            }
+            VMStaticViolation::UnimplementedHandle => {
+                VMVerificationError::UnimplementedHandle(message)
+            }
+            VMStaticViolation::InconsistentFields => {
+                VMVerificationError::InconsistentFields(message)
+            }
+            VMStaticViolation::UnusedFields => VMVerificationError::UnusedFields(message),
+            VMStaticViolation::InvalidFieldDefReference(_, _) => {
+                VMVerificationError::InvalidFieldDefReference(message)
+            }
+            VMStaticViolation::RecursiveStructDef => {
+                VMVerificationError::RecursiveStructDefinition(message)
+            }
+            VMStaticViolation::InvalidResourceField => {
+                VMVerificationError::InvalidResourceField(message)
+            }
+            VMStaticViolation::InvalidFallThrough => {
+                VMVerificationError::InvalidFallThrough(message)
+            }
+            VMStaticViolation::JoinFailure(_) => VMVerificationError::JoinFailure(message),
+            VMStaticViolation::NegativeStackSizeInsideBlock(_, _) => {
+                VMVerificationError::NegativeStackSizeWithinBlock(message)
+            }
+            VMStaticViolation::PositiveStackSizeAtBlockEnd(_) => {
+                VMVerificationError::UnbalancedStack(message)
+            }
+            VMStaticViolation::InvalidMainFunctionSignature => {
+                VMVerificationError::InvalidMainFunctionSignature(message)
+            }
+            VMStaticViolation::LookupFailed => VMVerificationError::LookupFailed(message),
+            VMStaticViolation::VisibilityMismatch => {
+                VMVerificationError::VisibilityMismatch(message)
+            }
+            VMStaticViolation::TypeResolutionFailure => {
+                VMVerificationError::TypeResolutionFailure(message)
+            }
+            VMStaticViolation::TypeMismatch => VMVerificationError::TypeMismatch(message),
+            VMStaticViolation::MissingDependency => VMVerificationError::MissingDependency(message),
+            VMStaticViolation::PopReferenceError(_) => {
+                VMVerificationError::PopReferenceError(message)
+            }
+            VMStaticViolation::PopResourceError(_) => {
+                VMVerificationError::PopResourceError(message)
+            }
+            VMStaticViolation::ReleaseRefTypeMismatchError(_) => {
+                VMVerificationError::ReleaseRefTypeMismatchError(message)
+            }
+            VMStaticViolation::BrTypeMismatchError(_) => {
+                VMVerificationError::BrTypeMismatchError(message)
+            }
+            VMStaticViolation::AssertTypeMismatchError(_) => {
+                VMVerificationError::AssertTypeMismatchError(message)
+            }
+            VMStaticViolation::StLocTypeMismatchError(_) => {
+                VMVerificationError::StLocTypeMismatchError(message)
+            }
+            VMStaticViolation::StLocUnsafeToDestroyError(_) => {
+                VMVerificationError::StLocUnsafeToDestroyError(message)
+            }
+            VMStaticViolation::RetUnsafeToDestroyError(_) => {
+                VMVerificationError::RetUnsafeToDestroyError(message)
+            }
+            VMStaticViolation::RetTypeMismatchError(_) => {
+                VMVerificationError::RetTypeMismatchError(message)
+            }
+            VMStaticViolation::FreezeRefTypeMismatchError(_) => {
+                VMVerificationError::FreezeRefTypeMismatchError(message)
+            }
+            VMStaticViolation::FreezeRefExistsMutableBorrowError(_) => {
+                VMVerificationError::FreezeRefExistsMutableBorrowError(message)
+            }
+            VMStaticViolation::BorrowFieldTypeMismatchError(_) => {
+                VMVerificationError::BorrowFieldTypeMismatchError(message)
+            }
+            VMStaticViolation::BorrowFieldBadFieldError(_) => {
+                VMVerificationError::BorrowFieldBadFieldError(message)
+            }
+            VMStaticViolation::BorrowFieldExistsMutableBorrowError(_) => {
+                VMVerificationError::BorrowFieldExistsMutableBorrowError(message)
+            }
+            VMStaticViolation::CopyLocUnavailableError(_) => {
+                VMVerificationError::CopyLocUnavailableError(message)
+            }
+            VMStaticViolation::CopyLocResourceError(_) => {
+                VMVerificationError::CopyLocResourceError(message)
+            }
+            VMStaticViolation::CopyLocExistsBorrowError(_) => {
+                VMVerificationError::CopyLocExistsBorrowError(message)
+            }
+            VMStaticViolation::MoveLocUnavailableError(_) => {
+                VMVerificationError::MoveLocUnavailableError(message)
+            }
+            VMStaticViolation::MoveLocExistsBorrowError(_) => {
+                VMVerificationError::MoveLocExistsBorrowError(message)
+            }
+            VMStaticViolation::BorrowLocReferenceError(_) => {
+                VMVerificationError::BorrowLocReferenceError(message)
+            }
+            VMStaticViolation::BorrowLocUnavailableError(_) => {
+                VMVerificationError::BorrowLocUnavailableError(message)
+            }
+            VMStaticViolation::BorrowLocExistsBorrowError(_) => {
+                VMVerificationError::BorrowLocExistsBorrowError(message)
+            }
+            VMStaticViolation::CallTypeMismatchError(_) => {
+                VMVerificationError::CallTypeMismatchError(message)
+            }
+            VMStaticViolation::CallBorrowedMutableReferenceError(_) => {
+                VMVerificationError::CallBorrowedMutableReferenceError(message)
+            }
+            VMStaticViolation::PackTypeMismatchError(_) => {
+                VMVerificationError::PackTypeMismatchError(message)
+            }
+            VMStaticViolation::UnpackTypeMismatchError(_) => {
+                VMVerificationError::UnpackTypeMismatchError(message)
+            }
+            VMStaticViolation::ReadRefTypeMismatchError(_) => {
+                VMVerificationError::ReadRefTypeMismatchError(message)
+            }
+            VMStaticViolation::ReadRefResourceError(_) => {
+                VMVerificationError::ReadRefResourceError(message)
+            }
+            VMStaticViolation::ReadRefExistsMutableBorrowError(_) => {
+                VMVerificationError::ReadRefExistsMutableBorrowError(message)
+            }
+            VMStaticViolation::WriteRefTypeMismatchError(_) => {
+                VMVerificationError::WriteRefTypeMismatchError(message)
+            }
+            VMStaticViolation::WriteRefResourceError(_) => {
+                VMVerificationError::WriteRefResourceError(message)
+            }
+            VMStaticViolation::WriteRefExistsBorrowError(_) => {
+                VMVerificationError::WriteRefExistsBorrowError(message)
+            }
+            VMStaticViolation::WriteRefNoMutableReferenceError(_) => {
+                VMVerificationError::WriteRefNoMutableReferenceError(message)
+            }
+            VMStaticViolation::IntegerOpTypeMismatchError(_) => {
+                VMVerificationError::IntegerOpTypeMismatchError(message)
+            }
+            VMStaticViolation::BooleanOpTypeMismatchError(_) => {
+                VMVerificationError::BooleanOpTypeMismatchError(message)
+            }
+            VMStaticViolation::EqualityOpTypeMismatchError(_) => {
+                VMVerificationError::EqualityOpTypeMismatchError(message)
+            }
+            VMStaticViolation::ExistsResourceTypeMismatchError(_) => {
+                VMVerificationError::ExistsResourceTypeMismatchError(message)
+            }
+            VMStaticViolation::BorrowGlobalTypeMismatchError(_) => {
+                VMVerificationError::BorrowGlobalTypeMismatchError(message)
+            }
+            VMStaticViolation::BorrowGlobalNoResourceError(_) => {
+                VMVerificationError::BorrowGlobalNoResourceError(message)
+            }
+            VMStaticViolation::MoveFromTypeMismatchError(_) => {
+                VMVerificationError::MoveFromTypeMismatchError(message)
+            }
+            VMStaticViolation::MoveFromNoResourceError(_) => {
+                VMVerificationError::MoveFromNoResourceError(message)
+            }
+            VMStaticViolation::MoveToSenderTypeMismatchError(_) => {
+                VMVerificationError::MoveToSenderTypeMismatchError(message)
+            }
+            VMStaticViolation::MoveToSenderNoResourceError(_) => {
+                VMVerificationError::MoveToSenderNoResourceError(message)
+            }
+            VMStaticViolation::CreateAccountTypeMismatchError(_) => {
+                VMVerificationError::CreateAccountTypeMismatchError(message)
+            }
+        }
+    }
+}
+
+impl From<&VerificationStatus> for VMVerificationStatus {
+    fn from(status: &VerificationStatus) -> Self {
+        match status {
+            VerificationStatus::Script(err) => VMVerificationStatus::Script(err.into()),
+            VerificationStatus::Module(module_idx, err) => {
+                VMVerificationStatus::Module(*module_idx, err.into())
+            }
+        }
+    }
+}
+
+impl<'a> FromIterator<&'a VerificationStatus> for VMStatus {
+    fn from_iter<T>(iter: T) -> Self
+    where
+        T: IntoIterator<Item = &'a VerificationStatus>,
+    {
+        let status_list = iter.into_iter().map(VMVerificationStatus::from).collect();
+        VMStatus::Verification(status_list)
+    }
+}
+
+impl From<&VMErrorKind> for VMStatus {
+    fn from(error: &VMErrorKind) -> Self {
+        use types::vm_error::{ArithmeticErrorType, DynamicReferenceErrorType, ExecutionStatus};
+        let err = match error {
+            VMErrorKind::ArithmeticError => {
+                ExecutionStatus::ArithmeticError(ArithmeticErrorType::Underflow)
+            }
+            VMErrorKind::AssertionFailure(err_code) => ExecutionStatus::AssertionFailure(*err_code),
+            VMErrorKind::OutOfGasError => ExecutionStatus::OutOfGas,
+            VMErrorKind::TypeError => ExecutionStatus::TypeError,
+            VMErrorKind::GlobalRefAlreadyReleased => ExecutionStatus::DynamicReferenceError(
+                DynamicReferenceErrorType::GlobalRefAlreadyReleased,
+            ),
+            VMErrorKind::MissingReleaseRef => {
+                ExecutionStatus::DynamicReferenceError(DynamicReferenceErrorType::MissingReleaseRef)
+            }
+            VMErrorKind::GlobalAlreadyBorrowed => ExecutionStatus::DynamicReferenceError(
+                DynamicReferenceErrorType::GlobalAlreadyBorrowed,
+            ),
+            VMErrorKind::MissingData => ExecutionStatus::MissingData,
+            VMErrorKind::DataFormatError => ExecutionStatus::DataFormatError,
+            VMErrorKind::InvalidData => ExecutionStatus::InvalidData,
+            VMErrorKind::RemoteDataError => ExecutionStatus::RemoteDataError,
+            VMErrorKind::CannotWriteExistingResource => {
+                ExecutionStatus::CannotWriteExistingResource
+            }
+            VMErrorKind::ValueSerializerError => ExecutionStatus::ValueSerializationError,
+            VMErrorKind::ValueDeserializerError => ExecutionStatus::ValueDeserializationError,
+            VMErrorKind::DuplicateModuleName => ExecutionStatus::DuplicateModuleName,
+            VMErrorKind::CodeSerializerError(err) => return VMStatus::from(err),
+            VMErrorKind::CodeDeserializerError(err) => return VMStatus::from(err),
+        };
+        VMStatus::Execution(err)
+    }
+}
+
+impl From<&VMRuntimeError> for VMStatus {
+    fn from(error: &VMRuntimeError) -> Self {
+        VMStatus::from(&error.err)
+    }
+}
diff --git a/language/vm/src/file_format.rs b/language/vm/src/file_format.rs
new file mode 100644
index 0000000000000..3db3965eeea14
--- /dev/null
+++ b/language/vm/src/file_format.rs
@@ -0,0 +1,1251 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Binary format for transactions and modules.
+//!
+//! This module provides a simple Rust abstraction over the binary format. That is the format of
+//! modules stored on chain or the format of the code section of a transaction.
+//!
+//! `file_format_common.rs` provides the constant values for entities in the binary format.
+//! (*The binary format is evolving, so please check back here from time to time for updates.*)
+//!
+//! Overall the binary format is structured in a number of sections:
+//! - **Header**: this must start at offset 0 in the binary. It contains a blob that starts every
+//!   Libra binary, followed by the version of the VM used to compile the code, and last is the
+//!   number of tables present in this binary.
+//! - **Table Specification**: it is a list of tuples of the form
+//!   `(table type, starting_offset, byte_count)`. The number of entries is specified in the
+//!   header (last entry in header). There can only be a single entry per table type. The
+//!   `starting offset` is from the beginning of the binary. Tables must cover the entire size of
+//!   the binary blob and cannot overlap.
+//! - **Table Content**: the serialized form of the specific entries in the table. Those roughly
+//!   map to the structs defined in this module. Entries in each table must be unique.
+//!
+//! We have two formats: one for modules, represented here by `CompiledModule`, and another
+//! for transaction scripts, `CompiledScript`. Building those tables and passing them
+//! to the serializer (`serializer.rs`) generates a binary of the form described. Vectors in
+//! those structs translate to tables and table specifications.
+
+use crate::{access::BaseAccess, internals::ModuleIndex, IndexKind, SignatureTokenKind};
+use proptest::{collection::vec, prelude::*, strategy::BoxedStrategy};
+use proptest_derive::Arbitrary;
+use types::{account_address::AccountAddress, byte_array::ByteArray, language_storage::CodeKey};
+
+/// Generic index into one of the tables in the binary format.
+pub type TableIndex = u16;
+
+macro_rules! define_index {
+    {
+        name: $name: ident,
+        kind: $kind: ident,
+        doc: $comment: literal,
+    } => {
+        #[derive(Arbitrary, Clone, Copy, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
+        #[proptest(no_params)]
+        #[doc=$comment]
+        pub struct $name(pub TableIndex);
+
+        impl $name {
+            /// Returns an instance of the given `Index`.
+            pub fn new(idx: TableIndex) -> Self {
+                Self(idx)
+            }
+        }
+
+        impl ::std::fmt::Display for $name {
+            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+                write!(f, "{}", self.0)
+            }
+        }
+
+        impl ::std::fmt::Debug for $name {
+            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+                write!(f, "{}({})", stringify!($name), self.0)
+            }
+        }
+
+        impl ModuleIndex for $name {
+            const KIND: IndexKind = IndexKind::$kind;
+
+            #[inline]
+            fn into_index(self) -> usize {
+                self.0 as usize
+            }
+        }
+    };
+}
+
+define_index! {
+    name: ModuleHandleIndex,
+    kind: ModuleHandle,
+    doc: "Index into the `ModuleHandle` table.",
+}
+define_index! {
+    name: StructHandleIndex,
+    kind: StructHandle,
+    doc: "Index into the `StructHandle` table.",
+}
+define_index! {
+    name: FunctionHandleIndex,
+    kind: FunctionHandle,
+    doc: "Index into the `FunctionHandle` table.",
+}
+define_index! {
+    name: StringPoolIndex,
+    kind: StringPool,
+    doc: "Index into the `StringPool` table.",
+}
+define_index! {
+    name: ByteArrayPoolIndex,
+    kind: ByteArrayPool,
+    doc: "Index into the `ByteArrayPool` table.",
+}
+define_index! {
+    name: AddressPoolIndex,
+    kind: AddressPool,
+    doc: "Index into the `AddressPool` table.",
+}
+define_index! {
+    name: TypeSignatureIndex,
+    kind: TypeSignature,
+    doc: "Index into the `TypeSignature` table.",
+}
+define_index! {
+    name: FunctionSignatureIndex,
+    kind: FunctionSignature,
+    doc: "Index into the `FunctionSignature` table.",
+}
+define_index! {
+    name: LocalsSignatureIndex,
+    kind: LocalsSignature,
+    doc: "Index into the `LocalsSignature` table.",
+}
+define_index! {
+    name: StructDefinitionIndex,
+    kind: StructDefinition,
+    doc: "Index into the `StructDefinition` table.",
+}
+define_index! {
+    name: FieldDefinitionIndex,
+    kind: FieldDefinition,
+    doc: "Index into the `FieldDefinition` table.",
+}
+define_index! {
+    name: FunctionDefinitionIndex,
+    kind: FunctionDefinition,
+    doc: "Index into the `FunctionDefinition` table.",
+}
+
+/// Index of a local variable in a function.
+///
+/// Bytecodes that operate on locals carry indexes to the locals of a function.
+pub type LocalIndex = u8;
+/// Max number of fields in a `StructDefinition`.
+pub type MemberCount = u16;
+/// Index into the code stream for a jump. The offset is relative to the beginning of
+/// the instruction stream.
+pub type CodeOffset = u16;
+
+/// The pool of identifiers and string literals.
+pub type StringPool = Vec<String>;
+/// The pool of `ByteArray` literals.
+pub type ByteArrayPool = Vec<ByteArray>;
+/// The pool of `AccountAddress` literals.
+///
+/// Code references have literal addresses in `ModuleHandle`s. Literal references to data in
+/// the blockchain are also published here.
+pub type AddressPool = Vec<AccountAddress>;
+/// The pool of `TypeSignature` instances. Those are the system and user types used and
+/// their composition (e.g. `&U64`).
+pub type TypeSignaturePool = Vec<TypeSignature>;
+/// The pool of `FunctionSignature` instances.
+pub type FunctionSignaturePool = Vec<FunctionSignature>;
+/// The pool of `LocalsSignature` instances. Every function definition must define the set of
+/// locals used and their types.
+pub type LocalsSignaturePool = Vec<LocalsSignature>;
+
+/// Name of the placeholder module. Every compiled script has an entry that
+/// refers to itself in its module handle list. This is the name of that script.
+pub const SELF_MODULE_NAME: &str = "";
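+// Illustrative sketch (the helper name is hypothetical): pools are plain
+// vectors, so an index newtype generated by `define_index!` above resolves
+// by position via `ModuleIndex::into_index`.
+#[allow(dead_code)]
+fn _pool_lookup_example(pool: &AddressPool, idx: AddressPoolIndex) -> &AccountAddress {
+    &pool[idx.into_index()]
+}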
+// HANDLES:
+// Handles are structs that accompany opcodes that need references: a type reference,
+// or a function reference (a field reference, being available only within the module that
+// defines the field, can be a definition).
+// Handles refer to both internal and external "entities" and are embedded as indexes
+// in the instruction stream.
+// Handles define resolution. Resolution is assumed to be by (name, signature)
+
+/// A `ModuleHandle` is a reference to a MOVE module. It is composed of an `address` and a `name`.
+///
+/// A `ModuleHandle` uniquely identifies a code resource in the blockchain.
+/// The `address` is a reference to the account that holds the code and the `name` is used as a
+/// key in order to load the module.
+///
+/// Modules live in the *code* namespace of a LibraAccount.
+///
+/// Modules introduce a scope made of all types defined in the module and all functions.
+/// Type definitions (fields) are private to the module. Outside the module a
+/// Type is an opaque handle.
+#[derive(Arbitrary, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
+#[proptest(no_params)]
+pub struct ModuleHandle {
+    /// Index into the `AddressPool`. Identifies the account that holds the module.
+    pub address: AddressPoolIndex,
+    /// The name of the module published in the code section for the account in `address`.
+    pub name: StringPoolIndex,
+}
+
+/// A `StructHandle` is a reference to a user defined type. It is composed of a `ModuleHandle`
+/// and the name of the type within that module.
+///
+/// A type in a module is uniquely identified by its name and as such the name is enough
+/// to perform resolution.
+///
+/// The `StructHandle` also carries the type *kind* (resource/unrestricted) so that the verifier
+/// can check resource semantics without having to load the referenced type.
+/// At link time a check of the kind is performed and an error is reported if there is a
+/// mismatch with the definition.
+#[derive(Arbitrary, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
+#[proptest(no_params)]
+pub struct StructHandle {
+    /// The module that defines the type.
+    pub module: ModuleHandleIndex,
+    /// The name of the type.
+    pub name: StringPoolIndex,
+    /// Whether the type is a resource or an unrestricted type.
+    pub is_resource: bool,
+}
+
+/// A `FunctionHandle` is a reference to a function. It is composed of a
+/// `ModuleHandle` and the name and signature of that function within the module.
+///
+/// A function within a module is uniquely identified by its name. No overloading is allowed
+/// and the verifier enforces that property. The signature of the function is used at link time to
+/// ensure the function reference is valid and it is also used by the verifier to type check
+/// function calls.
+#[derive(Arbitrary, Clone, Debug, Eq, Hash, PartialEq)]
+#[proptest(no_params)]
+pub struct FunctionHandle {
+    /// The module that defines the function.
+    pub module: ModuleHandleIndex,
+    /// The name of the function.
+    pub name: StringPoolIndex,
+    /// The signature of the function.
+    pub signature: FunctionSignatureIndex,
+}
+
+// DEFINITIONS:
+// Definitions are the module code. So the set of types and functions in the module.
+
+/// A `StructDefinition` is a user type definition. It defines all the fields declared on the type.
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq)]
+#[proptest(no_params)]
+pub struct StructDefinition {
+    /// The `StructHandle` for this `StructDefinition`. This has the name and the resource flag
+    /// for the type.
+    pub struct_handle: StructHandleIndex,
+    /// The number of fields in this type.
+    pub field_count: MemberCount,
+    /// The starting index for the fields of this type. `FieldDefinition`s for each type must
+    /// be consecutively stored in the `FieldDefinition` table.
+    pub fields: FieldDefinitionIndex,
+}
+
+/// A `FieldDefinition` is the definition of a field: the type the field is defined on,
+/// its name and the field type.
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq)]
+#[proptest(no_params)]
+pub struct FieldDefinition {
+    /// The type (resource or unrestricted) the field is defined on.
+    pub struct_: StructHandleIndex,
+    /// The name of the field.
+    pub name: StringPoolIndex,
+    /// The type of the field.
+    pub signature: TypeSignatureIndex,
+}
+
+/// A `FunctionDefinition` is the implementation of a function. It defines
+/// the *prototype* of the function and the function body.
+#[derive(Arbitrary, Clone, Debug, Default, Eq, PartialEq)]
+#[proptest(params = "usize")]
+pub struct FunctionDefinition {
+    /// The prototype of the function (module, name, signature).
+    pub function: FunctionHandleIndex,
+    /// Flags for this function (private, public, native, etc.)
+    pub flags: u8,
+    /// Code for this function.
+    #[proptest(strategy = "any_with::<CodeUnit>(params)")]
+    pub code: CodeUnit,
+}
+
+impl FunctionDefinition {
+    /// Returns whether the FunctionDefinition is public.
+    pub fn is_public(&self) -> bool {
+        self.flags & CodeUnit::PUBLIC != 0
+    }
+    /// Returns whether the FunctionDefinition is native.
+    pub fn is_native(&self) -> bool {
+        self.flags & CodeUnit::NATIVE != 0
+    }
+}
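+// Illustrative sketch (the helper name is hypothetical): a struct's
+// `FieldDefinition`s occupy consecutive indexes starting at `fields`, so
+// walking them is an index-range iteration.
+#[allow(dead_code)]
+fn _field_index_range(def: &StructDefinition) -> impl Iterator<Item = FieldDefinitionIndex> {
+    let start = def.fields.0;
+    (start..start + def.field_count).map(FieldDefinitionIndex::new)
+}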
+// Signature definitions.
+// A signature can be for a type (field, local) or for a function - return type: (arguments).
+// They both go into the signature table so there is a marker that tags the signature.
+// Signatures usually don't carry a size, so you have to read them to get to the end.
+
+/// A type definition. `SignatureToken` allows the definition of the set of known types and their
+/// composition.
+#[derive(Arbitrary, Clone, Debug, Eq, Hash, PartialEq)]
+#[proptest(no_params)]
+pub struct TypeSignature(pub SignatureToken);
+
+/// A `FunctionSignature` describes the arguments and the return types of a function.
+#[derive(Arbitrary, Clone, Debug, Eq, Hash, PartialEq)]
+#[proptest(params = "usize")]
+pub struct FunctionSignature {
+    /// The list of return types.
+    #[proptest(strategy = "vec(any::<SignatureToken>(), 0..=params)")]
+    pub return_types: Vec<SignatureToken>,
+    /// The list of arguments to the function.
+    #[proptest(strategy = "vec(any::<SignatureToken>(), 0..=params)")]
+    pub arg_types: Vec<SignatureToken>,
+}
+
+/// A `LocalsSignature` is the list of locals used by a function.
+///
+/// Locals include the arguments to the function from position `0` to argument `count - 1`.
+/// The remaining elements are the type of each local.
+#[derive(Arbitrary, Clone, Debug, Default, Eq, Hash, PartialEq)]
+#[proptest(params = "usize")]
+pub struct LocalsSignature(
+    #[proptest(strategy = "vec(any::<SignatureToken>(), 0..=params)")] pub Vec<SignatureToken>,
+);
+
+impl LocalsSignature {
+    /// Length of the `LocalsSignature`.
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    /// Whether the function has no locals (neither arguments nor locals).
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.0.is_empty()
+    }
+}
+
+/// A `SignatureToken` is a type declaration for a location.
+///
+/// Any location in the system has a TypeSignature.
+/// A TypeSignature is also used in composed signatures.
+///
+/// A SignatureToken can express more types than the VM can handle safely, and correctness is
+/// enforced by the verifier.
+#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
+pub enum SignatureToken {
+    /// Boolean, `true` or `false`.
+    Bool,
+    /// Unsigned integers, 64-bit length.
+    U64,
+    /// Strings, immutable, utf8 representation.
+    String,
+    /// ByteArray, variable size, immutable byte array.
+    ByteArray,
+    /// Address, a 32-byte immutable type.
+    Address,
+    /// MOVE user type, resource or unrestricted.
+    Struct(StructHandleIndex),
+    /// Immutable reference to a type.
+    Reference(Box<SignatureToken>),
+    /// Mutable reference to a type.
+    MutableReference(Box<SignatureToken>),
+}
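+// Illustrative sketch (the helper name is hypothetical): tokens compose, so
+// `&mut Foo`, where `Foo` is the struct at handle 3, is a `MutableReference`
+// wrapping a `Struct` token.
+#[allow(dead_code)]
+fn _signature_token_example() -> SignatureToken {
+    SignatureToken::MutableReference(Box::new(SignatureToken::Struct(StructHandleIndex::new(3))))
+}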
+/// `Arbitrary` for `SignatureToken` cannot be derived automatically as it's a recursive type.
+impl Arbitrary for SignatureToken {
+    type Strategy = BoxedStrategy<Self>;
+    type Parameters = ();
+
+    fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
+        use SignatureToken::*;
+
+        let leaf = prop_oneof![
+            Just(Bool),
+            Just(U64),
+            Just(String),
+            Just(ByteArray),
+            Just(Address),
+            any::<StructHandleIndex>().prop_map(Struct),
+        ];
+        leaf.prop_recursive(
+            8,  // levels deep
+            16, // max size
+            1,  // items per collection
+            |inner| {
+                prop_oneof![
+                    inner.clone().prop_map(|token| Reference(Box::new(token))),
+                    inner
+                        .clone()
+                        .prop_map(|token| MutableReference(Box::new(token))),
+                ]
+            },
+        )
+        .boxed()
+    }
+}
+
+impl ::std::fmt::Debug for SignatureToken {
+    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+        match self {
+            SignatureToken::Bool => write!(f, "Bool"),
+            SignatureToken::U64 => write!(f, "U64"),
+            SignatureToken::String => write!(f, "String"),
+            SignatureToken::ByteArray => write!(f, "ByteArray"),
+            SignatureToken::Address => write!(f, "Address"),
+            SignatureToken::Struct(idx) => write!(f, "Struct({:?})", idx),
+            SignatureToken::Reference(boxed) => write!(f, "Reference({:?})", boxed),
+            SignatureToken::MutableReference(boxed) => write!(f, "MutableReference({:?})", boxed),
+        }
+    }
+}
+
+impl SignatureToken {
+    /// If a `SignatureToken` is a reference to a struct, returns the `StructHandleIndex` of
+    /// the underlying struct (e.g. the handle of `Foo` for `&Foo`), otherwise `None`.
+    #[inline]
+    pub fn get_struct_handle_from_reference(
+        reference_signature: &SignatureToken,
+    ) -> Option<StructHandleIndex> {
+        match reference_signature {
+            SignatureToken::Reference(signature) => match **signature {
+                SignatureToken::Struct(idx) => Some(idx),
+                _ => None,
+            },
+            SignatureToken::MutableReference(signature) => match **signature {
+                SignatureToken::Struct(idx) => Some(idx),
+                _ => None,
+            },
+            _ => None,
+        }
+    }
+
+    /// Returns the "kind" for the `SignatureToken`
+    #[inline]
+    pub fn kind(&self) -> SignatureTokenKind {
+        // TODO: SignatureTokenKind is out-dated. fix/update/remove SignatureTokenKind and see if
+        // this function needs to be cleaned up
+        use SignatureToken::*;
+
+        match self {
+            Reference(_) => SignatureTokenKind::Reference,
+            MutableReference(_) => SignatureTokenKind::MutableReference,
+            Bool | U64 | ByteArray | String | Address | Struct(_) => SignatureTokenKind::Value,
+        }
+    }
+
+    /// Returns the `StructHandleIndex` for a `SignatureToken` that contains a reference to a user
+    /// defined type (a resource or unrestricted type).
+    #[inline]
+    pub fn struct_index(&self) -> Option<StructHandleIndex> {
+        use SignatureToken::*;
+
+        match self {
+            Struct(sh_idx) => Some(*sh_idx),
+            Reference(token) | MutableReference(token) => token.struct_index(),
+            Bool | U64 | ByteArray | String | Address => None,
+        }
+    }
+
+    /// Returns `true` if the `SignatureToken` is a primitive type.
+    pub fn is_primitive(&self) -> bool {
+        use SignatureToken::*;
+        match self {
+            Bool | U64 | String | ByteArray | Address => true,
+            Struct(_) | Reference(_) | MutableReference(_) => false,
+        }
+    }
+
+    /// Checks if the signature token is usable for Eq and Neq.
+    ///
+    /// Currently equality operations are only allowed on:
+    /// - Bool
+    /// - U64
+    /// - String
+    /// - ByteArray
+    /// - Address
+    /// - Reference or Mutable reference to these types
+    pub fn allows_equality(&self) -> bool {
+        use SignatureToken::*;
+        match self {
+            Struct(_) => false,
+            Reference(token) | MutableReference(token) => token.is_primitive(),
+            token => token.is_primitive(),
+        }
+    }
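+    // Illustrative note: equality checks see through one reference layer, so
+    // `allows_equality` is true for `Reference(U64)` or `MutableReference(Bool)`,
+    // while `Struct(..)` and references to structs do not allow equality.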
+    /// Returns true if the `SignatureToken` is any kind of reference (mutable and immutable).
+    pub fn is_reference(&self) -> bool {
+        use SignatureToken::*;
+
+        match self {
+            Reference(_) | MutableReference(_) => true,
+            _ => false,
+        }
+    }
+
+    /// Returns true if the `SignatureToken` is a mutable reference.
+    pub fn is_mutable_reference(&self) -> bool {
+        use SignatureToken::*;
+
+        match self {
+            MutableReference(_) => true,
+            _ => false,
+        }
+    }
+
+    /// Set the index to this one. Useful for random testing.
+    ///
+    /// Panics if this token doesn't contain a struct handle.
+    pub fn debug_set_sh_idx(&mut self, sh_idx: StructHandleIndex) {
+        match self {
+            SignatureToken::Struct(ref mut wrapped) => *wrapped = sh_idx,
+            SignatureToken::Reference(ref mut token)
+            | SignatureToken::MutableReference(ref mut token) => token.debug_set_sh_idx(sh_idx),
+            other => panic!(
+                "debug_set_sh_idx (to {}) called for non-struct token {:?}",
+                sh_idx, other
+            ),
+        }
+    }
+}
+
+/// A `CodeUnit` is the body of a function. It has the function header and the instruction stream.
+#[derive(Arbitrary, Clone, Debug, Default, Eq, PartialEq)]
+#[proptest(params = "usize")]
+pub struct CodeUnit {
+    /// Max stack size for the function - currently unused.
+    pub max_stack_size: u16,
+    /// The list of locals' types. All locals are typed.
+    pub locals: LocalsSignatureIndex,
+    /// Code stream, function body.
+    #[proptest(strategy = "vec(any::<Bytecode>(), 0..=params)")]
+    pub code: Vec<Bytecode>,
+}
+
+/// Flags for `FunctionDefinition`.
+impl CodeUnit {
+    /// Function can be invoked outside of its declaring module.
+    pub const PUBLIC: u8 = 0x1;
+    /// A native function implemented in Rust.
+    pub const NATIVE: u8 = 0x2;
+}
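+// Illustrative sketch (hypothetical values): a minimal function body pointing
+// at the locals signature at index 0 with a trivial instruction stream.
+#[allow(dead_code)]
+fn _code_unit_example() -> CodeUnit {
+    CodeUnit {
+        max_stack_size: 1,
+        locals: LocalsSignatureIndex::new(0),
+        code: vec![Bytecode::LdTrue, Bytecode::Pop, Bytecode::Ret],
+    }
+}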
+/// `Bytecode` is a VM instruction of variable size. The type of the bytecode (opcode) defines
+/// the size of the bytecode.
+///
+/// Bytecodes operate on a stack machine and each bytecode has side effect on the stack and the
+/// instruction stream.
+#[derive(Arbitrary, Clone, Hash, Eq, PartialEq)]
+#[proptest(no_params)]
+pub enum Bytecode {
+    /// Pop and discard the value at the top of the stack.
+    /// The value on the stack must be an unrestricted type.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., value -> ...```
+    Pop,
+    /// Return from function, possibly with values according to the return types in the
+    /// function signature. The returned values are pushed on the stack.
+    /// The function signature of the function being executed defines the semantic of
+    /// the Ret opcode.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., arg_val(1), ..., arg_val(n) -> ..., return_val(1), ..., return_val(n)```
+    Ret,
+    /// Branch to the instruction at position `CodeOffset` if the value at the top of the stack
+    /// is true. Code offsets are relative to the start of the instruction stream.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., bool_value -> ...```
+    BrTrue(CodeOffset),
+    /// Branch to the instruction at position `CodeOffset` if the value at the top of the stack
+    /// is false. Code offsets are relative to the start of the instruction stream.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., bool_value -> ...```
+    BrFalse(CodeOffset),
+    /// Branch unconditionally to the instruction at position `CodeOffset`. Code offsets are
+    /// relative to the start of the instruction stream.
+    ///
+    /// Stack transition: none
+    Branch(CodeOffset),
+    /// Push integer constant onto the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., u64_value```
+    LdConst(u64),
+    /// Push a `string` literal onto the stack. The string is loaded from the `StringPool` via
+    /// `StringPoolIndex`.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., string_value```
+    LdStr(StringPoolIndex),
+    /// Push a `ByteArray` literal onto the stack. The `ByteArray` is loaded from the
+    /// `ByteArrayPool` via `ByteArrayPoolIndex`.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., bytearray_value```
+    LdByteArray(ByteArrayPoolIndex),
+    /// Push an `Address` literal onto the stack. The address is loaded from the
+    /// `AddressPool` via `AddressPoolIndex`.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., address_value```
+    LdAddr(AddressPoolIndex),
+    /// Push `true` onto the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., true```
+    LdTrue,
+    /// Push `false` onto the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., false```
+    LdFalse,
+    /// Push the local identified by `LocalIndex` onto the stack. The value is copied and the
+    /// local is still safe to use.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., value```
+    CopyLoc(LocalIndex),
+    /// Push the local identified by `LocalIndex` onto the stack. The local is moved and it is
+    /// invalid to use from that point on, unless a store operation writes to the local before
+    /// any read to that local.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., value```
+    MoveLoc(LocalIndex),
+    /// Pop value from the top of the stack and store it into the function locals at
+    /// position `LocalIndex`.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., value -> ...```
+    StLoc(LocalIndex),
+    /// Call a function. The stack has the arguments pushed first to last.
+    /// The arguments are consumed and pushed to the locals of the function.
+    /// Return values are pushed on the stack and available to the caller.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., arg(1), arg(2), ..., arg(n) -> ..., return_value(1), return_value(2), ...,
+    /// return_value(k)```
+    Call(FunctionHandleIndex),
+    /// Create an instance of the type specified via `StructDefinitionIndex` and push it on the
+    /// stack. The values of the fields of the struct, in the order they appear in the struct
+    /// declaration, must be pushed on the stack. All fields must be provided.
+    ///
+    /// A Pack instruction must fully initialize an instance.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., field(1)_value, field(2)_value, ..., field(n)_value -> ..., instance_value```
+    Pack(StructDefinitionIndex),
+    /// Destroy an instance of a type and push the values bound to each field on the
+    /// stack.
+    ///
+    /// The values of the fields of the instance appear on the stack in the order defined
+    /// in the struct definition.
+    ///
+    /// This order makes Unpack the inverse of Pack. So `Unpack; Pack` is the identity
+    /// for struct T.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., instance_value -> ..., field(1)_value, field(2)_value, ..., field(n)_value```
+    Unpack(StructDefinitionIndex),
+    /// Read a reference. The reference is on the stack, it is consumed and the value read is
+    /// pushed on the stack.
+    ///
+    /// Reading a reference performs a copy of the value referenced. As such
+    /// ReadRef cannot be used on a reference to a Resource.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., reference_value -> ..., value```
+    ReadRef,
+    /// Write to a reference. The reference and the value are on the stack and are consumed.
+    ///
+    /// The reference must be to an unrestricted type because Resources cannot be overwritten.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., value, reference_value -> ...```
+    WriteRef,
+    /// Release a reference. The reference will become invalid and cannot be used after.
+    ///
+    /// All references must be consumed and ReleaseRef is a way to release references not
+    /// consumed by other opcodes.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., reference_value -> ...```
+    ReleaseRef,
+    /// Convert a mutable reference to an immutable reference.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., reference_value -> ..., reference_value```
+    FreezeRef,
+    /// Load a reference to a local identified by LocalIndex.
+    ///
+    /// The local must not be a reference.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., reference```
+    BorrowLoc(LocalIndex),
+    /// Load a reference to a field identified by `FieldDefinitionIndex`.
+    /// The top of the stack must be a reference to a type that contains the field definition.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., reference -> ..., field_reference```
+    BorrowField(FieldDefinitionIndex),
+    /// Return a reference to an instance of type `StructDefinitionIndex` published at the address
+    /// passed as argument. Abort execution if such an object does not exist or if a reference
+    /// has already been handed out.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., address_value -> ..., reference_value```
+    BorrowGlobal(StructDefinitionIndex),
+    /// Add the two u64 values at the top of the stack and push the result on the stack.
+    /// The operation aborts the transaction in case of overflow.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., u64_value(1), u64_value(2) -> ..., u64_value```
+    Add,
+    /// Subtract the two u64 values at the top of the stack and push the result on the stack.
+    /// The operation aborts the transaction in case of underflow.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., u64_value(1), u64_value(2) -> ..., u64_value```
+    Sub,
+    /// Multiply the two u64 values at the top of the stack and push the result on the stack.
+    /// The operation aborts the transaction in case of overflow.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., u64_value(1), u64_value(2) -> ..., u64_value```
+    Mul,
+    /// Perform a modulo operation on the two u64 values at the top of the stack and push the
+    /// result on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., u64_value(1), u64_value(2) -> ..., u64_value```
+    Mod,
+    /// Divide the two u64 values at the top of the stack and push the result on the stack.
+    /// The operation aborts the transaction in case of "divide by 0".
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., u64_value(1), u64_value(2) -> ..., u64_value```
+    Div,
+    /// Bitwise OR the two u64 values at the top of the stack and push the result on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., u64_value(1), u64_value(2) -> ..., u64_value```
+    BitOr,
+    /// Bitwise AND the two u64 values at the top of the stack and push the result on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., u64_value(1), u64_value(2) -> ..., u64_value```
+    BitAnd,
+    /// Bitwise XOR the two u64 values at the top of the stack and push the result on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., u64_value(1), u64_value(2) -> ..., u64_value```
+    Xor,
+    /// Logically OR the two bool values at the top of the stack and push the result on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., bool_value(1), bool_value(2) -> ..., bool_value```
+    Or,
+    /// Logically AND the two bool values at the top of the stack and push the result on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., bool_value(1), bool_value(2) -> ..., bool_value```
+    And,
+    /// Logically negate the bool at the top of the stack and push the result on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., bool_value -> ..., bool_value```
+    Not,
+    /// Compare the two values at the top of the stack for equality and push the result on the
+    /// stack.
+    /// The values on the stack cannot be resources, as they would be consumed and so destroyed.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., value(1), value(2) -> ..., bool_value```
+    Eq,
+    /// Compare the two values at the top of the stack for inequality and push the result on the
+    /// stack.
+    /// The values on the stack cannot be resources, as they would be consumed and so destroyed.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., value(1), value(2) -> ..., bool_value```
+    Neq,
+    /// Perform a "less than" operation on the two u64 values at the top of the stack and push
+    /// the result on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., u64_value(1), u64_value(2) -> ..., bool_value```
+    Lt,
+    /// Perform a "greater than" operation on the two u64 values at the top of the stack and push
+    /// the result on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., u64_value(1), u64_value(2) -> ..., bool_value```
+    Gt,
+    /// Perform a "less than or equal" operation on the two u64 values at the top of the stack
+    /// and push the result on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., u64_value(1), u64_value(2) -> ..., bool_value```
+    Le,
+    /// Perform a "greater than or equal" operation on the two u64 values at the top of the stack
+    /// and push the result on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., u64_value(1), u64_value(2) -> ..., bool_value```
+    Ge,
+    /// Assert that the value at the top of the stack is true. Abort execution with the
+    /// error code otherwise.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., bool_value, errorcode -> ...```
+    Assert,
+    /// Get the gas unit price from the transaction and push it on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., u64_value```
+    GetTxnGasUnitPrice,
+    /// Get the max gas units set in the transaction and push it on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., u64_value```
+    GetTxnMaxGasUnits,
+    /// Get the remaining gas for the given transaction at the point of execution of this
+    /// bytecode. The result is pushed on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., u64_value```
+    GetGasRemaining,
+    /// Get the sender address from the transaction and push it on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., address_value```
+    GetTxnSenderAddress,
+    /// Return whether the given address already has an object of type StructDefinitionIndex
+    /// published.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., address_value -> ..., bool_value```
+    Exists(StructDefinitionIndex),
+    /// Move the instance of type StructDefinitionIndex at the address at the top of the stack.
+    /// Abort execution if such an object does not exist.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., address_value -> ..., value```
+    MoveFrom(StructDefinitionIndex),
+    /// Move the instance at the top of the stack to the address of the sender.
+    /// Abort execution if an object of type StructDefinitionIndex already exists at that address.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., address_value -> ...```
+    MoveToSender(StructDefinitionIndex),
+    /// Create an account at the address specified. Does not return anything.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., address_value -> ...```
+    CreateAccount,
+    /// Emit a log message.
+    /// This bytecode is not fully specified yet.
+    ///
+    /// Stack transition:
+    ///
+    /// ```..., reference, key, value -> ...```
+    EmitEvent,
+    /// Get the sequence number submitted with the transaction and push it on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., u64_value```
+    GetTxnSequenceNumber,
+    /// Get the public key of the sender from the transaction and push it on the stack.
+    ///
+    /// Stack transition:
+    ///
+    /// ```... -> ..., bytearray_value```
+    GetTxnPublicKey,
+}
+
+impl ::std::fmt::Debug for Bytecode {
+    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+        match self {
+            Bytecode::Pop => write!(f, "Pop"),
+            Bytecode::Ret => write!(f, "Ret"),
+            Bytecode::BrTrue(a) => write!(f, "BrTrue({})", a),
+            Bytecode::BrFalse(a) => write!(f, "BrFalse({})", a),
+            Bytecode::Branch(a) => write!(f, "Branch({})", a),
+            Bytecode::LdConst(a) => write!(f, "LdConst({})", a),
+            Bytecode::LdStr(a) => write!(f, "LdStr({})", a),
+            Bytecode::LdByteArray(a) => write!(f, "LdByteArray({})", a),
+            Bytecode::LdAddr(a) => write!(f, "LdAddr({})", a),
+            Bytecode::LdTrue => write!(f, "LdTrue"),
+            Bytecode::LdFalse => write!(f, "LdFalse"),
+            Bytecode::CopyLoc(a) => write!(f, "CopyLoc({})", a),
+            Bytecode::MoveLoc(a) => write!(f, "MoveLoc({})", a),
+            Bytecode::StLoc(a) => write!(f, "StLoc({})", a),
+            Bytecode::Call(a) => write!(f, "Call({})", a),
+            Bytecode::Pack(a) => write!(f, "Pack({})", a),
+            Bytecode::Unpack(a) => write!(f, "Unpack({})", a),
+            Bytecode::ReadRef => write!(f, "ReadRef"),
+            Bytecode::WriteRef => write!(f, "WriteRef"),
+            Bytecode::ReleaseRef => write!(f, "ReleaseRef"),
+            Bytecode::FreezeRef => write!(f, "FreezeRef"),
+            Bytecode::BorrowLoc(a) => write!(f, "BorrowLoc({})", a),
+            Bytecode::BorrowField(a) => write!(f, "BorrowField({})", a),
+            Bytecode::BorrowGlobal(a) => write!(f, "BorrowGlobal({})", a),
+            Bytecode::Add => write!(f, "Add"),
+            Bytecode::Sub => write!(f, "Sub"),
+            Bytecode::Mul => write!(f, "Mul"),
+            Bytecode::Mod => write!(f, "Mod"),
+            Bytecode::Div => write!(f, "Div"),
+            Bytecode::BitOr => write!(f, "BitOr"),
+            Bytecode::BitAnd => write!(f, "BitAnd"),
+            Bytecode::Xor => write!(f, "Xor"),
+            Bytecode::Or => write!(f, "Or"),
+            Bytecode::And => write!(f, "And"),
+            Bytecode::Not => write!(f, "Not"),
+            Bytecode::Eq => write!(f, "Eq"),
+            Bytecode::Neq => write!(f, "Neq"),
+            Bytecode::Lt => write!(f, "Lt"),
+            Bytecode::Gt => write!(f, "Gt"),
+            Bytecode::Le => write!(f, "Le"),
+            Bytecode::Ge => write!(f, "Ge"),
+            Bytecode::Assert => write!(f, "Assert"),
+            Bytecode::GetTxnGasUnitPrice => write!(f, "GetTxnGasUnitPrice"),
+            Bytecode::GetTxnMaxGasUnits => write!(f, "GetTxnMaxGasUnits"),
+            Bytecode::GetGasRemaining => write!(f, "GetGasRemaining"),
+            Bytecode::GetTxnSenderAddress => write!(f, "GetTxnSenderAddress"),
+            Bytecode::Exists(a) => write!(f, "Exists({})", a),
+            Bytecode::MoveFrom(a) => write!(f, "MoveFrom({})", a),
+            Bytecode::MoveToSender(a) => write!(f, "MoveToSender({})", a),
+            Bytecode::CreateAccount => write!(f, "CreateAccount"),
+            Bytecode::EmitEvent => write!(f, "EmitEvent"),
+            Bytecode::GetTxnSequenceNumber => write!(f, "GetTxnSequenceNumber"),
+            Bytecode::GetTxnPublicKey => write!(f, "GetTxnPublicKey"),
+        }
+    }
+}
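+// Illustrative sketch (the helper name is hypothetical): the instruction
+// stream for `assert(1 + 2 == 3, 99)`, following the stack transitions
+// documented on each opcode above.
+#[allow(dead_code)]
+fn _bytecode_stream_example() -> Vec<Bytecode> {
+    vec![
+        Bytecode::LdConst(1),
+        Bytecode::LdConst(2),
+        Bytecode::Add,         // ..., 1, 2 -> ..., 3
+        Bytecode::LdConst(3),
+        Bytecode::Eq,          // ..., 3, 3 -> ..., true
+        Bytecode::LdConst(99), // error code consumed by Assert
+        Bytecode::Assert,      // ..., true, 99 -> ...
+        Bytecode::Ret,
+    ]
+}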
+/// A `CompiledProgram` defines the structure of a transaction to execute.
+/// It has two parts: modules to be published and a transaction script.
+#[derive(Clone, Default, Eq, PartialEq, Debug)]
+pub struct CompiledProgram {
+    /// The modules to be published
+    pub modules: Vec<CompiledModule>,
+    /// The transaction script to execute
+    pub script: CompiledScript,
+}
+
+impl CompiledProgram {
+    /// Creates a new compiled program from compiled modules and a script
+    pub fn new(modules: Vec<CompiledModule>, script: CompiledScript) -> Self {
+        CompiledProgram { modules, script }
+    }
+}
+
+/// A `CompiledScript` contains the main function to execute and its dependencies.
+///
+/// A `CompiledScript` does not have definition tables because it can only have a `main(args)`.
+/// A `CompiledScript` defines the constant pools (string, address, signatures, etc.), the handle
+/// tables (external code references) and it has a `main` definition.
+#[derive(Arbitrary, Clone, Default, Eq, PartialEq, Debug)]
+#[proptest(params = "usize")]
+pub struct CompiledScript {
+    /// Handles to all modules referenced.
+    #[proptest(strategy = "vec(any::<ModuleHandle>(), 0..=params)")]
+    pub module_handles: Vec<ModuleHandle>,
+    /// Handles to external/imported types.
+    #[proptest(strategy = "vec(any::<StructHandle>(), 0..=params)")]
+    pub struct_handles: Vec<StructHandle>,
+    /// Handles to external/imported functions.
+    #[proptest(strategy = "vec(any::<FunctionHandle>(), 0..=params)")]
+    pub function_handles: Vec<FunctionHandle>,
+
+    /// Type pool. All external types referenced by the transaction.
+    #[proptest(strategy = "vec(any::<TypeSignature>(), 0..=params)")]
+    pub type_signatures: TypeSignaturePool,
+    /// Function signature pool. The signatures of the functions referenced by the transaction.
+    #[proptest(strategy = "vec(any_with::<FunctionSignature>(params), 0..=params)")]
+    pub function_signatures: FunctionSignaturePool,
+    /// Locals signature pool. The signature of the locals in `main`.
+    #[proptest(strategy = "vec(any_with::<LocalsSignature>(params), 0..=params)")]
+    pub locals_signatures: LocalsSignaturePool,
+
+    /// String pool. All literals and identifiers used in this transaction.
+    #[proptest(strategy = "vec(\".*\", 0..=params)")]
+    pub string_pool: StringPool,
+    /// ByteArray pool. The byte array literals used in the transaction.
+    #[proptest(strategy = "vec(any::<ByteArray>(), 0..=params)")]
+    pub byte_array_pool: ByteArrayPool,
+    /// Address pool. The address literals used in the module. Those include literals for
+    /// code references (`ModuleHandle`).
+    #[proptest(strategy = "vec(any::<AccountAddress>(), 0..=params)")]
+    pub address_pool: AddressPool,
+
+    /// The main (script) function to execute.
+    #[proptest(strategy = "any_with::<FunctionDefinition>(params)")]
+    pub main: FunctionDefinition,
+}
+
+impl CompiledScript {
+    /// Converts a `CompiledScript` to a `CompiledModule` for code that wants a uniform view of
+    /// both.
+    pub fn into_module(self) -> CompiledModule {
+        CompiledModule {
+            module_handles: self.module_handles,
+            struct_handles: self.struct_handles,
+            function_handles: self.function_handles,
+
+            type_signatures: self.type_signatures,
+            function_signatures: self.function_signatures,
+            locals_signatures: self.locals_signatures,
+
+            string_pool: self.string_pool,
+            byte_array_pool: self.byte_array_pool,
+            address_pool: self.address_pool,
+
+            struct_defs: vec![],
+            field_defs: vec![],
+            function_defs: vec![self.main],
+        }
+    }
+}
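A brief sketch of why `into_module` is useful: code written against `CompiledModule` can serve scripts as well. The helper name below is hypothetical, editorial only:

```rust
// Editorial sketch: treat a script as a module to reuse pool-walking code.
fn handle_count(script: CompiledScript) -> usize {
    let module = script.into_module(); // `main` becomes function_defs[0]
    module.module_handles.len()
        + module.struct_handles.len()
        + module.function_handles.len()
}
```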
+/// A `CompiledModule` defines the structure of a module which is the unit of published code.
+///
+/// A `CompiledModule` contains a definition of types (with their fields) and functions.
+/// It is a unit of code that can be used by transactions or other modules.
+///
+/// A module is published as a single entry and it is retrieved as a single blob.
+#[derive(Clone, Default, Eq, PartialEq, Debug)]
+pub struct CompiledModule {
+    /// Handles to external modules and self at position 0.
+    pub module_handles: Vec<ModuleHandle>,
+    /// Handles to external and internal types.
+    pub struct_handles: Vec<StructHandle>,
+    /// Handles to external and internal functions.
+    pub function_handles: Vec<FunctionHandle>,
+
+    /// Type pool. A definition for all types used in the module.
+    pub type_signatures: TypeSignaturePool,
+    /// Function signature pool. Represents all function signatures defined or used in
+    /// the module.
+    pub function_signatures: FunctionSignaturePool,
+    /// Locals signature pool. The signature for all locals of the functions defined in
+    /// the module.
+    pub locals_signatures: LocalsSignaturePool,
+
+    /// String pool. All literals and identifiers used in the module.
+    pub string_pool: StringPool,
+    /// ByteArray pool. The byte array literals used in the module.
+    pub byte_array_pool: ByteArrayPool,
+    /// Address pool. The address literals used in the module. Those include literals for
+    /// code references (`ModuleHandle`).
+    pub address_pool: AddressPool,
+
+    /// Types defined in this module.
+    pub struct_defs: Vec<StructDefinition>,
+    /// Fields defined on types in this module.
+    pub field_defs: Vec<FieldDefinition>,
+    /// Functions defined in this module.
+    pub function_defs: Vec<FunctionDefinition>,
+}
+
+// Need a custom implementation of Arbitrary because as of proptest-derive 0.1.1, the derivation
+// doesn't work for structs with more than 10 fields.
+impl Arbitrary for CompiledModule {
+    type Strategy = BoxedStrategy<Self>;
+    /// The size of the compiled module.
+    type Parameters = usize;
+
+    fn arbitrary_with(size: Self::Parameters) -> Self::Strategy {
+        (
+            (
+                vec(any::<ModuleHandle>(), 0..=size),
+                vec(any::<StructHandle>(), 0..=size),
+                vec(any::<FunctionHandle>(), 0..=size),
+            ),
+            (
+                vec(any::<TypeSignature>(), 0..=size),
+                vec(any_with::<FunctionSignature>(size), 0..=size),
+                vec(any_with::<LocalsSignature>(size), 0..=size),
+            ),
+            (
+                vec(any::<String>(), 0..=size),
+                vec(any::<ByteArray>(), 0..=size),
+                vec(any::<AccountAddress>(), 0..=size),
+            ),
+            (
+                vec(any::<StructDefinition>(), 0..=size),
+                vec(any::<FieldDefinition>(), 0..=size),
+                vec(any_with::<FunctionDefinition>(size), 0..=size),
+            ),
+        )
+            .prop_map(
+                |(
+                    (module_handles, struct_handles, function_handles),
+                    (type_signatures, function_signatures, locals_signatures),
+                    (string_pool, byte_array_pool, address_pool),
+                    (struct_defs, field_defs, function_defs),
+                )| {
+                    CompiledModule {
+                        module_handles,
+                        struct_handles,
+                        function_handles,
+                        type_signatures,
+                        function_signatures,
+                        locals_signatures,
+                        string_pool,
+                        byte_array_pool,
+                        address_pool,
+                        struct_defs,
+                        field_defs,
+                        function_defs,
+                    }
+                },
+            )
+            .boxed()
+    }
+}
+
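A sketch of how this `Arbitrary` implementation might be driven from a test; the crate path and table bound are assumptions, not taken from the diff:

```rust
// Editorial sketch: exercising the Arbitrary impl with proptest.
use proptest::prelude::*;
use vm::file_format::CompiledModule; // assumed crate path

proptest! {
    #[test]
    fn tables_respect_bound(module in any_with::<CompiledModule>(16)) {
        // Each table strategy above uses 0..=size, so 16 bounds every pool.
        prop_assert!(module.module_handles.len() <= 16);
    }
}
```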
+impl CompiledModule {
+    /// By convention, the index of the module being implemented is 0.
+    pub const IMPLEMENTED_MODULE_INDEX: u16 = 0;
+
+    fn self_handle(&self) -> &ModuleHandle {
+        &self.module_handles[Self::IMPLEMENTED_MODULE_INDEX as usize]
+    }
+
+    /// Returns the name of the module.
+    pub fn name(&self) -> &str {
+        self.string_at(self.self_handle().name)
+    }
+
+    /// Returns the address of the module.
+    pub fn address(&self) -> &AccountAddress {
+        self.address_at(self.self_handle().address)
+    }
+
+    /// Returns the code key of `module_handle`.
+    pub fn code_key_for_handle(&self, module_handle: &ModuleHandle) -> CodeKey {
+        CodeKey::new(
+            *self.address_at(module_handle.address),
+            self.string_at(module_handle.name).to_string(),
+        )
+    }
+
+    /// Returns the code key of `self`.
+    pub fn self_code_key(&self) -> CodeKey {
+        self.code_key_for_handle(self.self_handle())
+    }
+
+    /// Returns the count of a specific `IndexKind`.
+    pub fn kind_count(&self, kind: IndexKind) -> usize {
+        match kind {
+            IndexKind::ModuleHandle => self.module_handles.len(),
+            IndexKind::StructHandle => self.struct_handles.len(),
+            IndexKind::FunctionHandle => self.function_handles.len(),
+            IndexKind::StructDefinition => self.struct_defs.len(),
+            IndexKind::FieldDefinition => self.field_defs.len(),
+            IndexKind::FunctionDefinition => self.function_defs.len(),
+            IndexKind::TypeSignature => self.type_signatures.len(),
+            IndexKind::FunctionSignature => self.function_signatures.len(),
+            IndexKind::LocalsSignature => self.locals_signatures.len(),
+            IndexKind::StringPool => self.string_pool.len(),
+            IndexKind::ByteArrayPool => self.byte_array_pool.len(),
+            IndexKind::AddressPool => self.address_pool.len(),
+            // XXX these two don't seem to belong here
+            other @ IndexKind::LocalPool | other @ IndexKind::CodeDefinition => {
+                panic!("invalid kind for count: {:?}", other)
+            }
+        }
+    }
+
+    /// This function should only be called on an instance of `CompiledModule` obtained by
+    /// invoking `into_module` on some instance of `CompiledScript`. This function is the
+    /// inverse of `into_module`, i.e., `script.into_module().into_script() == script`.
+    pub fn into_script(mut self) -> CompiledScript {
+        let main = self.function_defs.remove(0);
+        CompiledScript {
+            module_handles: self.module_handles,
+            struct_handles: self.struct_handles,
+            function_handles: self.function_handles,
+
+            type_signatures: self.type_signatures,
+            function_signatures: self.function_signatures,
+            locals_signatures: self.locals_signatures,
+
+            string_pool: self.string_pool,
+            byte_array_pool: self.byte_array_pool,
+            address_pool: self.address_pool,
+
+            main,
+        }
+    }
+}
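The inverse property documented on `into_script` can be checked directly; a minimal editorial sketch, relying only on the derived `Clone` and `Eq`:

```rust
// Editorial sketch: into_module and into_script are inverses for scripts.
fn check_roundtrip(script: CompiledScript) {
    let expected = script.clone();
    assert_eq!(expected, script.into_module().into_script());
}
```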
diff --git a/language/vm/src/file_format_common.rs b/language/vm/src/file_format_common.rs
new file mode 100644
index 0000000000000..f31749179ce8e
--- /dev/null
+++ b/language/vm/src/file_format_common.rs
@@ -0,0 +1,241 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Constants for the binary format.
+//!
+//! Definition for the constants of the binary format, used by the serializer and the deserializer.
+//! This module also offers helpers for the serialization and deserialization of certain
+//! integer indexes.
+//!
+//! We use LEB128 for integer compression. LEB128 is a representation from the DWARF3 spec,
+//! http://dwarfstd.org/Dwarf3Std.php or https://en.wikipedia.org/wiki/LEB128.
+//! It is mostly used to compress indexes into the main binary tables.
+use byteorder::ReadBytesExt;
+use failure::*;
+use std::{
+    io::Cursor,
+    mem::{size_of, transmute},
+};
+
+/// Constant values for the binary format header.
+///
+/// The binary header is magic + version info + table count.
+pub enum BinaryConstants {}
+impl BinaryConstants {
+    /// The blob that must start a binary.
+    pub const LIBRA_MAGIC_SIZE: usize = 8;
+    pub const LIBRA_MAGIC: [u8; BinaryConstants::LIBRA_MAGIC_SIZE] =
+        [b'L', b'I', b'B', b'R', b'A', b'V', b'M', b'\n'];
+    /// The `LIBRA_MAGIC` size, plus 1 byte for the major version, 1 byte for the minor version
+    /// and 1 byte for the table count.
+    pub const HEADER_SIZE: usize = BinaryConstants::LIBRA_MAGIC_SIZE + 3;
+    /// The size of a (Table Type, Start Offset, Byte Count) table header: 1 byte for the type
+    /// and 4 bytes each for the offset and count.
+    pub const TABLE_HEADER_SIZE: u32 = size_of::<u32>() as u32 * 2 + 1;
+}
+
+/// Constants for table types in the binary.
+///
+/// The binary contains a subset of those tables. A table specification is a tuple (table type,
+/// start offset, byte count) for a given table.
+#[rustfmt::skip]
+#[allow(non_camel_case_types)]
+#[repr(u8)]
+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
+pub enum TableType {
+    MODULE_HANDLES          = 0x1,
+    STRUCT_HANDLES          = 0x2,
+    FUNCTION_HANDLES        = 0x3,
+    ADDRESS_POOL            = 0x4,
+    STRING_POOL             = 0x5,
+    BYTE_ARRAY_POOL         = 0x6,
+    MAIN                    = 0x7,
+    STRUCT_DEFS             = 0x8,
+    FIELD_DEFS              = 0x9,
+    FUNCTION_DEFS           = 0xA,
+    TYPE_SIGNATURES         = 0xB,
+    FUNCTION_SIGNATURES     = 0xC,
+    LOCALS_SIGNATURES       = 0xD,
+}
+
+/// Constants for signature kinds (type, function, locals). Those values start a signature blob.
+#[rustfmt::skip]
+#[allow(non_camel_case_types)]
+#[repr(u8)]
+#[derive(Clone, Copy, Debug)]
+pub enum SignatureType {
+    TYPE_SIGNATURE          = 0x1,
+    FUNCTION_SIGNATURE      = 0x2,
+    LOCAL_SIGNATURE         = 0x3,
+}
+
+/// Constants for signature blob values.
+#[rustfmt::skip]
+#[allow(non_camel_case_types)]
+#[repr(u8)]
+#[derive(Clone, Copy, Debug)]
+pub enum SerializedType {
+    BOOL                    = 0x1,
+    INTEGER                 = 0x2,
+    STRING                  = 0x3,
+    ADDRESS                 = 0x4,
+    REFERENCE               = 0x5,
+    MUTABLE_REFERENCE       = 0x6,
+    STRUCT                  = 0x7,
+    BYTEARRAY               = 0x8,
+}
+
+/// List of opcode constants.
+#[rustfmt::skip]
+#[allow(non_camel_case_types)]
+#[repr(u8)]
+#[derive(Clone, Copy, Debug)]
+pub enum Opcodes {
+    POP                         = 0x01,
+    RET                         = 0x02,
+    BR_TRUE                     = 0x03,
+    BR_FALSE                    = 0x04,
+    BRANCH                      = 0x05,
+    LD_CONST                    = 0x06,
+    LD_ADDR                     = 0x07,
+    LD_STR                      = 0x08,
+    LD_TRUE                     = 0x09,
+    LD_FALSE                    = 0x0A,
+    COPY_LOC                    = 0x0B,
+    MOVE_LOC                    = 0x0C,
+    ST_LOC                      = 0x0D,
+    LD_REF_LOC                  = 0x0E,
+    LD_REF_FIELD                = 0x0F,
+    LD_BYTEARRAY                = 0x10,
+    CALL                        = 0x11,
+    PACK                        = 0x12,
+    UNPACK                      = 0x13,
+    READ_REF                    = 0x14,
+    WRITE_REF                   = 0x15,
+    ADD                         = 0x16,
+    SUB                         = 0x17,
+    MUL                         = 0x18,
+    MOD                         = 0x19,
+    DIV                         = 0x1A,
+    BIT_OR                      = 0x1B,
+    BIT_AND                     = 0x1C,
+    XOR                         = 0x1D,
+    OR                          = 0x1E,
+    AND                         = 0x1F,
+    NOT                         = 0x20,
+    EQ                          = 0x21,
+    NEQ                         = 0x22,
+    LT                          = 0x23,
+    GT                          = 0x24,
+    LE                          = 0x25,
+    GE                          = 0x26,
+    ASSERT                      = 0x27,
+    GET_TXN_GAS_UNIT_PRICE      = 0x28,
+    GET_TXN_MAX_GAS_UNITS       = 0x29,
+    GET_GAS_REMAINING           = 0x2A,
+    GET_TXN_SENDER              = 0x2B,
+    EXISTS                      = 0x2C,
+    BORROW_REF                  = 0x2D,
+    RELEASE_REF                 = 0x2E,
+    MOVE_FROM                   = 0x2F,
+    MOVE_TO                     = 0x30,
+    CREATE_ACCOUNT              = 0x31,
+    EMIT_EVENT                  = 0x32,
+    GET_TXN_SEQUENCE_NUMBER     = 0x33,
+    GET_TXN_PUBLIC_KEY          = 0x34,
+    FREEZE_REF                  = 0x35,
+}
+
+/// Takes a `Vec<u8>` and a value to write to that vector and applies LEB128 logic to
+/// compress the u16.
+pub fn write_u16_as_uleb128(binary: &mut Vec<u8>, value: u16) {
+    write_u32_as_uleb128(binary, u32::from(value));
+}
+
+/// Takes a `Vec<u8>` and a value to write to that vector and applies LEB128 logic to
+/// compress the u32.
+pub fn write_u32_as_uleb128(binary: &mut Vec<u8>, value: u32) {
+    let mut val = value;
+    loop {
+        let v: u8 = (val & 0x7f) as u8;
+        if u32::from(v) != val {
+            binary.push(v | 0x80);
+            val >>= 7;
+        } else {
+            binary.push(v);
+            break;
+        }
+    }
+}
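To make the encoding loop concrete, a small worked sketch (editorial, not part of the diff): each emitted byte carries 7 payload bits, with the high bit set on every byte except the last. The round-trip uses the reader defined further below.

```rust
use std::io::Cursor;

fn main() {
    let mut buf = Vec::new();
    write_u32_as_uleb128(&mut buf, 300);
    // 300 = 0b10_0101100: the low 7 bits (0x2C) go first with the
    // continuation bit set (0x2C | 0x80 = 0xAC), then the remaining bits.
    assert_eq!(buf, vec![0xAC, 0x02]);
    // Round-trip through the ULEB128 reader defined below.
    let mut cursor = Cursor::new(&buf[..]);
    assert_eq!(read_uleb128_as_u32(&mut cursor).unwrap(), 300);
}
```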
+/// Write a `u16` in Little Endian format.
+pub fn write_u16(binary: &mut Vec<u8>, value: u16) {
+    let bytes: [u8; 2] = unsafe { transmute(value.to_le()) };
+    for byte in &bytes {
+        binary.push(*byte);
+    }
+}
+
+/// Write a `u32` in Little Endian format.
+pub fn write_u32(binary: &mut Vec<u8>, value: u32) {
+    let bytes: [u8; 4] = unsafe { transmute(value.to_le()) };
+    for byte in &bytes {
+        binary.push(*byte);
+    }
+}
+
+/// Write a `u64` in Little Endian format.
+pub fn write_u64(binary: &mut Vec<u8>, value: u64) {
+    let bytes: [u8; 8] = unsafe { transmute(value.to_le()) };
+    for byte in &bytes {
+        binary.push(*byte);
+    }
+}
+
+/// Reads a `u16` in ULEB128 format from a `binary`.
+///
+/// Takes a `&mut Cursor<&[u8]>` and returns the `u16` value read.
+///
+/// Returns an error on an invalid representation.
+pub fn read_uleb128_as_u16(cursor: &mut Cursor<&[u8]>) -> Result<u16> {
+    let mut value: u16 = 0;
+    let mut shift: u8 = 0;
+    while let Ok(byte) = cursor.read_u8() {
+        let val = byte & 0x7f;
+        value |= u16::from(val) << shift;
+        if val == byte {
+            return Ok(value);
+        }
+        shift += 7;
+        if shift > 14 {
+            break;
+        }
+    }
+    bail!("invalid ULEB128 representation for u16")
+}
+
+/// Reads a `u32` in ULEB128 format from a `binary`.
+///
+/// Takes a `&mut Cursor<&[u8]>` and returns the `u32` value read.
+///
+/// Returns an error on an invalid representation.
+pub fn read_uleb128_as_u32(cursor: &mut Cursor<&[u8]>) -> Result<u32> {
+    let mut value: u32 = 0;
+    let mut shift: u8 = 0;
+    while let Ok(byte) = cursor.read_u8() {
+        let val = byte & 0x7f;
+        value |= u32::from(val) << shift;
+        if val == byte {
+            return Ok(value);
+        }
+        shift += 7;
+        if shift > 28 {
+            break;
+        }
+    }
+    bail!("invalid ULEB128 representation for u32")
+}
diff --git a/language/vm/src/gas_schedule.rs b/language/vm/src/gas_schedule.rs
new file mode 100644
index 0000000000000..700c15fcfc6c3
--- /dev/null
+++ b/language/vm/src/gas_schedule.rs
@@ -0,0 +1,434 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module lays out the basic abstract costing schedule for bytecode instructions.
+//!
+//! It is important to note that the cost schedule defined in this file does not track hashing
+//! operations or other native operations; the cost of each native operation will be returned by
+//! the native function itself.
+use crate::file_format::Bytecode;
+use std::{ops::Add, u64};
+
+/// The underlying carrier for the gas cost.
+pub type GasUnits = u64;
+
+/// Type for representing the size of our memory.
+pub type AbstractMemorySize = u64;
+
+/// The units of gas that should be charged per byte for every transaction.
+pub const INTRINSIC_GAS_PER_BYTE: GasUnits = 8;
+
+/// The minimum gas price that a transaction can be submitted with.
+pub const MIN_PRICE_PER_GAS_UNIT: u64 = 0;
+
+/// The maximum gas unit price that a transaction can be submitted with.
+pub const MAX_PRICE_PER_GAS_UNIT: u64 = 10_000;
+
+/// 1 nanosecond should equal one unit of computational gas. We bound the maximum
+/// computational time of any given transaction at 10 milliseconds.
We want this number and +/// `MAX_PRICE_PER_GAS_UNIT` to always satisfy the inequality that +/// MAXIMUM_NUMBER_OF_GAS_UNITS * MAX_PRICE_PER_GAS_UNIT < min(u64::MAX, GasUnits::MAX) +pub const MAXIMUM_NUMBER_OF_GAS_UNITS: GasUnits = 1_000_000; + +/// We charge one unit of gas per-byte for the first 600 bytes +pub const MIN_TRANSACTION_GAS_UNITS: GasUnits = 600; + +/// The word size that we charge by +pub const WORD_SIZE: AbstractMemorySize = 8; + +/// The size in words for a non-string or address constant on the stack +pub const CONST_SIZE: AbstractMemorySize = 1; + +/// The size in words for a reference on the stack +pub const REFERENCE_SIZE: AbstractMemorySize = 8; + +/// The size of a struct in words +pub const STRUCT_SIZE: AbstractMemorySize = 2; + +/// For V1 all accounts will be 32 words +pub const DEFAULT_ACCOUNT_SIZE: AbstractMemorySize = 32; + +/// Any transaction over this size will be charged `INTRINSIC_GAS_PER_BYTE` per byte +pub const LARGE_TRANSACTION_CUTOFF: AbstractMemorySize = 600; + +/// A number of instructions fall within three different tiers of gas usage. However some +/// instructions do not fall within these three categories and are separated out into their own gas +/// costs. +pub enum GasCostTable { + Low, + Mid, + High, + EmitEvent, + GetTxnSequenceNumber, + GetTxnPublicKey, + GetTxnGasUnitPrice, + GetTxnMaxGasUnits, + GetGasRemaining, + GetTxnSenderAddress, + Unpack, + Exists, + BorrowGlobal, + ReleaseRef, + MoveFrom, + MoveToSender, + CreateAccount, +} + +/// The `GasCost` tracks: +/// - instruction cost: how much time/computational power is needed to perform the instruction +/// - memory cost: how much memory is required for the instruction, and storage overhead +/// - stack cost: how large does the value stack grow or shrink because of this operation. Note that +/// this could in a sense have ``negative'' cost; an instruction can decrease the size of the +/// stack -- however in this case we view the instruction as having zero stack cost. +#[derive(Debug)] +pub struct GasCost { + pub instruction_gas: GasUnits, + pub memory_gas: GasUnits, + pub stack_gas: GasUnits, +} + +/// The implementation of the `GasCostTable` provides denotations for each abstract gas tier for +/// each of the three different resources that we track. +/// +/// Note that these denotations are filler (read: bogus) for now. The constants will be filled in +/// with the synthesized costs later on, or provided by a separate contract that we read in. +impl GasCostTable { + /// Return the instruction (computational) cost for the tier. + pub fn instr_cost(&self) -> GasUnits { + use GasCostTable::*; + match *self { + Low => 1, + Mid => 2, + High => 3, + EmitEvent => 4, + GetTxnSequenceNumber => 2, + GetTxnPublicKey => 2, + GetTxnGasUnitPrice => 2, + GetTxnMaxGasUnits => 2, + GetGasRemaining => 2, + GetTxnSenderAddress => 2, + Unpack => 5, + Exists => 3, + BorrowGlobal => 3, + ReleaseRef => 3, + MoveFrom => 5, + MoveToSender => 5, + CreateAccount => 5, + } + } + + /// Return the memory cost for the tier. + pub fn mem_cost(&self) -> GasUnits { + use GasCostTable::*; + match *self { + Low => 0, + Mid => 1, + High => 2, + EmitEvent => 3, + GetTxnSequenceNumber => 2, + GetTxnPublicKey => 2, + GetTxnGasUnitPrice => 4, + GetTxnMaxGasUnits => 4, + GetGasRemaining => 4, + GetTxnSenderAddress => 1, + Unpack => 2, + Exists => 3, + BorrowGlobal => 4, + ReleaseRef => 4, + MoveFrom => 5, + MoveToSender => 5, + CreateAccount => 5, + } + } + + /// Return the stack cost for the tier. 
+ pub fn stack_cost(&self) -> GasUnits { + use GasCostTable::*; + match *self { + Low => 0, + Mid => 1, + High => 2, + EmitEvent => 0, + GetTxnSequenceNumber => 1, + GetTxnPublicKey => 1, + GetTxnGasUnitPrice => 1, + GetTxnMaxGasUnits => 1, + GetGasRemaining => 1, + GetTxnSenderAddress => 1, + Unpack => 1, + Exists => 1, + BorrowGlobal => 1, + ReleaseRef => 1, + MoveFrom => 0, + MoveToSender => 0, + CreateAccount => 1, + } + } +} + +// The general costing methodology that we will use for instruction cost +// determinations is as follows: +// * Each stack op (push, pop, ld, etc.) will have a gas cost of Low +// * Each primitive operation (+, -, *, jmp, etc.) will also have a gas cost of Low +// * Each function call will be charged a Mid cost +// * Each local memory operation will have a gas cost of Mid +// * Each global storage operation will have a gas cost of High +// * Examples: +// 1. gas_cost_instr(Add) = 4 * Low.value() +// 2 - one for each pop from the stack +// 1 - for the arithmetic op +// 1 - for the push of the resulting value onto the stack +// 2. gas_cost_instr(Branch(_)) = 1 * Low.value() +// 1 - A single jmp instruction, with no stack transition +// 3. gas_cost_instr(BrFalse(_)) = 2 * Low.value() +// 1 - perform one pop on the stack +// 1 - perform one jmp (possibly) +// -> NOTE: Will we want to charge a cost for possible lack of pipelining that +// we can perform with this instruction? +// `size_provider` provides the size for the costing, this can be e.g. the size of memory, or the +// number of arguments to the Call of Pack. +fn static_gas_cost_instr(instr: &Bytecode, size_provider: GasUnits) -> GasUnits { + use GasCostTable::*; + match instr { + // pop -> pop -> op -> push | all Low + Bytecode::Add + | Bytecode::Sub + | Bytecode::Mul + | Bytecode::Mod + | Bytecode::Div + | Bytecode::BitOr + | Bytecode::BitAnd + | Bytecode::Xor + | Bytecode::Or + | Bytecode::And + | Bytecode::Not + | Bytecode::Eq + | Bytecode::Neq + | Bytecode::Lt + | Bytecode::Gt + | Bytecode::Le + | Bytecode::Ge => 4 * Low.instr_cost(), + // push XOR pop XOR jmp | all Tier 0 + Bytecode::FreezeRef + | Bytecode::LdTrue + | Bytecode::LdFalse + | Bytecode::LdConst(_) + | Bytecode::BorrowLoc(_) + | Bytecode::Branch(_) + | Bytecode::Pop => Low.instr_cost(), + // Read from local/pool table -> push | tier 1 -> tier 0 + Bytecode::LdStr(_) | Bytecode::LdAddr(_) | Bytecode::LdByteArray(_) => { + Mid.instr_cost() + Low.instr_cost() + } + // pop -> push XOR pop -> op{assert, jmpeq} | all Low + Bytecode::Ret | Bytecode::Assert | Bytecode::BrTrue(_) | Bytecode::BrFalse(_) => { + 2 * Low.instr_cost() + } + // Load from global mem -> push, High -> Low + Bytecode::BorrowField(_) => High.instr_cost() + Low.instr_cost(), + // Load from local mem -> push XOR pop -> write local mem | Mid -> Low XOR Low -> Mid + Bytecode::MoveLoc(_) | Bytecode::StLoc(_) => Mid.instr_cost() + Low.instr_cost(), + // pop -> mem op -> push XOR pop -> pop -> mem op | Low -> High + // Since we charge for the BorrowGlobal, we don't need to charge that high of a cost for + // ReadRef. WriteRef on the other hand will incur a possible write to global memory since + // we can't determine if the reference is global or local at this time. 
+ Bytecode::ReadRef => 2 * Low.instr_cost() + Mid.instr_cost(), + Bytecode::WriteRef => 2 * Low.instr_cost() + High.instr_cost(), + // size_provider gives us the number of bytes that need to be copied + // Copy bytes from locals -> push value onto stack | size_provider * Mid + Low + Bytecode::CopyLoc(_) => Low.instr_cost() + size_provider * Mid.instr_cost(), + // Allocate size_provider bytes for the new class -> push | size_provider * Mid + Low + // Question: Where will we get the class info to determine the layout of + // the object? We will need to include that in the cost estimate for + // this function. + Bytecode::Pack(_) => size_provider * Mid.instr_cost(), + // #size_provider pops -> #size_provider writes to local memory -> fn call + // | size_provider * Low + size_provider * Mid + Mid + Bytecode::Call(_) => { + size_provider * Low.instr_cost() + (size_provider + 1) * Mid.instr_cost() + } + Bytecode::Unpack(_) => size_provider * Unpack.instr_cost(), + Bytecode::CreateAccount => CreateAccount.instr_cost(), + Bytecode::EmitEvent => EmitEvent.instr_cost(), + Bytecode::GetTxnSenderAddress => GetTxnSenderAddress.instr_cost(), + Bytecode::GetTxnSequenceNumber => GetTxnSequenceNumber.instr_cost(), + Bytecode::GetTxnPublicKey => GetTxnPublicKey.instr_cost(), + Bytecode::GetTxnGasUnitPrice => GetTxnGasUnitPrice.instr_cost(), + Bytecode::GetTxnMaxGasUnits => GetTxnMaxGasUnits.instr_cost(), + Bytecode::GetGasRemaining => GetGasRemaining.instr_cost(), + Bytecode::Exists(_) => Exists.instr_cost(), + Bytecode::BorrowGlobal(_) => BorrowGlobal.instr_cost(), + Bytecode::ReleaseRef => ReleaseRef.instr_cost(), + Bytecode::MoveFrom(_) => MoveFrom.instr_cost(), + Bytecode::MoveToSender(_) => MoveToSender.instr_cost(), + } +} + +// Determine the cost to memory (in terms of size) for various operations. The +// tiers have the following meaning here: +// - Low: Don't touch any memory, just stack +// - Mid: Touch only local memory +// - High: Touch global storage +// Note that we _do not_ track the size of memory, and charge for expansion +// of this memory, nor do we track whether or not we are setting bits from zero +// or not (i.e. what Ethereum does). +fn static_gas_cost_mem(instr: &Bytecode, size_provider: GasUnits) -> GasUnits { + use GasCostTable::*; + match instr { + // All of these operations don't touch memory. So have Low memory cost + Bytecode::FreezeRef + | Bytecode::Pop + | Bytecode::Ret + | Bytecode::Add + | Bytecode::Sub + | Bytecode::Mul + | Bytecode::Mod + | Bytecode::Div + | Bytecode::BitOr + | Bytecode::BitAnd + | Bytecode::Xor + | Bytecode::Or + | Bytecode::And + | Bytecode::Not + | Bytecode::Eq + | Bytecode::Neq + | Bytecode::Lt + | Bytecode::Gt + | Bytecode::Le + | Bytecode::Ge + | Bytecode::LdTrue + | Bytecode::LdFalse + | Bytecode::LdConst(_) + | Bytecode::Assert + | Bytecode::BrTrue(_) + | Bytecode::BrFalse(_) + | Bytecode::Branch(_) + | Bytecode::CopyLoc(_) // Stored on the stack, so no overhead as such + | Bytecode::MoveLoc(_) // Moved, so stays the same + | Bytecode::BorrowLoc(_) => Low.mem_cost(), + // Call and Pack values (etc.) 
are allocated on the stack + Bytecode::LdByteArray(_) | Bytecode::LdAddr(_) | Bytecode::LdStr(_) | Bytecode::Pack(_) | Bytecode::Call(_) => size_provider * Low.mem_cost(), + // pop -> write local memory + Bytecode::StLoc(_) => size_provider * Mid.mem_cost() + Low.mem_cost(), + // One load from global, and push to the stack + Bytecode::BorrowField(_) => High.mem_cost() + Low.mem_cost(), + // We assume that all references are non-local + Bytecode::WriteRef | Bytecode::ReadRef => High.mem_cost(), + Bytecode::EmitEvent => size_provider * EmitEvent.mem_cost(), + Bytecode::GetTxnSenderAddress => GetTxnSenderAddress.mem_cost(), + Bytecode::Unpack(_) => size_provider * Unpack.mem_cost(), + Bytecode::CreateAccount => size_provider * CreateAccount.mem_cost(), + Bytecode::GetTxnSequenceNumber => GetTxnSequenceNumber.mem_cost(), + Bytecode::GetTxnPublicKey => GetTxnPublicKey.mem_cost(), + Bytecode::GetTxnGasUnitPrice => GetTxnGasUnitPrice.mem_cost(), + Bytecode::GetTxnMaxGasUnits => GetTxnMaxGasUnits.mem_cost(), + Bytecode::GetGasRemaining => GetGasRemaining.mem_cost(), + Bytecode::Exists(_) => Exists.mem_cost(), + Bytecode::BorrowGlobal(_) => size_provider * BorrowGlobal.mem_cost(), + Bytecode::ReleaseRef => ReleaseRef.mem_cost(), + Bytecode::MoveFrom(_) => size_provider * MoveFrom.mem_cost(), + Bytecode::MoveToSender(_) => size_provider * MoveToSender.mem_cost(), + } +} + +// We charge a stack cost based upon how much the given bytecode will effect the stack size. +// The tiers have the following meaning here: +// - Low: Reduction in stack size of numeric constant size +// - Mid: Stack size remains constant +// - High: Push to the stack of constant size +fn static_gas_cost_stack(instr: &Bytecode, size_provider: GasUnits) -> GasUnits { + use GasCostTable::*; + match instr { + Bytecode::FreezeRef + | Bytecode::Pop + | Bytecode::BrTrue(_) + | Bytecode::BrFalse(_) + | Bytecode::ReadRef + | Bytecode::Assert => Low.stack_cost(), + Bytecode::Ret | Bytecode::Branch(_) => Mid.stack_cost(), + Bytecode::BorrowLoc(_) + | Bytecode::BorrowField(_) + | Bytecode::Add + | Bytecode::Sub + | Bytecode::Mul + | Bytecode::Mod + | Bytecode::Div + | Bytecode::BitOr + | Bytecode::BitAnd + | Bytecode::Xor + | Bytecode::Or + | Bytecode::And + | Bytecode::Not + | Bytecode::Eq + | Bytecode::Neq + | Bytecode::Lt + | Bytecode::Gt + | Bytecode::Le + | Bytecode::Ge => 2 * Low.stack_cost() + Mid.stack_cost(), + Bytecode::LdTrue + | Bytecode::LdFalse + | Bytecode::LdAddr(_) + | Bytecode::LdByteArray(_) + | Bytecode::LdStr(_) + | Bytecode::LdConst(_) => High.stack_cost(), + Bytecode::CopyLoc(_) | Bytecode::MoveLoc(_) => size_provider * High.stack_cost(), + Bytecode::StLoc(_) => size_provider * Low.stack_cost(), + Bytecode::Pack(_) | Bytecode::Call(_) => (size_provider + 1) * High.mem_cost(), + Bytecode::WriteRef => 2 * Low.mem_cost(), + Bytecode::EmitEvent => EmitEvent.stack_cost(), + Bytecode::GetTxnSenderAddress => GetTxnSenderAddress.stack_cost(), + Bytecode::GetTxnSequenceNumber => GetTxnSequenceNumber.stack_cost(), + Bytecode::GetTxnPublicKey => GetTxnPublicKey.stack_cost(), + Bytecode::GetTxnGasUnitPrice => GetTxnGasUnitPrice.stack_cost(), + Bytecode::GetTxnMaxGasUnits => GetTxnMaxGasUnits.stack_cost(), + Bytecode::GetGasRemaining => GetGasRemaining.stack_cost(), + Bytecode::Unpack(_) => Unpack.stack_cost(), + Bytecode::Exists(_) => Exists.stack_cost(), + Bytecode::BorrowGlobal(_) => BorrowGlobal.stack_cost(), + Bytecode::ReleaseRef => ReleaseRef.stack_cost(), + Bytecode::MoveFrom(_) => MoveFrom.stack_cost(), + 
Bytecode::MoveToSender(_) => MoveToSender.stack_cost(), + Bytecode::CreateAccount => CreateAccount.stack_cost(), + } +} + +/// Statically cost a bytecode instruction. +/// +/// Don't take into account current stack or memory size. Don't track whether references are to +/// global or local storage. +pub fn static_cost_instr(instr: &Bytecode, size_provider: GasUnits) -> GasCost { + GasCost { + instruction_gas: static_gas_cost_instr(instr, size_provider), + memory_gas: static_gas_cost_mem(instr, size_provider), + stack_gas: static_gas_cost_stack(instr, size_provider), + } +} + +/// Computes the number of words rounded up +pub fn words_in(size: AbstractMemorySize) -> AbstractMemorySize { + // round-up div truncate + (size + (WORD_SIZE - 1)) / WORD_SIZE +} + +/// Calculate the intrinsic gas for the transaction based upon its size in bytes/words. +pub fn calculate_intrinsic_gas(transaction_size: u64) -> GasUnits { + let min_transaction_fee = MIN_TRANSACTION_GAS_UNITS; + + if transaction_size > LARGE_TRANSACTION_CUTOFF { + let excess = words_in(transaction_size - LARGE_TRANSACTION_CUTOFF); + min_transaction_fee + INTRINSIC_GAS_PER_BYTE * excess + } else { + min_transaction_fee + } +} + +impl Add for GasCost { + type Output = GasCost; + fn add(self, other: GasCost) -> GasCost { + GasCost { + instruction_gas: self.instruction_gas + other.instruction_gas, + memory_gas: self.memory_gas + other.memory_gas, + stack_gas: self.stack_gas + other.stack_gas, + } + } +} diff --git a/language/vm/src/internals.rs b/language/vm/src/internals.rs new file mode 100644 index 0000000000000..6bb5748f731c9 --- /dev/null +++ b/language/vm/src/internals.rs @@ -0,0 +1,14 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Types meant for use by other parts of this crate, and by other crates that are designed to +//! work with the internals of these data structures. + +use crate::IndexKind; + +/// Represents a module index. +pub trait ModuleIndex { + const KIND: IndexKind; + + fn into_index(self) -> usize; +} diff --git a/language/vm/src/lib.rs b/language/vm/src/lib.rs new file mode 100644 index 0000000000000..6d352c8b0cdb7 --- /dev/null +++ b/language/vm/src/lib.rs @@ -0,0 +1,120 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![feature(never_type)] +#![feature(exhaustive_patterns)] + +use std::fmt; + +pub mod access; +pub mod checks; +#[macro_use] +pub mod errors; +pub mod deserializer; +pub mod file_format; +pub mod file_format_common; +pub mod gas_schedule; +pub mod internals; +pub mod printers; +pub mod proptest_types; +pub mod resolver; +pub mod serializer; +pub mod transaction_metadata; +pub mod views; + +#[cfg(test)] +mod unit_tests; + +pub use file_format::CompiledModule; + +/// Represents a kind of index -- useful for error messages. +#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum IndexKind { + ModuleHandle, + StructHandle, + FunctionHandle, + StructDefinition, + FieldDefinition, + FunctionDefinition, + TypeSignature, + FunctionSignature, + LocalsSignature, + StringPool, + ByteArrayPool, + AddressPool, + LocalPool, + CodeDefinition, +} + +impl IndexKind { + pub fn variants() -> &'static [IndexKind] { + use IndexKind::*; + + // XXX ensure this list stays up to date! 
+        &[
+            ModuleHandle,
+            StructHandle,
+            FunctionHandle,
+            StructDefinition,
+            FieldDefinition,
+            FunctionDefinition,
+            TypeSignature,
+            FunctionSignature,
+            LocalsSignature,
+            StringPool,
+            ByteArrayPool,
+            AddressPool,
+            LocalPool,
+            CodeDefinition,
+        ]
+    }
+}
+
+impl fmt::Display for IndexKind {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        use IndexKind::*;
+
+        let desc = match self {
+            ModuleHandle => "module handle",
+            StructHandle => "struct handle",
+            FunctionHandle => "function handle",
+            StructDefinition => "struct definition",
+            FieldDefinition => "field definition",
+            FunctionDefinition => "function definition",
+            TypeSignature => "type signature",
+            FunctionSignature => "function signature",
+            LocalsSignature => "locals signature",
+            StringPool => "string pool",
+            ByteArrayPool => "byte_array pool",
+            AddressPool => "address pool",
+            LocalPool => "local pool",
+            CodeDefinition => "code definition pool",
+        };
+
+        f.write_str(desc)
+    }
+}
+
+/// Represents the kind of a signature token.
+#[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
+pub enum SignatureTokenKind {
+    /// Any sort of owned value that isn't an array (Integer, Bool, Struct etc.).
+    Value,
+    /// A reference.
+    Reference,
+    /// A mutable reference.
+    MutableReference,
+}
+
+impl fmt::Display for SignatureTokenKind {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        use SignatureTokenKind::*;
+
+        let desc = match self {
+            Value => "value",
+            Reference => "reference",
+            MutableReference => "mutable reference",
+        };
+
+        f.write_str(desc)
+    }
+}
diff --git a/language/vm/src/printers.rs b/language/vm/src/printers.rs
new file mode 100644
index 0000000000000..18d9e82f33f39
--- /dev/null
+++ b/language/vm/src/printers.rs
@@ -0,0 +1,621 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::file_format::*;
+use failure::*;
+use hex;
+use std::{collections::VecDeque, fmt};
+use types::{account_address::AccountAddress, byte_array::ByteArray};
+
+//
+// Display printing
+// Display the top level compilation units (CompiledScript and CompiledModule) in a more
+// readable format. Essentially the printing resolves all table indexes and prints each table
+// line by line with reasonable indentation, e.g.
+// ```text +// CompiledModule: { +// Struct Handles: [ +// ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000,] +// Field Handles: [ +// ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000.item: Value,] +// Function Handles: [ +// ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000.get(): Value, +// ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000.new(Value): ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000,] +// Struct Definitions: [ +// {public resource ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000 +// private ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000.item: Value +// public ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000.get(): Value +// static public ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000.new(Value): ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000},] +// Field Definitions: [ +// private ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000.item: Value,] +// Function Definitions: [ +// public ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000.get(): Value +// local(0): ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000, +// local(1): &Value, +// local(2): Value, +// CopyLoc(0) +// BorrowField(ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000.item: Value) +// StLoc(1) +// CopyLoc(1) +// ReadRef +// StLoc(2) +// MoveLoc(2) +// Ret, +// static public ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000.new(Value): ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000 +// local(0): Value, +// local(1): ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000, +// MoveLoc(0) +// Pack(ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000) +// StLoc(1) +// MoveLoc(1) +// Ret,] +// Signatures: [ +// Value, +// (): Value, +// (Value): ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000, +// ResourceBox@0x0000000000000000000000000000000000000000000000000000000000000000, +// &Value,] +// Strings: [ +// ResourceBox, +// item, +// get, +// new,] +// Addresses: [ +// 0x0000000000000000000000000000000000000000000000000000000000000000,] +// } +// ``` + +// Trait to access tables for both CompiledScript and CompiledModule. +// This is designed mainly for the printer -- public APIs should be based on the accessors in +// `access.rs`. 
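Given the `Display` implementations that follow, producing the resolved listing shown above is a single `format!` call; a minimal editorial sketch:

```rust
// Editorial sketch: render the human-readable listing of a module, with
// every table index resolved by the Display impls below.
fn dump(module: &CompiledModule) -> String {
    format!("{}", module)
}
```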
+pub trait TableAccess { + fn get_field_def_at(&self, idx: FieldDefinitionIndex) -> Result<&FieldDefinition>; + + fn get_module_at(&self, idx: ModuleHandleIndex) -> Result<&ModuleHandle>; + fn get_struct_at(&self, idx: StructHandleIndex) -> Result<&StructHandle>; + fn get_function_at(&self, idx: FunctionHandleIndex) -> Result<&FunctionHandle>; + + fn get_string_at(&self, idx: StringPoolIndex) -> Result<&String>; + fn get_address_at(&self, idx: AddressPoolIndex) -> Result<&AccountAddress>; + fn get_type_signature_at(&self, idx: TypeSignatureIndex) -> Result<&TypeSignature>; + fn get_function_signature_at(&self, idx: FunctionSignatureIndex) -> Result<&FunctionSignature>; + fn get_locals_signature_at(&self, idx: LocalsSignatureIndex) -> Result<&LocalsSignature>; +} + +impl TableAccess for CompiledScript { + fn get_field_def_at(&self, _idx: FieldDefinitionIndex) -> Result<&FieldDefinition> { + bail!("no field definitions in scripts"); + } + + fn get_module_at(&self, idx: ModuleHandleIndex) -> Result<&ModuleHandle> { + match self.module_handles.get(idx.0 as usize) { + None => bail!("bad module handle index {}", idx), + Some(m) => Ok(m), + } + } + + fn get_struct_at(&self, idx: StructHandleIndex) -> Result<&StructHandle> { + match self.struct_handles.get(idx.0 as usize) { + None => bail!("bad struct handle index {}", idx), + Some(s) => Ok(s), + } + } + + fn get_function_at(&self, idx: FunctionHandleIndex) -> Result<&FunctionHandle> { + match self.function_handles.get(idx.0 as usize) { + None => bail!("bad function handle index {}", idx), + Some(m) => Ok(m), + } + } + + fn get_string_at(&self, idx: StringPoolIndex) -> Result<&String> { + match self.string_pool.get(idx.0 as usize) { + None => bail!("bad string index {}", idx), + Some(s) => Ok(s), + } + } + + fn get_address_at(&self, idx: AddressPoolIndex) -> Result<&AccountAddress> { + match self.address_pool.get(idx.0 as usize) { + None => bail!("bad address index {}", idx), + Some(addr) => Ok(addr), + } + } + + fn get_type_signature_at(&self, idx: TypeSignatureIndex) -> Result<&TypeSignature> { + match self.type_signatures.get(idx.0 as usize) { + None => bail!("bad signature index {}", idx), + Some(sig) => Ok(sig), + } + } + + fn get_function_signature_at(&self, idx: FunctionSignatureIndex) -> Result<&FunctionSignature> { + match self.function_signatures.get(idx.0 as usize) { + None => bail!("bad signature index {}", idx), + Some(sig) => Ok(sig), + } + } + + fn get_locals_signature_at(&self, idx: LocalsSignatureIndex) -> Result<&LocalsSignature> { + match self.locals_signatures.get(idx.0 as usize) { + None => bail!("bad signature index {}", idx), + Some(sig) => Ok(sig), + } + } +} + +impl TableAccess for CompiledModule { + fn get_field_def_at(&self, idx: FieldDefinitionIndex) -> Result<&FieldDefinition> { + match self.field_defs.get(idx.0 as usize) { + None => bail!("bad field definition index {}", idx), + Some(f) => Ok(f), + } + } + + fn get_module_at(&self, idx: ModuleHandleIndex) -> Result<&ModuleHandle> { + match self.module_handles.get(idx.0 as usize) { + None => bail!("bad module handle index {}", idx), + Some(m) => Ok(m), + } + } + + fn get_struct_at(&self, idx: StructHandleIndex) -> Result<&StructHandle> { + match self.struct_handles.get(idx.0 as usize) { + None => bail!("bad struct handle index {}", idx), + Some(s) => Ok(s), + } + } + + fn get_function_at(&self, idx: FunctionHandleIndex) -> Result<&FunctionHandle> { + match self.function_handles.get(idx.0 as usize) { + None => bail!("bad function handle index {}", idx), + Some(m) => 
Ok(m), + } + } + + fn get_string_at(&self, idx: StringPoolIndex) -> Result<&String> { + match self.string_pool.get(idx.0 as usize) { + None => bail!("bad string index {}", idx), + Some(s) => Ok(s), + } + } + + fn get_address_at(&self, idx: AddressPoolIndex) -> Result<&AccountAddress> { + match self.address_pool.get(idx.0 as usize) { + None => bail!("bad address index {}", idx), + Some(addr) => Ok(addr), + } + } + + fn get_type_signature_at(&self, idx: TypeSignatureIndex) -> Result<&TypeSignature> { + match self.type_signatures.get(idx.0 as usize) { + None => bail!("bad signature index {}", idx), + Some(sig) => Ok(sig), + } + } + + fn get_function_signature_at(&self, idx: FunctionSignatureIndex) -> Result<&FunctionSignature> { + match self.function_signatures.get(idx.0 as usize) { + None => bail!("bad signature index {}", idx), + Some(sig) => Ok(sig), + } + } + + fn get_locals_signature_at(&self, idx: LocalsSignatureIndex) -> Result<&LocalsSignature> { + match self.locals_signatures.get(idx.0 as usize) { + None => bail!("bad signature index {}", idx), + Some(sig) => Ok(sig), + } + } +} + +impl fmt::Display for CompiledProgram { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "CompiledProgram: {{\nModules: [\n")?; + for m in &self.modules { + writeln!(f, "{},", m)?; + } + write!(f, "],\nScript: {}\n}}", self.script) + } +} + +impl fmt::Display for CompiledScript { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "CompiledScript: {{\nMain:\n\t")?; + display_function_definition(&self.main, self, f)?; + display_code(&self.main.code, self, "\n\t\t", f)?; + write!(f, "\nStruct Handles: [")?; + for struct_handle in &self.struct_handles { + write!(f, "\n\t")?; + display_struct_handle(struct_handle, self, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Module Handles: [")?; + for module_handle in &self.module_handles { + write!(f, "\n\t")?; + display_module_handle(module_handle, self, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Function Handles: [")?; + for function_handle in &self.function_handles { + write!(f, "\n\t")?; + display_function_handle(function_handle, self, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Type Signatures: [")?; + for signature in &self.type_signatures { + write!(f, "\n\t")?; + display_type_signature(signature, self, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Function Signatures: [")?; + for signature in &self.function_signatures { + write!(f, "\n\t")?; + display_function_signature(signature, self, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Locals Signatures: [")?; + for signature in &self.locals_signatures { + write!(f, "\n\t")?; + display_locals_signature(signature, self, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Strings: [")?; + for string in &self.string_pool { + write!(f, "\n\t{},", string)?; + } + writeln!(f, "]")?; + write!(f, "ByteArrays: [")?; + for byte_array in &self.byte_array_pool { + write!(f, "\n\t")?; + display_byte_array(byte_array, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Addresses: [")?; + for address in &self.address_pool { + write!(f, "\n\t")?; + display_address(address, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + writeln!(f, "}}") + } +} + +impl fmt::Display for CompiledModule { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "CompiledModule: {{")?; + write!(f, "Module Handles: [")?; + for module_handle in &self.module_handles { + write!(f, "\n\t")?; + 
display_module_handle(module_handle, self, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Struct Handles: [")?; + for struct_handle in &self.struct_handles { + write!(f, "\n\t")?; + display_struct_handle(struct_handle, self, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Function Handles: [")?; + for function_handle in &self.function_handles { + write!(f, "\n\t")?; + display_function_handle(function_handle, self, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Struct Definitions: [")?; + for struct_def in &self.struct_defs { + write!(f, "\n\t{{")?; + display_struct_definition(struct_def, self, f)?; + let f_start_idx = struct_def.fields; + let f_end_idx = f_start_idx.0 as u16 + struct_def.field_count; + for idx in f_start_idx.0 as u16..f_end_idx { + let field_def = match self.field_defs.get(idx as usize) { + None => panic!("bad field definition index {}", idx), + Some(f) => f, + }; + write!(f, "\n\t\t")?; + display_field_definition(field_def, self, f)?; + } + write!(f, "}},")?; + } + writeln!(f, "]")?; + write!(f, "Field Definitions: [")?; + for field_def in &self.field_defs { + write!(f, "\n\t")?; + display_field_definition(field_def, self, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Function Definitions: [")?; + for function_def in &self.function_defs { + write!(f, "\n\t")?; + display_function_definition(function_def, self, f)?; + if function_def.flags & CodeUnit::NATIVE == 0 { + display_code(&function_def.code, self, "\n\t\t", f)?; + } + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Type Signatures: [")?; + for signature in &self.type_signatures { + write!(f, "\n\t")?; + display_type_signature(signature, self, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Function Signatures: [")?; + for signature in &self.function_signatures { + write!(f, "\n\t")?; + display_function_signature(signature, self, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Locals Signatures: [")?; + for signature in &self.locals_signatures { + write!(f, "\n\t")?; + display_locals_signature(signature, self, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Strings: [")?; + for string in &self.string_pool { + write!(f, "\n\t{},", string)?; + } + writeln!(f, "]")?; + write!(f, "ByteArrays: [")?; + for byte_array in &self.byte_array_pool { + write!(f, "\n\t")?; + display_byte_array(byte_array, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + write!(f, "Addresses: [")?; + for address in &self.address_pool { + write!(f, "\n\t")?; + display_address(address, f)?; + write!(f, ",")?; + } + writeln!(f, "]")?; + writeln!(f, "}}") + } +} + +fn display_struct_handle( + struct_: &StructHandle, + tables: &T, + f: &mut fmt::Formatter, +) -> fmt::Result { + write!( + f, + "{} ", + if struct_.is_resource { + "resource" + } else { + "struct" + } + )?; + write!(f, "{}@", tables.get_string_at(struct_.name).unwrap())?; + display_module_handle(tables.get_module_at(struct_.module).unwrap(), tables, f) +} + +fn display_module_handle( + module: &ModuleHandle, + tables: &T, + f: &mut fmt::Formatter, +) -> fmt::Result { + display_address(tables.get_address_at(module.address).unwrap(), f)?; + write!(f, ".{}", tables.get_string_at(module.name).unwrap()) +} + +fn display_function_handle( + function: &FunctionHandle, + tables: &T, + f: &mut fmt::Formatter, +) -> fmt::Result { + display_module_handle(tables.get_module_at(function.module).unwrap(), tables, f)?; + write!(f, ".{}", tables.get_string_at(function.name).unwrap())?; + 
display_function_signature( + tables + .get_function_signature_at(function.signature) + .unwrap(), + tables, + f, + ) +} + +fn display_struct_definition( + struct_: &StructDefinition, + tables: &T, + f: &mut fmt::Formatter, +) -> fmt::Result { + display_struct_handle( + tables.get_struct_at(struct_.struct_handle).unwrap(), + tables, + f, + ) +} + +fn display_field_definition( + field: &FieldDefinition, + tables: &T, + f: &mut fmt::Formatter, +) -> fmt::Result { + display_struct_handle(tables.get_struct_at(field.struct_).unwrap(), tables, f)?; + write!(f, ".{}: ", tables.get_string_at(field.name).unwrap())?; + display_type_signature( + tables.get_type_signature_at(field.signature).unwrap(), + tables, + f, + ) +} + +fn display_function_definition( + function: &FunctionDefinition, + tables: &T, + f: &mut fmt::Formatter, +) -> fmt::Result { + display_function_flags(function.flags, f)?; + display_function_handle( + tables.get_function_at(function.function).unwrap(), + tables, + f, + ) +} + +fn display_code( + code: &CodeUnit, + tables: &T, + indentation: &str, + f: &mut fmt::Formatter, +) -> fmt::Result { + write!(f, "{}locals({}): ", indentation, code.locals,)?; + display_locals_signature( + tables.get_locals_signature_at(code.locals).unwrap(), + tables, + f, + )?; + write!(f, ",")?; + for bytecode in &code.code { + write!(f, "{}", indentation)?; + display_bytecode(bytecode, tables, f)?; + } + Ok(()) +} + +fn display_address(addr: &AccountAddress, f: &mut fmt::Formatter) -> fmt::Result { + let hex = format!("{:x}", addr); + let mut v: VecDeque = hex.chars().collect(); + while v.len() > 1 && v[0] == '0' { + v.pop_front(); + } + write!(f, "0x{}", v.into_iter().collect::()) +} + +// Clippy will complain about passing Vec<_> by reference; instead you should pass &[_] +// In order to keep the logic of abstracting ByteArray, I think it is alright to ignore the warning +#[allow(clippy::ptr_arg)] +fn display_byte_array(byte_array: &ByteArray, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "0x{}", hex::encode(&byte_array.as_bytes())) +} + +fn display_type_signature( + sig: &TypeSignature, + tables: &T, + f: &mut fmt::Formatter, +) -> fmt::Result { + display_signature_token(&sig.0, tables, f) +} + +fn display_function_signature( + sig: &FunctionSignature, + tables: &T, + f: &mut fmt::Formatter, +) -> fmt::Result { + let mut iter = sig.arg_types.iter().peekable(); + write!(f, "(")?; + while let Some(token) = iter.next() { + display_signature_token(token, tables, f)?; + if iter.peek().is_some() { + write!(f, ", ")?; + } + } + write!(f, "): ")?; + + let mut iter = sig.return_types.iter().peekable(); + write!(f, "(")?; + while let Some(token) = iter.next() { + display_signature_token(token, tables, f)?; + if iter.peek().is_some() { + write!(f, ", ")?; + } + } + write!(f, ")")?; + Ok(()) +} + +fn display_locals_signature( + sig: &LocalsSignature, + tables: &T, + f: &mut fmt::Formatter, +) -> fmt::Result { + let mut iter = sig.0.iter().peekable(); + while let Some(token) = iter.next() { + display_signature_token(token, tables, f)?; + if iter.peek().is_some() { + write!(f, ", ")?; + } + } + Ok(()) +} + +fn display_signature_token( + token: &SignatureToken, + tables: &T, + f: &mut fmt::Formatter, +) -> fmt::Result { + match token { + SignatureToken::Bool => write!(f, "Bool"), + SignatureToken::U64 => write!(f, "Integer"), + SignatureToken::String => write!(f, "String"), + SignatureToken::ByteArray => write!(f, "ByteArray"), + SignatureToken::Address => write!(f, "Address"), + SignatureToken::Struct(idx) 
=> { + display_struct_handle(tables.get_struct_at(*idx).unwrap(), tables, f) + } + SignatureToken::Reference(token) => { + write!(f, "&")?; + display_signature_token(token, tables, f) + } + SignatureToken::MutableReference(token) => { + write!(f, "&mut ")?; + display_signature_token(token, tables, f) + } + } +} + +fn display_function_flags(flags: u8, f: &mut fmt::Formatter) -> fmt::Result { + if flags & CodeUnit::NATIVE != 0 { + write!(f, "native ")?; + } + if flags & CodeUnit::PUBLIC != 0 { + write!(f, "public ")?; + } + Ok(()) +} + +fn display_bytecode( + bytecode: &Bytecode, + tables: &T, + f: &mut fmt::Formatter, +) -> fmt::Result { + match bytecode { + Bytecode::LdAddr(idx) => { + write!(f, "LdAddr(")?; + display_address(tables.get_address_at(*idx).unwrap(), f)?; + write!(f, ")") + } + Bytecode::LdStr(idx) => write!(f, "LdStr({})", tables.get_string_at(*idx).unwrap()), + Bytecode::BorrowField(idx) => { + write!(f, "BorrowField(")?; + display_field_definition(tables.get_field_def_at(*idx).unwrap(), tables, f)?; + write!(f, ")") + } + Bytecode::Call(idx) => { + write!(f, "Call(")?; + display_function_handle(tables.get_function_at(*idx).unwrap(), tables, f)?; + write!(f, ")") + } + _ => write!(f, "{:?}", bytecode), + } +} diff --git a/language/vm/src/proptest_types.rs b/language/vm/src/proptest_types.rs new file mode 100644 index 0000000000000..6552a38382dc5 --- /dev/null +++ b/language/vm/src/proptest_types.rs @@ -0,0 +1,488 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Utilities for property-based testing. + +use crate::file_format::{ + AddressPoolIndex, CompiledModule, FieldDefinition, FieldDefinitionIndex, FunctionHandle, + FunctionSignatureIndex, MemberCount, ModuleHandle, ModuleHandleIndex, SignatureToken, + StringPoolIndex, StructDefinition, StructHandle, StructHandleIndex, TableIndex, TypeSignature, + TypeSignatureIndex, +}; +use proptest::{ + collection::{vec, SizeRange}, + prelude::*, + sample::Index as PropIndex, +}; +use proptest_helpers::GrowingSubset; +use types::{account_address::AccountAddress, byte_array::ByteArray}; + +mod functions; +mod signature; + +use functions::{FnDefnMaterializeState, FunctionDefinitionGen}; +use signature::{FunctionSignatureGen, SignatureTokenGen}; + +/// Represents how large [`CompiledModule`] tables can be. +pub type TableSize = u16; + +impl CompiledModule { + /// Convenience wrapper around [`CompiledModuleStrategyGen`][CompiledModuleStrategyGen] that + /// generates valid modules with the given size. + pub fn valid_strategy(size: usize) -> impl Strategy { + CompiledModuleStrategyGen::new(size as TableSize).generate() + } +} + +/// Contains configuration to generate [`CompiledModule`] instances. +/// +/// If you don't care about customizing these parameters, see [`CompiledModule::valid_strategy`]. +/// +/// A `CompiledModule` can be looked at as a graph, with several kinds of nodes, and a nest of +/// pointers among those nodes. This graph has some properties: +/// +/// 1. The graph has cycles. Generating DAGs is often simpler, but is not an option in this case. +/// 2. The actual structure of the graph is well-defined in terms of the kinds of nodes and +/// pointers that exist. +/// +/// TODO: the graph also has pointers *out* of it, via address references to other modules. +/// This doesn't need to be handled when viewing modules in isolation, but some verification passes +/// will need to look at the entire set of modules. 
The work to make generating such modules +/// possible remains to be done. +/// +/// Intermediate types +/// ------------------ +/// +/// The pointers are represented as indexes into vectors of other kinds of nodes. One of the +/// bigger problems is that the number of types, functions etc isn't known upfront so it is +/// impossible to know what range to pick from for the index types (`ModuleHandleIndex`, +/// `StructHandleIndex`, etc). To deal with this, the code generates a bunch of intermediate +/// structures (sometimes tuples, other times more complicated structures with their own internal +/// constraints), with "holes" represented by [`Index`](proptest::sample::Index) instances. Once all +/// the lengths are known, there's a final "materialize" step at the end that "fills in" these +/// holes. +/// +/// One alternative would have been to generate lengths up front, then create vectors of those +/// lengths. This would have worked fine for generation but would have made shrinking take much +/// longer, because the shrinker would be less aware of the overall structure of the problem and +/// would have ended up redoing a lot of work. The approach taken here does end up being more +/// verbose but should perform optimally. +/// +/// See [`proptest` issue #130](https://github.com/AltSysrq/proptest/issues/130) for more discussion +/// about this. +#[derive(Clone, Debug)] +pub struct CompiledModuleStrategyGen { + size: usize, + /// Range of number of fields in a struct and number of arguments in a function to generate. + /// The default value is 0..4. + member_count: SizeRange, + /// Length of code units (function definition). XXX the unit might change here. + code_len: SizeRange, +} + +impl CompiledModuleStrategyGen { + /// Create a new configuration for randomly generating [`CompiledModule`] instances. + pub fn new(size: TableSize) -> Self { + Self { + size: size as usize, + member_count: (0..4).into(), + code_len: (0..50).into(), + } + } + + /// Set a new range for the number of fields in a struct or the number of arguments in a + /// function. + #[inline] + pub fn member_count(&mut self, count: impl Into) -> &mut Self { + self.member_count = count.into(); + self + } + + /// Create a `proptest` strategy for `CompiledModule` instances using this configuration. + pub fn generate(self) -> impl Strategy { + // Base data -- everything points to this eventually. + let address_pool_strat = vec(any::(), 1..=self.size); + // This ensures that there are no empty ByteArrays + // TODO: Should we enable empty ByteArrays in Move, e.g. let byte_array = b""; + let byte_array_pool_strat = vec(any::(), 1..=self.size); + let string_pool_strat = vec(".*", 1..=self.size); + + let type_signatures_strat = vec(SignatureTokenGen::strategy(), 1..=self.size); + // Ensure at least one owned non-struct type signature. + let owned_non_struct_strat = vec( + SignatureTokenGen::owned_non_struct_strategy(), + 1..=self.size, + ); + let owned_type_sigs_strat = vec(SignatureTokenGen::owned_strategy(), 1..=self.size); + let function_signatures_strat = vec( + FunctionSignatureGen::strategy(self.member_count.clone(), self.member_count.clone()), + 1..=self.size, + ); + + // The number of PropIndex instances in each tuple represents the number of pointers out + // from an instance of that particular kind of node. 
+        let module_handles_strat = vec(any::<(PropIndex, PropIndex)>(), 1..=self.size);
+        let struct_handles_strat = vec(any::<(PropIndex, PropIndex, bool)>(), 1..=self.size);
+        let function_handles_strat = vec(any::<(PropIndex, PropIndex, PropIndex)>(), 1..=self.size);
+        let struct_defs_strat = vec(
+            StructDefinitionGen::strategy(self.member_count.clone()),
+            1..=self.size,
+        );
+        let function_defs_strat = vec(
+            FunctionDefinitionGen::strategy(
+                self.member_count.clone(),
+                self.member_count.clone(),
+                self.code_len,
+            ),
+            1..=self.size,
+        );
+        // Note that proptest only allows tuples of up to ten elements, so the
+        // handle strategies and the definition strategies are grouped into nested
+        // tuples below to keep this one under the limit.
+        (
+            address_pool_strat,
+            byte_array_pool_strat,
+            string_pool_strat,
+            type_signatures_strat,
+            owned_non_struct_strat,
+            owned_type_sigs_strat,
+            function_signatures_strat,
+            (
+                module_handles_strat,
+                struct_handles_strat,
+                function_handles_strat,
+            ),
+            (struct_defs_strat, function_defs_strat),
+        )
+            .prop_map(
+                |(
+                    address_pool,
+                    byte_array_pool,
+                    string_pool,
+                    type_signatures,
+                    owned_non_structs,
+                    owned_type_sigs,
+                    function_signatures,
+                    (module_handles, struct_handles, function_handles),
+                    (struct_defs, function_defs),
+                )| {
+                    let address_pool_len = address_pool.len();
+                    let string_pool_len = string_pool.len();
+                    let byte_array_pool_len = byte_array_pool.len();
+                    let module_handles_len = module_handles.len();
+                    // StDefnMaterializeState adds one new handle for each definition, so the total
+                    // number of struct handles is the sum of the number of generated struct
+                    // handles (i.e. representing structs in external modules) and the number of
+                    // internal ones.
+                    let struct_handles_len = struct_handles.len() + struct_defs.len();
+                    // XXX FnDefnMaterializeState below adds more function signatures. This line
+                    // means that no signatures generated later will be used by handles generated
+                    // earlier.
+                    //
+                    // Instead, one could use function_signatures.len() + function_defs.len() to
+                    // use signatures from later.
+                    let function_signatures_len = function_signatures.len();
+                    // FnDefnMaterializeState below adds function handles equal to the number of
+                    // function definitions.
+                    let function_handles_len = function_handles.len() + function_defs.len();
+
+                    let owned_type_sigs: Vec<_> =
+                        SignatureTokenGen::map_materialize(owned_non_structs, struct_handles_len)
+                            .chain(SignatureTokenGen::map_materialize(
+                                owned_type_sigs,
+                                struct_handles_len,
+                            ))
+                            .map(TypeSignature)
+                            .collect();
+                    let owned_type_indexes = type_indexes(&owned_type_sigs);
+
+                    // Put the owned type signatures first so they're in the range
+                    // 0..owned_type_sigs.len(). These are the signatures that will be used to pick
+                    // field definition sigs from.
+                    // Note that this doesn't result in a distribution that's spread out -- it
+                    // would be nice to achieve that.
+ let type_signatures: Vec<_> = owned_type_sigs + .into_iter() + .chain( + SignatureTokenGen::map_materialize(type_signatures, struct_handles_len) + .map(TypeSignature), + ) + .collect(); + let function_signatures = function_signatures + .into_iter() + .map(|sig| sig.materialize(struct_handles_len)) + .collect(); + + let module_handles: Vec<_> = module_handles + .into_iter() + .map(|(address_idx, name_idx)| ModuleHandle { + address: AddressPoolIndex::new( + address_idx.index(address_pool_len) as TableIndex + ), + name: StringPoolIndex::new( + name_idx.index(string_pool_len) as TableIndex + ), + }) + .collect(); + + let struct_handles: Vec<_> = struct_handles + .into_iter() + .map(|(module_idx, name_idx, is_resource)| StructHandle { + module: ModuleHandleIndex::new( + module_idx.index(module_handles_len) as TableIndex + ), + name: StringPoolIndex::new( + name_idx.index(string_pool_len) as TableIndex + ), + is_resource, + }) + .collect(); + + let function_handles: Vec<_> = function_handles + .into_iter() + .map(|(module_idx, name_idx, signature_idx)| FunctionHandle { + module: ModuleHandleIndex::new( + module_idx.index(module_handles_len) as TableIndex + ), + name: StringPoolIndex::new( + name_idx.index(string_pool_len) as TableIndex + ), + signature: FunctionSignatureIndex::new( + signature_idx.index(function_signatures_len) as TableIndex, + ), + }) + .collect(); + + // Struct definitions also generate field definitions. + let mut state = StDefnMaterializeState { + string_pool_len, + owned_type_indexes, + struct_handles, + type_signatures, + // field_defs will be filled out by StructDefinitionGen::materialize + field_defs: vec![], + }; + let struct_defs: Vec<_> = struct_defs + .into_iter() + .map(|def| def.materialize(&mut state)) + .collect(); + + let StDefnMaterializeState { + struct_handles, + type_signatures, + field_defs, + .. + } = state; + assert_eq!(struct_handles_len, struct_handles.len()); + + // Definitions get generated at the end. But some of the other pools need to be + // involved here, so temporarily give up ownership to the state accumulators. + let mut state = FnDefnMaterializeState { + struct_handles_len, + address_pool_len, + string_pool_len, + byte_array_pool_len, + function_handles_len, + type_signatures_len: type_signatures.len(), + field_defs_len: field_defs.len(), + struct_defs_len: struct_defs.len(), + function_defs_len: function_defs.len(), + function_signatures, + // locals will be filled out by FunctionDefinitionGen::materialize + locals_signatures: vec![], + function_handles, + }; + + let function_defs = function_defs + .into_iter() + .map(|def| def.materialize(&mut state)) + .collect(); + + let FnDefnMaterializeState { + function_signatures, + locals_signatures, + function_handles, + .. + } = state; + assert_eq!(function_handles_len, function_handles.len()); + + // Put it all together. + CompiledModule { + module_handles, + struct_handles, + function_handles, + + struct_defs, + field_defs, + function_defs, + + type_signatures, + function_signatures, + locals_signatures, + + string_pool, + byte_array_pool, + address_pool, + } + }, + ) + } +} + +#[derive(Debug)] +struct StDefnMaterializeState { + string_pool_len: usize, + // Struct definitions need to be nonrecursive -- this is ensured by only picking signatures + // that either have no struct handle (represented as None), or have a handle less than the + // one for the definition currently being added. + owned_type_indexes: GrowingSubset, TypeSignatureIndex>, + // These get mutated by StructDefinitionGen. 
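+    // (Each StructDefinitionGen::materialize call pushes exactly one new
+    // StructHandle for the definition being added and appends its field
+    // definitions through add_field_defs, which is why struct_handles_len in
+    // generate() above is the number of generated handles plus the number of
+    // definitions.)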
+ struct_handles: Vec, + field_defs: Vec, + type_signatures: Vec, +} + +impl StDefnMaterializeState { + fn next_struct_handle(&self) -> StructHandleIndex { + StructHandleIndex::new(self.struct_handles.len() as TableIndex) + } + + fn add_struct_handle(&mut self, handle: StructHandle) -> StructHandleIndex { + self.struct_handles.push(handle); + StructHandleIndex::new((self.struct_handles.len() - 1) as TableIndex) + } + + /// Adds field defs to the pool. Returns the number of fields added and the index of the first + /// field. + fn add_field_defs( + &mut self, + new_defs: impl IntoIterator, + ) -> (MemberCount, FieldDefinitionIndex) { + let old_len = self.field_defs.len(); + self.field_defs.extend(new_defs); + let new_len = self.field_defs.len(); + ( + (new_len - old_len) as MemberCount, + FieldDefinitionIndex::new(old_len as TableIndex), + ) + } + + fn is_resource(&self, signature: &SignatureToken) -> bool { + use SignatureToken::*; + + match signature { + Struct(struct_handle_index) => { + self.struct_handles[struct_handle_index.0 as usize].is_resource + } + Reference(token) | MutableReference(token) => self.is_resource(token), + Bool | U64 | ByteArray | String | Address => false, + } + } +} + +#[derive(Clone, Debug)] +struct StructDefinitionGen { + name_idx: PropIndex, + // the is_resource field of generated struct handle is set to true if + // either any of the fields is a resource or self.is_resource is true + is_resource: bool, + is_public: bool, + field_defs: Vec, +} + +impl StructDefinitionGen { + fn strategy(member_count: impl Into) -> impl Strategy { + ( + any::(), + any::(), + any::(), + // XXX 0..4 is the default member_count in CompiledModule -- is 0 (structs without + // fields) possible? + vec(FieldDefinitionGen::strategy(), member_count), + ) + .prop_map(|(name_idx, is_resource, is_public, field_defs)| Self { + name_idx, + is_resource, + is_public, + field_defs, + }) + } + + fn materialize(self, state: &mut StDefnMaterializeState) -> StructDefinition { + let sh_idx = state.next_struct_handle(); + state.owned_type_indexes.advance_to(&Some(sh_idx)); + + // Each struct defines one or more fields. The collect() is to work around the borrow + // checker -- it's annoying. + let field_defs: Vec<_> = self + .field_defs + .into_iter() + .map(|field| field.materialize(sh_idx, state)) + .collect(); + let is_resource = self.is_resource + || field_defs + .iter() + .any(|x| state.is_resource(&state.type_signatures[x.signature.0 as usize].0)); + + let (field_count, fields) = state.add_field_defs(field_defs); + + let handle = StructHandle { + // 0 represents the current module + module: ModuleHandleIndex::new(0), + name: StringPoolIndex::new(self.name_idx.index(state.string_pool_len) as TableIndex), + is_resource, + }; + state.add_struct_handle(handle); + + StructDefinition { + struct_handle: sh_idx, + field_count, + fields, + } + } +} + +#[derive(Clone, Debug)] +struct FieldDefinitionGen { + name_idx: PropIndex, + signature_idx: PropIndex, + // XXX flags? 
+} + +impl FieldDefinitionGen { + fn strategy() -> impl Strategy { + (any::(), any::()).prop_map(|(name_idx, signature_idx)| Self { + name_idx, + signature_idx, + }) + } + + fn materialize( + self, + sh_idx: StructHandleIndex, + state: &StDefnMaterializeState, + ) -> FieldDefinition { + FieldDefinition { + struct_: sh_idx, + name: StringPoolIndex::new(self.name_idx.index(state.string_pool_len) as TableIndex), + signature: *state.owned_type_indexes.pick_value(&self.signature_idx), + } + } +} + +fn type_indexes<'a>( + signatures: impl IntoIterator, +) -> GrowingSubset, TypeSignatureIndex> { + signatures + .into_iter() + .enumerate() + .map(|(idx, signature)| { + // Any signatures that don't have a struct handle in them can always be picked. + // None is less than Some(0) so set those to None. + ( + signature.0.struct_index(), + TypeSignatureIndex::new(idx as TableIndex), + ) + }) + .collect() +} diff --git a/language/vm/src/proptest_types/functions.rs b/language/vm/src/proptest_types/functions.rs new file mode 100644 index 0000000000000..68b072b65571d --- /dev/null +++ b/language/vm/src/proptest_types/functions.rs @@ -0,0 +1,370 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + file_format::{ + AddressPoolIndex, ByteArrayPoolIndex, Bytecode, CodeOffset, CodeUnit, FieldDefinitionIndex, + FunctionDefinition, FunctionHandle, FunctionHandleIndex, FunctionSignature, + FunctionSignatureIndex, LocalIndex, LocalsSignature, LocalsSignatureIndex, + ModuleHandleIndex, StringPoolIndex, StructDefinitionIndex, TableIndex, + }, + proptest_types::signature::{FunctionSignatureGen, SignatureTokenGen}, +}; +use proptest::{ + collection::{vec, SizeRange}, + prelude::*, + sample::{select, Index as PropIndex}, +}; + +/// Represents state required to materialize final data structures for function definitions. +#[derive(Debug)] +pub struct FnDefnMaterializeState { + pub struct_handles_len: usize, + pub address_pool_len: usize, + pub string_pool_len: usize, + pub byte_array_pool_len: usize, + pub type_signatures_len: usize, + pub field_defs_len: usize, + pub struct_defs_len: usize, + pub function_defs_len: usize, + // This is the final length of function_handles, after all the definitions add their own + // handles. + pub function_handles_len: usize, + // These get mutated by `FunctionDefinitionGen`. 
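+    // (Per definition, FunctionDefinitionGen::materialize pushes one
+    // FunctionSignature and one FunctionHandle, and its CodeUnitGen pushes one
+    // LocalsSignature, all through the add_* helpers below.)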
+ pub function_signatures: Vec, + pub locals_signatures: Vec, + pub function_handles: Vec, +} + +impl FnDefnMaterializeState { + #[inline] + fn add_function_signature(&mut self, sig: FunctionSignature) -> FunctionSignatureIndex { + self.function_signatures.push(sig); + FunctionSignatureIndex::new((self.function_signatures.len() - 1) as TableIndex) + } + + #[inline] + fn add_locals_signature(&mut self, sig: LocalsSignature) -> LocalsSignatureIndex { + self.locals_signatures.push(sig); + LocalsSignatureIndex::new((self.locals_signatures.len() - 1) as TableIndex) + } + + #[inline] + fn add_function_handle(&mut self, handle: FunctionHandle) -> FunctionHandleIndex { + self.function_handles.push(handle); + FunctionHandleIndex::new((self.function_handles.len() - 1) as TableIndex) + } +} + +#[derive(Clone, Debug)] +pub struct FunctionDefinitionGen { + name: PropIndex, + signature: FunctionSignatureGen, + is_public: bool, + code: CodeUnitGen, +} + +impl FunctionDefinitionGen { + pub fn strategy( + return_count: impl Into, + arg_count: impl Into, + code_len: impl Into, + ) -> impl Strategy { + let return_count = return_count.into(); + let arg_count = arg_count.into(); + ( + any::(), + FunctionSignatureGen::strategy(return_count.clone(), arg_count.clone()), + any::(), + CodeUnitGen::strategy(arg_count, code_len), + ) + .prop_map(|(name, signature, is_public, code)| Self { + name, + signature, + is_public, + code, + }) + } + + pub fn materialize(self, state: &mut FnDefnMaterializeState) -> FunctionDefinition { + let signature = self.signature.materialize(state.struct_handles_len); + + let handle = FunctionHandle { + // 0 represents the current module + module: ModuleHandleIndex::new(0), + // XXX need to guarantee uniqueness of names? + name: StringPoolIndex::new(self.name.index(state.string_pool_len) as TableIndex), + signature: state.add_function_signature(signature), + }; + let function_handle = state.add_function_handle(handle); + + FunctionDefinition { + function: function_handle, + // XXX is this even correct? + flags: if self.is_public { + CodeUnit::PUBLIC + } else { + // No qualifiers. + 0 + }, + code: self.code.materialize(state), + } + } +} + +#[derive(Clone, Debug)] +struct CodeUnitGen { + locals_signature: Vec, + code: Vec, +} + +impl CodeUnitGen { + fn strategy( + arg_count: impl Into, + code_len: impl Into, + ) -> impl Strategy { + ( + vec(SignatureTokenGen::strategy(), arg_count), + vec(BytecodeGen::garbage_strategy(), code_len), + ) + .prop_map(|(locals_signature, code)| Self { + locals_signature, + code, + }) + } + + fn materialize(self, state: &mut FnDefnMaterializeState) -> CodeUnit { + let locals_signature = LocalsSignature( + self.locals_signature + .into_iter() + .map(|sig| sig.materialize(state.struct_handles_len)) + .collect(), + ); + + // Not all bytecodes will be successfully materialized -- count how many will. + let code_len = self + .code + .iter() + .filter(|code| code.will_materialize(state, &locals_signature)) + .count(); + + let code = self + .code + .into_iter() + .filter_map(|code| code.materialize(state, code_len, &locals_signature)) + .collect(); + + CodeUnit { + max_stack_size: 0, + locals: state.add_locals_signature(locals_signature), + // XXX actually generate code + code, + } + } +} + +#[derive(Clone, Debug)] +enum BytecodeGen { + // "Simple" means this doesn't refer to any other indexes. + Simple(Bytecode), + // All of these refer to other indexes. 
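+    // At materialize time, each PropIndex is resolved against the length of the
+    // pool it refers to; e.g. LdAddr(idx) becomes
+    // Bytecode::LdAddr(AddressPoolIndex::new(idx.index(address_pool_len) as TableIndex)).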
+ LdAddr(PropIndex), + LdStr(PropIndex), + LdByteArray(PropIndex), + BorrowField(PropIndex), + Call(PropIndex), + Pack(PropIndex), + Unpack(PropIndex), + Exists(PropIndex), + BorrowGlobal(PropIndex), + MoveFrom(PropIndex), + MoveToSender(PropIndex), + BrTrue(PropIndex), + BrFalse(PropIndex), + Branch(PropIndex), + CopyLoc(PropIndex), + MoveLoc(PropIndex), + StLoc(PropIndex), + BorrowLoc(PropIndex), +} + +impl BytecodeGen { + // This just generates nonsensical bytecodes. This will be cleaned up later as the generation + // model is refined. + fn garbage_strategy() -> impl Strategy { + use BytecodeGen::*; + + prop_oneof![ + Self::simple_bytecode_strategy().prop_map(Simple), + any::().prop_map(LdAddr), + any::().prop_map(LdStr), + any::().prop_map(LdByteArray), + any::().prop_map(BorrowField), + any::().prop_map(Call), + any::().prop_map(Pack), + any::().prop_map(Unpack), + any::().prop_map(Exists), + any::().prop_map(BorrowGlobal), + any::().prop_map(MoveFrom), + any::().prop_map(MoveToSender), + any::().prop_map(BrTrue), + any::().prop_map(BrFalse), + any::().prop_map(Branch), + any::().prop_map(CopyLoc), + any::().prop_map(MoveLoc), + any::().prop_map(StLoc), + any::().prop_map(BorrowLoc), + ] + } + + /// Whether this code will be materialized into a Some(bytecode). + fn will_materialize( + &self, + state: &FnDefnMaterializeState, + locals_signature: &LocalsSignature, + ) -> bool { + // This method should remain in sync with the `None` below. + use BytecodeGen::*; + + match self { + BorrowField(_) => state.field_defs_len != 0, + CopyLoc(_) | MoveLoc(_) | StLoc(_) | BorrowLoc(_) => !locals_signature.is_empty(), + _ => true, + } + } + + fn materialize( + self, + state: &FnDefnMaterializeState, + code_len: usize, + locals_signature: &LocalsSignature, + ) -> Option { + // This method returns an Option because some bytecodes cannot be represented if + // some tables are empty. + // + // Once more sensible function bodies are generated this will probably have to start using + // prop_flat_map anyway, so revisit this then. + + let bytecode = match self { + BytecodeGen::Simple(bytecode) => bytecode, + BytecodeGen::LdAddr(idx) => Bytecode::LdAddr(AddressPoolIndex::new( + idx.index(state.address_pool_len) as TableIndex, + )), + BytecodeGen::LdStr(idx) => Bytecode::LdStr(StringPoolIndex::new( + idx.index(state.string_pool_len) as TableIndex, + )), + BytecodeGen::LdByteArray(idx) => Bytecode::LdByteArray(ByteArrayPoolIndex::new( + idx.index(state.byte_array_pool_len) as TableIndex, + )), + BytecodeGen::BorrowField(idx) => { + // Again, once meaningful bytecodes are generated this won't actually be a + // possibility since it would be impossible to load a field from a struct that + // doesn't have any. 
+ if state.field_defs_len == 0 { + return None; + } + Bytecode::BorrowField(FieldDefinitionIndex::new( + idx.index(state.field_defs_len) as TableIndex + )) + } + BytecodeGen::Call(idx) => Bytecode::Call(FunctionHandleIndex::new( + idx.index(state.function_handles_len) as TableIndex, + )), + BytecodeGen::Pack(idx) => Bytecode::Pack(StructDefinitionIndex::new( + idx.index(state.struct_defs_len) as TableIndex, + )), + BytecodeGen::Unpack(idx) => Bytecode::Unpack(StructDefinitionIndex::new( + idx.index(state.struct_defs_len) as TableIndex, + )), + BytecodeGen::Exists(idx) => Bytecode::Exists(StructDefinitionIndex::new( + idx.index(state.struct_defs_len) as TableIndex, + )), + BytecodeGen::BorrowGlobal(idx) => Bytecode::BorrowGlobal(StructDefinitionIndex::new( + idx.index(state.struct_defs_len) as TableIndex, + )), + BytecodeGen::MoveFrom(idx) => Bytecode::MoveFrom(StructDefinitionIndex::new( + idx.index(state.struct_defs_len) as TableIndex, + )), + BytecodeGen::MoveToSender(idx) => Bytecode::MoveToSender(StructDefinitionIndex::new( + idx.index(state.struct_defs_len) as TableIndex, + )), + BytecodeGen::BrTrue(idx) => Bytecode::BrTrue(idx.index(code_len) as CodeOffset), + BytecodeGen::BrFalse(idx) => Bytecode::BrFalse(idx.index(code_len) as CodeOffset), + BytecodeGen::Branch(idx) => Bytecode::Branch(idx.index(code_len) as CodeOffset), + BytecodeGen::CopyLoc(idx) => { + if locals_signature.is_empty() { + return None; + } + Bytecode::CopyLoc(idx.index(locals_signature.len()) as LocalIndex) + } + BytecodeGen::MoveLoc(idx) => { + if locals_signature.is_empty() { + return None; + } + Bytecode::MoveLoc(idx.index(locals_signature.len()) as LocalIndex) + } + BytecodeGen::StLoc(idx) => { + if locals_signature.is_empty() { + return None; + } + Bytecode::StLoc(idx.index(locals_signature.len()) as LocalIndex) + } + BytecodeGen::BorrowLoc(idx) => { + if locals_signature.is_empty() { + return None; + } + Bytecode::BorrowLoc(idx.index(locals_signature.len()) as LocalIndex) + } + }; + + Some(bytecode) + } + + fn simple_bytecode_strategy() -> impl Strategy { + prop_oneof![ + // The numbers are relative weights, somewhat arbitrarily picked. 
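+            // (A 9:1 weighting means roughly 90% of simple bytecodes are drawn from
+            // the fixed operand-free list in just_bytecode_strategy and 10% are
+            // LdConst with an arbitrary u64.)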
+ 9 => Self::just_bytecode_strategy(), + 1 => any::().prop_map(Bytecode::LdConst), + ] + } + + fn just_bytecode_strategy() -> impl Strategy { + use Bytecode::*; + + static JUST_BYTECODES: &[Bytecode] = &[ + FreezeRef, + ReleaseRef, + Pop, + Ret, + LdTrue, + LdFalse, + ReadRef, + WriteRef, + Add, + Sub, + Mul, + Mod, + Div, + BitOr, + BitAnd, + Xor, + Or, + And, + Eq, + Neq, + Lt, + Gt, + Le, + Ge, + Assert, + GetTxnGasUnitPrice, + GetTxnMaxGasUnits, + GetTxnSenderAddress, + CreateAccount, + EmitEvent, + GetTxnSequenceNumber, + GetTxnPublicKey, + ]; + select(JUST_BYTECODES) + } +} diff --git a/language/vm/src/proptest_types/signature.rs b/language/vm/src/proptest_types/signature.rs new file mode 100644 index 0000000000000..af2f70cc22763 --- /dev/null +++ b/language/vm/src/proptest_types/signature.rs @@ -0,0 +1,131 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::file_format::{FunctionSignature, SignatureToken, StructHandleIndex, TableIndex}; +use proptest::{ + collection::{vec, SizeRange}, + prelude::*, + sample::{select, Index as PropIndex}, +}; + +#[derive(Clone, Debug)] +pub struct FunctionSignatureGen { + return_types: Vec, + arg_types: Vec, +} + +impl FunctionSignatureGen { + pub fn strategy( + return_count: impl Into, + arg_count: impl Into, + ) -> impl Strategy { + ( + vec(SignatureTokenGen::strategy(), return_count), + vec(SignatureTokenGen::strategy(), arg_count), + ) + .prop_map(|(return_types, arg_types)| Self { + return_types, + arg_types, + }) + } + + pub fn materialize(self, struct_handles_len: usize) -> FunctionSignature { + FunctionSignature { + return_types: SignatureTokenGen::map_materialize(self.return_types, struct_handles_len) + .collect(), + arg_types: SignatureTokenGen::map_materialize(self.arg_types, struct_handles_len) + .collect(), + } + } +} + +#[derive(Clone, Debug)] +pub enum SignatureTokenGen { + // Atomic signature tokens. + Bool, + Integer, + String, + ByteArray, + Address, + Struct(PropIndex), + + // Composite signature tokens. + Reference(Box), + MutableReference(Box), +} + +impl SignatureTokenGen { + pub fn strategy() -> impl Strategy { + prop::strategy::Union::new_weighted(vec![ + (5, Self::atom_strategy().boxed()), + (1, Self::reference_strategy().boxed()), + (1, Self::mutable_reference_strategy().boxed()), + ]) + } + + /// Generates a signature token for an owned (non-reference) type. + pub fn owned_strategy() -> impl Strategy { + prop::strategy::Union::new_weighted(vec![(3, Self::atom_strategy().boxed())]) + } + + pub fn atom_strategy() -> impl Strategy { + use SignatureTokenGen::*; + + prop_oneof![ + 9 => Self::owned_non_struct_strategy(), + 1 => any::().prop_map(Struct), + ] + } + + /// Generates a signature token for a non-struct owned type. + pub fn owned_non_struct_strategy() -> impl Strategy { + use SignatureTokenGen::*; + + static OWNED_NON_STRUCTS: &[SignatureTokenGen] = + &[Bool, Integer, String, ByteArray, Address]; + + select(OWNED_NON_STRUCTS) + } + + pub fn reference_strategy() -> impl Strategy { + // References to references are not supported. + Self::owned_strategy().prop_map(|atom| SignatureTokenGen::Reference(Box::new(atom))) + } + + pub fn mutable_reference_strategy() -> impl Strategy { + // References to references are not supported. 
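+        // (Starting from owned_strategy, which never yields a reference token,
+        // rules out `&mut &T` shapes; check_structure rejects a reference to a
+        // reference, as exercised in unit_tests/checks/signature_tests.rs.)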
+        Self::owned_strategy().prop_map(|atom| SignatureTokenGen::MutableReference(Box::new(atom)))
+    }
+
+    pub fn materialize(self, struct_handles_len: usize) -> SignatureToken {
+        use SignatureTokenGen::*;
+
+        match self {
+            Bool => SignatureToken::Bool,
+            Integer => SignatureToken::U64,
+            String => SignatureToken::String,
+            ByteArray => SignatureToken::ByteArray,
+            Address => SignatureToken::Address,
+            Struct(idx) => SignatureToken::Struct(StructHandleIndex::new(
+                idx.index(struct_handles_len) as TableIndex,
+            )),
+            Reference(token) => {
+                SignatureToken::Reference(Box::new(token.materialize(struct_handles_len)))
+            }
+            MutableReference(token) => {
+                SignatureToken::MutableReference(Box::new(token.materialize(struct_handles_len)))
+            }
+        }
+    }
+
+    /// Convenience function to materialize many tokens.
+    #[inline]
+    pub fn map_materialize(
+        tokens: impl IntoIterator<Item = Self>,
+        struct_handles_len: usize,
+    ) -> impl Iterator<Item = SignatureToken> {
+        tokens
+            .into_iter()
+            .map(move |token| token.materialize(struct_handles_len))
+    }
+}
diff --git a/language/vm/src/resolver.rs b/language/vm/src/resolver.rs
new file mode 100644
index 0000000000000..4c04affbfa443
--- /dev/null
+++ b/language/vm/src/resolver.rs
@@ -0,0 +1,130 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module implements a resolver for importing a SignatureToken defined in one module into
+//! another. This functionality is used in verify_module_dependencies and verify_script_dependencies.
+use crate::{
+    access::BaseAccess,
+    errors::VMStaticViolation,
+    file_format::{
+        AddressPoolIndex, CompiledModule, FunctionSignature, ModuleHandle, ModuleHandleIndex,
+        SignatureToken, StringPoolIndex, StructHandle, StructHandleIndex,
+    },
+};
+use std::collections::BTreeMap;
+use types::account_address::AccountAddress;
+
+/// Resolution context for importing types
+pub struct Resolver {
+    address_map: BTreeMap<AccountAddress, AddressPoolIndex>,
+    string_map: BTreeMap<String, StringPoolIndex>,
+    module_handle_map: BTreeMap<ModuleHandle, ModuleHandleIndex>,
+    struct_handle_map: BTreeMap<StructHandle, StructHandleIndex>,
+}
+
+impl Resolver {
+    /// Create a new instance of `Resolver` for a module.
+    pub fn new(module: &CompiledModule) -> Self {
+        let mut address_map = BTreeMap::new();
+        for (idx, address) in module.address_pool().enumerate() {
+            address_map.insert(address.clone(), AddressPoolIndex(idx as u16));
+        }
+        let mut string_map = BTreeMap::new();
+        for (idx, name) in module.string_pool().enumerate() {
+            string_map.insert(name.clone(), StringPoolIndex(idx as u16));
+        }
+        let mut module_handle_map = BTreeMap::new();
+        for (idx, module_handle) in module.module_handles().enumerate() {
+            module_handle_map.insert(module_handle.clone(), ModuleHandleIndex(idx as u16));
+        }
+        let mut struct_handle_map = BTreeMap::new();
+        for (idx, struct_handle) in module.struct_handles().enumerate() {
+            struct_handle_map.insert(struct_handle.clone(), StructHandleIndex(idx as u16));
+        }
+        Self {
+            address_map,
+            string_map,
+            module_handle_map,
+            struct_handle_map,
+        }
+    }
+
+    /// Given a signature token in a dependency, construct an equivalent signature token in the
+    /// context of this resolver and return it; returns an error if resolution fails.
+    pub fn import_signature_token(
+        &self,
+        dependency: &CompiledModule,
+        sig_token: &SignatureToken,
+    ) -> Result<SignatureToken, VMStaticViolation> {
+        match sig_token {
+            SignatureToken::Bool
+            | SignatureToken::U64
+            | SignatureToken::String
+            | SignatureToken::ByteArray
+            | SignatureToken::Address => Ok(sig_token.clone()),
+            SignatureToken::Struct(sh_idx) => {
+                let struct_handle = dependency.struct_handle_at(*sh_idx);
+                let defining_module_handle = dependency.module_handle_at(struct_handle.module);
+                let defining_module_address = dependency.address_at(defining_module_handle.address);
+                let defining_module_name = dependency.string_at(defining_module_handle.name);
+                let local_module_handle = ModuleHandle {
+                    address: *self
+                        .address_map
+                        .get(defining_module_address)
+                        .ok_or(VMStaticViolation::TypeResolutionFailure)?,
+                    name: *self
+                        .string_map
+                        .get(defining_module_name)
+                        .ok_or(VMStaticViolation::TypeResolutionFailure)?,
+                };
+                let struct_name = dependency.string_at(struct_handle.name);
+                let local_struct_handle = StructHandle {
+                    module: *self
+                        .module_handle_map
+                        .get(&local_module_handle)
+                        .ok_or(VMStaticViolation::TypeResolutionFailure)?,
+                    name: *self
+                        .string_map
+                        .get(struct_name)
+                        .ok_or(VMStaticViolation::TypeResolutionFailure)?,
+                    is_resource: struct_handle.is_resource,
+                };
+                Ok(SignatureToken::Struct(
+                    *self
+                        .struct_handle_map
+                        .get(&local_struct_handle)
+                        .ok_or(VMStaticViolation::TypeResolutionFailure)?,
+                ))
+            }
+            SignatureToken::Reference(sub_sig_token) => Ok(SignatureToken::Reference(Box::new(
+                self.import_signature_token(dependency, sub_sig_token)?,
+            ))),
+            SignatureToken::MutableReference(sub_sig_token) => {
+                Ok(SignatureToken::MutableReference(Box::new(
+                    self.import_signature_token(dependency, sub_sig_token)?,
+                )))
+            }
+        }
+    }
+
+    /// Given a function signature in a dependency, construct an equivalent function signature in
+    /// the context of this resolver and return it; returns an error if resolution fails.
+    pub fn import_function_signature(
+        &self,
+        dependency: &CompiledModule,
+        func_sig: &FunctionSignature,
+    ) -> Result<FunctionSignature, VMStaticViolation> {
+        let mut return_types = Vec::<SignatureToken>::new();
+        let mut arg_types = Vec::<SignatureToken>::new();
+        for e in &func_sig.return_types {
+            return_types.push(self.import_signature_token(dependency, e)?);
+        }
+        for e in &func_sig.arg_types {
+            arg_types.push(self.import_signature_token(dependency, e)?);
+        }
+        Ok(FunctionSignature {
+            return_types,
+            arg_types,
+        })
+    }
+}
diff --git a/language/vm/src/serializer.rs b/language/vm/src/serializer.rs
new file mode 100644
index 0000000000000..71f6bc51c938e
--- /dev/null
+++ b/language/vm/src/serializer.rs
@@ -0,0 +1,921 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Serialization of transactions and modules.
+//!
+//! This module exposes two entry points for serialization of `CompiledScript` and
+//! `CompiledModule`. The entry points are exposed on the main structs `CompiledScript` and
+//! `CompiledModule`.
+
+use crate::{file_format::*, file_format_common::*};
+use failure::*;
+use std::ops::Deref;
+use types::{account_address::AccountAddress, byte_array::ByteArray};
+
+impl CompiledScript {
+    /// Serializes a `CompiledScript` into a binary. The mutable `Vec<u8>` will contain the
+    /// binary blob on return.
+    pub fn serialize(&self, binary: &mut Vec<u8>) -> Result<()> {
+        let mut ser = ScriptSerializer::new(1, 0);
+        let mut temp: Vec<u8> = Vec::new();
+        ser.serialize(&mut temp, self)?;
+        ser.serialize_header(binary)?;
+        binary.append(&mut temp);
+        Ok(())
+    }
+}
+
+impl CompiledModule {
+    /// Serializes a `CompiledModule` into a binary. The mutable `Vec<u8>` will contain the
+    /// binary blob on return.
+    pub fn serialize(&self, binary: &mut Vec<u8>) -> Result<()> {
+        let mut ser = ModuleSerializer::new(1, 0);
+        let mut temp: Vec<u8> = Vec::new();
+        ser.serialize(&mut temp, self)?;
+        ser.serialize_header(binary)?;
+        binary.append(&mut temp);
+        Ok(())
+    }
+}
+
+/// Holds data to compute the header of a generic binary.
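+///
+/// Concretely, the serialized header is: the Libra magic, one byte each for the
+/// major and minor version, one byte for the table count, and then one 9-byte
+/// entry per non-empty table -- a 1-byte table kind followed by two u32s with the
+/// table's offset and byte count (see `serialize_header` and `serialize_table`).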
+/// +/// A binary header contains information about the tables serialized. +/// The serializer needs to serialize the tables in order to compute the offset and size +/// of each table. +/// `CommonSerializer` keeps track of the tables common to `CompiledScript` and +/// `CompiledModule`. +#[derive(Debug)] +struct CommonSerializer { + major_version: u8, + minor_version: u8, + table_count: u8, + module_handles: (u32, u32), + struct_handles: (u32, u32), + function_handles: (u32, u32), + type_signatures: (u32, u32), + function_signatures: (u32, u32), + locals_signatures: (u32, u32), + string_pool: (u32, u32), + address_pool: (u32, u32), + byte_array_pool: (u32, u32), +} + +/// Holds data to compute the header of a module binary. +#[derive(Debug)] +struct ModuleSerializer { + common: CommonSerializer, + struct_defs: (u32, u32), + field_defs: (u32, u32), + function_defs: (u32, u32), +} + +/// Holds data to compute the header of a transaction script binary. +#[derive(Debug)] +struct ScriptSerializer { + common: CommonSerializer, + main: (u32, u32), +} + +// +// Helpers +// +fn check_index_in_binary(index: usize) -> Result { + if index > u32::max_value() as usize { + bail!( + "Compilation unit too big ({}) cannot exceed {}", + index, + u32::max_value() + ) + } + Ok(index as u32) +} + +fn serialize_table(binary: &mut Vec, kind: TableType, offset: u32, count: u32) { + if count != 0 { + binary.push(kind as u8); + write_u32(binary, offset); + write_u32(binary, count); + } +} + +fn serialize_magic(binary: &mut Vec) { + for byte in &BinaryConstants::LIBRA_MAGIC { + binary.push(*byte); + } +} + +/// Trait to access tables for both `CompiledScript` and `CompiledModule`, +/// used by `CommonSerializer`. +trait CommonTables { + fn get_module_handles(&self) -> &[ModuleHandle]; + fn get_struct_handles(&self) -> &[StructHandle]; + fn get_function_handles(&self) -> &[FunctionHandle]; + fn get_string_pool(&self) -> &[String]; + fn get_address_pool(&self) -> &[AccountAddress]; + fn get_byte_array_pool(&self) -> &[ByteArray]; + fn get_type_signatures(&self) -> &[TypeSignature]; + fn get_function_signatures(&self) -> &[FunctionSignature]; + fn get_locals_signatures(&self) -> &[LocalsSignature]; +} + +impl CommonTables for CompiledScript { + fn get_module_handles(&self) -> &[ModuleHandle] { + &self.module_handles + } + + fn get_struct_handles(&self) -> &[StructHandle] { + &self.struct_handles + } + + fn get_function_handles(&self) -> &[FunctionHandle] { + &self.function_handles + } + + fn get_string_pool(&self) -> &[String] { + &self.string_pool + } + + fn get_address_pool(&self) -> &[AccountAddress] { + &self.address_pool + } + + fn get_byte_array_pool(&self) -> &[ByteArray] { + &self.byte_array_pool + } + + fn get_type_signatures(&self) -> &[TypeSignature] { + &self.type_signatures + } + + fn get_function_signatures(&self) -> &[FunctionSignature] { + &self.function_signatures + } + + fn get_locals_signatures(&self) -> &[LocalsSignature] { + &self.locals_signatures + } +} + +impl CommonTables for CompiledModule { + fn get_module_handles(&self) -> &[ModuleHandle] { + &self.module_handles + } + + fn get_struct_handles(&self) -> &[StructHandle] { + &self.struct_handles + } + + fn get_function_handles(&self) -> &[FunctionHandle] { + &self.function_handles + } + + fn get_string_pool(&self) -> &[String] { + &self.string_pool + } + + fn get_address_pool(&self) -> &[AccountAddress] { + &self.address_pool + } + + fn get_byte_array_pool(&self) -> &[ByteArray] { + &self.byte_array_pool + } + + fn get_type_signatures(&self) 
-> &[TypeSignature] {
+        &self.type_signatures
+    }
+
+    fn get_function_signatures(&self) -> &[FunctionSignature] {
+        &self.function_signatures
+    }
+
+    fn get_locals_signatures(&self) -> &[LocalsSignature] {
+        &self.locals_signatures
+    }
+}
+
+/// Serializes a `ModuleHandle`.
+///
+/// A `ModuleHandle` gets serialized as follows:
+/// - `ModuleHandle.address` as a ULEB128 (index into the `AddressPool`)
+/// - `ModuleHandle.name` as a ULEB128 (index into the `StringPool`)
+fn serialize_module_handle(binary: &mut Vec<u8>, module_handle: &ModuleHandle) {
+    write_u16_as_uleb128(binary, module_handle.address.0);
+    write_u16_as_uleb128(binary, module_handle.name.0);
+}
+
+/// Serializes a `StructHandle`.
+///
+/// A `StructHandle` gets serialized as follows:
+/// - `StructHandle.module` as a ULEB128 (index into the `ModuleHandle` table)
+/// - `StructHandle.name` as a ULEB128 (index into the `StringPool`)
+/// - `StructHandle.is_resource` as a 1 byte boolean (0 for false, 1 for true)
+fn serialize_struct_handle(binary: &mut Vec<u8>, struct_handle: &StructHandle) {
+    write_u16_as_uleb128(binary, struct_handle.module.0);
+    write_u16_as_uleb128(binary, struct_handle.name.0);
+    if struct_handle.is_resource {
+        binary.push(1);
+    } else {
+        binary.push(0);
+    }
+}
+
+/// Serializes a `FunctionHandle`.
+///
+/// A `FunctionHandle` gets serialized as follows:
+/// - `FunctionHandle.module` as a ULEB128 (index into the `ModuleHandle` table)
+/// - `FunctionHandle.name` as a ULEB128 (index into the `StringPool`)
+/// - `FunctionHandle.signature` as a ULEB128 (index into the `FunctionSignaturePool`)
+fn serialize_function_handle(binary: &mut Vec<u8>, function_handle: &FunctionHandle) {
+    write_u16_as_uleb128(binary, function_handle.module.0);
+    write_u16_as_uleb128(binary, function_handle.name.0);
+    write_u16_as_uleb128(binary, function_handle.signature.0);
+}
+
+/// Serializes a `String`.
+///
+/// A `String` gets serialized as follows:
+/// - `String` size as a ULEB128
+/// - `String` bytes - *exact format to be defined, Rust utf8 right now*
+fn serialize_string(binary: &mut Vec<u8>, string: &str) -> Result<()> {
+    let bytes = string.as_bytes();
+    let len = bytes.len();
+    if len > u32::max_value() as usize {
+        bail!("string size ({}) cannot exceed {}", len, u32::max_value())
+    }
+    write_u32_as_uleb128(binary, len as u32);
+    for byte in bytes {
+        binary.push(*byte);
+    }
+    Ok(())
+}
+
+/// Serializes a `ByteArray`.
+///
+/// A `ByteArray` gets serialized as follows:
+/// - `ByteArray` size as a ULEB128
+/// - `ByteArray` bytes in increasing index order
+fn serialize_byte_array(binary: &mut Vec<u8>, byte_array: &ByteArray) -> Result<()> {
+    let bytes = byte_array.as_bytes();
+    let len = bytes.len();
+    if len > u32::max_value() as usize {
+        bail!(
+            "byte array size ({}) cannot exceed {}",
+            len,
+            u32::max_value()
+        )
+    }
+    write_u32_as_uleb128(binary, len as u32);
+    for byte in bytes {
+        binary.push(*byte);
+    }
+    Ok(())
+}
+
+/// Serializes an `AccountAddress`.
+///
+/// An `AccountAddress` gets serialized as follows:
+/// - 32 bytes in increasing index order
+fn serialize_address(binary: &mut Vec<u8>, address: &AccountAddress) -> Result<()> {
+    for byte in address.as_ref() {
+        binary.push(*byte);
+    }
+    Ok(())
+}
+
+/// Serializes a `StructDefinition`.
+///
+/// A `StructDefinition` gets serialized as follows:
+/// - `StructDefinition.struct_handle` as a ULEB128 (index into the `StructHandle` table)
+/// - `StructDefinition.field_count` as a ULEB128 (number of fields defined in the type)
+/// - `StructDefinition.fields` as a ULEB128 (index into the `FieldDefinition` table)
+fn serialize_struct_definition(binary: &mut Vec<u8>, struct_definition: &StructDefinition) {
+    write_u16_as_uleb128(binary, struct_definition.struct_handle.0);
+    write_u16_as_uleb128(binary, struct_definition.field_count);
+    write_u16_as_uleb128(binary, struct_definition.fields.0);
+}
+
+/// Serializes a `FieldDefinition`.
+///
+/// A `FieldDefinition` gets serialized as follows:
+/// - `FieldDefinition.struct_` as a ULEB128 (index into the `StructHandle` table)
+/// - `FieldDefinition.name` as a ULEB128 (index into the `StringPool` table)
+/// - `FieldDefinition.signature` as a ULEB128 (index into the `TypeSignaturePool`)
+fn serialize_field_definition(binary: &mut Vec<u8>, field_definition: &FieldDefinition) {
+    write_u16_as_uleb128(binary, field_definition.struct_.0);
+    write_u16_as_uleb128(binary, field_definition.name.0);
+    write_u16_as_uleb128(binary, field_definition.signature.0);
+}
+
+/// Serializes a `FunctionDefinition`.
+///
+/// A `FunctionDefinition` gets serialized as follows:
+/// - `FunctionDefinition.function` as a ULEB128 (index into the `FunctionHandle` table)
+/// - `FunctionDefinition.flags` as 1 byte for the flags of the function
+/// - `FunctionDefinition.code` as a variable size stream for the `CodeUnit`
+fn serialize_function_definition(
+    binary: &mut Vec<u8>,
+    function_definition: &FunctionDefinition,
+) -> Result<()> {
+    write_u16_as_uleb128(binary, function_definition.function.0);
+    binary.push(function_definition.flags);
+    serialize_code_unit(binary, &function_definition.code)
+}
+
+/// Serializes a `TypeSignature`.
+///
+/// A `TypeSignature` gets serialized as follows:
+/// - `SignatureType::TYPE_SIGNATURE` as 1 byte
+/// - The `SignatureToken` as a blob
+fn serialize_type_signature(binary: &mut Vec<u8>, signature: &TypeSignature) -> Result<()> {
+    binary.push(SignatureType::TYPE_SIGNATURE as u8);
+    serialize_signature_token(binary, &signature.0)
+}
+
+/// Serializes a `FunctionSignature`.
+///
+/// A `FunctionSignature` gets serialized as follows:
+/// - `SignatureType::FUNCTION_SIGNATURE` as 1 byte
+/// - The vector of `SignatureToken`s for the return values
+/// - The vector of `SignatureToken`s for the arguments
+fn serialize_function_signature(binary: &mut Vec<u8>, signature: &FunctionSignature) -> Result<()> {
+    binary.push(SignatureType::FUNCTION_SIGNATURE as u8);
+    serialize_signature_tokens(binary, &signature.return_types)?;
+    serialize_signature_tokens(binary, &signature.arg_types)
+}
+
+/// Serializes a `LocalsSignature`.
+///
+/// A `LocalsSignature` gets serialized as follows:
+/// - `SignatureType::LOCAL_SIGNATURE` as 1 byte
+/// - The vector of `SignatureToken`s for locals
+fn serialize_locals_signature(binary: &mut Vec<u8>, signature: &LocalsSignature) -> Result<()> {
+    binary.push(SignatureType::LOCAL_SIGNATURE as u8);
+    serialize_signature_tokens(binary, &signature.0)
+}
+
+/// Serializes a slice of `SignatureToken`s.
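+///
+/// The slice is written as a 1-byte length followed by each token in order, which
+/// is why argument and local lists are capped at `u8::max_value()` entries.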
+fn serialize_signature_tokens(binary: &mut Vec<u8>, tokens: &[SignatureToken]) -> Result<()> {
+    let len = tokens.len();
+    if len > u8::max_value() as usize {
+        bail!(
+            "arguments/locals size ({}) cannot exceed {}",
+            len,
+            u8::max_value(),
+        )
+    }
+    binary.push(len as u8);
+    for token in tokens {
+        serialize_signature_token(binary, token)?;
+    }
+    Ok(())
+}
+
+/// Serializes a `SignatureToken`.
+///
+/// A `SignatureToken` gets serialized as a variable size blob depending on composition.
+/// Values for types are defined in `SerializedType`.
+fn serialize_signature_token(binary: &mut Vec<u8>, token: &SignatureToken) -> Result<()> {
+    match token {
+        SignatureToken::Bool => binary.push(SerializedType::BOOL as u8),
+        SignatureToken::U64 => binary.push(SerializedType::INTEGER as u8),
+        SignatureToken::String => binary.push(SerializedType::STRING as u8),
+        SignatureToken::ByteArray => binary.push(SerializedType::BYTEARRAY as u8),
+        SignatureToken::Address => binary.push(SerializedType::ADDRESS as u8),
+        SignatureToken::Struct(idx) => {
+            binary.push(SerializedType::STRUCT as u8);
+            write_u16_as_uleb128(binary, idx.0);
+        }
+        SignatureToken::Reference(boxed_token) => {
+            binary.push(SerializedType::REFERENCE as u8);
+            serialize_signature_token(binary, boxed_token.deref())?
+        }
+        SignatureToken::MutableReference(boxed_token) => {
+            binary.push(SerializedType::MUTABLE_REFERENCE as u8);
+            serialize_signature_token(binary, boxed_token.deref())?
+        }
+    }
+    Ok(())
+}
+
+/// Serializes a `CodeUnit`.
+///
+/// A `CodeUnit` is serialized as the last field of a `FunctionDefinition`.
+/// A `CodeUnit` gets serialized as follows:
+/// - `CodeUnit.max_stack_size` as a ULEB128
+/// - `CodeUnit.locals` as a ULEB128 (index into the `LocalSignaturePool`)
+/// - `CodeUnit.code` as variable size byte stream for the bytecode
+fn serialize_code_unit(binary: &mut Vec<u8>, code: &CodeUnit) -> Result<()> {
+    write_u16_as_uleb128(binary, code.max_stack_size);
+    write_u16_as_uleb128(binary, code.locals.0);
+    serialize_code(binary, &code.code)
+}
+
+/// Serializes a `Bytecode` stream. Serialization of the function body.
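+///
+/// The stream is a u16 instruction count followed by each instruction: a 1-byte
+/// opcode plus any inline operand (a u16 code offset for branches, a ULEB128
+/// table index for pool references, a full u64 for `LdConst`, or a single byte
+/// for local indexes).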
+fn serialize_code(binary: &mut Vec, code: &[Bytecode]) -> Result<()> { + let code_size = code.len(); + if code_size > u16::max_value() as usize { + bail!( + "code size ({}) cannot exceed {}", + code_size, + u16::max_value(), + ) + } + write_u16(binary, code_size as u16); + for opcode in code { + match opcode { + Bytecode::FreezeRef => binary.push(Opcodes::FREEZE_REF as u8), + Bytecode::Pop => binary.push(Opcodes::POP as u8), + Bytecode::Ret => binary.push(Opcodes::RET as u8), + Bytecode::BrTrue(code_offset) => { + binary.push(Opcodes::BR_TRUE as u8); + write_u16(binary, *code_offset); + } + Bytecode::BrFalse(code_offset) => { + binary.push(Opcodes::BR_FALSE as u8); + write_u16(binary, *code_offset); + } + Bytecode::Branch(code_offset) => { + binary.push(Opcodes::BRANCH as u8); + write_u16(binary, *code_offset); + } + Bytecode::LdConst(value) => { + binary.push(Opcodes::LD_CONST as u8); + write_u64(binary, *value); + } + Bytecode::LdAddr(address_idx) => { + binary.push(Opcodes::LD_ADDR as u8); + write_u16_as_uleb128(binary, address_idx.0); + } + Bytecode::LdByteArray(byte_array_idx) => { + binary.push(Opcodes::LD_BYTEARRAY as u8); + write_u16_as_uleb128(binary, byte_array_idx.0); + } + Bytecode::LdStr(string_idx) => { + binary.push(Opcodes::LD_STR as u8); + write_u16_as_uleb128(binary, string_idx.0); + } + Bytecode::LdTrue => binary.push(Opcodes::LD_TRUE as u8), + Bytecode::LdFalse => binary.push(Opcodes::LD_FALSE as u8), + Bytecode::CopyLoc(local_idx) => { + binary.push(Opcodes::COPY_LOC as u8); + binary.push(*local_idx); + } + Bytecode::MoveLoc(local_idx) => { + binary.push(Opcodes::MOVE_LOC as u8); + binary.push(*local_idx); + } + Bytecode::StLoc(local_idx) => { + binary.push(Opcodes::ST_LOC as u8); + binary.push(*local_idx); + } + Bytecode::BorrowLoc(local_idx) => { + binary.push(Opcodes::LD_REF_LOC as u8); + binary.push(*local_idx); + } + Bytecode::BorrowField(field_idx) => { + binary.push(Opcodes::LD_REF_FIELD as u8); + write_u16_as_uleb128(binary, field_idx.0); + } + Bytecode::Call(method_idx) => { + binary.push(Opcodes::CALL as u8); + write_u16_as_uleb128(binary, method_idx.0); + } + Bytecode::Pack(class_idx) => { + binary.push(Opcodes::PACK as u8); + write_u16_as_uleb128(binary, class_idx.0); + } + Bytecode::Unpack(class_idx) => { + binary.push(Opcodes::UNPACK as u8); + write_u16_as_uleb128(binary, class_idx.0); + } + Bytecode::ReadRef => binary.push(Opcodes::READ_REF as u8), + Bytecode::WriteRef => binary.push(Opcodes::WRITE_REF as u8), + Bytecode::Add => binary.push(Opcodes::ADD as u8), + Bytecode::Sub => binary.push(Opcodes::SUB as u8), + Bytecode::Mul => binary.push(Opcodes::MUL as u8), + Bytecode::Mod => binary.push(Opcodes::MOD as u8), + Bytecode::Div => binary.push(Opcodes::DIV as u8), + Bytecode::BitOr => binary.push(Opcodes::BIT_OR as u8), + Bytecode::BitAnd => binary.push(Opcodes::BIT_AND as u8), + Bytecode::Xor => binary.push(Opcodes::XOR as u8), + Bytecode::Or => binary.push(Opcodes::OR as u8), + Bytecode::And => binary.push(Opcodes::AND as u8), + Bytecode::Not => binary.push(Opcodes::NOT as u8), + Bytecode::Eq => binary.push(Opcodes::EQ as u8), + Bytecode::Neq => binary.push(Opcodes::NEQ as u8), + Bytecode::Lt => binary.push(Opcodes::LT as u8), + Bytecode::Gt => binary.push(Opcodes::GT as u8), + Bytecode::Le => binary.push(Opcodes::LE as u8), + Bytecode::Ge => binary.push(Opcodes::GE as u8), + Bytecode::Assert => binary.push(Opcodes::ASSERT as u8), + Bytecode::GetTxnGasUnitPrice => binary.push(Opcodes::GET_TXN_GAS_UNIT_PRICE as u8), + Bytecode::GetTxnMaxGasUnits => 
binary.push(Opcodes::GET_TXN_MAX_GAS_UNITS as u8), + Bytecode::GetGasRemaining => binary.push(Opcodes::GET_GAS_REMAINING as u8), + Bytecode::GetTxnSenderAddress => binary.push(Opcodes::GET_TXN_SENDER as u8), + Bytecode::Exists(class_idx) => { + binary.push(Opcodes::EXISTS as u8); + write_u16_as_uleb128(binary, class_idx.0); + } + Bytecode::BorrowGlobal(class_idx) => { + binary.push(Opcodes::BORROW_REF as u8); + write_u16_as_uleb128(binary, class_idx.0); + } + Bytecode::ReleaseRef => binary.push(Opcodes::RELEASE_REF as u8), + Bytecode::MoveFrom(class_idx) => { + binary.push(Opcodes::MOVE_FROM as u8); + write_u16_as_uleb128(binary, class_idx.0); + } + Bytecode::MoveToSender(class_idx) => { + binary.push(Opcodes::MOVE_TO as u8); + write_u16_as_uleb128(binary, class_idx.0); + } + Bytecode::CreateAccount => binary.push(Opcodes::CREATE_ACCOUNT as u8), + Bytecode::EmitEvent => binary.push(Opcodes::EMIT_EVENT as u8), + Bytecode::GetTxnSequenceNumber => binary.push(Opcodes::GET_TXN_SEQUENCE_NUMBER as u8), + Bytecode::GetTxnPublicKey => binary.push(Opcodes::GET_TXN_PUBLIC_KEY as u8), + } + } + Ok(()) +} + +impl CommonSerializer { + pub fn new(major_version: u8, minor_version: u8) -> CommonSerializer { + CommonSerializer { + major_version, + minor_version, + table_count: 0, + module_handles: (0, 0), + struct_handles: (0, 0), + function_handles: (0, 0), + type_signatures: (0, 0), + function_signatures: (0, 0), + locals_signatures: (0, 0), + string_pool: (0, 0), + address_pool: (0, 0), + byte_array_pool: (0, 0), + } + } + + /// Common binary header serialization. + fn serialize_header(&mut self, binary: &mut Vec) -> Result { + serialize_magic(binary); + binary.push(self.major_version); + binary.push(self.minor_version); + binary.push(self.table_count); + + let start_offset = check_index_in_binary(binary.len())? 
+ u32::from(self.table_count) * 9; + + serialize_table( + binary, + TableType::MODULE_HANDLES, + self.module_handles.0 + start_offset, + self.module_handles.1, + ); + serialize_table( + binary, + TableType::STRUCT_HANDLES, + self.struct_handles.0 + start_offset, + self.struct_handles.1, + ); + serialize_table( + binary, + TableType::FUNCTION_HANDLES, + self.function_handles.0 + start_offset, + self.function_handles.1, + ); + serialize_table( + binary, + TableType::TYPE_SIGNATURES, + self.type_signatures.0 + start_offset, + self.type_signatures.1, + ); + serialize_table( + binary, + TableType::FUNCTION_SIGNATURES, + self.function_signatures.0 + start_offset, + self.function_signatures.1, + ); + serialize_table( + binary, + TableType::LOCALS_SIGNATURES, + self.locals_signatures.0 + start_offset, + self.locals_signatures.1, + ); + serialize_table( + binary, + TableType::STRING_POOL, + self.string_pool.0 + start_offset, + self.string_pool.1, + ); + serialize_table( + binary, + TableType::ADDRESS_POOL, + self.address_pool.0 + start_offset, + self.address_pool.1, + ); + serialize_table( + binary, + TableType::BYTE_ARRAY_POOL, + self.byte_array_pool.0 + start_offset, + self.byte_array_pool.1, + ); + Ok(start_offset) + } + + fn serialize_common( + &mut self, + binary: &mut Vec, + tables: &T, + ) -> Result<()> { + self.serialize_module_handles(binary, tables.get_module_handles())?; + self.serialize_struct_handles(binary, tables.get_struct_handles())?; + self.serialize_function_handles(binary, tables.get_function_handles())?; + self.serialize_type_signatures(binary, tables.get_type_signatures())?; + self.serialize_function_signatures(binary, tables.get_function_signatures())?; + self.serialize_locals_signatures(binary, tables.get_locals_signatures())?; + self.serialize_strings(binary, tables.get_string_pool())?; + self.serialize_addresses(binary, tables.get_address_pool())?; + self.serialize_byte_arrays(binary, tables.get_byte_array_pool())?; + Ok(()) + } + + /// Serializes `ModuleHandle` table. + fn serialize_module_handles( + &mut self, + binary: &mut Vec, + module_handles: &[ModuleHandle], + ) -> Result<()> { + if !module_handles.is_empty() { + self.table_count += 1; + self.module_handles.0 = check_index_in_binary(binary.len())?; + for module_handle in module_handles { + serialize_module_handle(binary, module_handle); + } + self.module_handles.1 = check_index_in_binary(binary.len())? - self.module_handles.0; + } + Ok(()) + } + + /// Serializes `StructHandle` table. + fn serialize_struct_handles( + &mut self, + binary: &mut Vec, + struct_handles: &[StructHandle], + ) -> Result<()> { + if !struct_handles.is_empty() { + self.table_count += 1; + self.struct_handles.0 = check_index_in_binary(binary.len())?; + for struct_handle in struct_handles { + serialize_struct_handle(binary, struct_handle); + } + self.struct_handles.1 = check_index_in_binary(binary.len())? - self.struct_handles.0; + } + Ok(()) + } + + /// Serializes `FunctionHandle` table. + fn serialize_function_handles( + &mut self, + binary: &mut Vec, + function_handles: &[FunctionHandle], + ) -> Result<()> { + if !function_handles.is_empty() { + self.table_count += 1; + self.function_handles.0 = check_index_in_binary(binary.len())?; + for function_handle in function_handles { + serialize_function_handle(binary, function_handle); + } + self.function_handles.1 = + check_index_in_binary(binary.len())? - self.function_handles.0; + } + Ok(()) + } + + /// Serializes `StringPool`. 
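+    ///
+    /// Each string is a ULEB128 length followed by its UTF-8 bytes (see
+    /// `serialize_string`); the `string_pool` pair records the table's start
+    /// offset and byte count.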
+ fn serialize_strings(&mut self, binary: &mut Vec, strings: &[String]) -> Result<()> { + if !strings.is_empty() { + self.table_count += 1; + self.string_pool.0 = check_index_in_binary(binary.len())?; + for string in strings { + serialize_string(binary, string)?; + } + self.string_pool.1 = check_index_in_binary(binary.len())? - self.string_pool.0; + } + Ok(()) + } + + /// Serializes `ByteArrayPool`. + fn serialize_byte_arrays( + &mut self, + binary: &mut Vec, + byte_arrays: &[ByteArray], + ) -> Result<()> { + if !byte_arrays.is_empty() { + self.table_count += 1; + self.byte_array_pool.0 = check_index_in_binary(binary.len())?; + for byte_array in byte_arrays { + serialize_byte_array(binary, byte_array)?; + } + self.byte_array_pool.1 = check_index_in_binary(binary.len())? - self.byte_array_pool.0; + } + Ok(()) + } + + /// Serializes `AddressPool`. + fn serialize_addresses( + &mut self, + binary: &mut Vec, + addresses: &[AccountAddress], + ) -> Result<()> { + if !addresses.is_empty() { + self.table_count += 1; + self.address_pool.0 = check_index_in_binary(binary.len())?; + for address in addresses { + serialize_address(binary, address)?; + } + self.address_pool.1 = check_index_in_binary(binary.len())? - self.address_pool.0; + } + Ok(()) + } + + /// Serializes `TypeSignaturePool` table. + fn serialize_type_signatures( + &mut self, + binary: &mut Vec, + signatures: &[TypeSignature], + ) -> Result<()> { + if !signatures.is_empty() { + self.table_count += 1; + self.type_signatures.0 = check_index_in_binary(binary.len())?; + for signature in signatures { + serialize_type_signature(binary, signature)?; + } + self.type_signatures.1 = check_index_in_binary(binary.len())? - self.type_signatures.0; + } + Ok(()) + } + + /// Serializes `FunctionSignaturePool` table. + fn serialize_function_signatures( + &mut self, + binary: &mut Vec, + signatures: &[FunctionSignature], + ) -> Result<()> { + if !signatures.is_empty() { + self.table_count += 1; + self.function_signatures.0 = check_index_in_binary(binary.len())?; + for signature in signatures { + serialize_function_signature(binary, signature)?; + } + self.function_signatures.1 = + check_index_in_binary(binary.len())? - self.function_signatures.0; + } + Ok(()) + } + + /// Serializes `LocalSignaturePool` table. + fn serialize_locals_signatures( + &mut self, + binary: &mut Vec, + signatures: &[LocalsSignature], + ) -> Result<()> { + if !signatures.is_empty() { + self.table_count += 1; + self.locals_signatures.0 = check_index_in_binary(binary.len())?; + for signature in signatures { + serialize_locals_signature(binary, signature)?; + } + self.locals_signatures.1 = + check_index_in_binary(binary.len())? 
- self.locals_signatures.0; + } + Ok(()) + } +} + +impl ModuleSerializer { + fn new(major_version: u8, minor_version: u8) -> ModuleSerializer { + ModuleSerializer { + common: CommonSerializer::new(major_version, minor_version), + struct_defs: (0, 0), + field_defs: (0, 0), + function_defs: (0, 0), + } + } + + fn serialize(&mut self, binary: &mut Vec, module: &CompiledModule) -> Result<()> { + self.common.serialize_common(binary, module)?; + self.serialize_struct_definitions(binary, &module.struct_defs)?; + self.serialize_field_definitions(binary, &module.field_defs)?; + self.serialize_function_definitions(binary, &module.function_defs) + } + + fn serialize_header(&mut self, binary: &mut Vec) -> Result<()> { + let start_offset = self.common.serialize_header(binary)?; + serialize_table( + binary, + TableType::STRUCT_DEFS, + self.struct_defs.0 + start_offset, + self.struct_defs.1, + ); + serialize_table( + binary, + TableType::FIELD_DEFS, + self.field_defs.0 + start_offset, + self.field_defs.1, + ); + serialize_table( + binary, + TableType::FUNCTION_DEFS, + self.function_defs.0 + start_offset, + self.function_defs.1, + ); + Ok(()) + } + + /// Serializes `StructDefinition` table. + fn serialize_struct_definitions( + &mut self, + binary: &mut Vec, + struct_definitions: &[StructDefinition], + ) -> Result<()> { + if !struct_definitions.is_empty() { + self.common.table_count += 1; + self.struct_defs.0 = check_index_in_binary(binary.len())?; + for struct_definition in struct_definitions { + serialize_struct_definition(binary, struct_definition); + } + self.struct_defs.1 = check_index_in_binary(binary.len())? - self.struct_defs.0; + } + Ok(()) + } + + /// Serializes `FieldDefinition` table. + fn serialize_field_definitions( + &mut self, + binary: &mut Vec, + field_definitions: &[FieldDefinition], + ) -> Result<()> { + if !field_definitions.is_empty() { + self.common.table_count += 1; + self.field_defs.0 = check_index_in_binary(binary.len())?; + for field_definition in field_definitions { + serialize_field_definition(binary, field_definition); + } + self.field_defs.1 = check_index_in_binary(binary.len())? - self.field_defs.0; + } + Ok(()) + } + + /// Serializes `FunctionDefinition` table. + fn serialize_function_definitions( + &mut self, + binary: &mut Vec, + function_definitions: &[FunctionDefinition], + ) -> Result<()> { + if !function_definitions.is_empty() { + self.common.table_count += 1; + self.function_defs.0 = check_index_in_binary(binary.len())?; + for function_definition in function_definitions { + serialize_function_definition(binary, function_definition)?; + } + self.function_defs.1 = check_index_in_binary(binary.len())? - self.function_defs.0; + } + Ok(()) + } +} + +impl ScriptSerializer { + fn new(major_version: u8, minor_version: u8) -> ScriptSerializer { + ScriptSerializer { + common: CommonSerializer::new(major_version, minor_version), + main: (0, 0), + } + } + + fn serialize(&mut self, binary: &mut Vec, script: &CompiledScript) -> Result<()> { + self.common.serialize_common(binary, script)?; + self.serialize_main(binary, &script.main) + } + + fn serialize_header(&mut self, binary: &mut Vec) -> Result<()> { + let start_offset = self.common.serialize_header(binary)?; + serialize_table( + binary, + TableType::MAIN, + self.main.0 + start_offset, + self.main.1, + ); + Ok(()) + } + + /// Serializes `CompiledScript` main. 
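+    ///
+    /// `main` is written as a single `FunctionDefinition` into its own `MAIN`
+    /// table (see `serialize_header` above).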
+    fn serialize_main(&mut self, binary: &mut Vec<u8>, main: &FunctionDefinition) -> Result<()> {
+        self.common.table_count += 1;
+        self.main.0 = check_index_in_binary(binary.len())?;
+        serialize_function_definition(binary, main)?;
+        self.main.1 = check_index_in_binary(binary.len())? - self.main.0;
+        Ok(())
+    }
+}
diff --git a/language/vm/src/transaction_metadata.rs b/language/vm/src/transaction_metadata.rs
new file mode 100644
index 0000000000000..81a5b5e0878a0
--- /dev/null
+++ b/language/vm/src/transaction_metadata.rs
@@ -0,0 +1,65 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crypto::{signing::generate_genesis_keypair, PublicKey};
+use types::{account_address::AccountAddress, transaction::SignedTransaction};
+
+pub struct TransactionMetadata {
+    pub sender: AccountAddress,
+    pub public_key: PublicKey,
+    pub sequence_number: u64,
+    pub max_gas_amount: u64,
+    pub gas_unit_price: u64,
+    pub transaction_size: u64,
+}
+
+impl TransactionMetadata {
+    pub fn new(txn: &SignedTransaction) -> Self {
+        Self {
+            sender: txn.sender(),
+            public_key: txn.public_key(),
+            sequence_number: txn.sequence_number(),
+            max_gas_amount: txn.max_gas_amount(),
+            gas_unit_price: txn.gas_unit_price(),
+            transaction_size: txn.raw_txn_bytes_len() as u64,
+        }
+    }
+
+    pub fn max_gas_amount(&self) -> u64 {
+        self.max_gas_amount
+    }
+
+    pub fn gas_unit_price(&self) -> u64 {
+        self.gas_unit_price
+    }
+
+    pub fn sender(&self) -> AccountAddress {
+        self.sender.to_owned()
+    }
+
+    pub fn public_key(&self) -> &PublicKey {
+        &self.public_key
+    }
+
+    pub fn sequence_number(&self) -> u64 {
+        self.sequence_number
+    }
+
+    pub fn transaction_size(&self) -> u64 {
+        self.transaction_size
+    }
+}
+
+impl Default for TransactionMetadata {
+    fn default() -> Self {
+        let (_, public_key) = generate_genesis_keypair();
+        TransactionMetadata {
+            sender: AccountAddress::default(),
+            public_key,
+            sequence_number: 0,
+            max_gas_amount: 100_000_000,
+            gas_unit_price: 0,
+            transaction_size: 0,
+        }
+    }
+}
diff --git a/language/vm/src/unit_tests/checks.rs b/language/vm/src/unit_tests/checks.rs
new file mode 100644
index 0000000000000..9dfc4c7dc487b
--- /dev/null
+++ b/language/vm/src/unit_tests/checks.rs
@@ -0,0 +1,4 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+mod signature_tests;
diff --git a/language/vm/src/unit_tests/checks/signature_tests.rs b/language/vm/src/unit_tests/checks/signature_tests.rs
new file mode 100644
index 0000000000000..13532deb0fc87
--- /dev/null
+++ b/language/vm/src/unit_tests/checks/signature_tests.rs
@@ -0,0 +1,41 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    errors::VMStaticViolation,
+    file_format::{SignatureToken, StructHandleIndex},
+    SignatureTokenKind,
+};
+
+#[test]
+fn test_sig_token_structure() {
+    // Valid cases.
+    let bool_token = SignatureToken::Bool;
+    assert_eq!(bool_token.check_structure(), None);
+    let struct_token = SignatureToken::Struct(StructHandleIndex::new(0));
+    assert_eq!(struct_token.check_structure(), None);
+    let ref_token = SignatureToken::Reference(Box::new(struct_token.clone()));
+    assert_eq!(ref_token.check_structure(), None);
+    let mut_ref_token = SignatureToken::MutableReference(Box::new(struct_token.clone()));
+    assert_eq!(mut_ref_token.check_structure(), None);
+
+    // Invalid cases.
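+    // A reference may not wrap another reference: for a nested token the checker
+    // reports the outer and inner kinds, so a reference to a reference (roughly
+    // `& &T`) yields (Reference, Reference) and a reference to a mutable
+    // reference (roughly `& &mut T`) yields (Reference, MutableReference), as
+    // the assertions below expect.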
+ let ref_ref_token = SignatureToken::Reference(Box::new(ref_token.clone())); + assert_eq!( + ref_ref_token.check_structure(), + Some(VMStaticViolation::InvalidSignatureToken( + ref_ref_token.clone(), + SignatureTokenKind::Reference, + SignatureTokenKind::Reference, + )) + ); + let ref_mut_ref_token = SignatureToken::Reference(Box::new(mut_ref_token.clone())); + assert_eq!( + ref_mut_ref_token.check_structure(), + Some(VMStaticViolation::InvalidSignatureToken( + ref_mut_ref_token.clone(), + SignatureTokenKind::Reference, + SignatureTokenKind::MutableReference, + )) + ); +} diff --git a/language/vm/src/unit_tests/deserializer_tests.rs b/language/vm/src/unit_tests/deserializer_tests.rs new file mode 100644 index 0000000000000..a73321b2f6347 --- /dev/null +++ b/language/vm/src/unit_tests/deserializer_tests.rs @@ -0,0 +1,64 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + errors::*, + file_format::{CompiledModule, CompiledScript}, + file_format_common::*, +}; + +#[test] +fn malformed_simple() { + // empty binary + let mut binary = vec![]; + let mut res = CompiledScript::deserialize(&binary); + assert_eq!( + res.expect_err("Expected malformed binary"), + BinaryError::Malformed + ); + + // under-sized binary + binary = vec![0u8, 0u8, 0u8]; + res = CompiledScript::deserialize(&binary); + assert_eq!( + res.expect_err("Expected malformed binary"), + BinaryError::Malformed + ); + + // bad magic + binary = vec![0u8; 15]; + res = CompiledScript::deserialize(&binary); + assert_eq!(res.expect_err("Expected bad magic"), BinaryError::BadMagic); + + // only magic + binary = BinaryConstants::LIBRA_MAGIC.to_vec(); + res = CompiledScript::deserialize(&binary); + assert_eq!( + res.expect_err("Expected malformed binary"), + BinaryError::Malformed + ); + + // bad major version + binary = BinaryConstants::LIBRA_MAGIC.to_vec(); + binary.push(2); // major version + binary.push(0); // minor version + binary.push(10); // table count + binary.push(0); // rest of binary ;) + res = CompiledScript::deserialize(&binary); + assert_eq!( + res.expect_err("Expected unknown version"), + BinaryError::UnknownVersion + ); + + // bad minor version + binary = BinaryConstants::LIBRA_MAGIC.to_vec(); + binary.push(1); // major version + binary.push(1); // minor version + binary.push(10); // table count + binary.push(0); // rest of binary ;) + let res1 = CompiledModule::deserialize(&binary); + assert_eq!( + res1.expect_err("Expected unknown version"), + BinaryError::UnknownVersion + ); +} diff --git a/language/vm/src/unit_tests/fixture_tests.rs b/language/vm/src/unit_tests/fixture_tests.rs new file mode 100644 index 0000000000000..ace6ba8e72f29 --- /dev/null +++ b/language/vm/src/unit_tests/fixture_tests.rs @@ -0,0 +1,16 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::file_format::{CompiledModule, CompiledScript}; +use types::test_helpers::transaction_test_helpers::placeholder_script; + +// Ensure that the placeholder_script fixture deserializes properly, i.e. is kept up to date. 
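+// A failure here usually signals that the binary format changed without the
+// fixture being regenerated; the same check can be run by hand against any
+// serialized module blob (a sketch, assuming some `bytes: &[u8]`):
+//
+//     CompiledModule::deserialize(bytes).expect("fixture is out of date");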
+#[test]
+fn placeholder_script_deserialize() {
+    let placeholder_program = placeholder_script();
+    CompiledScript::deserialize(&placeholder_program.code())
+        .expect("script should deserialize properly");
+    for module in placeholder_program.modules() {
+        CompiledModule::deserialize(module).expect("module should deserialize properly");
+    }
+}
diff --git a/language/vm/src/unit_tests/mod.rs b/language/vm/src/unit_tests/mod.rs
new file mode 100644
index 0000000000000..1824b979c6571
--- /dev/null
+++ b/language/vm/src/unit_tests/mod.rs
@@ -0,0 +1,7 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+mod checks;
+mod deserializer_tests;
+mod fixture_tests;
+mod number_tests;
diff --git a/language/vm/src/unit_tests/number_tests.rs b/language/vm/src/unit_tests/number_tests.rs
new file mode 100644
index 0000000000000..5cc44cbf7d3d8
--- /dev/null
+++ b/language/vm/src/unit_tests/number_tests.rs
@@ -0,0 +1,142 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::file_format_common::*;
+use byteorder::{LittleEndian, ReadBytesExt};
+use proptest::prelude::*;
+use std::io::Cursor;
+
+// Verify that all bytes in the vector have the high bit set except the last one.
+fn check_vector(buf: &[u8]) {
+    let mut last_byte: bool = false;
+    for byte in buf {
+        assert!(!last_byte);
+        if *byte & 0x80 == 0 {
+            last_byte = true;
+        }
+        if !last_byte {
+            assert!(*byte & 0x80 > 0, "{} & 0x80", *byte);
+        }
+    }
+    assert!(last_byte);
+}
+
+fn test_u16(value: u16, expected_bytes: usize) {
+    let mut buf: Vec<u8> = Vec::new();
+    write_u16_as_uleb128(&mut buf, value);
+    assert_eq!(buf.len(), expected_bytes);
+    check_vector(&buf);
+    let mut cursor = Cursor::new(&buf[..]);
+    let val = read_uleb128_as_u16(&mut cursor).unwrap();
+    assert_eq!(value, val);
+}
+
+fn test_u32(value: u32, expected_bytes: usize) {
+    let mut buf: Vec<u8> = Vec::new();
+    write_u32_as_uleb128(&mut buf, value);
+    assert_eq!(buf.len(), expected_bytes);
+    check_vector(&buf);
+    let mut cursor = Cursor::new(&buf[..]);
+    let val = read_uleb128_as_u32(&mut cursor).unwrap();
+    assert_eq!(value, val);
+}
+
+#[test]
+fn uleb128_u16_test() {
+    test_u16(0, 1);
+    test_u16(16, 1);
+    test_u16(2u16.pow(7) - 1, 1);
+    test_u16(2u16.pow(7), 2);
+    test_u16(2u16.pow(7) + 1, 2);
+    test_u16(2u16.pow(14) - 1, 2);
+    test_u16(2u16.pow(14), 3);
+    test_u16(2u16.pow(14) + 1, 3);
+    test_u16(u16::max_value() - 2, 3);
+    test_u16(u16::max_value() - 1, 3);
+    test_u16(u16::max_value(), 3);
+}
+
+#[test]
+fn uleb128_u32_test() {
+    test_u32(0, 1);
+    test_u32(16, 1);
+    test_u32(2u32.pow(7) - 1, 1);
+    test_u32(2u32.pow(7), 2);
+    test_u32(2u32.pow(7) + 1, 2);
+    test_u32(2u32.pow(14) - 1, 2);
+    test_u32(2u32.pow(14), 3);
+    test_u32(2u32.pow(14) + 1, 3);
+    test_u32(2u32.pow(21) - 1, 3);
+    test_u32(2u32.pow(21), 4);
+    test_u32(2u32.pow(21) + 1, 4);
+    test_u32(2u32.pow(28) - 1, 4);
+    test_u32(2u32.pow(28), 5);
+    test_u32(2u32.pow(28) + 1, 5);
+    test_u32(u32::max_value() - 2, 5);
+    test_u32(u32::max_value() - 1, 5);
+    test_u32(u32::max_value(), 5);
+}
+
+#[test]
+fn uleb128_malformed_test() {
+    assert!(read_uleb128_as_u16(&mut Cursor::new(&[])).is_err());
+    assert!(read_uleb128_as_u16(&mut Cursor::new(&[0x80, 0x80])).is_err());
+    assert!(read_uleb128_as_u16(&mut Cursor::new(&[0x80])).is_err());
+    assert!(read_uleb128_as_u16(&mut Cursor::new(&[0x80, 0x80])).is_err());
+    assert!(read_uleb128_as_u16(&mut Cursor::new(&[0x80, 0x80, 0x80, 0x80])).is_err());
+    assert!(read_uleb128_as_u16(&mut Cursor::new(&[0x80, 0x80, 0x80, 0x2])).is_err());
+
+    assert!(read_uleb128_as_u32(&mut Cursor::new(&[])).is_err());
+    assert!(read_uleb128_as_u32(&mut Cursor::new(&[0x80, 0x80])).is_err());
+    assert!(read_uleb128_as_u32(&mut Cursor::new(&[0x80])).is_err());
+    assert!(read_uleb128_as_u32(&mut Cursor::new(&[0x80, 0x80])).is_err());
+    assert!(read_uleb128_as_u32(&mut Cursor::new(&[0x80, 0x80, 0x80, 0x80])).is_err());
+    assert!(read_uleb128_as_u32(&mut Cursor::new(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x2])).is_err());
+}
+
+proptest! {
+    #[test]
+    fn u16_uleb128_roundtrip(input in any::<u16>()) {
+        let mut serialized = vec![];
+        write_u16_as_uleb128(&mut serialized, input);
+        let mut cursor = Cursor::new(&serialized[..]);
+        let output = read_uleb128_as_u16(&mut cursor).expect("deserialization should work");
+        prop_assert_eq!(input, output);
+    }
+
+    #[test]
+    fn u32_uleb128_roundtrip(input in any::<u32>()) {
+        let mut serialized = vec![];
+        write_u32_as_uleb128(&mut serialized, input);
+        let mut cursor = Cursor::new(&serialized[..]);
+        let output = read_uleb128_as_u32(&mut cursor).expect("deserialization should work");
+        prop_assert_eq!(input, output);
+    }
+
+    #[test]
+    fn u16_roundtrip(input in any::<u16>()) {
+        let mut serialized = vec![];
+        write_u16(&mut serialized, input);
+        let mut cursor = Cursor::new(&serialized[..]);
+        let output = cursor.read_u16::<LittleEndian>().expect("deserialization should work");
+        prop_assert_eq!(input, output);
+    }
+
+    #[test]
+    fn u32_roundtrip(input in any::<u32>()) {
+        let mut serialized = vec![];
+        write_u32(&mut serialized, input);
+        let mut cursor = Cursor::new(&serialized[..]);
+        let output = cursor.read_u32::<LittleEndian>().expect("deserialization should work");
+        prop_assert_eq!(input, output);
+    }
+
+    #[test]
+    fn u64_roundtrip(input in any::<u64>()) {
+        let mut serialized = vec![];
+        write_u64(&mut serialized, input);
+        let mut cursor = Cursor::new(&serialized[..]);
+        let output = cursor.read_u64::<LittleEndian>().expect("deserialization should work");
+        prop_assert_eq!(input, output);
+    }
+}
diff --git a/language/vm/src/views.rs b/language/vm/src/views.rs
new file mode 100644
index 0000000000000..c59f09803e737
--- /dev/null
+++ b/language/vm/src/views.rs
@@ -0,0 +1,542 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! An alternate representation of the file format built on top of the existing format.
+//!
+//! Some general notes:
+//!
+//! * These views are not meant to be set in stone. Feel free to change the views exposed as the
+//!   format and our understanding evolves.
+//! * The typical use for these views would be to materialize all the lazily evaluated data
+//!   immediately -- the views are a convenience to make that simpler. They've been written as lazy
+//!   iterators to aid understanding of the file format and to make it easy to generate views.
+
+use std::iter::DoubleEndedIterator;
+
+use crate::{
+    access::ModuleAccess,
+    file_format::{
+        CodeUnit, FieldDefinition, FunctionDefinition, FunctionHandle, FunctionSignature,
+        LocalIndex, LocalsSignature, ModuleHandle, SignatureToken, StructDefinition, StructHandle,
+        StructHandleIndex, TypeSignature,
+    },
+    SignatureTokenKind,
+};
+
+use types::language_storage::CodeKey;
+
+use std::collections::BTreeMap;
+
+/// Represents a lazily evaluated abstraction over a module.
+///
+/// `T` here is any sort of `ModuleAccess`. See the documentation in access.rs for more.
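+// A hedged usage sketch (assuming some `m: CompiledModule`, which implements
+// `ModuleAccess`): the lazy iterators below can be materialized eagerly, e.g.
+//
+//     let view = ModuleView::new(&m);
+//     let names: Vec<&str> = view.functions().map(|f| f.name()).collect();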
+pub struct ModuleView<'a, T> {
+    module: &'a T,
+    name_to_function_definition_view: BTreeMap<&'a str, FunctionDefinitionView<'a, T>>,
+    name_to_struct_definition_view: BTreeMap<&'a str, StructDefinitionView<'a, T>>,
+}
+
+impl<'a, T: ModuleAccess> ModuleView<'a, T> {
+    pub fn new(module: &'a T) -> Self {
+        let mut name_to_function_definition_view = BTreeMap::new();
+        for function_def in module.function_defs() {
+            let view = FunctionDefinitionView::new(module, function_def);
+            name_to_function_definition_view.insert(view.name(), view);
+        }
+        let mut name_to_struct_definition_view = BTreeMap::new();
+        for struct_def in module.struct_defs() {
+            let view = StructDefinitionView::new(module, struct_def);
+            name_to_struct_definition_view.insert(view.name(), view);
+        }
+        Self {
+            module,
+            name_to_function_definition_view,
+            name_to_struct_definition_view,
+        }
+    }
+
+    pub fn module_handles(
+        &self,
+    ) -> impl DoubleEndedIterator<Item = ModuleHandleView<'a, T>> + Send {
+        let module = self.module;
+        module
+            .module_handles()
+            .map(move |module_handle| ModuleHandleView::new(module, module_handle))
+    }
+
+    pub fn struct_handles(
+        &self,
+    ) -> impl DoubleEndedIterator<Item = StructHandleView<'a, T>> + Send {
+        let module = self.module;
+        module
+            .struct_handles()
+            .map(move |struct_handle| StructHandleView::new(module, struct_handle))
+    }
+
+    pub fn function_handles(
+        &self,
+    ) -> impl DoubleEndedIterator<Item = FunctionHandleView<'a, T>> + Send {
+        let module = self.module;
+        module
+            .function_handles()
+            .map(move |function_handle| FunctionHandleView::new(module, function_handle))
+    }
+
+    pub fn structs(&self) -> impl DoubleEndedIterator<Item = StructDefinitionView<'a, T>> + Send {
+        let module = self.module;
+        module
+            .struct_defs()
+            .map(move |struct_def| StructDefinitionView::new(module, struct_def))
+    }
+
+    pub fn fields(&self) -> impl DoubleEndedIterator<Item = FieldDefinitionView<'a, T>> + Send {
+        let module = self.module;
+        module
+            .field_defs()
+            .map(move |field_def| FieldDefinitionView::new(module, field_def))
+    }
+
+    pub fn functions(
+        &self,
+    ) -> impl DoubleEndedIterator<Item = FunctionDefinitionView<'a, T>> + Send {
+        let module = self.module;
+        module
+            .function_defs()
+            .map(move |function_def| FunctionDefinitionView::new(module, function_def))
+    }
+
+    pub fn type_signatures(
+        &self,
+    ) -> impl DoubleEndedIterator<Item = TypeSignatureView<'a, T>> + Send {
+        let module = self.module;
+        module
+            .type_signatures()
+            .map(move |type_signature| TypeSignatureView::new(module, type_signature))
+    }
+
+    pub fn function_signatures(
+        &self,
+    ) -> impl DoubleEndedIterator<Item = FunctionSignatureView<'a, T>> + Send {
+        let module = self.module;
+        module
+            .function_signatures()
+            .map(move |function_signature| FunctionSignatureView::new(module, function_signature))
+    }
+
+    pub fn locals_signatures(
+        &self,
+    ) -> impl DoubleEndedIterator<Item = LocalsSignatureView<'a, T>> + Send {
+        let module = self.module;
+        module
+            .locals_signatures()
+            .map(move |locals_signature| LocalsSignatureView::new(module, locals_signature))
+    }
+
+    pub fn function_definition(&self, name: &'a str) -> Option<&FunctionDefinitionView<'a, T>> {
+        self.name_to_function_definition_view.get(name)
+    }
+
+    pub fn struct_definition(&self, name: &'a str) -> Option<&StructDefinitionView<'a, T>> {
+        self.name_to_struct_definition_view.get(name)
+    }
+}
+
+pub struct ModuleHandleView<'a, T> {
+    module: &'a T,
+    module_handle: &'a ModuleHandle,
+}
+
+impl<'a, T: ModuleAccess> ModuleHandleView<'a, T> {
+    pub fn new(module: &'a T, module_handle: &'a ModuleHandle) -> Self {
+        Self {
+            module,
+            module_handle,
+        }
+    }
+
+    pub fn module_code_key(&self) -> CodeKey {
+        self.module.code_key_for_handle(self.module_handle)
+    }
+}
+
+pub struct StructHandleView<'a, T> {
+    module: &'a T,
+    struct_handle: &'a StructHandle,
+}
+
+impl<'a, T: ModuleAccess> StructHandleView<'a, T> {
+    pub fn new(module: &'a T, struct_handle: &'a StructHandle) -> Self {
+        Self {
+            module,
+            struct_handle,
+        }
+    }
+
+    pub fn is_resource(&self) -> bool {
+        self.struct_handle.is_resource
+    }
+
+    pub fn definition(&self) -> StructDefinitionView<'a, T> {
+        unimplemented!("this requires linking")
+    }
+
+    pub fn module_handle(&self) -> &ModuleHandle {
+        self.module.module_handle_at(self.struct_handle.module)
+    }
+
+    pub fn name(&self) -> &'a str {
+        self.module.string_at(self.struct_handle.name)
+    }
+
+    pub fn module_code_key(&self) -> CodeKey {
+        self.module.code_key_for_handle(self.module_handle())
+    }
+}
+
+pub struct FunctionHandleView<'a, T> {
+    module: &'a T,
+    function_handle: &'a FunctionHandle,
+}
+
+impl<'a, T: ModuleAccess> FunctionHandleView<'a, T> {
+    pub fn new(module: &'a T, function_handle: &'a FunctionHandle) -> Self {
+        Self {
+            module,
+            function_handle,
+        }
+    }
+
+    pub fn module_handle(&self) -> &ModuleHandle {
+        self.module.module_handle_at(self.function_handle.module)
+    }
+
+    pub fn name(&self) -> &'a str {
+        self.module.string_at(self.function_handle.name)
+    }
+
+    pub fn signature(&self) -> FunctionSignatureView<'a, T> {
+        let function_signature = self
+            .module
+            .function_signature_at(self.function_handle.signature);
+        FunctionSignatureView::new(self.module, function_signature)
+    }
+
+    pub fn module_code_key(&self) -> CodeKey {
+        self.module.code_key_for_handle(self.module_handle())
+    }
+}
+
+pub struct StructDefinitionView<'a, T> {
+    module: &'a T,
+    struct_def: &'a StructDefinition,
+    struct_handle_view: StructHandleView<'a, T>,
+}
+
+impl<'a, T: ModuleAccess> StructDefinitionView<'a, T> {
+    pub fn new(module: &'a T, struct_def: &'a StructDefinition) -> Self {
+        let struct_handle = module.struct_handle_at(struct_def.struct_handle);
+        let struct_handle_view = StructHandleView::new(module, struct_handle);
+        Self {
+            module,
+            struct_def,
+            struct_handle_view,
+        }
+    }
+
+    pub fn is_resource(&self) -> bool {
+        self.struct_handle_view.is_resource()
+    }
+
+    pub fn fields(&self) -> impl DoubleEndedIterator<Item = FieldDefinitionView<'a, T>> + Send {
+        let module = self.module;
+        module
+            .field_def_range(self.struct_def.field_count, self.struct_def.fields)
+            .map(move |field_def| FieldDefinitionView::new(module, field_def))
+    }
+
+    pub fn name(&self) -> &'a str {
+        self.struct_handle_view.name()
+    }
+}
+
+pub struct FieldDefinitionView<'a, T> {
+    module: &'a T,
+    field_def: &'a FieldDefinition,
+}
+
+impl<'a, T: ModuleAccess> FieldDefinitionView<'a, T> {
+    pub fn new(module: &'a T, field_def: &'a FieldDefinition) -> Self {
+        Self { module, field_def }
+    }
+
+    pub fn name(&self) -> &'a str {
+        self.module.string_at(self.field_def.name)
+    }
+
+    pub fn type_signature(&self) -> TypeSignatureView<'a, T> {
+        let type_signature = self.module.type_signature_at(self.field_def.signature);
+        TypeSignatureView::new(self.module, type_signature)
+    }
+
+    // Field definitions are always private.
+
+    /// The struct this field is defined in.
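+    // A field back-references its declaring struct through `field_def.struct_`,
+    // so (as a sketch) `field_view.member_of().name()` recovers the declaring
+    // struct's name without scanning the struct definitions.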
+    pub fn member_of(&self) -> StructHandleView<'a, T> {
+        let struct_handle = self.module.struct_handle_at(self.field_def.struct_);
+        StructHandleView::new(self.module, struct_handle)
+    }
+}
+
+pub struct FunctionDefinitionView<'a, T> {
+    module: &'a T,
+    function_def: &'a FunctionDefinition,
+    function_handle_view: FunctionHandleView<'a, T>,
+}
+
+impl<'a, T: ModuleAccess> FunctionDefinitionView<'a, T> {
+    pub fn new(module: &'a T, function_def: &'a FunctionDefinition) -> Self {
+        let function_handle = module.function_handle_at(function_def.function);
+        let function_handle_view = FunctionHandleView::new(module, function_handle);
+        Self {
+            module,
+            function_def,
+            function_handle_view,
+        }
+    }
+
+    pub fn is_public(&self) -> bool {
+        self.function_def.is_public()
+    }
+
+    pub fn is_native(&self) -> bool {
+        self.function_def.is_native()
+    }
+
+    pub fn locals_signature(&self) -> LocalsSignatureView<'a, T> {
+        let locals_signature = self
+            .module
+            .locals_signature_at(self.function_def.code.locals);
+        LocalsSignatureView::new(self.module, locals_signature)
+    }
+
+    pub fn name(&self) -> &'a str {
+        self.function_handle_view.name()
+    }
+
+    pub fn signature(&self) -> FunctionSignatureView<'a, T> {
+        self.function_handle_view.signature()
+    }
+
+    pub fn code(&self) -> &'a CodeUnit {
+        &self.function_def.code
+    }
+}
+
+pub struct TypeSignatureView<'a, T> {
+    module: &'a T,
+    type_signature: &'a TypeSignature,
+}
+
+impl<'a, T: ModuleAccess> TypeSignatureView<'a, T> {
+    #[inline]
+    pub fn new(module: &'a T, type_signature: &'a TypeSignature) -> Self {
+        Self {
+            module,
+            type_signature,
+        }
+    }
+
+    #[inline]
+    pub fn token(&self) -> SignatureTokenView<'a, T> {
+        SignatureTokenView::new(self.module, &self.type_signature.0)
+    }
+
+    #[inline]
+    pub fn is_resource(&self) -> bool {
+        self.token().is_resource()
+    }
+}
+
+pub struct FunctionSignatureView<'a, T> {
+    module: &'a T,
+    function_signature: &'a FunctionSignature,
+}
+
+impl<'a, T: ModuleAccess> FunctionSignatureView<'a, T> {
+    #[inline]
+    pub fn new(module: &'a T, function_signature: &'a FunctionSignature) -> Self {
+        Self {
+            module,
+            function_signature,
+        }
+    }
+
+    #[inline]
+    pub fn return_tokens(&self) -> impl DoubleEndedIterator<Item = SignatureTokenView<'a, T>> + 'a {
+        let module = self.module;
+        self.function_signature
+            .return_types
+            .iter()
+            .map(move |token| SignatureTokenView::new(module, token))
+    }
+
+    #[inline]
+    pub fn arg_tokens(&self) -> impl DoubleEndedIterator<Item = SignatureTokenView<'a, T>> + 'a {
+        let module = self.module;
+        self.function_signature
+            .arg_types
+            .iter()
+            .map(move |token| SignatureTokenView::new(module, token))
+    }
+
+    pub fn return_count(&self) -> usize {
+        self.function_signature.return_types.len()
+    }
+
+    pub fn arg_count(&self) -> usize {
+        self.function_signature.arg_types.len()
+    }
+}
+
+pub struct LocalsSignatureView<'a, T> {
+    module: &'a T,
+    locals_signature: &'a LocalsSignature,
+}
+
+impl<'a, T: ModuleAccess> LocalsSignatureView<'a, T> {
+    #[inline]
+    pub fn new(module: &'a T, locals_signature: &'a LocalsSignature) -> Self {
+        Self {
+            module,
+            locals_signature,
+        }
+    }
+
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.locals_signature.0.len()
+    }
+
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    #[inline]
+    pub fn tokens(&self) -> impl DoubleEndedIterator<Item = SignatureTokenView<'a, T>> + 'a {
+        let module = self.module;
+        self.locals_signature
+            .0
+            .iter()
+            .map(move |token| SignatureTokenView::new(module, token))
+    }
+
+    pub fn token_at(&self, index: LocalIndex) -> SignatureTokenView<'a, T> {
+        SignatureTokenView::new(self.module, &self.locals_signature.0[index as usize])
+    }
+}
+
+pub struct SignatureTokenView<'a, T> {
+    module: &'a T,
+    token: &'a SignatureToken,
+}
+
+impl<'a, T: ModuleAccess> SignatureTokenView<'a, T> {
+    #[inline]
+    pub fn new(module: &'a T, token: &'a SignatureToken) -> Self {
+        Self { module, token }
+    }
+
+    #[inline]
+    pub fn struct_handle(&self) -> Option<StructHandleView<'a, T>> {
+        self.struct_index()
+            .map(|sh_idx| StructHandleView::new(self.module, self.module.struct_handle_at(sh_idx)))
+    }
+
+    #[inline]
+    pub fn kind(&self) -> SignatureTokenKind {
+        self.token.kind()
+    }
+
+    #[inline]
+    pub fn is_resource(&self) -> bool {
+        match self.token {
+            SignatureToken::Struct(sh_idx) => self.module.struct_handle_at(*sh_idx).is_resource,
+            SignatureToken::Reference(_)
+            | SignatureToken::MutableReference(_)
+            | SignatureToken::Bool
+            | SignatureToken::U64
+            | SignatureToken::String
+            | SignatureToken::ByteArray
+            | SignatureToken::Address => false,
+        }
+    }
+
+    #[inline]
+    pub fn is_reference(&self) -> bool {
+        self.token.is_reference()
+    }
+
+    #[inline]
+    pub fn is_mutable_reference(&self) -> bool {
+        self.token.is_mutable_reference()
+    }
+
+    #[inline]
+    pub fn struct_index(&self) -> Option<StructHandleIndex> {
+        self.token.struct_index()
+    }
+}
+
+/// This is used to expose some view internals to checks and other areas. This might be exposed
+/// to external code in the future.
+pub trait ViewInternals {
+    type ModuleType;
+    type Inner;
+
+    fn module(&self) -> Self::ModuleType;
+    fn as_inner(&self) -> Self::Inner;
+}
+
+macro_rules! impl_view_internals {
+    ($view_type:ident, $inner_type:ty, $inner_var:ident) => {
+        impl<'a, T: ModuleAccess> ViewInternals for $view_type<'a, T> {
+            type ModuleType = &'a T;
+            type Inner = &'a $inner_type;
+
+            #[inline]
+            fn module(&self) -> Self::ModuleType {
+                &self.module
+            }
+
+            #[inline]
+            fn as_inner(&self) -> Self::Inner {
+                &self.$inner_var
+            }
+        }
+    };
+}
+
+impl<'a, T: ModuleAccess> ViewInternals for ModuleView<'a, T> {
+    type ModuleType = &'a T;
+    type Inner = &'a T;
+
+    fn module(&self) -> Self::ModuleType {
+        self.module
+    }
+
+    fn as_inner(&self) -> Self::Inner {
+        self.module
+    }
+}
+
+impl_view_internals!(ModuleHandleView, ModuleHandle, module_handle);
+impl_view_internals!(StructHandleView, StructHandle, struct_handle);
+impl_view_internals!(FunctionHandleView, FunctionHandle, function_handle);
+impl_view_internals!(StructDefinitionView, StructDefinition, struct_def);
+impl_view_internals!(FunctionDefinitionView, FunctionDefinition, function_def);
+impl_view_internals!(FieldDefinitionView, FieldDefinition, field_def);
+impl_view_internals!(TypeSignatureView, TypeSignature, type_signature);
+impl_view_internals!(FunctionSignatureView, FunctionSignature, function_signature);
+impl_view_internals!(LocalsSignatureView, LocalsSignature, locals_signature);
+impl_view_internals!(SignatureTokenView, SignatureToken, token);
diff --git a/language/vm/tests/serializer_tests.proptest-regressions b/language/vm/tests/serializer_tests.proptest-regressions
new file mode 100644
index 0000000000000..2d82fd56459df
--- /dev/null
+++ b/language/vm/tests/serializer_tests.proptest-regressions
@@ -0,0 +1,9 @@
+# Seeds for failure cases proptest has generated in the past. It is
+# automatically read and these particular cases re-run before any
+# novel cases are generated.
+#
+# It is recommended to check this file in to source control so that
+# everyone who runs the test benefits from these saved cases.
+cc 3d00be1cbcb9f344e7bce080015d7a755f6e69acd37dbde62e449732af226fe4 # shrinks to module = CompiledModule: { module_handles: [ ModuleHandle { address: AddressPoolIndex(0), name: StringPoolIndex(0) },] struct_handles: [ StructHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), is_resource: false }, StructHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), is_resource: false },] function_handles: [ FunctionHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), signature: FunctionSignatureIndex(0) }, FunctionHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), signature: FunctionSignatureIndex(1) },] struct_defs: [ StructDefinition { struct_handle: 1, field_count: 0, fields: 0 },] field_defs: [] function_defs: [ FunctionDefinition { function: 1, flags: 0x0, code: CodeUnit { max_stack_size: 0, locals: 0 code: [] } },] type_signatures: [ TypeSignature(Unit), TypeSignature(Unit),] function_signatures: [ FunctionSignature { return_type: Unit, arg_types: [] }, FunctionSignature { return_type: Unit, arg_types: [] },] locals_signatures: [ LocalsSignature([]),] string_pool: [ "",] address_pool: [ Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),] } +cc 7b1bb969b87bfcdbb0f635eb46212f8437d21bcd1ba754de84d66bb552e6aec2 # shrinks to module = CompiledModule { module_handles: [], struct_handles: [], function_handles: [], type_signatures: [], function_signatures: [], locals_signatures: [], string_pool: [], byte_array_pool: [], address_pool: [0000000000000000000000000000000000000000000000000000000000000000, 0000000000000000000000000000000000000000000000000000000000000000], struct_defs: [], field_defs: [], function_defs: [] } +cc 4118fc247fb7d48382876de931d47a8999a6e42658bbecc93afff9245ade141b # shrinks to module = CompiledModule { module_handles: [], struct_handles: [], function_handles: [], type_signatures: [], function_signatures: [], locals_signatures: [], string_pool: [], byte_array_pool: [], address_pool: [], struct_defs: [], field_defs: [FieldDefinition { struct_: StructHandleIndex(0), name: StringPoolIndex(0), signature: TypeSignatureIndex(0) }], function_defs: [] } diff --git a/language/vm/tests/serializer_tests.rs b/language/vm/tests/serializer_tests.rs new file mode 100644 index 0000000000000..1059c54d2a45f --- /dev/null +++ b/language/vm/tests/serializer_tests.rs @@ -0,0 +1,34 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use proptest::prelude::*; +use vm::file_format::CompiledModule; + +proptest! { + #[test] + fn serializer_roundtrip(module in CompiledModule::valid_strategy(20)) { + let mut serialized = Vec::with_capacity(2048); + module.serialize(&mut serialized).expect("serialization should work"); + + let deserialized_module = CompiledModule::deserialize(&serialized) + .expect("deserialization should work"); + prop_assert_eq!(module, deserialized_module); + } +} + +proptest! { + // Generating arbitrary compiled modules is really slow, possibly because of + // https://github.com/AltSysrq/proptest/issues/143. + #![proptest_config(ProptestConfig::with_cases(16))] + + /// Make sure that garbage inputs don't crash the serializer and deserializer. 
+ #[test] + fn garbage_inputs(module in any_with::(16)) { + let mut serialized = Vec::with_capacity(65536); + module.serialize(&mut serialized).expect("serialization should work"); + + let deserialized_module = CompiledModule::deserialize_no_check_bounds(&serialized) + .expect("deserialization should work"); + prop_assert_eq!(module, deserialized_module); + } +} diff --git a/language/vm/vm_genesis/Cargo.toml b/language/vm/vm_genesis/Cargo.toml new file mode 100644 index 0000000000000..28d8dc1e6b681 --- /dev/null +++ b/language/vm/vm_genesis/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "vm_genesis" +version = "0.1.0" +edition = "2018" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false + +[dependencies] +config = { path = "../../../config" } +crypto = { path = "../../../crypto/legacy_crypto" } +failure = { path = "../../../common/failure_ext", package = "failure_ext" } +compiler = { path = "../../compiler"} +stdlib = { path = "../../stdlib" } +proto_conv = { path = "../../../common/proto_conv", features = ["derive"] } +state_view = { path = "../../../storage/state_view" } +types = { path = "../../../types" } +vm = { path = "../" } +vm_cache_map = { path = "../vm_runtime/vm_cache_map"} +vm_runtime = { path = "../vm_runtime" } +hex = "0.3.2" +lazy_static = "1.3.0" +rand = "0.6.5" +tiny-keccak = "1.4.2" +toml = "0.4" + +[dev-dependencies] +canonical_serialization = { path = "../../../common/canonical_serialization" } +proptest = "0.9.3" +proptest-derive = "0.1.1" +proptest_helpers = { path = "../../../common/proptest_helpers" } diff --git a/language/vm/vm_genesis/genesis/genesis.blob b/language/vm/vm_genesis/genesis/genesis.blob new file mode 100644 index 0000000000000..90126abea58ac Binary files /dev/null and b/language/vm/vm_genesis/genesis/genesis.blob differ diff --git a/language/vm/vm_genesis/genesis/vm_config.toml b/language/vm/vm_genesis/genesis/vm_config.toml new file mode 100644 index 0000000000000..7c1bdedb5be20 --- /dev/null +++ b/language/vm/vm_genesis/genesis/vm_config.toml @@ -0,0 +1,3 @@ +[publishing_options] +type = "Locked" +whitelist = ["d3493756a00b7a9e4d9ca8482e80fd055411ce53882bdcb08fec97d42eef0bde", "88c0c64595f6cec7d0c0bfe29e1be1886c736ec3d26888d049e30909f7a72836", "2bb3828f55bc640a85b17d9c6e120e84f8c068c9fd850e1a1d61d2f91ed295fd", "ee31d65b559ad5a300e6a508ff3edb2d23f1589ef68d0ead124d8f0374073d84"] diff --git a/language/vm/vm_genesis/src/lib.rs b/language/vm/vm_genesis/src/lib.rs new file mode 100644 index 0000000000000..9c678fdebd265 --- /dev/null +++ b/language/vm/vm_genesis/src/lib.rs @@ -0,0 +1,444 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use ::compiler::{compiler, parser::ast}; +use config::config::{VMConfig, VMPublishingOption}; +use crypto::{signing, PrivateKey, PublicKey}; +use failure::prelude::*; +use lazy_static::lazy_static; +use rand::{rngs::StdRng, SeedableRng}; +use state_view::StateView; +use std::{collections::HashSet, iter::FromIterator, time::Duration}; +use stdlib::{ + stdlib::*, + transaction_scripts::{ + CREATE_ACCOUNT_TXN_BODY, MINT_TXN_BODY, PEER_TO_PEER_TRANSFER_TXN_BODY, + ROTATE_AUTHENTICATION_KEY_TXN_BODY, + }, +}; +use tiny_keccak::Keccak; +use types::{ + access_path::AccessPath, + account_address::AccountAddress, + account_config, + byte_array::ByteArray, + language_storage::CodeKey, + transaction::{ + Program, RawTransaction, SignedTransaction, TransactionArgument, SCRIPT_HASH_LENGTH, + }, + validator_public_keys::ValidatorPublicKeys, +}; +use 
vm::{file_format::CompiledModule, transaction_metadata::TransactionMetadata}; +use vm_cache_map::Arena; +use vm_runtime::{ + code_cache::{ + module_adapter::FakeFetcher, + module_cache::{BlockModuleCache, VMModuleCache}, + }, + data_cache::BlockDataCache, + txn_executor::{TransactionExecutor, ACCOUNT_MODULE, COIN_MODULE}, + value::Local, +}; + +#[cfg(test)] +mod tests; + +// The seed is arbitrarily picked to produce a consistent key. XXX make this more formal? +const GENESIS_SEED: [u8; 32] = [42; 32]; +// Max size of the validator set +const VALIDATOR_SIZE_LIMIT: usize = 10; + +lazy_static! { + pub static ref GENESIS_KEYPAIR: (PrivateKey, PublicKey) = { + let mut rng: StdRng = SeedableRng::from_seed(GENESIS_SEED); + signing::generate_keypair_for_testing(&mut rng) + }; +} + +pub fn sign_genesis_transaction(raw_txn: RawTransaction) -> Result { + let (private_key, public_key) = &*GENESIS_KEYPAIR; + raw_txn.sign(private_key, *public_key) +} + +#[derive(Debug, Clone)] +pub struct Account { + pub addr: AccountAddress, + pub privkey: PrivateKey, + pub pubkey: PublicKey, +} + +impl Account { + pub fn new(rng: &mut StdRng) -> Self { + let (privkey, pubkey) = crypto::signing::generate_keypair_for_testing(rng); + let addr = pubkey.into(); + Account { + addr, + privkey, + pubkey, + } + } +} + +pub struct Accounts { + accounts: Vec, + pub randomness_source: StdRng, +} + +impl Default for Accounts { + fn default() -> Self { + let mut accounts = Accounts::empty(); + for _i in 0..Self::NUMBER_OF_ACCOUNTS { + accounts.new_account(); + } + accounts + } +} + +impl Accounts { + const NUMBER_OF_ACCOUNTS: i64 = 10; + + pub fn empty() -> Self { + let mut seed = [0u8; 32]; + seed[..4].copy_from_slice(&[1, 2, 3, 4]); + let rng: StdRng = StdRng::from_seed(seed); + Accounts { + accounts: vec![], + randomness_source: rng, + } + } + + pub fn fresh_account(&mut self) -> Account { + Account::new(&mut self.randomness_source) + } + + pub fn new_account(&mut self) -> usize { + self.accounts + .push(Account::new(&mut self.randomness_source)); + self.accounts.len() - 1 + } + + pub fn get_address(&self, account: usize) -> AccountAddress { + self.accounts[account].addr + } + + pub fn get_account(&self, account: usize) -> Account { + self.accounts[account].clone() + } + + pub fn get_public_key(&self, account: usize) -> PublicKey { + self.accounts[account].pubkey + } + + pub fn create_signed_txn_with_args( + &self, + program: Vec, + args: Vec, + sender: AccountAddress, + sender_account: Account, + sequence_number: u64, + max_gas_amount: u64, + gas_unit_price: u64, + ) -> SignedTransaction { + RawTransaction::new( + sender, + sequence_number, + Program::new(program, vec![], args), + max_gas_amount, + gas_unit_price, + Duration::from_secs(u64::max_value()), + ) + .sign(&sender_account.privkey, sender_account.pubkey) + .unwrap() + } + + pub fn get_addresses(&self) -> Vec { + self.accounts.iter().map(|account| account.addr).collect() + } + + pub fn accounts(&self) -> &[Account] { + &self.accounts + } +} + +lazy_static! 
{ + pub static ref STDLIB_ADDRESS: AccountAddress = { account_config::core_code_address() }; + pub static ref STDLIB_MODULES: Vec = { + let mut modules: Vec = vec![]; + let stdlib = vec![coin_module(), native_hash_module(), account_module(), signature_module(), validator_set_module()]; + for m in stdlib.iter() { + let (compiled_module, verification_errors) = + compiler::compile_and_verify_module(&STDLIB_ADDRESS, m, &modules).unwrap(); + + // Fail if the module doesn't verify + for e in &verification_errors { + println!("{:?}", e); + } + assert!(verification_errors.is_empty()); + + modules.push(compiled_module); + } + modules + }; + static ref PEER_TO_PEER_TXN: Vec = { compile_script(&PEER_TO_PEER_TRANSFER_TXN_BODY) }; + static ref CREATE_ACCOUNT_TXN: Vec = { compile_script(&CREATE_ACCOUNT_TXN_BODY) }; + static ref ROTATE_AUTHENTICATION_KEY_TXN: Vec = + { compile_script(&ROTATE_AUTHENTICATION_KEY_TXN_BODY) }; + static ref MINT_TXN: Vec = { compile_script(&MINT_TXN_BODY) }; + static ref GENESIS_ACCOUNT: Accounts = { + let mut account = Accounts::empty(); + account.new_account(); + account + }; +} + +fn compile_script(body: &ast::Program) -> Vec { + let compiled_program = + compiler::compile_program(&AccountAddress::default(), body, &STDLIB_MODULES.clone()) + .unwrap(); + let mut script_bytes = vec![]; + compiled_program + .script + .serialize(&mut script_bytes) + .unwrap(); + script_bytes +} + +/// Encode a program transferring `amount` coins from `sender` to `recipient`. Fails if there is no +/// account at the recipient address or if the sender's balance is lower than `amount`. +pub fn encode_transfer_program(recipient: &AccountAddress, amount: u64) -> Program { + Program::new( + PEER_TO_PEER_TXN.clone(), + vec![], + vec![ + TransactionArgument::Address(*recipient), + TransactionArgument::U64(amount), + ], + ) +} + +/// Encode a program creating a fresh account at `account_address` with `initial_balance` coins +/// transferred from the sender's account balance. Fails if there is already an account at +/// `account_address` or if the sender's balance is lower than `initial_balance`. +pub fn encode_create_account_program( + account_address: &AccountAddress, + initial_balance: u64, +) -> Program { + Program::new( + CREATE_ACCOUNT_TXN.clone(), + vec![], + vec![ + TransactionArgument::Address(*account_address), + TransactionArgument::U64(initial_balance), + ], + ) +} + +/// Encode a program that rotates the sender's authentication key to `new_key`. +pub fn rotate_authentication_key_program(new_key: AccountAddress) -> Program { + Program::new( + ROTATE_AUTHENTICATION_KEY_TXN.clone(), + vec![], + vec![TransactionArgument::ByteArray(ByteArray::new( + new_key.as_ref().to_vec(), + ))], + ) +} + +// TODO: this should go away once we are no longer using it in tests +/// Encode a program creating `amount` coins for sender +pub fn encode_mint_program(sender: &AccountAddress, amount: u64) -> Program { + Program::new( + MINT_TXN.clone(), + vec![], + vec![ + TransactionArgument::Address(*sender), + TransactionArgument::U64(amount), + ], + ) +} + +/// Returns a user friendly mnemonic for the transaction type if the transaction is +/// for a known, white listed, transaction. +pub fn get_transaction_name(code: &[u8]) -> String { + if code == &PEER_TO_PEER_TXN[..] { + return "peer_to_peer_transaction".to_string(); + } else if code == &CREATE_ACCOUNT_TXN[..] { + return "create_account_transaction".to_string(); + } else if code == &MINT_TXN[..] 
{ + return "mint_transaction".to_string(); + } else if code == &ROTATE_AUTHENTICATION_KEY_TXN[..] { + return "rotate_authentication_key_transaction".to_string(); + } + "".to_string() +} + +pub fn allowing_script_hashes() -> Vec<[u8; SCRIPT_HASH_LENGTH]> { + vec![ + MINT_TXN.clone(), + PEER_TO_PEER_TXN.clone(), + ROTATE_AUTHENTICATION_KEY_TXN.clone(), + CREATE_ACCOUNT_TXN.clone(), + ] + .into_iter() + .map(|s| { + let mut hash = [0u8; SCRIPT_HASH_LENGTH]; + let mut keccak = Keccak::new_sha3_256(); + + keccak.update(&s); + keccak.finalize(&mut hash); + hash + }) + .collect() +} + +pub fn default_config() -> VMConfig { + VMConfig { + publishing_options: VMPublishingOption::Locked(HashSet::from_iter( + allowing_script_hashes().into_iter(), + )), + } +} + +struct FakeStateView; + +impl StateView for FakeStateView { + fn get(&self, _access_path: &AccessPath) -> Result>> { + Ok(None) + } + + fn multi_get(&self, _access_paths: &[AccessPath]) -> Result>>> { + unimplemented!() + } + + fn is_genesis(&self) -> bool { + true + } +} + +pub fn encode_genesis_transaction( + private_key: &PrivateKey, + public_key: PublicKey, +) -> SignedTransaction { + encode_genesis_transaction_with_validator(private_key, public_key, vec![]) +} + +pub fn encode_genesis_transaction_with_validator( + private_key: &PrivateKey, + public_key: PublicKey, + validator_set: Vec, +) -> SignedTransaction { + assert!(validator_set.len() <= VALIDATOR_SIZE_LIMIT); + const INIT_BALANCE: u64 = 1_000_000_000; + + // Compile the needed stdlib modules. + let modules = STDLIB_MODULES.clone(); + let arena = Arena::new(); + let state_view = FakeStateView; + let vm_cache = VMModuleCache::new(&arena); + let genesis_addr = account_config::association_address(); + let genesis_auth_key = ByteArray::new(AccountAddress::from(public_key).to_vec()); + + let genesis_write_set = { + let fake_fetcher = FakeFetcher::new(modules.clone()); + let data_cache = BlockDataCache::new(&state_view); + let block_cache = BlockModuleCache::new(&vm_cache, fake_fetcher); + { + let mut txn_data = TransactionMetadata::default(); + txn_data.sender = genesis_addr; + let validator_set_key = CodeKey::new( + account_config::core_code_address(), + "ValidatorSet".to_string(), + ); + + let mut txn_executor = TransactionExecutor::new(&block_cache, &data_cache, txn_data); + txn_executor.create_account(genesis_addr).unwrap().unwrap(); + txn_executor + .execute_function(&COIN_MODULE, "grant_mint_capability", vec![]) + .unwrap() + .unwrap(); + + txn_executor + .execute_function( + &ACCOUNT_MODULE, + "mint_to_address", + vec![Local::address(genesis_addr), Local::u64(INIT_BALANCE)], + ) + .unwrap() + .unwrap(); + + txn_executor + .execute_function( + &ACCOUNT_MODULE, + "rotate_authentication_key", + vec![Local::bytearray(genesis_auth_key)], + ) + .unwrap() + .unwrap(); + + let mut validator_args = vec![Local::u64(validator_set.len() as u64)]; + for key in validator_set.iter() { + txn_executor + .execute_function( + &validator_set_key, + "make_new_validator_key", + vec![ + Local::address(*key.account_address()), + Local::bytearray(ByteArray::new( + key.consensus_public_key().to_slice().to_vec(), + )), + Local::bytearray(ByteArray::new( + key.network_signing_public_key().to_slice().to_vec(), + )), + Local::bytearray(ByteArray::new( + key.network_identity_public_key().to_slice().to_vec(), + )), + ], + ) + .unwrap() + .unwrap(); + validator_args.push(txn_executor.pop_stack().unwrap()); + } + let placeholder = { + txn_executor + .execute_function( + &validator_set_key, + 
"make_new_validator_key", + vec![ + Local::address(AccountAddress::default()), + Local::bytearray(ByteArray::new(vec![])), + Local::bytearray(ByteArray::new(vec![])), + Local::bytearray(ByteArray::new(vec![])), + ], + ) + .unwrap() + .unwrap(); + txn_executor.pop_stack().unwrap() + }; + validator_args.resize(VALIDATOR_SIZE_LIMIT + 1, placeholder); + + txn_executor + .execute_function(&validator_set_key, "publish_validator_set", validator_args) + .unwrap() + .unwrap(); + + let stdlib_modules = modules + .into_iter() + .map(|m| { + let mut module_vec = vec![]; + m.serialize(&mut module_vec).unwrap(); + (m.self_code_key(), module_vec) + }) + .collect(); + + txn_executor + .make_write_set(stdlib_modules, Ok(Ok(()))) + .unwrap() + .write_set() + .clone() + .into_mut() + } + }; + let transaction = + RawTransaction::new_write_set(genesis_addr, 0, genesis_write_set.freeze().unwrap()); + transaction.sign(private_key, public_key).unwrap() +} diff --git a/language/vm/vm_genesis/src/main.rs b/language/vm/vm_genesis/src/main.rs new file mode 100644 index 0000000000000..40031da371ba9 --- /dev/null +++ b/language/vm/vm_genesis/src/main.rs @@ -0,0 +1,25 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use std::{fs::File, io::prelude::*}; +use vm_genesis::{default_config, encode_genesis_transaction, GENESIS_KEYPAIR}; + +const CONFIG_LOCATION: &str = "genesis/vm_config.toml"; +const GENESIS_LOCATION: &str = "genesis/genesis.blob"; + +use proto_conv::IntoProtoBytes; + +fn main() { + println!( + "Creating genesis binary blob at {} from configuration file {}", + GENESIS_LOCATION, CONFIG_LOCATION + ); + let config = default_config(); + config.save_config(CONFIG_LOCATION); + + // Generate a genesis blob used for vm tests. + let genesis_txn = encode_genesis_transaction(&GENESIS_KEYPAIR.0, GENESIS_KEYPAIR.1); + let mut file = File::create(GENESIS_LOCATION).unwrap(); + file.write_all(&genesis_txn.into_proto_bytes().unwrap()) + .unwrap(); +} diff --git a/language/vm/vm_genesis/src/tests.rs b/language/vm/vm_genesis/src/tests.rs new file mode 100644 index 0000000000000..c635b5b88be74 --- /dev/null +++ b/language/vm/vm_genesis/src/tests.rs @@ -0,0 +1,4 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod genesis_test; diff --git a/language/vm/vm_genesis/src/tests/genesis_test.rs b/language/vm/vm_genesis/src/tests/genesis_test.rs new file mode 100644 index 0000000000000..becfd6bd67a62 --- /dev/null +++ b/language/vm/vm_genesis/src/tests/genesis_test.rs @@ -0,0 +1,33 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::encode_genesis_transaction_with_validator; +use canonical_serialization::SimpleDeserializer; +use crypto::signing::generate_keypair; +use proptest::{collection::vec, prelude::*}; +use types::{ + access_path::VALIDATOR_SET_ACCESS_PATH, transaction::TransactionPayload, + validator_public_keys::ValidatorPublicKeys, validator_set::ValidatorSet, write_set::WriteOp, +}; + +proptest! 
{
+    #[test]
+    fn test_validator_set_roundtrip(keys in vec(any::<ValidatorPublicKeys>(), 0..10)) {
+        let (priv_key, pub_key) = generate_keypair();
+        let writeset = match encode_genesis_transaction_with_validator(&priv_key, pub_key, keys.clone()).payload() {
+            TransactionPayload::WriteSet(ws) => ws.clone(),
+            _ => panic!("Unexpected Transaction"),
+        };
+        let (_, validator_entry) = writeset.iter().find(
+            |(ap, _)| *ap == *VALIDATOR_SET_ACCESS_PATH
+        ).cloned().unwrap();
+        let validator_set_bytes = match validator_entry {
+            WriteOp::Value(blob) => blob,
+            _ => panic!("Unexpected WriteOp"),
+        };
+        let validator_set: ValidatorSet =
+            SimpleDeserializer::deserialize(&validator_set_bytes).unwrap();
+
+        prop_assert_eq!(validator_set.payload(), keys.as_slice());
+    }
+}
diff --git a/language/vm/vm_runtime/Cargo.toml b/language/vm/vm_runtime/Cargo.toml
new file mode 100644
index 0000000000000..5edd679830846
--- /dev/null
+++ b/language/vm/vm_runtime/Cargo.toml
@@ -0,0 +1,38 @@
+[package]
+name = "vm_runtime"
+version = "0.1.0"
+authors = ["Libra Association "]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+rental = "0.5.3"
+tiny-keccak = "1.4.2"
+proptest = "0.9"
+
+bytecode_verifier = { path = "../../bytecode_verifier" }
+canonical_serialization = { path = "../../../common/canonical_serialization" }
+crypto = { path = "../../../crypto/legacy_crypto"}
+failure = { path = "../../../common/failure_ext", package = "failure_ext" }
+metrics = { path = "../../../common/metrics" }
+state_view = { path = "../../../storage/state_view" }
+types = { path = "../../../types" }
+vm = { path = "../" }
+vm_cache_map = { path = "vm_cache_map" }
+lazy_static = "1.3.0"
+move_ir_natives = { path = "../../stdlib/natives" }
+hex = "0.3.2"
+config = { path = "../../../config"}
+logger = { path = "../../../common/logger" }
+
+[dev-dependencies]
+compiler = { path = "../../compiler"}
+
+[dependencies.prometheus]
+version = "0.4.2"
+default-features = false
+features = ["nightly", "push"]
+
+[features]
+instruction_synthesis = []
diff --git a/language/vm/vm_runtime/README.md b/language/vm/vm_runtime/README.md
new file mode 100644
index 0000000000000..59a4d9e45876f
--- /dev/null
+++ b/language/vm/vm_runtime/README.md
@@ -0,0 +1,107 @@
+---
+id: vm-runtime
+title: MoveVM Runtime
+custom_edit_url: https://github.com/libra/libra/edit/master/language/vm/vm_runtime/README.md
+---
+
+# MoveVM Runtime
+
+The MoveVM runtime is the verification and execution engine for the Move
+bytecode format. The runtime is imported and loaded in two modes:
+verification mode (by the [admission control](../../../admission_control)
+and [mempool](../../../mempool) components) and execution mode (by the
+[execution](../../../execution) component).
+
+## Overview
+
+The MoveVM runtime is a stack machine. The VM runtime receives as input a
+*block*, which is a list of *transaction scripts*, and a *data view*. The
+data view is a **read only** snapshot of the data and code in the blockchain at
+a given version (i.e., block height). At the time of startup, the runtime
+does not have any code or data loaded. It is effectively *β€œempty”*.
+
+Every transaction executes within the context of a [Libra
+account](../../stdlib/modules/libra_account.mvir)---specifically the transaction
+submitter's account. The execution of every transaction consists of three
+parts: the account prologue, the transaction itself, and the account
+epilogue. This is the only transaction flow known to the runtime, and it is
+the only flow the runtime executes.
+The runtime is responsible for loading the individual transactions from the
+block and executing the transaction flow:
+
+1. ***Transaction Prologue*** - in verification mode, the runtime runs the
+   bytecode verifier over the transaction script and executes the
+   prologue defined in the [Libra account
+   module](../../stdlib/modules/libra_account.mvir). The prologue is responsible
+   for checking the structure of the transaction and
+   rejecting obviously bad transactions. In verification mode, the runtime
+   returns a status of either `success` or `failure` depending upon the
+   result of running the prologue. No updates to the blockchain state are
+   ever performed by the prologue.
+2. ***Transaction Execution*** - in execution mode, and after verification,
+   the runtime starts executing transaction-specific/client code. Client
+   code typically performs updates to data in the blockchain. Execution of the
+   transaction by the VM runtime produces a write set that acts as an
+   atomic state change from the current state of the blockchain---received
+   via the data view---to a new version that is the result of applying the
+   write set. Importantly, on-chain data is _never_ changed during the
+   execution of the transaction. Further, while the write set is produced as the
+   result of executing the bytecode, the changes are not applied to the global
+   blockchain state by the VM---this is the responsibility of the
+   [execution module](../../../execution/).
+3. ***Transaction Epilogue*** - in execution mode, the epilogue defined in
+   the [Libra account module](../../stdlib/modules/libra_account.mvir) is
+   executed to perform actions based upon the result of the execution of
+   the user-submitted transaction. One example of such an action is
+   debiting the gas fee for the transaction from the submitting account's
+   balance.
+
+During execution, the runtime resolves references to code by loading the
+referenced code via the data view. One can think of this process as similar
+to linking. Then, within the context of a block of transactions---a list of
+transactions coupled with a data view---the runtime caches code, and
+linked and imported modules, across transactions within the block.
+The runtime tracks state changes (data updates) from one transaction
+to the next within each block of transactions; the semantics of the
+execution of a block specify that transactions are sequentially executed
+and, as a consequence, state changes of previous transactions must be
+visible to subsequent transactions within each block.
+
+## Implementation Details
+
+* The runtime's top-level structs are in the `runtime` and Libra VM related
+  code.
+* The transaction flow is implemented in the [`process_txn`](./src/process_txn.rs)
+  module.
+* The interpreter is implemented within the [transaction
+  executor](./src/txn_executor.rs).
+* Code caching logic and policies are defined under the [code
+  cache](./src/code_cache/) directory.
+* Runtime-loaded code and the type system view for the runtime are defined
+  under the [loaded data](./src/loaded_data/) directory.
+* The data representation of values, and the logic for write set generation,
+  can be found under the [value](./src/value.rs) and [data
+  cache](./src/data_cache.rs) files.
+* Test cases and infrastructure for building proptests for the VM runtime
+  can be found under the [`vm_runtime_tests`](./vm_runtime_tests/)
+  directory.
+
+## Folder Structure
+
+```
+.
+β”œβ”€β”€ src # VM Runtime files +β”‚Β Β  β”œβ”€β”€ code_cache # VM Runtime code cache +β”‚Β Β  β”œβ”€β”€ loaded_data # VM Runtime loaded data types, runtime caches over code +β”‚Β Β  β”œβ”€β”€ unit_tests # unit tests +β”œβ”€β”€ vm_cache_map # abstractions for the code cache +└── vm_runtime_tests # test infrastructure for the runtime, proptest and simple tests +``` + +## This Module Interacts With + +This crate is mainly used in two parts: AC and mempool use it to determine +if it should accept a transaction or not; the Executor runs the MoveVM +runtime to execute the program field in a SignedTransaction and convert +it into a TransactionOutput, which contains a writeset that the +executor need to patch to the blockchain as a side effect of this +transaction. diff --git a/language/vm/vm_runtime/src/block_processor.rs b/language/vm/vm_runtime/src/block_processor.rs new file mode 100644 index 0000000000000..d5e2c5cdb282d --- /dev/null +++ b/language/vm/vm_runtime/src/block_processor.rs @@ -0,0 +1,140 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + code_cache::{ + module_adapter::ModuleFetcherImpl, + module_cache::{BlockModuleCache, ModuleCache, VMModuleCache}, + script_cache::ScriptCache, + }, + counters, + data_cache::BlockDataCache, + process_txn::{execute::ExecutedTransaction, validate::ValidationMode, ProcessTransaction}, +}; +use config::config::VMPublishingOption; +use logger::prelude::*; +use state_view::StateView; +use types::{ + transaction::{SignedTransaction, TransactionOutput, TransactionStatus}, + vm_error::{ExecutionStatus, VMStatus, VMValidationStatus}, + write_set::WriteSet, +}; +use vm_cache_map::Arena; + +pub fn execute_block<'alloc>( + txn_block: Vec, + code_cache: &VMModuleCache<'alloc>, + script_cache: &ScriptCache<'alloc>, + data_view: &dyn StateView, + publishing_option: &VMPublishingOption, +) -> Vec { + trace!("[VM] Execute block, transaction count: {}", txn_block.len()); + + let mode = if data_view.is_genesis() { + // The genesis transaction must be in a block of its own. + if txn_block.len() != 1 { + // XXX Need a way to return that an entire block failed. + return txn_block + .iter() + .map(|_| { + TransactionOutput::new( + WriteSet::default(), + vec![], + 0, + TransactionStatus::from(VMStatus::Validation( + VMValidationStatus::RejectedWriteSet, + )), + ) + }) + .collect(); + } else { + ValidationMode::Genesis + } + } else { + ValidationMode::Executing + }; + + let module_cache = BlockModuleCache::new(code_cache, ModuleFetcherImpl::new(data_view)); + let mut data_cache = BlockDataCache::new(data_view); + let mut result = vec![]; + for txn in txn_block.into_iter() { + let output = transaction_flow( + txn, + &module_cache, + script_cache, + &data_cache, + mode, + publishing_option, + ); + data_cache.push_write_set(&output.write_set()); + result.push(output); + } + trace!("[VM] Execute block finished"); + result +} + +/// Process a transaction and emit a TransactionOutput. +/// +/// A successful execution will have `TransactionStatus::Keep` in the TransactionOutput and a +/// non-empty writeset. There are two possibilities for a failed transaction. If a verification or +/// runtime error occurs, the TransactionOutput will have `TransactionStatus::Keep` and a writeset +/// that only contains the charged gas of this transaction. If a validation or `InvariantViolation` +/// error occurs, the TransactionOutput will have `TransactionStatus::Discard` and an empty +/// writeset. 
+
+/// Process a transaction and emit a TransactionOutput.
+///
+/// A successful execution will have `TransactionStatus::Keep` in the TransactionOutput and a
+/// non-empty write set. There are two possibilities for a failed transaction. If a verification
+/// or runtime error occurs, the TransactionOutput will have `TransactionStatus::Keep` and a
+/// write set that only contains the charged gas of this transaction. If a validation or
+/// `InvariantViolation` error occurs, the TransactionOutput will have
+/// `TransactionStatus::Discard` and an empty write set.
+///
+/// Note that this function does have side effects: if the transaction publishes modules and
+/// executes successfully, `module_cache` is updated to include the newly published modules. The
+/// function also updates `script_cache` to cache the script of this `txn`.
+fn transaction_flow<'alloc, P>(
+    txn: SignedTransaction,
+    module_cache: P,
+    script_cache: &ScriptCache<'alloc>,
+    data_cache: &BlockDataCache<'_>,
+    mode: ValidationMode,
+    publishing_option: &VMPublishingOption,
+) -> TransactionOutput
+where
+    P: ModuleCache<'alloc>,
+{
+    let arena = Arena::new();
+    let process_txn = ProcessTransaction::new(txn, &module_cache, data_cache, &arena);
+
+    let validated_txn = match process_txn.validate(mode, publishing_option) {
+        Ok(validated_txn) => validated_txn,
+        Err(vm_status) => {
+            counters::FAILED_TRANSACTION.inc();
+            return ExecutedTransaction::discard_error_output(vm_status);
+        }
+    };
+    let verified_txn = match validated_txn.verify() {
+        Ok(verified_txn) => verified_txn,
+        Err(vm_status) => {
+            counters::FAILED_TRANSACTION.inc();
+            return ExecutedTransaction::discard_error_output(vm_status);
+        }
+    };
+    let executed_txn = verified_txn.execute(script_cache);
+
+    // On success, publish the modules into the cache so that future transactions can refer to
+    // them directly.
+    let output = executed_txn.into_output();
+    match output.status() {
+        TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)) => {
+            match module_cache.reclaim_cached_module(arena.into_vec()) {
+                Ok(_) => {
+                    counters::SUCCESSFUL_TRANSACTION.inc();
+                    output
+                }
+                Err(err) => {
+                    counters::FAILED_TRANSACTION.inc();
+                    ExecutedTransaction::discard_error_output(&err)
+                }
+            }
+        }
+        _ => {
+            counters::FAILED_TRANSACTION.inc();
+            output
+        }
+    }
+}
diff --git a/language/vm/vm_runtime/src/code_cache/mod.rs b/language/vm/vm_runtime/src/code_cache/mod.rs
new file mode 100644
index 0000000000000..0f9a583cbc913
--- /dev/null
+++ b/language/vm/vm_runtime/src/code_cache/mod.rs
@@ -0,0 +1,7 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+//! Caches for code data stored on chain.
+
+pub mod module_adapter;
+pub mod module_cache;
+pub mod script_cache;
diff --git a/language/vm/vm_runtime/src/code_cache/module_adapter.rs b/language/vm/vm_runtime/src/code_cache/module_adapter.rs
new file mode 100644
index 0000000000000..f540f9420faa3
--- /dev/null
+++ b/language/vm/vm_runtime/src/code_cache/module_adapter.rs
@@ -0,0 +1,87 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+//! Fetches code data from the blockchain.
+
+use logger::prelude::*;
+use state_view::StateView;
+use std::collections::HashMap;
+use types::language_storage::CodeKey;
+use vm::file_format::CompiledModule;
+
+/// Trait that describes how the VM expects code data to be stored.
+pub trait ModuleFetcher {
+    /// `CodeKey` is the fully qualified name for the module we are trying to fetch.
+    fn get_module(&self, key: &CodeKey) -> Option<CompiledModule>;
+}
+
+/// A wrapper around the state store database for fetching code data stored on chain.
+pub struct ModuleFetcherImpl<'a>(&'a dyn StateView);
+
+impl<'a> ModuleFetcherImpl<'a> {
+    /// Creates a new Fetcher instance with a `StateView` reference.
+    pub fn new(storage: &'a dyn StateView) -> Self {
+        ModuleFetcherImpl(storage)
+    }
+}
+
+impl<'a> ModuleFetcher for ModuleFetcherImpl<'a> {
+    fn get_module(&self, key: &CodeKey) -> Option<CompiledModule> {
+        let access_path = key.into();
+        match self.0.get(&access_path) {
+            Ok(opt_module_blob) => match opt_module_blob {
+                Some(module_blob) => match CompiledModule::deserialize(&module_blob) {
+                    Ok(module) => Some(module),
+                    Err(_) => {
+                        crit!(
+                            "[VM] Storage contains a malformed module with key {:?}",
+                            key
+                        );
+                        None
+                    }
+                },
+                None => {
+                    crit!("[VM] Storage returned None for module with key {:?}", key);
+                    None
+                }
+            },
+            Err(_) => {
+                crit!("[VM] Error fetching module with key {:?}", key);
+                None
+            }
+        }
+    }
+}
+
+/// A wrapper for an empty state with no code data stored.
+pub struct NullFetcher();
+
+impl ModuleFetcher for NullFetcher {
+    fn get_module(&self, _key: &CodeKey) -> Option<CompiledModule> {
+        None
+    }
+}
+
+/// A wrapper for a state with a list of pre-compiled modules.
+pub struct FakeFetcher(HashMap<CodeKey, CompiledModule>);
+
+impl FakeFetcher {
+    /// Create a FakeFetcher instance from a vector of pre-compiled modules.
+    pub fn new(modules: Vec<CompiledModule>) -> Self {
+        let mut map = HashMap::new();
+        for m in modules.into_iter() {
+            map.insert(m.self_code_key(), m);
+        }
+        FakeFetcher(map)
+    }
+
+    /// Remove all modules stored in the fetcher.
+    pub fn clear(&mut self) {
+        self.0 = HashMap::new();
+    }
+}
+
+impl ModuleFetcher for FakeFetcher {
+    fn get_module(&self, key: &CodeKey) -> Option<CompiledModule> {
+        self.0.get(key).cloned()
+    }
+}
diff --git a/language/vm/vm_runtime/src/code_cache/module_cache.rs b/language/vm/vm_runtime/src/code_cache/module_cache.rs
new file mode 100644
index 0000000000000..1eef17f83e5a2
--- /dev/null
+++ b/language/vm/vm_runtime/src/code_cache/module_cache.rs
@@ -0,0 +1,504 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+//! Cache for modules published on chain.
+
+use crate::{
+    code_cache::module_adapter::{ModuleFetcher, NullFetcher},
+    gas_meter::GasMeter,
+    loaded_data::{
+        function::{FunctionRef, FunctionReference},
+        loaded_module::LoadedModule,
+        struct_def::StructDef,
+        types::Type,
+    },
+};
+use std::marker::PhantomData;
+use types::language_storage::CodeKey;
+use vm::{
+    access::{BaseAccess, ModuleAccess},
+    errors::*,
+    file_format::{
+        CompiledModule, CompiledScript, FunctionDefinitionIndex, FunctionHandleIndex,
+        SignatureToken, StructDefinitionIndex, StructHandleIndex,
+    },
+    views::{FunctionHandleView, StructHandleView},
+};
+use vm_cache_map::{Arena, CacheRefMap};
+
+#[cfg(test)]
+#[path = "../unit_tests/module_cache_tests.rs"]
+mod module_cache_tests;
+
+/// Trait describing a cache for modules. Implementations are in charge of loading and
+/// resolving all dependencies of a needed module from storage.
+pub trait ModuleCache<'alloc> {
+    /// Given a function handle index, resolve that handle into an internal representation of a
+    /// Move function. The return value can be one of the three following cases:
+    /// 1. `Ok(Some(FunctionRef))` if such a function exists.
+    /// 2. `Ok(None)` if no such function exists.
+    /// 3. `Err` if the module we are referring to has some internal consistency error
+    fn resolve_function_ref(
+        &self,
+        caller_module: &LoadedModule,
+        idx: FunctionHandleIndex,
+    ) -> Result<Option<FunctionRef<'alloc>>, VMInvariantViolation>;
+
+    /// Resolve a StructDefinitionIndex into a StructDef. This process can be recursive, so we
+    /// may charge gas on each recursive step. The return value can be one of the following
+    /// cases:
+    /// 1.
 `Ok(Some(StructDef))` if such a struct exists.
+    /// 2. `Ok(None)` if no such struct exists.
+    /// 3. `Err(VMInvariantViolation)` if the module we are referring to has some internal
+    ///    consistency error
+    /// 4. `Err(LinkerError)` if some field contains an unknown struct.
+    /// 5. `Err(OutOfGas)` if the recursive resolution costs too much gas
+    fn resolve_struct_def(
+        &self,
+        module: &LoadedModule,
+        idx: StructDefinitionIndex,
+        gas_meter: &GasMeter,
+    ) -> VMResult<Option<StructDef>>;
+
+    /// Resolve a CodeKey into a LoadedModule if the module has been cached already. The return
+    /// value can be one of the three following cases:
+    /// 1. `Ok(Some(LoadedModule))` if such a module exists.
+    /// 2. `Ok(None)` if no such module exists.
+    /// 3. `Err` if the module we are referring to has some internal consistency error
+    fn get_loaded_module(
+        &self,
+        id: &CodeKey,
+    ) -> Result<Option<&'alloc LoadedModule>, VMInvariantViolation>;
+
+    fn cache_module(&self, module: CompiledModule) -> Result<(), VMInvariantViolation>;
+
+    /// Recache the list of previously resolved modules. Think of the cache as a generational
+    /// cache: we need to move modules across generations.
+    fn reclaim_cached_module(&self, v: Vec<LoadedModule>) -> Result<(), VMInvariantViolation>;
+}
+
+/// `ModuleCache` is also implemented for references.
+impl<'alloc, P> ModuleCache<'alloc> for &P
+where
+    P: ModuleCache<'alloc>,
+{
+    fn resolve_function_ref(
+        &self,
+        caller_module: &LoadedModule,
+        idx: FunctionHandleIndex,
+    ) -> Result<Option<FunctionRef<'alloc>>, VMInvariantViolation> {
+        (*self).resolve_function_ref(caller_module, idx)
+    }
+
+    fn resolve_struct_def(
+        &self,
+        module: &LoadedModule,
+        idx: StructDefinitionIndex,
+        gas_meter: &GasMeter,
+    ) -> VMResult<Option<StructDef>> {
+        (*self).resolve_struct_def(module, idx, gas_meter)
+    }
+
+    fn get_loaded_module(
+        &self,
+        id: &CodeKey,
+    ) -> Result<Option<&'alloc LoadedModule>, VMInvariantViolation> {
+        (*self).get_loaded_module(id)
+    }
+
+    fn cache_module(&self, module: CompiledModule) -> Result<(), VMInvariantViolation> {
+        (*self).cache_module(module)
+    }
+
+    fn reclaim_cached_module(&self, v: Vec<LoadedModule>) -> Result<(), VMInvariantViolation> {
+        (*self).reclaim_cached_module(v)
+    }
+}
+
+/// Cache for modules that reside in a VM. It is an internally mutable map from module
+/// identifier to a reference to a loaded module, where the actual module is owned by the Arena
+/// allocator, which guarantees that it outlives the lifetime of the transaction.
+pub struct VMModuleCache<'alloc> {
+    map: CacheRefMap<'alloc, CodeKey, LoadedModule>,
+}
+
+/// Convert a CompiledScript into a CompiledModule.
+pub fn create_fake_module(script: CompiledScript) -> (CompiledModule, FunctionDefinitionIndex) {
+    (
+        CompiledModule {
+            module_handles: script.module_handles,
+            struct_handles: script.struct_handles,
+            function_handles: script.function_handles,
+
+            struct_defs: vec![],
+            field_defs: vec![],
+            function_defs: vec![script.main],
+            type_signatures: script.type_signatures,
+            function_signatures: script.function_signatures,
+            locals_signatures: script.locals_signatures,
+            string_pool: script.string_pool,
+            byte_array_pool: script.byte_array_pool,
+            address_pool: script.address_pool,
+        },
+        FunctionDefinitionIndex::new(0),
+    )
+}
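+
+// A sketch (illustrative only) of how `create_fake_module` is used elsewhere in
+// this crate: a script is treated as a module whose single function definition
+// is `main`, so executing the script reduces to executing function definition 0
+// of the fake module (see `ScriptCache::cache_script` below):
+//
+//     let (module, main_idx) = create_fake_module(compiled_script);
+//     let loaded_module = LoadedModule::new(module)?;
+//     let entry_point = FunctionRef::new(&loaded_module, main_idx)?;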
+
+impl<'alloc> VMModuleCache<'alloc> {
+    /// To get a cleaner lifetime, the cache takes the Arena allocator as an input parameter so
+    /// that every element allocated for the loaded program can share the same lifetime.
+    pub fn new(allocator: &'alloc Arena<LoadedModule>) -> Self {
+        VMModuleCache {
+            map: CacheRefMap::new(allocator),
+        }
+    }
+
+    /// Resolve a CodeKey into a LoadedModule. If there is a cache miss, try to fetch the module
+    /// from the `fetcher` and insert it into the cache if found. If nothing is found, return
+    /// `Ok(None)`.
+    pub fn get_loaded_module_with_fetcher<F: ModuleFetcher>(
+        &self,
+        id: &CodeKey,
+        fetcher: &F,
+    ) -> Result<Option<&'alloc LoadedModule>, VMInvariantViolation> {
+        // Currently it is still possible for a script to invoke a function with a nonsense
+        // module id. However, once we have a verifier that checks the well-formedness of all
+        // the linked module ids, we should get rid of the ok_or case here.
+        if let Some(m) = self.map.get(id) {
+            return Ok(Some(&*m));
+        }
+        Ok(fetcher
+            .get_module(id)
+            .map(LoadedModule::new)
+            .transpose()?
+            .map(|m| self.map.or_insert(id.clone(), m)))
+    }
+
+    #[cfg(test)]
+    pub fn new_from_module(
+        module: CompiledModule,
+        allocator: &'alloc Arena<LoadedModule>,
+    ) -> Result<Self, VMInvariantViolation> {
+        let module_id = module.self_code_key();
+        let map = CacheRefMap::new(allocator);
+        let loaded_module = LoadedModule::new(module)?;
+        map.or_insert(module_id, loaded_module);
+        Ok(VMModuleCache { map })
+    }
+
+    /// Resolve a FunctionHandleIndex into a FunctionRef in either the cache or the `fetcher`.
+    /// An `Ok(None)` will be returned if no such function is found.
+    pub fn resolve_function_ref_with_fetcher<F>(
+        &self,
+        caller_module: &LoadedModule,
+        idx: FunctionHandleIndex,
+        fetcher: &F,
+    ) -> Result<Option<FunctionRef<'alloc>>, VMInvariantViolation>
+    where
+        F: ModuleFetcher,
+    {
+        let function_handle = caller_module.module.function_handle_at(idx);
+        let callee_name = caller_module.string_at(function_handle.name);
+        let callee_module_id =
+            FunctionHandleView::new(&caller_module.module, function_handle).module_code_key();
+        self.get_loaded_module_with_fetcher(&callee_module_id, fetcher)
+            .and_then(|callee_module_opt| {
+                if let Some(callee_module) = callee_module_opt {
+                    let callee_func_id = callee_module
+                        .function_defs_table
+                        .get(callee_name)
+                        .ok_or(VMInvariantViolation::LinkerError)?;
+                    Ok(Some(FunctionRef::new(callee_module, *callee_func_id)?))
+                } else {
+                    Ok(None)
+                }
+            })
+    }
+
+    /// Resolve a StructHandle into a StructDef recursively in either the cache or the `fetcher`.
+    pub fn resolve_struct_handle_with_fetcher<F: ModuleFetcher>(
+        &self,
+        module: &LoadedModule,
+        idx: StructHandleIndex,
+        gas_meter: &GasMeter,
+        fetcher: &F,
+    ) -> VMResult<Option<StructDef>> {
+        let struct_handle = module.module.struct_handle_at(idx);
+        let struct_name = module.module.string_at(struct_handle.name);
+        let struct_def_module_id =
+            StructHandleView::new(&module.module, struct_handle).module_code_key();
+        let defined_module = self.get_loaded_module_with_fetcher(&struct_def_module_id, fetcher)?;
+        if let Some(m) = defined_module {
+            let struct_def_idx = m
+                .struct_defs_table
+                .get(struct_name)
+                .ok_or(VMInvariantViolation::LinkerError)?;
+            self.resolve_struct_def_with_fetcher(m, *struct_def_idx, gas_meter, fetcher)
+        } else {
+            Ok(Ok(None))
+        }
+    }
+
+    /// Resolve a SignatureToken into a Type recursively in either the cache or the `fetcher`.
+ pub fn resolve_signature_token_with_fetcher<'txn, F: ModuleFetcher>( + &'txn self, + module: &LoadedModule, + tok: &SignatureToken, + gas_meter: &GasMeter, + fetcher: &F, + ) -> VMResult> { + match tok { + SignatureToken::Bool => Ok(Ok(Some(Type::Bool))), + SignatureToken::U64 => Ok(Ok(Some(Type::U64))), + SignatureToken::String => Ok(Ok(Some(Type::String))), + SignatureToken::ByteArray => Ok(Ok(Some(Type::ByteArray))), + SignatureToken::Address => Ok(Ok(Some(Type::Address))), + SignatureToken::Struct(sh_idx) => { + let struct_def = + try_runtime!(self + .resolve_struct_handle_with_fetcher(module, *sh_idx, gas_meter, fetcher)); + Ok(Ok(struct_def.map(Type::Struct))) + } + SignatureToken::Reference(sub_tok) => { + let inner_ty = + try_runtime!(self + .resolve_signature_token_with_fetcher(module, sub_tok, gas_meter, fetcher)); + Ok(Ok(inner_ty.map(|t| Type::Reference(Box::new(t))))) + } + SignatureToken::MutableReference(sub_tok) => { + let inner_ty = + try_runtime!(self + .resolve_signature_token_with_fetcher(module, sub_tok, gas_meter, fetcher)); + Ok(Ok(inner_ty.map(|t| Type::MutableReference(Box::new(t))))) + } + } + } + + /// Resolve a StructDefinition into a StructDef recursively in either the cache or the + /// `fetcher`. + pub fn resolve_struct_def_with_fetcher<'txn, F: ModuleFetcher>( + &'txn self, + module: &LoadedModule, + idx: StructDefinitionIndex, + gas_meter: &GasMeter, + fetcher: &F, + ) -> VMResult> { + if let Some(def) = module.cached_struct_def_at(idx) { + return Ok(Ok(Some(def))); + } + let def = { + let struct_def = module.module.struct_def_at(idx); + let mut field_types = vec![]; + for field in module + .module + .field_def_range(struct_def.field_count, struct_def.fields) + { + let ty = try_runtime!(self.resolve_signature_token_with_fetcher( + module, + &module.module.type_signature_at(field.signature).0, + gas_meter, + fetcher + )); + if let Some(t) = ty { + field_types.push(t); + } else { + return Ok(Ok(None)); + } + } + StructDef::new(field_types) + }; + // If multiple writers write to def at the same time, the last one will win. It's possible + // to have multiple copies of a struct def floating around, but that probably isn't going + // to be a big deal. + module.cache_struct_def(idx, def.clone()); + Ok(Ok(Some(def))) + } +} + +impl<'alloc> ModuleCache<'alloc> for VMModuleCache<'alloc> { + fn resolve_function_ref( + &self, + caller_module: &LoadedModule, + idx: FunctionHandleIndex, + ) -> Result>, VMInvariantViolation> { + self.resolve_function_ref_with_fetcher(caller_module, idx, &NullFetcher()) + } + + fn resolve_struct_def( + &self, + module: &LoadedModule, + idx: StructDefinitionIndex, + gas_meter: &GasMeter, + ) -> VMResult> { + self.resolve_struct_def_with_fetcher(module, idx, gas_meter, &NullFetcher()) + } + + fn get_loaded_module( + &self, + id: &CodeKey, + ) -> Result, VMInvariantViolation> { + // Currently it is still possible for a script to invoke a nonsense module id function. + // However, once we have the verifier that checks the well-formedness of the all the linked + // module id, we should get rid of that ok_or case here. 
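+        // With no backing fetcher here, a cache miss simply yields `None`;
+        // `BlockModuleCache` below layers a `ModuleFetcher` on top of this
+        // cache to fall back to storage on a miss.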
+ Ok(self.map.get(id)) + } + + fn cache_module(&self, module: CompiledModule) -> Result<(), VMInvariantViolation> { + let module_id = module.self_code_key(); + // TODO: Check CodeKey duplication in statedb + let loaded_module = LoadedModule::new(module)?; + self.map.or_insert(module_id, loaded_module); + Ok(()) + } + + fn reclaim_cached_module(&self, v: Vec) -> Result<(), VMInvariantViolation> { + for m in v.into_iter() { + let module_id = m.module.self_code_key(); + self.map.or_insert(module_id, m); + } + Ok(()) + } +} + +/// A cache for all modules stored on chain. `vm_cache` holds the local cached modules whereas +/// `storage` should implement trait ModuleFetcher that can fetch the modules that aren't in the +/// cache yet. In production, it will usually provide a connection to the StateStore client to fetch +/// the needed data. `alloc` is the lifetime for the entire VM and `blk` is the lifetime for the +/// current block we are executing. +pub struct BlockModuleCache<'alloc, 'blk, F> +where + 'alloc: 'blk, + F: ModuleFetcher, +{ + vm_cache: &'blk VMModuleCache<'alloc>, + storage: F, +} + +impl<'alloc, 'blk, F> BlockModuleCache<'alloc, 'blk, F> +where + 'alloc: 'blk, + F: ModuleFetcher, +{ + pub fn new(vm_cache: &'blk VMModuleCache<'alloc>, module_fetcher: F) -> Self { + BlockModuleCache { + vm_cache, + storage: module_fetcher, + } + } +} + +impl<'alloc, 'blk, F: ModuleFetcher> ModuleCache<'alloc> for BlockModuleCache<'alloc, 'blk, F> { + fn resolve_function_ref( + &self, + caller_module: &LoadedModule, + idx: FunctionHandleIndex, + ) -> Result>, VMInvariantViolation> { + self.vm_cache + .resolve_function_ref_with_fetcher(caller_module, idx, &self.storage) + } + + fn resolve_struct_def( + &self, + module: &LoadedModule, + idx: StructDefinitionIndex, + gas_meter: &GasMeter, + ) -> VMResult> { + self.vm_cache + .resolve_struct_def_with_fetcher(module, idx, gas_meter, &self.storage) + } + + fn get_loaded_module( + &self, + id: &CodeKey, + ) -> Result, VMInvariantViolation> { + self.vm_cache + .get_loaded_module_with_fetcher(id, &self.storage) + } + + fn cache_module(&self, module: CompiledModule) -> Result<(), VMInvariantViolation> { + self.vm_cache.cache_module(module) + } + + fn reclaim_cached_module(&self, v: Vec) -> Result<(), VMInvariantViolation> { + self.vm_cache.reclaim_cached_module(v) + } +} + +/// A temporary cache for module published by a single transaction. This cache allows the +/// transaction script to refer to either those newly published modules in `local_cache` or those +/// existing on chain modules in `block_cache`. VM can choose to discard those newly published +/// modules if there is an error during execution. +pub struct TransactionModuleCache<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + block_cache: P, + local_cache: VMModuleCache<'txn>, + + phantom: PhantomData<&'alloc ()>, +} + +impl<'alloc, 'txn, P> TransactionModuleCache<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + pub fn new(block_cache: P, allocator: &'txn Arena) -> Self { + TransactionModuleCache { + block_cache, + local_cache: VMModuleCache::new(allocator), + phantom: PhantomData, + } + } +} + +impl<'alloc, 'txn, P> ModuleCache<'txn> for TransactionModuleCache<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + fn resolve_function_ref( + &self, + caller_module: &LoadedModule, + idx: FunctionHandleIndex, + ) -> Result>, VMInvariantViolation> { + if let Some(f) = self.local_cache.resolve_function_ref(caller_module, idx)? 
{ + Ok(Some(f)) + } else { + self.block_cache.resolve_function_ref(caller_module, idx) + } + } + + fn resolve_struct_def( + &self, + module: &LoadedModule, + idx: StructDefinitionIndex, + gas_meter: &GasMeter, + ) -> VMResult> { + if let Some(f) = try_runtime!(self.local_cache.resolve_struct_def(module, idx, gas_meter)) { + Ok(Ok(Some(f))) + } else { + self.block_cache.resolve_struct_def(module, idx, gas_meter) + } + } + + fn get_loaded_module( + &self, + id: &CodeKey, + ) -> Result, VMInvariantViolation> { + if let Some(m) = self.local_cache.get_loaded_module(id)? { + Ok(Some(m)) + } else { + self.block_cache.get_loaded_module(id) + } + } + + fn cache_module(&self, module: CompiledModule) -> Result<(), VMInvariantViolation> { + self.local_cache.cache_module(module) + } + + fn reclaim_cached_module(&self, _v: Vec) -> Result<(), VMInvariantViolation> { + Err(VMInvariantViolation::LinkerError) + } +} diff --git a/language/vm/vm_runtime/src/code_cache/script_cache.rs b/language/vm/vm_runtime/src/code_cache/script_cache.rs new file mode 100644 index 0000000000000..82ed5fa2e2760 --- /dev/null +++ b/language/vm/vm_runtime/src/code_cache/script_cache.rs @@ -0,0 +1,60 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 +//! Cache for commonly executed scripts + +use crate::{ + code_cache::module_cache::create_fake_module, + loaded_data::{ + function::{FunctionRef, FunctionReference}, + loaded_module::LoadedModule, + }, +}; +use logger::prelude::*; +use tiny_keccak::Keccak; +use types::transaction::SCRIPT_HASH_LENGTH; +use vm::{errors::VMResult, file_format::CompiledScript}; +use vm_cache_map::{Arena, CacheMap}; + +/// The cache for commonly executed scripts. Currently there's no eviction policy, and it maps +/// hash of script bytes into `FunctionRef`. +pub struct ScriptCache<'alloc> { + map: CacheMap<'alloc, [u8; SCRIPT_HASH_LENGTH], LoadedModule, FunctionRef<'alloc>>, +} + +impl<'alloc> ScriptCache<'alloc> { + /// Create a new ScriptCache. + pub fn new(allocator: &'alloc Arena) -> Self { + ScriptCache { + map: CacheMap::new(allocator), + } + } + + /// Cache and resolve `script` into a `FunctionRef` that can be executed + pub fn cache_script( + &self, + script: CompiledScript, + raw_bytes: &[u8], + ) -> VMResult> { + let mut hash = [0u8; SCRIPT_HASH_LENGTH]; + let mut keccak = Keccak::new_sha3_256(); + + keccak.update(raw_bytes); + keccak.finalize(&mut hash); + + if let Some(f) = self.map.get(&hash) { + trace!("[VM] Script cache hit"); + Ok(Ok(f)) + } else { + trace!("[VM] Script cache miss"); + let (fake_module, idx) = create_fake_module(script); + let loaded_module = LoadedModule::new(fake_module)?; + self.map + .or_insert_with_try_transform( + hash, + move || loaded_module, + |module_ref| FunctionRef::new(module_ref, idx), + ) + .map(Ok) + } + } +} diff --git a/language/vm/vm_runtime/src/counters.rs b/language/vm/vm_runtime/src/counters.rs new file mode 100644 index 0000000000000..7d84c3031ab62 --- /dev/null +++ b/language/vm/vm_runtime/src/counters.rs @@ -0,0 +1,24 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use lazy_static; +use metrics::OpMetrics; +use prometheus::IntCounter; + +lazy_static::lazy_static! { + pub static ref VM_COUNTERS: OpMetrics = OpMetrics::new_and_registered("move_vm"); +} + +lazy_static::lazy_static! { +/// Counter of the successfully executed transactions. 
+pub static ref SUCCESSFUL_TRANSACTION: IntCounter = VM_COUNTERS.counter("txn.execution.success"); + +/// Counter of the transactions that failed to execute. +pub static ref FAILED_TRANSACTION: IntCounter = VM_COUNTERS.counter("txn.execution.fail"); + +/// Counter of the successfully verified transactions. +pub static ref VERIFIED_TRANSACTION: IntCounter = VM_COUNTERS.counter("txn.verification.success"); + +/// Counter of the transactions that failed to verify. +pub static ref UNVERIFIED_TRANSACTION: IntCounter = VM_COUNTERS.counter("txn.verification.fail"); +} diff --git a/language/vm/vm_runtime/src/data_cache.rs b/language/vm/vm_runtime/src/data_cache.rs new file mode 100644 index 0000000000000..b3db17efc6db6 --- /dev/null +++ b/language/vm/vm_runtime/src/data_cache.rs @@ -0,0 +1,256 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 +//! Scratchpad for on chain values during the execution. + +use crate::{ + loaded_data::struct_def::StructDef, + value::{GlobalRef, Local, MutVal, Reference, Value}, +}; +use logger::prelude::*; +use state_view::StateView; +use std::{collections::btree_map::BTreeMap, mem::replace}; +use types::{ + access_path::AccessPath, + language_storage::CodeKey, + write_set::{WriteOp, WriteSet, WriteSetMut}, +}; +use vm::{errors::*, gas_schedule::AbstractMemorySize}; + +/// The wrapper around the StateVersionView for the block. +/// It keeps track of the value that have been changed during execution of a block. +/// It's effectively the write set for the block. +pub struct BlockDataCache<'block> { + data_view: &'block dyn StateView, + // TODO: an AccessPath corresponds to a top level resource but that may not be the + // case moving forward, so we need to review this. + // Also need to relate this to a ResourceKey. + data_map: BTreeMap>, +} + +impl<'block> BlockDataCache<'block> { + pub fn new(data_view: &'block dyn StateView) -> Self { + BlockDataCache { + data_view, + data_map: BTreeMap::new(), + } + } + + pub fn get(&self, access_path: &AccessPath) -> Result>, VMInvariantViolation> { + match self.data_map.get(access_path) { + Some(data) => Ok(Some(data.clone())), + None => match self.data_view.get(&access_path) { + Ok(remote_data) => Ok(remote_data), + // TODO: should we forward some error info? + Err(_) => { + crit!("[VM] Error getting data from storage for {:?}", access_path); + Err(VMInvariantViolation::StorageError) + } + }, + } + } + + pub fn push_write_set(&mut self, write_set: &WriteSet) { + for (ref ap, ref write_op) in write_set.iter() { + match write_op { + WriteOp::Value(blob) => { + self.data_map.insert(ap.clone(), blob.clone()); + } + WriteOp::Deletion => { + self.data_map.remove(ap); + } + } + } + } +} + +/// Trait for the StateVersionView or a mock implementation of the remote cache. +/// Unit and integration tests should use this to mock implementations of "storage" +pub trait RemoteCache { + fn get(&self, access_path: &AccessPath) -> Result>, VMInvariantViolation>; +} + +impl<'block> RemoteCache for BlockDataCache<'block> { + fn get(&self, access_path: &AccessPath) -> Result>, VMInvariantViolation> { + BlockDataCache::get(self, access_path) + } +} + +/// Global cache for a transaction. +/// Materializes Values from the RemoteCache and keeps an Rc to them. +/// It also implements the opcodes that talk to storage and gives the proper guarantees of +/// reference lifetime. 
+/// Dirty objects are serialized and returned in make_write_set +pub struct TransactionDataCache<'txn> { + // TODO: an AccessPath corresponds to a top level resource but that may not be the + // case moving forward, so we need to review this. + // Also need to relate this to a ResourceKey. + data_map: BTreeMap, + data_cache: &'txn RemoteCache, +} + +impl<'txn> TransactionDataCache<'txn> { + pub fn new(data_cache: &'txn RemoteCache) -> Self { + TransactionDataCache { + data_cache, + data_map: BTreeMap::new(), + } + } + + // Retrieve data from the local cache or loads it from the remote cache into the local cache. + // All operations on the global data are based on this API and they all load the data + // into the cache. + // TODO: this may not be the most efficient model because we always load data into the + // cache even when that would not be strictly needed. Review once we have the whole story + // working + fn load_data(&mut self, ap: &AccessPath, def: StructDef) -> VMResult<&mut GlobalRef> { + if !self.data_map.contains_key(ap) { + match self.data_cache.get(ap)? { + Some(bytes) => { + let res = try_runtime!(Ok(Value::simple_deserialize(&bytes, def))); + let new_root = GlobalRef::make_root(ap.clone(), MutVal::new(res)); + self.data_map.insert(ap.clone(), new_root); + } + None => { + warn!("[VM] Missing data in storage for {:?}", ap); + return Ok(Err(VMRuntimeError { + loc: Location::new(), + err: VMErrorKind::MissingData, + })); + } + }; + } + Ok(Ok(self.data_map.get_mut(ap).expect("data must exist"))) + } + + /// BorrowGlobal opcode cache implementation + pub fn borrow_global(&mut self, ap: &AccessPath, def: StructDef) -> VMResult { + let root_ref = try_runtime!(self.load_data(ap, def)); + // is_loadable() checks ref count and whether the data was deleted + if root_ref.is_loadable() { + // shallow_ref increment ref count + Ok(Ok(root_ref.shallow_clone())) + } else { + Ok(Err(VMRuntimeError { + loc: Location::new(), + err: VMErrorKind::GlobalAlreadyBorrowed, + })) + } + } + + /// Exists opcode cache implementation + pub fn resource_exists( + &mut self, + ap: &AccessPath, + def: StructDef, + ) -> Result<(bool, AbstractMemorySize), VMInvariantViolation> { + Ok(match self.load_data(ap, def)? { + Ok(gref) => { + if gref.is_deleted() { + (false, 0) + } else { + (true, gref.size()) + } + } + Err(_) => (false, 0), + }) + } + + /// MoveFrom opcode cache implementation + pub fn move_resource_from(&mut self, ap: &AccessPath, def: StructDef) -> VMResult { + let root_ref = try_runtime!(self.load_data(ap, def)); + // is_loadable() checks ref count and whether the data was deleted + if root_ref.is_loadable() { + Ok(Ok(Local::Value(root_ref.move_from()))) + } else { + Ok(Err(VMRuntimeError { + loc: Location::new(), + // better name? this is true even for moved from data + err: VMErrorKind::GlobalAlreadyBorrowed, + })) + } + } + + /// MoveToSender opcode cache implementation + pub fn move_resource_to( + &mut self, + ap: &AccessPath, + def: StructDef, + res: MutVal, + ) -> VMResult<()> { + // a resource can be written to an AccessPath if the data does not exists or + // it was deleted (MoveFrom) + let can_write = match self.load_data(ap, def)? 
{
+            Ok(data) => data.is_deleted(),
+            Err(e) => match e.err {
+                VMErrorKind::MissingData => true,
+                _ => return Ok(Err(e)),
+            },
+        };
+        if can_write {
+            let new_root = GlobalRef::move_to(ap.clone(), res);
+            self.data_map.insert(ap.clone(), new_root);
+            Ok(Ok(()))
+        } else {
+            Ok(Err(VMRuntimeError {
+                loc: Location::new(),
+                err: VMErrorKind::CannotWriteExistingResource,
+            }))
+        }
+    }
+
+    /// Make a write set from the updated (dirty, deleted) global resources along with
+    /// the to-be-published modules.
+    /// Consumes the TransactionDataCache, so it must be called at the end of a transaction.
+    /// This also checks that the reference counts around global resources are correct
+    /// at the end of the transaction (all ReleaseRef instructions were properly executed).
+    pub fn make_write_set(
+        &mut self,
+        to_be_published_modules: Vec<(CodeKey, Vec<u8>)>,
+    ) -> VMRuntimeResult<WriteSet> {
+        let mut write_set = WriteSetMut::new(Vec::new());
+        let data_map = replace(&mut self.data_map, BTreeMap::new());
+        for (key, global_ref) in data_map {
+            if !global_ref.is_clean() {
+                // If there are pending references, get_data() returns None.
+                // This is the check at the end of a transaction to verify that all references
+                // were properly released.
+                let deleted = global_ref.is_deleted();
+                if let Some(data) = global_ref.get_data() {
+                    if deleted {
+                        write_set.push((key, WriteOp::Deletion));
+                    } else if let Some(blob) = data.simple_serialize() {
+                        write_set.push((key, WriteOp::Value(blob)));
+                    } else {
+                        return Err(VMRuntimeError {
+                            loc: Location::new(),
+                            err: VMErrorKind::ValueSerializerError,
+                        });
+                    }
+                } else {
+                    return Err(VMRuntimeError {
+                        loc: Location::new(),
+                        err: VMErrorKind::MissingReleaseRef,
+                    });
+                }
+            }
+        }
+
+        // Insert the code blobs into the write set.
+        for (key, blob) in to_be_published_modules.into_iter() {
+            write_set.push(((&key).into(), WriteOp::Value(blob)));
+        }
+
+        match write_set.freeze() {
+            Ok(ws) => Ok(ws),
+            Err(_) => Err(VMRuntimeError {
+                loc: Location::new(),
+                err: VMErrorKind::DataFormatError,
+            }),
+        }
+    }
+
+    /// Flush the cache and restart from a clean state.
+    pub fn clear(&mut self) {
+        self.data_map.clear()
+    }
+}
diff --git a/language/vm/vm_runtime/src/execution_stack.rs b/language/vm/vm_runtime/src/execution_stack.rs
new file mode 100644
index 0000000000000..f0ef149df927b
--- /dev/null
+++ b/language/vm/vm_runtime/src/execution_stack.rs
@@ -0,0 +1,170 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    code_cache::module_cache::ModuleCache,
+    frame::Frame,
+    loaded_data::function::{FunctionRef, FunctionReference},
+    value::{Local, MutVal, Value},
+};
+use move_ir_natives::dispatch::{Result as NativeResult, StackAccessor};
+use std::{fmt, marker::PhantomData};
+use types::byte_array::ByteArray;
+use vm::errors::*;
+
+pub struct ExecutionStack<'alloc, 'txn, P>
+where
+    'alloc: 'txn,
+    P: ModuleCache<'alloc>,
+{
+    stack: Vec<Local>,
+    function_stack: Vec<Frame<'txn, FunctionRef<'txn>>>,
+    pub module_cache: P,
+
+    // An execution stack holds an instance of the code cache for the lifetime of `'alloc`.
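+    // `PhantomData` ties the `'alloc` lifetime parameter to this struct even
+    // though no field stores an `&'alloc` reference directly; the cache itself
+    // is owned by `module_cache: P`, which is only bounded by `ModuleCache<'alloc>`.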
+ phantom: PhantomData<&'alloc ()>, +} + +impl<'alloc, 'txn, P> ExecutionStack<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + pub fn new(module_cache: P) -> Self { + ExecutionStack { + function_stack: vec![], + stack: vec![], + module_cache, + phantom: PhantomData, + } + } + + pub fn push_call(&mut self, function: FunctionRef<'txn>) -> VMResult<()> { + let callee_arg_size = function.arg_count(); + let args = self.popn(callee_arg_size as u16)?; + self.function_stack.push(Frame::new(function, args)); + Ok(Ok(())) + } + + pub fn pop_call(&mut self) -> VMResult<()> { + self.function_stack + .pop() + .ok_or(VMInvariantViolation::EmptyCallStack)?; + Ok(Ok(())) + } + + pub fn top_frame(&self) -> Result<&Frame<'txn, FunctionRef<'txn>>, VMInvariantViolation> { + Ok(self + .function_stack + .last() + .ok_or(VMInvariantViolation::EmptyCallStack)?) + } + + pub fn top_frame_mut( + &mut self, + ) -> Result<&mut Frame<'txn, FunctionRef<'txn>>, VMInvariantViolation> { + Ok(self + .function_stack + .last_mut() + .ok_or(VMInvariantViolation::EmptyCallStack)?) + } + + pub fn is_call_stack_empty(&self) -> bool { + self.function_stack.is_empty() + } + + pub fn location(&self) -> Result { + Ok(self.top_frame()?.into()) + } + + pub fn push(&mut self, value: Local) { + self.stack.push(value) + } + + pub fn peek(&self) -> Result<&Local, VMInvariantViolation> { + Ok(self + .stack + .last() + .ok_or(VMInvariantViolation::EmptyValueStack)?) + } + + pub fn peek_at(&self, index: usize) -> Result<&Local, VMInvariantViolation> { + let size = self.stack.len(); + Ok(self + .stack + .get(size - index - 1) + .ok_or(VMInvariantViolation::EmptyValueStack)?) + } + + pub fn pop(&mut self) -> Result { + Ok(self + .stack + .pop() + .ok_or(VMInvariantViolation::EmptyValueStack)?) 
+ } + + pub fn pop_as(&mut self) -> VMResult + where + Option: From, + { + let top = self.pop()?.value().and_then(std::convert::Into::into); + Ok(top.ok_or(VMRuntimeError { + loc: self.location()?, + err: VMErrorKind::TypeError, + })) + } + + pub fn popn(&mut self, n: u16) -> Result, VMInvariantViolation> { + let remaining_stack_size = self + .stack + .len() + .checked_sub(n as usize) + .ok_or(VMInvariantViolation::EmptyValueStack)?; + let args = self.stack.split_off(remaining_stack_size); + Ok(args) + } + + pub fn call_stack_height(&self) -> usize { + self.function_stack.len() + } + + pub fn set_stack(&mut self, stack: Vec) { + self.stack = stack; + } + + pub fn get_value_stack(&self) -> &Vec { + &self.stack + } + + pub fn push_frame(&mut self, func: FunctionRef<'txn>) { + self.function_stack.push(Frame::new(func, vec![])); + } +} + +impl<'alloc, 'txn, P> fmt::Debug for ExecutionStack<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "Stack: {:?}", self.stack)?; + writeln!(f, "Current Frames: {:?}", self.function_stack) + } +} + +impl<'alloc, 'txn, P> StackAccessor for &mut ExecutionStack<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + fn get_byte_array(&mut self) -> NativeResult { + match self.pop()?.value() { + Some(v) => match MutVal::try_own(v) { + Ok(Value::ByteArray(arr)) => Ok(arr), + Err(err) => Err(err.into()), + _ => Err(VMStaticViolation::TypeMismatch.into()), + }, + None => Err(VMStaticViolation::TypeMismatch.into()), + } + } +} diff --git a/language/vm/vm_runtime/src/frame.rs b/language/vm/vm_runtime/src/frame.rs new file mode 100644 index 0000000000000..22d1989a1cc6d --- /dev/null +++ b/language/vm/vm_runtime/src/frame.rs @@ -0,0 +1,120 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + bounded_fetch, + loaded_data::{function::FunctionReference, loaded_module::LoadedModule}, + value::Local, +}; +use std::{fmt, marker::PhantomData, mem::replace}; +use vm::{ + errors::{Location, VMInvariantViolation, VMResult}, + file_format::{Bytecode, CodeOffset, LocalIndex}, + IndexKind, +}; + +pub struct Frame<'txn, F: 'txn> { + pc: u16, + locals: Vec, + function: F, + phantom: PhantomData<&'txn F>, +} + +impl<'txn, F> Frame<'txn, F> +where + F: FunctionReference<'txn>, +{ + pub fn new(function: F, mut args: Vec) -> Self { + args.resize(function.local_count(), Local::Invalid); + Frame { + pc: 0, + locals: args, + function, + phantom: PhantomData, + } + } + + pub fn code_definition(&self) -> &'txn [Bytecode] { + self.function.code_definition() + } + + pub fn jump(&mut self, offset: CodeOffset) { + self.pc = offset; + } + + pub fn get_pc(&self) -> u16 { + self.pc + } + + pub fn get_local(&self, idx: LocalIndex) -> Result<&Local, VMInvariantViolation> { + bounded_fetch(&self.locals, idx as usize, IndexKind::LocalPool) + } + + pub fn invalidate_local(&mut self, idx: LocalIndex) -> Result { + if let Some(local_ref) = self.locals.get_mut(idx as usize) { + let old_local = replace(local_ref, Local::Invalid); + Ok(old_local) + } else { + Err(VMInvariantViolation::IndexOutOfBounds( + IndexKind::LocalPool, + idx as usize, + self.locals.len(), + )) + } + } + + pub fn store_local(&mut self, idx: LocalIndex, local: Local) -> VMResult<()> { + // We don't need to check if the local matches the local signature + // definition as VM is oblivous to value types. 
+        if let Some(local_ref) = self.locals.get_mut(idx as usize) {
+            // What should we do if the local already holds some other value?
+            *local_ref = local;
+            Ok(Ok(()))
+        } else {
+            Err(VMInvariantViolation::IndexOutOfBounds(
+                IndexKind::LocalPool,
+                idx as usize,
+                self.locals.len(),
+            ))
+        }
+    }
+
+    pub fn module(&self) -> &'txn LoadedModule {
+        self.function.module()
+    }
+}
+
+impl<'txn, F> Into<Location> for &Frame<'txn, F> {
+    fn into(self) -> Location {
+        Location::new()
+    }
+}
+
+impl<'txn, F> fmt::Debug for Frame<'txn, F>
+where
+    F: FunctionReference<'txn>,
+{
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "\n\tFunction: {}", self.function.name())?;
+        write!(f, "\n\tLocals: [")?;
+        for l in self.locals.iter() {
+            write!(f, "\n\t\t{:?},", l)?;
+        }
+        write!(f, "\n\t]")
+    }
+}
+
+#[cfg(any(test, feature = "instruction_synthesis"))]
+impl<'txn, F> Frame<'txn, F>
+where
+    F: FunctionReference<'txn>,
+{
+    pub fn set_with_states(&mut self, pc: u16, locals: Vec<Local>) {
+        self.pc = pc;
+        self.locals = locals;
+    }
+
+    pub fn get_locals(&self) -> &Vec<Local> {
+        &self.locals
+    }
+}
diff --git a/language/vm/vm_runtime/src/gas_meter.rs b/language/vm/vm_runtime/src/gas_meter.rs
new file mode 100644
index 0000000000000..159295463f944
--- /dev/null
+++ b/language/vm/vm_runtime/src/gas_meter.rs
@@ -0,0 +1,312 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Gas metering logic for the Move VM.
+use crate::{
+    code_cache::module_cache::ModuleCache, execution_stack::ExecutionStack,
+    loaded_data::function::FunctionReference,
+};
+use types::account_address::ADDRESS_LENGTH;
+use vm::{access::ModuleAccess, errors::*, file_format::Bytecode, gas_schedule::*};
+
+/// Holds the state of the gas meter.
+pub struct GasMeter {
+    // The current amount of gas that is left ("unburnt gas") in the gas meter.
+    current_gas_left: GasUnits,
+
+    // We need to disable and enable gas metering for both the prologue and epilogue of the
+    // Account contract. The VM will internally unset/set this flag before executing either of
+    // them.
+    meter_on: bool,
+}
+
+// NB: A number of the functions/methods in this struct will return a VMResult
+// since we will need to access stack and memory states, and we need to be able
+// to report errors properly from these accesses.
+impl GasMeter {
+    /// Create a new gas meter with starting gas amount `gas_amount`.
+    pub fn new(gas_amount: GasUnits) -> Self {
+        GasMeter {
+            current_gas_left: gas_amount,
+            meter_on: true,
+        }
+    }
+
+    /// Charges additional gas for the transaction based upon the total size (in bytes) of the
+    /// submitted transaction. It is important that we charge for the transaction size since a
+    /// transaction can contain arbitrary amounts of bytes in the `note` field. We also want to
+    /// disincentivize large transactions with large notes, so we charge the same amount up to a
+    /// cutoff, after which we start charging at a greater rate.
+    pub fn charge_transaction_gas<'alloc, 'txn, P>(
+        &mut self,
+        transaction_size: u64,
+        stk: &ExecutionStack<'alloc, 'txn, P>,
+    ) -> VMResult<()>
+    where
+        'alloc: 'txn,
+        P: ModuleCache<'alloc>,
+    {
+        let cost = calculate_intrinsic_gas(transaction_size);
+        self.consume_gas(cost, stk)
+    }
+
+    /// Queries the internal state of the gas meter to determine if it has at
+    /// least `needed_gas` amount of gas.
+    pub fn has_gas(&self, needed_gas: GasUnits) -> bool {
+        self.current_gas_left >= needed_gas
+    }
+
+    /// Disables metering of gas.
+ /// + /// We need to disable and enable gas metering for both the prologue and epilogue of the + /// Account contract. The VM will then internally turn off gas metering before executing either + /// of them using this method. + pub fn disable_metering(&mut self) { + self.meter_on = false; + } + + /// Re-enables metering of gas. + /// + /// After executing the prologue and epilogue in the Account contract gas metering is re-enabled + /// using this method. The VM is responsible for internally calling this method after disabling + /// gas metering. + pub fn enable_metering(&mut self) { + self.meter_on = true; + } + + /// A wrapper that calculates and then consumes the gas unless metering is disabled. + pub fn calculate_and_consume<'alloc, 'txn, P>( + &mut self, + instr: &Bytecode, + stk: &ExecutionStack<'alloc, 'txn, P>, + memory_size: AbstractMemorySize, + ) -> VMResult<()> + where + 'alloc: 'txn, + P: ModuleCache<'alloc>, + { + if self.meter_on { + let instruction_gas = try_runtime!(self.gas_for_instruction(instr, stk, memory_size)); + self.consume_gas(instruction_gas, stk) + } else { + Ok(Ok(())) + } + } + + /// Calculate the gas usage for an instruction taking into account the current stack state, and + /// the size of memory that is being accessed. + pub fn gas_for_instruction<'alloc, 'txn, P>( + &mut self, + instr: &Bytecode, + stk: &ExecutionStack<'alloc, 'txn, P>, + memory_size: AbstractMemorySize, + ) -> VMResult + where + 'alloc: 'txn, + P: ModuleCache<'alloc>, + { + // Get the base cost for the instruction. + let instruction_reqs = match instr { + Bytecode::Add + | Bytecode::Sub + | Bytecode::Mul + | Bytecode::Mod + | Bytecode::Div + | Bytecode::BitOr + | Bytecode::BitAnd + | Bytecode::Xor + | Bytecode::Or + | Bytecode::And + | Bytecode::Not + | Bytecode::Eq + | Bytecode::Neq + | Bytecode::Lt + | Bytecode::Gt + | Bytecode::Le + | Bytecode::LdTrue + | Bytecode::LdFalse + | Bytecode::LdConst(_) + | Bytecode::Branch(_) + | Bytecode::Assert + | Bytecode::Pop + | Bytecode::BrTrue(_) + | Bytecode::BrFalse(_) + | Bytecode::GetTxnGasUnitPrice + | Bytecode::GetTxnMaxGasUnits + | Bytecode::GetGasRemaining + | Bytecode::GetTxnPublicKey + | Bytecode::GetTxnSenderAddress + | Bytecode::GetTxnSequenceNumber + | Bytecode::Ge + | Bytecode::EmitEvent + | Bytecode::FreezeRef => { + let default_gas = static_cost_instr(instr, 1); + Self::gas_of(default_gas) + } + Bytecode::LdAddr(_) => { + let size = ADDRESS_LENGTH as AbstractMemorySize; + let default_gas = static_cost_instr(instr, size); + Self::gas_of(default_gas) + } + Bytecode::LdByteArray(idx) => { + let byte_array_ref = stk.top_frame()?.module().byte_array_at(*idx); + let byte_array_len = byte_array_ref.len() as AbstractMemorySize; + let byte_array_len = words_in(byte_array_len as AbstractMemorySize); + let default_gas = static_cost_instr(instr, byte_array_len); + Self::gas_of(default_gas) + } + // We charge by the length of the string being stored on the stack. 
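+            // For example (illustrative numbers only, assuming an 8-byte word):
+            // a 10-byte string occupies words_in(10) = 2 words, so the charge
+            // is the LdStr base cost plus a per-word memory cost for 2 words.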
+ Bytecode::LdStr(idx) => { + let string_ref = stk.top_frame()?.module().string_at(*idx); + let str_len = string_ref.len() as AbstractMemorySize; + let str_len = words_in(str_len as AbstractMemorySize); + let default_gas = static_cost_instr(instr, str_len); + Self::gas_of(default_gas) + } + Bytecode::StLoc(_) => { + // Get the local to store + let local = stk.peek()?; + // Get the size of the local + let size = local.size(); + let default_gas = static_cost_instr(instr, size); + Self::gas_of(default_gas) + } + // Note that a moveLoc incurs a copy overhead + Bytecode::CopyLoc(local_idx) | Bytecode::MoveLoc(local_idx) => { + let local = stk.top_frame()?.get_local(*local_idx)?; + let size = local.size(); + let default_gas = static_cost_instr(instr, size); + Self::gas_of(default_gas) + } + // A return does not affect the value stack at all, and simply pops the call stack + // -- the callee's frame then knows that the return value(s) will be at the top of the + // value stack. Because of this, the cost of the instruction is not dependent upon the + // size of the value being returned. + Bytecode::Ret => { + let default_gas = static_cost_instr(instr, 1); + Self::gas_of(default_gas) + } + Bytecode::Call(call_idx) => { + let self_module = &stk.top_frame()?.module(); + let function_ref = stk + .module_cache + .resolve_function_ref(self_module, *call_idx)? + .ok_or(VMInvariantViolation::LinkerError)?; + if function_ref.is_native() { + 0 // This will be costed at the call site/by the native function + } else { + let call_size = function_ref.arg_count(); + let call_gas = static_cost_instr(instr, call_size as u64); + Self::gas_of(call_gas) + } + } + Bytecode::Unpack(_) => { + let size = stk.peek()?.size(); + Self::gas_of(static_cost_instr(instr, size)) + } + Bytecode::Pack(struct_idx) => { + let struct_def = &stk.top_frame()?.module().module.struct_def_at(*struct_idx); + // Similar logic applies here as in Call, so we probably don't need to take + // into account the size of the values on the value stack that we are placing into + // the struct. + let arg_count = struct_def.field_count; + let total_size = u64::from(arg_count) + STRUCT_SIZE; + let new_gas = static_cost_instr(instr, total_size); + Self::gas_of(new_gas) + } + Bytecode::WriteRef => { + // Get a reference to the value that we are going to write + let write_val = stk.peek_at(1)?; + // Get the size of this value and charge accordingly + let size = write_val.size(); + let default_gas = static_cost_instr(instr, size); + Self::gas_of(default_gas) + } + | Bytecode::ReadRef => { + let size = stk.peek()?.size(); + let default_gas = static_cost_instr(instr, size); + Self::gas_of(default_gas) + } + | Bytecode::BorrowLoc(_) + | Bytecode::BorrowField(_) => { + let default_gas = static_cost_instr(instr, 1); + Self::gas_of(default_gas) + } + Bytecode::CreateAccount => Self::gas_of(static_cost_instr(instr, DEFAULT_ACCOUNT_SIZE)), + // Releasing a reference is not dependent on the size of the underlying data + Bytecode::ReleaseRef => { + Self::gas_of(static_cost_instr(instr, 1)) + } + // Note that we charge twice for these operations; once at the start of + // `execute_single_instruction` we charge once with size 1. This then covers the cost + // of accessing the value and guards (somewhat) against abusive memory accesses. Once + // we have the value/resource in hand we then charge a cost that is dependent on the + // size of the value being moved. + // + // Borrowing a global causes a read of the underlying data. 
Therefore the cost is + // dependent on the size of the data being borrowed. + Bytecode::BorrowGlobal(_) + // In the process of determining if a resource exists, we need to load/read that + // memory. We therefore need to charge for this query based on the size of the data + // being accessed. + | Bytecode::Exists(_) + // A MoveFrom does not trigger a write to memory. But it does push the value of that + // size onto the stack. So we charge based upon the size of the instruction. + | Bytecode::MoveFrom(_) + // A MoveToSender causes a write of the resource to storage. We therefore charge based + // on the size of the resource being moved. + | Bytecode::MoveToSender(_) => { + let mem_size = if memory_size > 1 { + memory_size - 1 + } else { + 0 // We already charged for size 1 + }; + Self::gas_of(static_cost_instr(instr, mem_size)) + } + }; + Ok(Ok(instruction_reqs)) + } + + /// Get the amount of gas that remains (that has _not_ been consumed) in the gas meter. + /// + /// This method is used by the `GetGasRemaining` bytecode instruction to get the current + /// amount of gas remaining at the point of call. + pub fn remaining_gas(&self) -> GasUnits { + self.current_gas_left + } + + /// Consume the amount of gas given by `gas_amount`. If there is not enough gas + /// left in the internal state, an `OutOfGasError` is returned. + pub fn consume_gas<'alloc, 'txn, P>( + &mut self, + gas_amount: GasUnits, + stk: &ExecutionStack<'alloc, 'txn, P>, + ) -> VMResult<()> + where + 'alloc: 'txn, + P: ModuleCache<'alloc>, + { + if !self.meter_on { + return Ok(Ok(())); + } + if self.current_gas_left >= gas_amount { + self.current_gas_left -= gas_amount; + Ok(Ok(())) + } else { + // Zero out the internal gas state + self.current_gas_left = 0; + let location = stk.location().unwrap_or_default(); + Ok(Err(VMRuntimeError { + loc: location, + err: VMErrorKind::OutOfGasError, + })) + } + } + + /// Take a GasCost from our gas schedule and convert it to a total gas charge in `GasUnits`. + /// + /// This is used internally for converting from a `GasCost` which is a triple of numbers + /// represeing instruction, stack, and memory consumption into a number of `GasUnits`. + fn gas_of(gas_cost: GasCost) -> GasUnits { + gas_cost.instruction_gas + gas_cost.memory_gas + gas_cost.stack_gas + } +} diff --git a/language/vm/vm_runtime/src/identifier.rs b/language/vm/vm_runtime/src/identifier.rs new file mode 100644 index 0000000000000..f53d317e8d94b --- /dev/null +++ b/language/vm/vm_runtime/src/identifier.rs @@ -0,0 +1,39 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 +//! A bunch of helper functions to fetch the storage key for move resources and values. + +use types::{ + access_path::{AccessPath, Accesses}, + account_address::AccountAddress, + language_storage::{ResourceKey, StructTag}, +}; +use vm::{ + access::{BaseAccess, ModuleAccess}, + file_format::{CompiledModule, StructDefinitionIndex}, +}; + +#[cfg(test)] +#[path = "unit_tests/identifier_prop_tests.rs"] +mod identifier_prop_tests; + +/// Get the StructTag for a StructDefinition defined in a published module. 
+pub fn resource_storage_key(module: &CompiledModule, idx: StructDefinitionIndex) -> StructTag { + let resource = module.struct_def_at(idx); + let res_handle = module.struct_handle_at(resource.struct_handle); + let res_module = module.module_handle_at(res_handle.module); + let res_name = module.string_at(res_handle.name); + let res_mod_addr = module.address_at(res_module.address); + let res_mod_name = module.string_at(res_module.name); + StructTag { + module: res_mod_name.to_string(), + address: *res_mod_addr, + name: res_name.to_string(), + type_params: vec![], + } +} + +/// Get the AccessPath to a resource stored under `address` with type name `tag` +pub fn create_access_path(address: &AccountAddress, tag: StructTag) -> AccessPath { + let resource_tag = ResourceKey::new(*address, tag); + AccessPath::resource_access_path(&resource_tag, &Accesses::empty()) +} diff --git a/language/vm/vm_runtime/src/lib.rs b/language/vm/vm_runtime/src/lib.rs new file mode 100644 index 0000000000000..3f93d13937c67 --- /dev/null +++ b/language/vm/vm_runtime/src/lib.rs @@ -0,0 +1,176 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! # The VM runtime +//! +//! ## Transaction flow +//! +//! This is the path taken to process a single transaction. +//! +//! ```text +//! SignedTransaction +//! + +//! | +//! +--------------------------|-------------------+ +//! | Validate +--------------+--------------+ | +//! | | | | +//! | | check signature | | +//! | | | | +//! | +--------------+--------------+ | +//! | | | +//! | | | +//! | v | +//! | +--------------+--------------+ | +//! | | | | +//! | | check size and gas | | +//! | | | +---------------------------------+ +//! | +--------------+--------------+ | validation error | +//! | | | | +//! | | | | +//! | v | | +//! | +--------------+--------------+ | | +//! | | | | | +//! | | run prologue | | | +//! | | | | | +//! | +--------------+--------------+ | | +//! | | | | +//! +--------------------------|-------------------+ | +//! | | +//! +--------------------------|-------------------+ | +//! | v | | +//! | Verify +--------------+--------------+ | | +//! | | | | | +//! | | deserialize script, | | | +//! | | verify arguments | | | +//! | | | | | +//! | +--------------+--------------+ | | +//! | | | | +//! | | | v +//! | v | +----------------+------+ +//! | +--------------+--------------+ | | | +//! | | | +------------------->+ discard, no write set | +//! | | deserialize modules | | verification error | | +//! | | | | +----------------+------+ +//! | +--------------+--------------+ | ^ +//! | | | | +//! | | | | +//! | v | | +//! | +--------------+--------------+ | | +//! | | | | | +//! | | verify scripts and modules | | | +//! | | | | | +//! | +--------------+--------------+ | | +//! | | | | +//! +--------------------------|-------------------+ | +//! | | +//! +--------------------------|-------------------+ | +//! | v | | +//! | Execute +--------------+--------------+ | | +//! | | | | | +//! | | execute main | | | +//! | | | | | +//! | +--------------+--------------+ | | +//! | | | | +//! | success or failure | | | +//! | v | | +//! | +--------------+--------------+ | | +//! | | | +---------------------------------+ +//! | | run epilogue | | invariant violation (internal panic) +//! | | | | +//! | +--------------+--------------+ | +//! | | | +//! | | | +//! | v | +//! | +--------------+--------------+ | +-----------------------+ +//! | | | | execution failure | | +//! 
| | make write set +------------------------>+ keep, only charge gas | +//! | | | | | | +//! | +--------------+--------------+ | +-----------------------+ +//! | | | +//! +--------------------------|-------------------+ +//! | +//! v +//! +--------------+--------------+ +//! | | +//! | keep, transaction executed | +//! | + gas charged | +//! | | +//! +-----------------------------+ +//! ``` + +#[macro_use] +extern crate vm; +#[macro_use] +extern crate lazy_static; +#[macro_use] +extern crate rental; + +mod block_processor; +mod counters; +mod frame; +mod gas_meter; +mod move_vm; +mod process_txn; +mod proptest_types; +mod runtime; +mod value_serializer; + +pub mod code_cache; +pub mod data_cache; +pub mod identifier; +pub mod loaded_data; +pub mod txn_executor; +pub mod value; + +#[cfg(feature = "instruction_synthesis")] +pub mod execution_stack; +#[cfg(not(feature = "instruction_synthesis"))] +mod execution_stack; + +pub use move_vm::MoveVM; +pub use process_txn::verify::static_verify_program; +pub use txn_executor::execute_function; + +use config::config::VMConfig; +use state_view::StateView; +use types::{ + transaction::{SignedTransaction, TransactionOutput}, + vm_error::VMStatus, +}; +use vm::{errors::VMInvariantViolation, IndexKind}; + +pub(crate) fn bounded_fetch( + pool: &[T], + idx: usize, + bound_type: IndexKind, +) -> Result<&T, VMInvariantViolation> { + pool.get(idx) + .ok_or_else(|| VMInvariantViolation::IndexOutOfBounds(bound_type, pool.len(), idx)) +} + +/// This trait describes the VM's verification interfaces. +pub trait VMVerifier { + /// Executes the prologue of the Libra Account and verifies that the transaction is valid. + /// only. Returns `None` if the transaction was validated, or Some(VMStatus) if the transaction + /// was unable to be validated with status `VMStatus`. + fn validate_transaction( + &self, + transaction: SignedTransaction, + state_view: &dyn StateView, + ) -> Option; +} + +/// This trait describes the VM's execution interface. +pub trait VMExecutor { + // NOTE: At the moment there are no persistent caches that live past the end of a block (that's + // why execute_block doesn't take &self.) + // There are some cache invalidation issues around transactions publishing code that need to be + // sorted out before that's possible. + + /// Executes a block of transactions and returns output for each one of them. + fn execute_block( + transactions: Vec, + config: &VMConfig, + state_view: &dyn StateView, + ) -> Vec; +} diff --git a/language/vm/vm_runtime/src/loaded_data/function.rs b/language/vm/vm_runtime/src/loaded_data/function.rs new file mode 100644 index 0000000000000..b45e94eb4a58d --- /dev/null +++ b/language/vm/vm_runtime/src/loaded_data/function.rs @@ -0,0 +1,130 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 +//! Loaded representation for function definitions and handles. + +use crate::loaded_data::loaded_module::LoadedModule; +use vm::{ + access::{BaseAccess, ModuleAccess}, + errors::*, + file_format::{Bytecode, CodeUnit, FunctionDefinitionIndex}, + internals::ModuleIndex, + CompiledModule, +}; + +/// Trait that defines the internal representation of a move function. +pub trait FunctionReference<'txn>: Sized + Clone { + /// Create a new function reference to a module + fn new( + module: &'txn LoadedModule, + idx: FunctionDefinitionIndex, + ) -> Result; + + /// Fetch the reference to the module where the function is defined. 
+    fn module(&self) -> &'txn LoadedModule;
+
+    /// Fetch the code of the function definition.
+    fn code_definition(&self) -> &'txn [Bytecode];
+
+    /// Return the number of locals declared by the function.
+    fn local_count(&self) -> usize;
+
+    /// Return the number of arguments the function takes.
+    fn arg_count(&self) -> usize;
+
+    /// Return the number of values the function returns.
+    fn return_count(&self) -> usize;
+
+    /// Return whether the function is native or not.
+    fn is_native(&self) -> bool;
+
+    /// Return the name of the function.
+    fn name(&self) -> &'txn str;
+}
+
+/// Resolved form of a function handle.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct FunctionRef<'txn> {
+    module: &'txn LoadedModule,
+    def: &'txn FunctionDef,
+    name: &'txn str,
+}
+
+impl<'txn> FunctionReference<'txn> for FunctionRef<'txn> {
+    fn new(
+        module: &'txn LoadedModule,
+        idx: FunctionDefinitionIndex,
+    ) -> Result<Self, VMInvariantViolation> {
+        let def = &module.function_defs[idx.into_index()];
+        let fn_definition = module.module.function_def_at(idx);
+        let name_idx = module
+            .module
+            .function_handle_at(fn_definition.function)
+            .name;
+        Ok(FunctionRef {
+            module,
+            def,
+            name: module.string_at(name_idx),
+        })
+    }
+
+    fn module(&self) -> &'txn LoadedModule {
+        self.module
+    }
+
+    fn code_definition(&self) -> &'txn [Bytecode] {
+        &self.def.code
+    }
+
+    fn local_count(&self) -> usize {
+        self.def.local_count
+    }
+
+    fn arg_count(&self) -> usize {
+        self.def.arg_count
+    }
+
+    fn return_count(&self) -> usize {
+        self.def.return_count
+    }
+
+    fn is_native(&self) -> bool {
+        (self.def.flags & CodeUnit::NATIVE) == CodeUnit::NATIVE
+    }
+
+    fn name(&self) -> &'txn str {
+        self.name
+    }
+}
+
+/// Resolved form of a function definition.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct FunctionDef {
+    pub local_count: usize,
+    pub arg_count: usize,
+    pub return_count: usize,
+    pub code: Vec<Bytecode>,
+    pub flags: u8,
+}
+
+impl FunctionDef {
+    pub fn new(module: &CompiledModule, idx: FunctionDefinitionIndex) -> Self {
+        let definition = module.function_def_at(idx);
+        let code = definition.code.code.clone();
+        let handle = module.function_handle_at(definition.function);
+        let function_sig = module.function_signature_at(handle.signature);
+        let flags = definition.flags;
+
+        FunctionDef {
+            code,
+            flags,
+            arg_count: function_sig.arg_types.len(),
+            return_count: function_sig.return_types.len(),
+            // The local count is omitted for native functions, which have no bytecode body.
+            local_count: if (flags & CodeUnit::NATIVE) == CodeUnit::NATIVE {
+                0
+            } else {
+                module.locals_signature_at(definition.code.locals).0.len()
+            },
+        }
+    }
+}
diff --git a/language/vm/vm_runtime/src/loaded_data/loaded_module.rs b/language/vm/vm_runtime/src/loaded_data/loaded_module.rs
new file mode 100644
index 0000000000000..ed579f66b7024
--- /dev/null
+++ b/language/vm/vm_runtime/src/loaded_data/loaded_module.rs
@@ -0,0 +1,152 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+//! Loaded representation for Move modules.
+
+use crate::loaded_data::{function::FunctionDef, struct_def::StructDef};
+use std::{collections::HashMap, sync::RwLock};
+use types::{account_address::AccountAddress, byte_array::ByteArray};
+use vm::{
+    access::{BaseAccess, ModuleAccess},
+    errors::VMInvariantViolation,
+    file_format::{
+        AddressPoolIndex, ByteArrayPoolIndex, CompiledModule, FieldDefinitionIndex,
+        FunctionDefinitionIndex, MemberCount, StringPoolIndex, StructDefinitionIndex, TableIndex,
+    },
+    internals::ModuleIndex,
+};
+
+/// Defines a loaded module in memory. Currently we just store the module itself along with a set
+/// of reverse mappings that allow querying the definition of a struct/function by name.
+#[derive(Debug, Eq, PartialEq)]
+pub struct LoadedModule {
+    pub module: CompiledModule,
+    #[allow(dead_code)]
+    pub struct_defs_table: HashMap<String, StructDefinitionIndex>,
+    #[allow(dead_code)]
+    pub field_defs_table: HashMap<String, FieldDefinitionIndex>,
+
+    pub function_defs_table: HashMap<String, FunctionDefinitionIndex>,
+
+    pub function_defs: Vec<FunctionDef>,
+
+    pub field_offsets: Vec<TableIndex>,
+
+    cache: LoadedModuleCache,
+}
+
+#[derive(Debug)]
+struct LoadedModuleCache {
+    // TODO: this can probably be made lock-free by using AtomicPtr or the "atom" crate. Consider
+    // doing so in the future.
+    struct_defs: Vec<RwLock<Option<StructDef>>>,
+}
+
+impl PartialEq for LoadedModuleCache {
+    fn eq(&self, _other: &Self) -> bool {
+        // This is a cache, so ignore it for equality checks.
+        true
+    }
+}
+
+impl Eq for LoadedModuleCache {}
+
+impl LoadedModule {
+    pub fn new(module: CompiledModule) -> Result<Self, VMInvariantViolation> {
+        let mut struct_defs_table = HashMap::new();
+        let mut field_defs_table = HashMap::new();
+        let mut function_defs_table = HashMap::new();
+        let mut function_defs = vec![];
+        let struct_defs = module.struct_defs().map(|_| RwLock::new(None)).collect();
+        let cache = LoadedModuleCache { struct_defs };
+
+        let mut field_offsets: Vec<TableIndex> = module.field_defs().map(|_| 0).collect();
+
+        for (idx, struct_def) in module.struct_defs().enumerate() {
+            let name = module
+                .string_at(module.struct_handle_at(struct_def.struct_handle).name)
+                .to_string();
+            let sd_idx = StructDefinitionIndex::new(idx as TableIndex);
+            struct_defs_table.insert(name, sd_idx);
+
+            for i in 0..struct_def.field_count {
+                field_offsets[struct_def.fields.into_index() + i as usize] = i;
+            }
+        }
+        for (idx, field_def) in module.field_defs().enumerate() {
+            let name = module.string_at(field_def.name).to_string();
+            let fd_idx = FieldDefinitionIndex::new(idx as TableIndex);
+            field_defs_table.insert(name, fd_idx);
+        }
+        for (idx, function_def) in module.function_defs().enumerate() {
+            let name = module
+                .string_at(module.function_handle_at(function_def.function).name)
+                .to_string();
+            let fd_idx = FunctionDefinitionIndex::new(idx as TableIndex);
+            function_defs_table.insert(name, fd_idx);
+            function_defs.push(FunctionDef::new(&module, fd_idx));
+        }
+        Ok(LoadedModule {
+            module,
+            struct_defs_table,
+            field_defs_table,
+            function_defs_table,
+            function_defs,
+            field_offsets,
+            cache,
+        })
+    }
+
+    pub fn address_at(&self, idx: AddressPoolIndex) -> &AccountAddress {
+        self.module.address_at(idx)
+    }
+
+    pub fn string_at(&self, idx: StringPoolIndex) -> &str {
+        self.module.string_at(idx)
+    }
+
+    pub fn byte_array_at(&self, idx: ByteArrayPoolIndex) -> ByteArray {
+        self.module.byte_array_at(idx).clone()
+    }
+
+    pub fn field_count_at(&self, idx: StructDefinitionIndex) -> MemberCount {
+        self.module.struct_def_at(idx).field_count
+    }
+
+    /// Return a cached copy of the struct def at this index, if available.
+    pub fn cached_struct_def_at(&self, idx: StructDefinitionIndex) -> Option<StructDef> {
+        let cached = self.cache.struct_defs[idx.into_index()]
+            .read()
+            .expect("lock poisoned");
+        cached.clone()
+    }
+
+    /// Cache this struct def at this location.
+    pub fn cache_struct_def(&self, idx: StructDefinitionIndex, def: StructDef) {
+        let mut cached = self.cache.struct_defs[idx.into_index()]
+            .write()
+            .expect("lock poisoned");
+        // XXX If multiple writers call this at the same time, the last write wins. Is this
+        // desirable?
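+        // (Struct defs resolved for a given index are deterministic, so a lost race here
+        // should only cost a redundant recomputation, not correctness.)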
+ cached.replace(def); + } + + pub fn get_field_offset( + &self, + idx: FieldDefinitionIndex, + ) -> Result { + self.field_offsets + .get(idx.into_index()) + .cloned() + .ok_or(VMInvariantViolation::LinkerError) + } +} + +// Compile-time test to ensure that this struct stays thread-safe. +#[test] +fn assert_thread_safe() { + fn assert_send() {}; + fn assert_sync() {}; + + assert_send::(); + assert_sync::(); +} diff --git a/language/vm/vm_runtime/src/loaded_data/mod.rs b/language/vm/vm_runtime/src/loaded_data/mod.rs new file mode 100644 index 0000000000000..75e0c2a0f2af5 --- /dev/null +++ b/language/vm/vm_runtime/src/loaded_data/mod.rs @@ -0,0 +1,10 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 +//! Loaded definition of code data used in runtime. +//! +//! This module contains the loaded definition of code data used in runtime. + +pub mod function; +pub mod loaded_module; +pub mod struct_def; +pub mod types; diff --git a/language/vm/vm_runtime/src/loaded_data/struct_def.rs b/language/vm/vm_runtime/src/loaded_data/struct_def.rs new file mode 100644 index 0000000000000..2e0759ff809e3 --- /dev/null +++ b/language/vm/vm_runtime/src/loaded_data/struct_def.rs @@ -0,0 +1,62 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 +//! Loaded representation for Move struct definition. + +use crate::loaded_data::types::Type; +use canonical_serialization::*; +use failure::prelude::*; +use std::sync::Arc; + +// Note that this data structure can represent recursive types but will end up creating reference +// cycles, which is bad. Other parts of the system disallow recursive types for now, but this may +// need to be handled more explicitly in the future. +/// Resolved form of struct definition. +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct StructDef(Arc); + +impl StructDef { + /// Constructs a new [`StructDef`] + pub fn new(field_definitions: Vec) -> Self { + Self(Arc::new(StructDefInner { field_definitions })) + } + + /// Get type declaration for each field in the struct. + #[inline] + pub fn field_definitions(&self) -> &[Type] { + &self.0.field_definitions + } +} + +// Do not implement Clone for this -- the outer StructDef should be Arc'd. +#[derive(Debug, Eq, PartialEq)] +struct StructDefInner { + field_definitions: Vec, +} + +/// This isn't used by any normal code at the moment, but is used by the fuzzer to serialize types +/// alongside values. +impl CanonicalSerialize for StructDef { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + let field_defs = self.field_definitions(); + // Encode the number of field definitions, then the definitions themselves. + field_defs.len().serialize(serializer)?; + for field_def in field_defs { + field_def.serialize(serializer)?; + } + Ok(()) + } +} + +impl CanonicalDeserialize for StructDef { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result + where + Self: Sized, + { + // Libra only runs on 64-bit machines. + let num_defs = deserializer.decode_u64()? 
as usize; + let field_defs: Result<_> = (0..num_defs) + .map(|_| Type::deserialize(deserializer)) + .collect(); + Ok(StructDef::new(field_defs?)) + } +} diff --git a/language/vm/vm_runtime/src/loaded_data/types.rs b/language/vm/vm_runtime/src/loaded_data/types.rs new file mode 100644 index 0000000000000..a2127de5de6ab --- /dev/null +++ b/language/vm/vm_runtime/src/loaded_data/types.rs @@ -0,0 +1,82 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 +//! Loaded representation for runtime types. + +use crate::loaded_data::struct_def::StructDef; +use canonical_serialization::*; +use failure::prelude::*; + +#[cfg(test)] +#[path = "../unit_tests/type_prop_tests.rs"] +mod type_prop_tests; + +/// Resolved form of runtime types. +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum Type { + Bool, + U64, + String, + ByteArray, + Address, + Struct(StructDef), + Reference(Box), + MutableReference(Box), +} + +/// This isn't used by any normal code at the moment, but is used by the fuzzer to serialize types +/// alongside values. +impl CanonicalSerialize for Type { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + use Type::*; + + // Add a type for each tag. + let _: &mut _ = match self { + Bool => serializer.encode_u8(0x01)?, + U64 => serializer.encode_u8(0x02)?, + String => serializer.encode_u8(0x03)?, + ByteArray => serializer.encode_u8(0x04)?, + Address => serializer.encode_u8(0x05)?, + Struct(struct_def) => { + serializer.encode_u8(0x06)?; + struct_def.serialize(serializer)?; + serializer + } + Reference(ty) => { + serializer.encode_u8(0x07)?; + ty.serialize(serializer)?; + serializer + } + MutableReference(ty) => { + serializer.encode_u8(0x08)?; + ty.serialize(serializer)?; + serializer + } + }; + Ok(()) + } +} + +impl CanonicalDeserialize for Type { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result + where + Self: Sized, + { + use Type::*; + + let ty = match deserializer.decode_u8()? { + 0x01 => Bool, + 0x02 => U64, + 0x03 => String, + 0x04 => ByteArray, + 0x05 => Address, + 0x06 => Struct(StructDef::deserialize(deserializer)?), + 0x07 => Reference(Box::new(Type::deserialize(deserializer)?)), + 0x08 => MutableReference(Box::new(Type::deserialize(deserializer)?)), + other => bail!( + "Error while deserializing type: found unexpected tag {:#x}", + other + ), + }; + Ok(ty) + } +} diff --git a/language/vm/vm_runtime/src/move_vm.rs b/language/vm/vm_runtime/src/move_vm.rs new file mode 100644 index 0000000000000..7227919a014de --- /dev/null +++ b/language/vm/vm_runtime/src/move_vm.rs @@ -0,0 +1,81 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{loaded_data::loaded_module::LoadedModule, runtime::VMRuntime, VMExecutor, VMVerifier}; +use state_view::StateView; +use std::sync::Arc; +use types::{ + transaction::{SignedTransaction, TransactionOutput}, + vm_error::VMStatus, +}; +use vm_cache_map::Arena; + +rental! { + mod move_vm_definition { + use super::*; + + #[rental] + pub struct MoveVMImpl { + alloc: Box>, + runtime: VMRuntime<'alloc>, + } + } +} + +use config::config::VMConfig; +pub use move_vm_definition::MoveVMImpl; + +/// A wrapper to make VMRuntime standalone and thread safe. 
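+///
+/// A minimal usage sketch (hypothetical caller code; the `vm_config` value and the
+/// `state_view` implementation are assumptions, not part of this module):
+///
+/// ```ignore
+/// let vm = MoveVM::new(&vm_config);
+/// // `None` means the transaction passed validation.
+/// let status = vm.validate_transaction(signed_txn, &state_view);
+/// ```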
+#[derive(Clone)] +pub struct MoveVM { + inner: Arc, +} + +impl MoveVM { + pub fn new(config: &VMConfig) -> Self { + let inner = MoveVMImpl::new(Box::new(Arena::new()), |arena| { + VMRuntime::new(&*arena, config) + }); + Self { + inner: Arc::new(inner), + } + } +} + +impl VMVerifier for MoveVM { + fn validate_transaction( + &self, + transaction: SignedTransaction, + state_view: &dyn StateView, + ) -> Option { + // TODO: This should be implemented as an async function. + self.inner + .rent(move |runtime| runtime.verify_transaction(transaction, state_view)) + } +} + +impl VMExecutor for MoveVM { + fn execute_block( + transactions: Vec, + config: &VMConfig, + state_view: &dyn StateView, + ) -> Vec { + let vm = MoveVMImpl::new(Box::new(Arena::new()), |arena| { + // XXX This means that scripts and modules are NOT tested against the whitelist! This + // needs to be fixed. + VMRuntime::new(&*arena, config) + }); + vm.rent(|runtime| runtime.execute_block_transactions(transactions, state_view)) + } +} + +#[test] +fn vm_thread_safe() { + fn assert_send() {} + fn assert_sync() {} + + assert_send::(); + assert_sync::(); + assert_send::(); + assert_sync::(); +} diff --git a/language/vm/vm_runtime/src/process_txn/execute.rs b/language/vm/vm_runtime/src/process_txn/execute.rs new file mode 100644 index 0000000000000..40c662fae7140 --- /dev/null +++ b/language/vm/vm_runtime/src/process_txn/execute.rs @@ -0,0 +1,153 @@ +use crate::{ + code_cache::{module_cache::ModuleCache, script_cache::ScriptCache}, + process_txn::verify::{VerifiedTransaction, VerifiedTransactionState}, +}; +use logger::prelude::*; +use types::{ + transaction::{TransactionOutput, TransactionPayload, TransactionStatus}, + vm_error::{ExecutionStatus, VMStatus}, + write_set::WriteSet, +}; +use vm::errors::{Location, VMErrorKind, VMRuntimeError}; + +/// Represents a transaction that has been executed. +pub struct ExecutedTransaction { + output: TransactionOutput, +} + +impl ExecutedTransaction { + /// Creates a new instance by executing this transaction. + pub fn new<'alloc, 'txn, P>( + verified_txn: VerifiedTransaction<'alloc, 'txn, P>, + script_cache: &'txn ScriptCache<'alloc>, + ) -> Self + where + 'alloc: 'txn, + P: ModuleCache<'alloc>, + { + let output = execute(verified_txn, script_cache); + Self { output } + } + + /// Returns the `TransactionOutput` for this transaction. + pub fn into_output(self) -> TransactionOutput { + self.output + } +} + +fn execute<'alloc, 'txn, P>( + mut verified_txn: VerifiedTransaction<'alloc, 'txn, P>, + script_cache: &'txn ScriptCache<'alloc>, +) -> TransactionOutput +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + let txn_state = verified_txn.take_state(); + + match verified_txn + .into_inner() + .into_raw_transaction() + .into_payload() + { + TransactionPayload::Program(program) => { + let VerifiedTransactionState { + mut txn_executor, + script, + modules, + } = txn_state.expect("program-based transactions should always have associated state"); + + // Add the script to the cache. + // XXX The cache should probably become a loader and do verification internally. 
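+            // (Presumably the cache is keyed on the script itself, so resubmitting an
+            // identical script can reuse the previously loaded entry.)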
+ let (code, args, module_bytes) = program.into_inner(); + let func_ref = match script_cache.cache_script(script, &code) { + Ok(Ok(func)) => func, + Ok(Err(err)) => { + error!("[VM] Error loading script: {:?}", err); + return txn_executor.failed_transaction_cleanup(Ok(Err(err))); + } + Err(err) => { + crit!("[VM] VM error loading script: {:?}", err); + return ExecutedTransaction::discard_error_output(&err); + } + }; + + // Add modules to the cache and prepare for publishing. + let mut publish_modules = vec![]; + for (module, raw_bytes) in modules.into_iter().zip(module_bytes) { + let code_key = module.self_code_key(); + + // Make sure that there is not already a module with this name published + // under the transaction sender's account. + // Note: although this reads from the "module cache", `get_loaded_module` + // will read through the cache to fetch the module from the global storage + // if it is not already cached. + match txn_executor.module_cache().get_loaded_module(&code_key) { + Ok(None) => (), // No module with this name exists. safe to publish one + Ok(Some(_)) => { + // A module with this name already exists. It is not safe to publish + // another one; it would clobber the old module. This would break + // code that links against the module and make published resources + // from the old module inaccessible (or worse, accessible and not + // typesafe). + // We are currently developing a versioning scheme for safe updates + // of modules and resources. + return txn_executor.failed_transaction_cleanup(Ok(Err(VMRuntimeError { + loc: Location::default(), + err: VMErrorKind::DuplicateModuleName, + }))); + } + Err(err) => { + crit!("[VM] VM error loading module: {:?}", err); + return ExecutedTransaction::discard_error_output(&err); + } + } + + match txn_executor.module_cache().cache_module(module) { + Ok(()) => (), + Err(err) => { + error!("[VM] error while caching module: {:?}", err); + return ExecutedTransaction::discard_error_output(&err); + } + }; + publish_modules.push((code_key, raw_bytes)); + } + + // Set up main. + txn_executor.setup_main_args(args); + + // Run main. + match txn_executor.execute_function_impl(func_ref) { + Ok(Ok(_)) => txn_executor.transaction_cleanup(publish_modules), + Ok(Err(err)) => { + error!("[VM] User error running script: {:?}", err); + txn_executor.failed_transaction_cleanup(Ok(Err(err))) + } + Err(err) => { + crit!("[VM] VM error running script: {:?}", err); + ExecutedTransaction::discard_error_output(&err) + } + } + } + // WriteSet transaction. Just proceed and use the writeset as output. + TransactionPayload::WriteSet(write_set) => TransactionOutput::new( + write_set, + vec![], + 0, + VMStatus::Execution(ExecutionStatus::Executed).into(), + ), + } +} + +impl ExecutedTransaction { + #[inline] + pub(crate) fn discard_error_output(err: impl Into) -> TransactionOutput { + // Since this transaction will be discarded, no writeset will be included. 
+ TransactionOutput::new( + WriteSet::default(), + vec![], + 0, + TransactionStatus::Discard(err.into()), + ) + } +} diff --git a/language/vm/vm_runtime/src/process_txn/mod.rs b/language/vm/vm_runtime/src/process_txn/mod.rs new file mode 100644 index 0000000000000..de22c0b1c6ad0 --- /dev/null +++ b/language/vm/vm_runtime/src/process_txn/mod.rs @@ -0,0 +1,61 @@ +use crate::{ + code_cache::module_cache::ModuleCache, data_cache::RemoteCache, + loaded_data::loaded_module::LoadedModule, +}; +use config::config::VMPublishingOption; +use std::marker::PhantomData; +use types::transaction::SignedTransaction; +use vm_cache_map::Arena; + +pub mod execute; +pub mod validate; +pub mod verify; + +use types::vm_error::VMStatus; +use validate::{ValidatedTransaction, ValidationMode}; + +/// The starting point for processing a transaction. All the different states involved are described +/// through the types present in submodules. +pub struct ProcessTransaction<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + txn: SignedTransaction, + module_cache: P, + data_cache: &'txn RemoteCache, + allocator: &'txn Arena, + phantom: PhantomData<&'alloc ()>, +} + +impl<'alloc, 'txn, P> ProcessTransaction<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + /// Creates a new instance of `ProcessTransaction`. + pub fn new( + txn: SignedTransaction, + module_cache: P, + data_cache: &'txn RemoteCache, + allocator: &'txn Arena, + ) -> Self { + Self { + txn, + module_cache, + data_cache, + allocator, + phantom: PhantomData, + } + } + + /// Validates this transaction. Returns a `ValidatedTransaction` on success or `VMStatus` on + /// failure. + pub fn validate( + self, + mode: ValidationMode, + publishing_option: &VMPublishingOption, + ) -> Result, VMStatus> { + ValidatedTransaction::new(self, mode, publishing_option) + } +} diff --git a/language/vm/vm_runtime/src/process_txn/validate.rs b/language/vm/vm_runtime/src/process_txn/validate.rs new file mode 100644 index 0000000000000..9a5e96ce0f462 --- /dev/null +++ b/language/vm/vm_runtime/src/process_txn/validate.rs @@ -0,0 +1,281 @@ +use crate::{ + code_cache::module_cache::{ModuleCache, TransactionModuleCache}, + data_cache::RemoteCache, + loaded_data::loaded_module::LoadedModule, + process_txn::{verify::VerifiedTransaction, ProcessTransaction}, + txn_executor::TransactionExecutor, +}; +use config::config::VMPublishingOption; +use logger::prelude::*; +use tiny_keccak::Keccak; +use types::{ + transaction::{ + SignedTransaction, TransactionPayload, MAX_TRANSACTION_SIZE_IN_BYTES, SCRIPT_HASH_LENGTH, + }, + vm_error::{VMStatus, VMValidationStatus}, +}; +use vm::{ + errors::convert_prologue_runtime_error, gas_schedule, transaction_metadata::TransactionMetadata, +}; +use vm_cache_map::Arena; + +pub fn is_allowed_script(publishing_option: &VMPublishingOption, program: &[u8]) -> bool { + match publishing_option { + VMPublishingOption::Open | VMPublishingOption::CustomScripts => true, + VMPublishingOption::Locked(whitelist) => { + let mut hash = [0u8; SCRIPT_HASH_LENGTH]; + let mut keccak = Keccak::new_sha3_256(); + keccak.update(program); + keccak.finalize(&mut hash); + whitelist.contains(&hash) + } + } +} + +/// Represents a [`SignedTransaction`] that has been *validated*. This includes all the steps +/// required to ensure that a transaction is valid, other than verifying the submitted program. 
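+///
+/// A sketch of the intended processing pipeline (hypothetical driver code, assuming a
+/// populated module cache, data cache, script cache, and arena):
+///
+/// ```ignore
+/// let validated = ProcessTransaction::new(txn, module_cache, &data_cache, &arena)
+///     .validate(ValidationMode::Validating, &publishing_option)?;
+/// let verified = validated.verify()?; // run the bytecode verifier
+/// let output = verified.execute(&script_cache).into_output();
+/// ```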
+pub struct ValidatedTransaction<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + txn: SignedTransaction, + txn_state: Option>, +} + +/// The mode to validate transactions in. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ValidationMode { + /// This is the genesis transaction. At the moment it is the only mode that allows for + /// write-set transactions. + Genesis, + /// We're only validating a transaction, not executing it. This tolerates the sequence number + /// being too new. + Validating, + /// We're executing a transaction. This runs the full suite of checks. + #[allow(dead_code)] + Executing, +} + +impl<'alloc, 'txn, P> ValidatedTransaction<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + /// Creates a new instance by validating a `SignedTransaction`. + /// + /// This should be called through [`ProcessTransaction::validate`]. + pub(super) fn new( + process_txn: ProcessTransaction<'alloc, 'txn, P>, + mode: ValidationMode, + publishing_option: &VMPublishingOption, + ) -> Result { + let ProcessTransaction { + txn, + module_cache, + data_cache, + allocator, + .. + } = process_txn; + if txn.verify_signature().is_err() { + error!("[VM] Verify signature error"); + return Err(VMStatus::Validation(VMValidationStatus::InvalidSignature)); + } + + let txn_state = match txn.payload() { + TransactionPayload::Program(program) => { + // The transaction is too large. + if txn.raw_txn_bytes_len() > MAX_TRANSACTION_SIZE_IN_BYTES { + let error_str = format!( + "max size: {}, txn size: {}", + MAX_TRANSACTION_SIZE_IN_BYTES, + txn.raw_txn_bytes_len() + ); + return Err(VMStatus::Validation( + VMValidationStatus::ExceededMaxTransactionSize(error_str), + )); + } + + // The submitted max gas units that the transaction can consume is greater than the + // maximum number of gas units bound that we have set for any + // transaction. + if txn.max_gas_amount() > gas_schedule::MAXIMUM_NUMBER_OF_GAS_UNITS { + let error_str = format!( + "max gas units: {}, gas units submitted: {}", + gas_schedule::MAXIMUM_NUMBER_OF_GAS_UNITS, + txn.max_gas_amount() + ); + return Err(VMStatus::Validation( + VMValidationStatus::MaxGasUnitsExceedsMaxGasUnitsBound(error_str), + )); + } + + // The submitted transactions max gas units needs to be at least enough to cover the + // intrinsic cost of the transaction as calculated against the size of the + // underlying `RawTransaction` + let min_txn_fee = + gas_schedule::calculate_intrinsic_gas(txn.raw_txn_bytes_len() as u64); + if txn.max_gas_amount() < min_txn_fee { + let error_str = format!( + "min gas required for txn: {}, gas submitted: {}", + min_txn_fee, + txn.max_gas_amount() + ); + return Err(VMStatus::Validation( + VMValidationStatus::MaxGasUnitsBelowMinTransactionGasUnits(error_str), + )); + } + + // The submitted gas price is less than the minimum gas unit price set by the VM. + // NB: MIN_PRICE_PER_GAS_UNIT may equal zero, but need not in the future. Hence why + // we turn off the clippy warning. + #[allow(clippy::absurd_extreme_comparisons)] + let below_min_bound = txn.gas_unit_price() < gas_schedule::MIN_PRICE_PER_GAS_UNIT; + if below_min_bound { + let error_str = format!( + "gas unit min price: {}, submitted price: {}", + gas_schedule::MIN_PRICE_PER_GAS_UNIT, + txn.gas_unit_price() + ); + return Err(VMStatus::Validation( + VMValidationStatus::GasUnitPriceBelowMinBound(error_str), + )); + } + + // The submitted gas price is greater than the maximum gas unit price set by the VM. 
+ if txn.gas_unit_price() > gas_schedule::MAX_PRICE_PER_GAS_UNIT { + let error_str = format!( + "gas unit max price: {}, submitted price: {}", + gas_schedule::MAX_PRICE_PER_GAS_UNIT, + txn.gas_unit_price() + ); + return Err(VMStatus::Validation( + VMValidationStatus::GasUnitPriceAboveMaxBound(error_str), + )); + } + + // Verify against whitelist if we are locked. Otherwise allow. + if !is_allowed_script(&publishing_option, &program.code()) { + error!("[VM] Custom scripts not allowed: {:?}", &program.code()); + return Err(VMStatus::Validation(VMValidationStatus::UnknownScript)); + } + + if !publishing_option.is_open() { + // Not allowing module publishing for now. + if !program.modules().is_empty() { + error!("[VM] Custom modules not allowed"); + return Err(VMStatus::Validation(VMValidationStatus::UnknownModule)); + } + } + + let metadata = TransactionMetadata::new(&txn); + let mut txn_state = + ValidatedTransactionState::new(metadata, module_cache, data_cache, allocator); + + // Run the prologue to ensure that clients have enough gas and aren't tricking us by + // sending us garbage. + // TODO: write-set transactions (other than genesis??) should also run the prologue. + match txn_state.txn_executor.run_prologue() { + Ok(Ok(_)) => {} + Ok(Err(ref err)) => { + let vm_status = convert_prologue_runtime_error(&err, &txn.sender()); + + // In validating mode, accept transactions with sequence number greater + // or equal to the current sequence number. + match (mode, vm_status) { + ( + ValidationMode::Validating, + VMStatus::Validation(VMValidationStatus::SequenceNumberTooNew), + ) => { + trace!("[VM] Sequence number too new error ignored"); + } + (_, vm_status) => { + warn!("[VM] Error in prologue: {:?}", err); + return Err(vm_status); + } + } + } + Err(ref err) => { + error!("[VM] VM error in prologue: {:?}", err); + return Err(err.into()); + } + }; + + Some(txn_state) + } + TransactionPayload::WriteSet(write_set) => { + // The only acceptable write-set transaction for now is for the genesis + // transaction. + // XXX figure out a story for hard forks. + if mode != ValidationMode::Genesis { + error!("[VM] Attempt to process genesis after initialization"); + return Err(VMStatus::Validation(VMValidationStatus::RejectedWriteSet)); + } + + for (_access_path, write_op) in write_set { + // Genesis transactions only add entries, never delete them. + if write_op.is_deletion() { + error!("[VM] Bad genesis block"); + // TODO: return more detailed error somehow? + return Err(VMStatus::Validation(VMValidationStatus::InvalidWriteSet)); + } + } + + None + } + }; + + Ok(Self { txn, txn_state }) + } + + /// Verifies the bytecode in this transaction. + pub fn verify(self) -> Result, VMStatus> { + VerifiedTransaction::new(self) + } + + /// Returns a reference to the `SignedTransaction` within. + pub fn as_inner(&self) -> &SignedTransaction { + &self.txn + } + + /// Consumes `self` and returns the `SignedTransaction` within. + #[allow(dead_code)] + pub fn into_inner(self) -> SignedTransaction { + self.txn + } + + /// Returns the `ValidatedTransactionState` within. + pub(super) fn take_state(&mut self) -> Option> { + self.txn_state.take() + } +} + +/// State for program-based [`ValidatedTransaction`] instances. +pub(super) struct ValidatedTransactionState<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + // <'txn, 'txn> looks weird, but it just means that the module cache passed in (the + // TransactionModuleCache) allocates for that long. 
+ pub(super) txn_executor: + TransactionExecutor<'txn, 'txn, TransactionModuleCache<'alloc, 'txn, P>>, +} + +impl<'alloc, 'txn, P> ValidatedTransactionState<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + fn new( + metadata: TransactionMetadata, + module_cache: P, + data_cache: &'txn RemoteCache, + allocator: &'txn Arena, + ) -> Self { + // This temporary cache is used for modules published by a single transaction. + let txn_module_cache = TransactionModuleCache::new(module_cache, allocator); + let txn_executor = TransactionExecutor::new(txn_module_cache, data_cache, metadata); + Self { txn_executor } + } +} diff --git a/language/vm/vm_runtime/src/process_txn/verify.rs b/language/vm/vm_runtime/src/process_txn/verify.rs new file mode 100644 index 0000000000000..63f298e077c96 --- /dev/null +++ b/language/vm/vm_runtime/src/process_txn/verify.rs @@ -0,0 +1,213 @@ +use crate::{ + code_cache::{ + module_cache::{ModuleCache, TransactionModuleCache}, + script_cache::ScriptCache, + }, + process_txn::{execute::ExecutedTransaction, validate::ValidatedTransaction}, + txn_executor::TransactionExecutor, +}; +use bytecode_verifier::{verify_module, verify_script}; +use logger::prelude::*; +use types::{ + account_address::AccountAddress, + transaction::{Program, SignedTransaction, TransactionArgument, TransactionPayload}, + vm_error::{VMStatus, VMVerificationError, VMVerificationStatus}, +}; +use vm::{ + access::BaseAccess, + errors::{VMStaticViolation, VerificationError, VerificationStatus}, + file_format::{CompiledModule, CompiledScript, SignatureToken}, + IndexKind, +}; + +/// Represents a transaction which has been validated and for which the program has been run +/// through the bytecode verifier. +pub struct VerifiedTransaction<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + txn: SignedTransaction, + #[allow(dead_code)] + txn_state: Option>, +} + +impl<'alloc, 'txn, P> VerifiedTransaction<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + /// Creates a new instance by verifying the bytecode in this validated transaction. + pub(super) fn new( + mut validated_txn: ValidatedTransaction<'alloc, 'txn, P>, + ) -> Result { + let txn_state = validated_txn.take_state(); + let txn = validated_txn.as_inner(); + let txn_state = match txn.payload() { + TransactionPayload::Program(program) => { + let txn_state = txn_state + .expect("program-based transactions should always have associated state"); + + let (script, modules) = Self::verify_program(&txn.sender(), program)?; + + Some(VerifiedTransactionState { + txn_executor: txn_state.txn_executor, + script, + modules, + }) + } + TransactionPayload::WriteSet(_write_set) => { + // All the checks are performed in validation, so there's no need for more checks + // here. + None + } + }; + + Ok(Self { + txn: validated_txn.into_inner(), + txn_state, + }) + } + + fn verify_program( + sender_address: &AccountAddress, + program: &Program, + ) -> Result<(CompiledScript, Vec), VMStatus> { + // Ensure modules and scripts deserialize correctly. 
+ let script = match CompiledScript::deserialize(&program.code()) { + Ok(script) => script, + Err(ref err) => { + error!("[VM] script deserialization failed"); + return Err(err.into()); + } + }; + if !verify_actuals(&script, program.args()) { + error!("[VM] actual type mismatch"); + return Err(VMStatus::Verification(vec![VMVerificationStatus::Script( + VMVerificationError::TypeMismatch("Actual Type Mismatch".to_string()), + )])); + } + + // Make sure all the modules trying to be published in this module are valid. + let modules: Vec = match program + .modules() + .iter() + .map(|module_blob| CompiledModule::deserialize(&module_blob)) + .collect() + { + Ok(modules) => modules, + Err(ref err) => { + error!("[VM] module deserialization failed"); + return Err(err.into()); + } + }; + + // Run the script and module through the bytecode verifier. + let (script, modules, statuses) = static_verify_program(sender_address, script, modules); + if !statuses.is_empty() { + error!("[VM] bytecode verifier returned errors"); + return Err(statuses.iter().collect()); + } + + Ok((script, modules)) + } + + /// Executes this transaction. + pub fn execute(self, script_cache: &'txn ScriptCache<'alloc>) -> ExecutedTransaction { + ExecutedTransaction::new(self, script_cache) + } + + /// Returns the state stored in the transaction, if any. + pub(super) fn take_state(&mut self) -> Option> { + self.txn_state.take() + } + + /// Returns a reference to the `SignedTransaction` within. + #[allow(dead_code)] + pub fn as_inner(&self) -> &SignedTransaction { + &self.txn + } + + /// Consumes `self` and returns the `SignedTransaction` within. + pub fn into_inner(self) -> SignedTransaction { + self.txn + } +} + +/// State for program-based [`VerifiedTransaction`] instances. +#[allow(dead_code)] +pub(super) struct VerifiedTransactionState<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + pub(super) txn_executor: + TransactionExecutor<'txn, 'txn, TransactionModuleCache<'alloc, 'txn, P>>, + pub(super) script: CompiledScript, + pub(super) modules: Vec, +} + +/// Run static checks on a program directly. Provided as an alternative API for tests. +pub fn static_verify_program( + sender_address: &AccountAddress, + script: CompiledScript, + modules: Vec, +) -> (CompiledScript, Vec, Vec) { + let mut statuses: Vec>> = vec![]; + let (script, errors) = verify_script(script); + statuses.push(Box::new(errors.into_iter().map(VerificationStatus::Script))); + + let mut modules_out = vec![]; + for (module_idx, module) in modules.into_iter().enumerate() { + let (module, errors) = verify_module(module); + + // Make sure the module's self address matches the transaction sender. The self address is + // where the module will actually be published. If we did not check this, the sender could + // publish a module under anyone's account. + // + // For scripts this isn't a problem because they don't get published to accounts. + let address_mismatch = if module.address() != sender_address { + Some(VerificationError { + kind: IndexKind::AddressPool, + idx: CompiledModule::IMPLEMENTED_MODULE_INDEX as usize, + err: VMStaticViolation::ModuleAddressDoesNotMatchSender, + }) + } else { + None + }; + + statuses.push(Box::new( + errors + .into_iter() + .chain(address_mismatch) + .map(move |err| VerificationStatus::Module(module_idx as u16, err)), + )); + + modules_out.push(module); + } + + // TODO: Cross-module verification. This will need some way of exposing module + // dependencies to the bytecode verifier. 
+ + let statuses = statuses.into_iter().flatten().collect(); + (script, modules_out, statuses) +} + +/// Verify if the transaction arguments match the type signature of the main function. +fn verify_actuals(script: &CompiledScript, args: &[TransactionArgument]) -> bool { + let fh = script.function_handle_at(script.main.function); + let sig = script.function_signature_at(fh.signature); + if sig.arg_types.len() != args.len() { + return false; + } + for (ty, arg) in sig.arg_types.iter().zip(args.iter()) { + match (ty, arg) { + (SignatureToken::U64, TransactionArgument::U64(_)) => (), + (SignatureToken::Address, TransactionArgument::Address(_)) => (), + (SignatureToken::ByteArray, TransactionArgument::ByteArray(_)) => (), + (SignatureToken::String, TransactionArgument::String(_)) => (), + _ => return false, + } + } + true +} diff --git a/language/vm/vm_runtime/src/proptest_types.rs b/language/vm/vm_runtime/src/proptest_types.rs new file mode 100644 index 0000000000000..bfb4b8d35f036 --- /dev/null +++ b/language/vm/vm_runtime/src/proptest_types.rs @@ -0,0 +1,107 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + loaded_data::{struct_def::StructDef, types::Type}, + value::{MutVal, Value}, +}; +use proptest::{collection::vec, prelude::*}; +use types::{account_address::AccountAddress, byte_array::ByteArray}; + +/// Strategies for property-based tests using `Value` instances. +impl Value { + /// Returns a [`Strategy`] that generates random primitive (non-struct) `Value` instances. + pub fn single_value_strategy() -> impl Strategy { + prop_oneof![ + any::().prop_map(Value::Address), + any::().prop_map(Value::U64), + any::().prop_map(Value::Bool), + ".*".prop_map(Value::String), + any::().prop_map(Value::ByteArray), + ] + } + + /// Returns a [`Strategy`] that generates arbitrary values, including `Struct`s. + /// + /// Arguments are used for recursion and define + /// - depth of the nested `Struct` + /// - approximate max number of `Value`s in the whole tree + /// - number of max `Value`s at each level + pub fn nested_strategy( + depth: u32, + desired_size: u32, + expected_branch_size: u32, + ) -> impl Strategy { + let leaf = Self::single_value_strategy(); + leaf.prop_recursive(depth, desired_size, expected_branch_size, |inner| { + Self::struct_strategy_impl(inner) + }) + } + + /// Returns a [`Strategy`] that generates random `Struct` instances. + pub fn struct_strategy() -> impl Strategy { + Self::struct_strategy_impl(Self::nested_strategy(5, 100, 10)) + } + + fn struct_strategy_impl(base: impl Strategy) -> impl Strategy { + vec(base, 0..10).prop_map(|values| { + let mut mut_vals: Vec = Vec::new(); + for value in values { + mut_vals.push(MutVal::new(value)); + } + Value::Struct(mut_vals) + }) + } +} + +impl Arbitrary for Value { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_args: ()) -> Self::Strategy { + Self::nested_strategy(3, 50, 10).boxed() + } +} + +/// Strategies for Type +impl Type { + /// Generate a random primitive Type, no Struct + pub fn single_value_strategy() -> impl Strategy { + use Type::*; + + prop_oneof![ + Just(Bool), + Just(U64), + Just(String), + Just(ByteArray), + Just(Address), + ] + } + + /// Generate either a primitive Value or a Struct. 
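+    /// (That is, a primitive `Type`, a reference type, or a `Struct`.)
+    ///
+    /// A sketch of typical property-test usage (hypothetical test code):
+    ///
+    /// ```ignore
+    /// proptest! {
+    ///     #[test]
+    ///     fn type_roundtrip(ty in Type::nested_strategy(3, 20, 10)) {
+    ///         // e.g. check that canonical serialization round-trips `ty`
+    ///     }
+    /// }
+    /// ```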
+ pub fn nested_strategy( + depth: u32, + desired_size: u32, + expected_branch_size: u32, + ) -> impl Strategy { + use Type::*; + + let leaf = Self::single_value_strategy(); + leaf.prop_recursive(depth, desired_size, expected_branch_size, |inner| { + prop_oneof![ + inner.clone().prop_map(|t| Reference(Box::new(t))), + inner.clone().prop_map(|t| MutableReference(Box::new(t))), + vec(inner, 0..10).prop_map(|defs| Struct(StructDef::new(defs))), + ] + }) + } +} + +impl Arbitrary for Type { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_args: ()) -> Self::Strategy { + Self::nested_strategy(3, 20, 10).boxed() + } +} diff --git a/language/vm/vm_runtime/src/runtime.rs b/language/vm/vm_runtime/src/runtime.rs new file mode 100644 index 0000000000000..6b1af8ebd39f7 --- /dev/null +++ b/language/vm/vm_runtime/src/runtime.rs @@ -0,0 +1,120 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + block_processor::execute_block, + code_cache::{ + module_adapter::ModuleFetcherImpl, + module_cache::{BlockModuleCache, VMModuleCache}, + script_cache::ScriptCache, + }, + counters, + data_cache::BlockDataCache, + loaded_data::loaded_module::LoadedModule, + process_txn::{validate::ValidationMode, ProcessTransaction}, +}; +use config::config::{VMConfig, VMPublishingOption}; +use logger::prelude::*; +use state_view::StateView; +use types::{ + transaction::{SignedTransaction, TransactionOutput}, + vm_error::VMStatus, +}; +use vm_cache_map::Arena; + +/// An instantiation of the MoveVM. +/// `code_cache` is the top level module cache that holds loaded published modules. +/// `script_cache` is the cache that stores all the scripts that have previously been invoked. +/// `publishing_option` is the publishing option that is set. This can be one of either: +/// * Locked, with a whitelist of scripts that the VM is allowed to execute. For scripts that aren't +/// in the whitelist, the VM will just reject it in `verify_transaction`. +/// * Custom scripts, which will allow arbitrary valid scripts, but no module publishing +/// * Open script and module publishing +pub struct VMRuntime<'alloc> { + code_cache: VMModuleCache<'alloc>, + script_cache: ScriptCache<'alloc>, + publishing_option: VMPublishingOption, +} + +impl<'alloc> VMRuntime<'alloc> { + /// Create a new VM instance with an Arena allocator to store the modules and a `config` that + /// contains the whitelist that this VM is allowed to execute. + pub fn new(allocator: &'alloc Arena, config: &VMConfig) -> Self { + VMRuntime { + code_cache: VMModuleCache::new(allocator), + script_cache: ScriptCache::new(allocator), + publishing_option: config.publishing_options.clone(), + } + } + + /// Determine if a transaction is valid. Will return `None` if the transaction is accepted, + /// `Some(Err)` if the VM rejects it, with `Err` as an error code. We verify the following + /// items: + /// 1. The signature on the `SignedTransaction` matches the public key included in the + /// transaction + /// 2. The script to be executed is in the whitelist. + /// 3. Invokes `LibraAccount.prologue`, which checks properties such as the transaction has the + /// right sequence number and the sender has enough balance to pay for the gas. 4. + /// Transaction arguments matches the main function's type signature. 5. Script and modules + /// in the transaction pass the bytecode static verifier. + /// + /// Note: In the future. 
we may defer these checks to a later pass, as all the scripts we will + /// execute are pre-verified scripts. And bytecode verification is expensive. Thus whether we + /// want to perform this check here remains unknown. + pub fn verify_transaction( + &self, + txn: SignedTransaction, + data_view: &dyn StateView, + ) -> Option { + trace!("[VM] Verify transaction: {:?}", txn); + // Treat a transaction as a single block. + let module_cache = + BlockModuleCache::new(&self.code_cache, ModuleFetcherImpl::new(data_view)); + let data_cache = BlockDataCache::new(data_view); + + let arena = Arena::new(); + let process_txn = ProcessTransaction::new(txn, module_cache, &data_cache, &arena); + let mode = if data_view.is_genesis() { + ValidationMode::Genesis + } else { + ValidationMode::Validating + }; + + let validated_txn = match process_txn.validate(mode, &self.publishing_option) { + Ok(validated_txn) => validated_txn, + Err(vm_status) => { + counters::UNVERIFIED_TRANSACTION.inc(); + return Some(vm_status); + } + }; + match validated_txn.verify() { + Ok(_) => { + counters::VERIFIED_TRANSACTION.inc(); + None + } + Err(vm_status) => { + counters::UNVERIFIED_TRANSACTION.inc(); + Some(vm_status) + } + } + } + + /// Execute a block of transactions. The output vector will have the exact same length as the + /// input vector. The discarded transactions will be marked as `TransactionStatus::Discard` and + /// have an empty writeset. Also the data view is immutable, and also does not have interior + /// mutability. writes to be applied to the data view are encoded in the write set part of a + /// transaction output. + pub fn execute_block_transactions( + &self, + txn_block: Vec, + data_view: &dyn StateView, + ) -> Vec { + execute_block( + txn_block, + &self.code_cache, + &self.script_cache, + data_view, + &self.publishing_option, + ) + } +} diff --git a/language/vm/vm_runtime/src/txn_executor.rs b/language/vm/vm_runtime/src/txn_executor.rs new file mode 100644 index 0000000000000..30342d8c2d9c1 --- /dev/null +++ b/language/vm/vm_runtime/src/txn_executor.rs @@ -0,0 +1,835 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 +//! Processor for a single transaction. + +use crate::{ + code_cache::module_cache::{create_fake_module, ModuleCache, VMModuleCache}, + data_cache::{RemoteCache, TransactionDataCache}, + execution_stack::ExecutionStack, + gas_meter::GasMeter, + identifier::{create_access_path, resource_storage_key}, + loaded_data::{ + function::{FunctionRef, FunctionReference}, + loaded_module::LoadedModule, + }, + value::{Local, MutVal, Reference, Value}, +}; +use move_ir_natives::dispatch::{dispatch_native_call, NativeReturnType}; +use types::{ + access_path::AccessPath, + account_address::AccountAddress, + account_config, + byte_array::ByteArray, + contract_event::ContractEvent, + language_storage::CodeKey, + transaction::{TransactionArgument, TransactionOutput, TransactionStatus}, + vm_error::{ExecutionStatus, VMStatus}, + write_set::WriteSet, +}; +use vm::{ + access::ModuleAccess, + errors::*, + file_format::{Bytecode, CodeOffset, CompiledModule, CompiledScript, StructDefinitionIndex}, + transaction_metadata::TransactionMetadata, +}; +use vm_cache_map::Arena; + +#[cfg(test)] +#[path = "unit_tests/runtime_tests.rs"] +mod runtime_tests; + +// Metadata needed for resolving the account module. +lazy_static! { + /// The CodeKey for where Account module is being stored. 
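+    /// (That is, the "LibraAccount" module published under the core code address.)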
+ pub static ref ACCOUNT_MODULE: CodeKey = + { CodeKey::new(account_config::core_code_address(), "LibraAccount".to_string()) }; + /// The CodeKey for where LibraCoin module is being stored. + pub static ref COIN_MODULE: CodeKey = + { CodeKey::new(account_config::core_code_address(), "LibraCoin".to_string()) }; +} + +const PROLOGUE_NAME: &str = "prologue"; +const EPILOGUE_NAME: &str = "epilogue"; +const CREATE_ACCOUNT_NAME: &str = "make"; +const ACCOUNT_STRUCT_NAME: &str = "T"; + +fn make_access_path( + module: &CompiledModule, + idx: StructDefinitionIndex, + address: AccountAddress, +) -> AccessPath { + let struct_tag = resource_storage_key(module, idx); + create_access_path(&address, struct_tag) +} + +/// A struct that executes one single transaction. +/// 'alloc is the lifetime for the code cache, which is the argument type P here. Hence the P should +/// live as long as alloc. +/// 'txn is the lifetime of one single transaction. +/// `execution_stack` contains the call stack and value stack of current execution. +/// `txn_data` contains the information of this transaction, such as sender, sequence number, etc. +/// `event_data` is the vector that stores all events emitted during execution. +/// `data_view` is the scratchpad for the local writes emitted by this transaction. +pub struct TransactionExecutor<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + #[cfg(feature = "instruction_synthesis")] + pub execution_stack: ExecutionStack<'alloc, 'txn, P>, + + #[cfg(not(feature = "instruction_synthesis"))] + execution_stack: ExecutionStack<'alloc, 'txn, P>, + gas_meter: GasMeter, + txn_data: TransactionMetadata, + event_data: Vec, + data_view: TransactionDataCache<'txn>, +} + +impl<'alloc, 'txn, P> TransactionExecutor<'alloc, 'txn, P> +where + 'alloc: 'txn, + P: ModuleCache<'alloc>, +{ + /// Create a new `TransactionExecutor` to execute a single transaction. `module_cache` is the + /// cache that stores the modules previously read from the blockchain. `data_cache` is the cache + /// that holds read-only connection to the state store as well as the changes made by previous + /// transactions within the same block. + pub fn new( + module_cache: P, + data_cache: &'txn RemoteCache, + txn_data: TransactionMetadata, + ) -> Self { + TransactionExecutor { + execution_stack: ExecutionStack::new(module_cache), + gas_meter: GasMeter::new(txn_data.max_gas_amount()), + txn_data, + event_data: Vec::new(), + data_view: TransactionDataCache::new(data_cache), + } + } + + /// Returns the module cache for this executor. + pub fn module_cache(&self) -> &P { + &self.execution_stack.module_cache + } + + /// Perform a binary operation to two values at the top of the stack. 
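+    ///
+    /// For example, `Bytecode::Add` dispatches here (see `execute_block` below) as
+    /// `self.binop_int(u64::checked_add)`: both operands are popped as `u64` locals, and a
+    /// `None` result (overflow) is surfaced as `VMErrorKind::ArithmeticError`.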
+ fn binop(&mut self, f: F) -> VMResult<()> + where + Option: From, + F: FnOnce(T, T) -> Option, + { + let rhs = try_runtime!(self.execution_stack.pop_as::()); + let lhs = try_runtime!(self.execution_stack.pop_as::()); + let result = f(lhs, rhs); + if let Some(v) = result { + self.execution_stack.push(v); + Ok(Ok(())) + } else { + Ok(Err(VMRuntimeError { + loc: self.execution_stack.location()?, + err: VMErrorKind::ArithmeticError, + })) + } + } + + fn binop_int(&mut self, f: F) -> VMResult<()> + where + Option: From, + F: FnOnce(T, T) -> Option, + { + self.binop(|lhs, rhs| f(lhs, rhs).map(Local::u64)) + } + + fn binop_bool(&mut self, f: F) -> VMResult<()> + where + Option: From, + F: FnOnce(T, T) -> bool, + { + self.binop(|lhs, rhs| Some(Local::bool(f(lhs, rhs)))) + } + + /// This function will execute the code sequence starting from the beginning_offset, and return + /// Ok(Ok(offset)) when the instruction sequence hit a branch, either by calling into a new + /// function, branches, function return, etc. The return value will be the pc for the next + /// instruction to be executed. + #[allow(clippy::cognitive_complexity)] + pub fn execute_block( + &mut self, + code: &[Bytecode], + beginning_offset: CodeOffset, + ) -> VMResult { + let mut pc = beginning_offset; + for instruction in &code[beginning_offset as usize..] { + // FIXME: Once we add in memory ops, we will need to pass in the current memory size to + // this function. + try_runtime!(self.gas_meter.calculate_and_consume( + &instruction, + &self.execution_stack, + 1 + )); + + match instruction.clone() { + Bytecode::Pop => { + self.execution_stack.pop()?; + } + Bytecode::Ret => { + try_runtime!(self.execution_stack.pop_call()); + if self.execution_stack.is_call_stack_empty() { + return Ok(Ok(0)); + } else { + return Ok(Ok(self.execution_stack.top_frame()?.get_pc() + 1)); + } + } + Bytecode::BrTrue(offset) => { + if try_runtime!(self.execution_stack.pop_as::()) { + return Ok(Ok(offset)); + } + } + Bytecode::BrFalse(offset) => { + let stack_top = try_runtime!(self.execution_stack.pop_as::()); + if !stack_top { + return Ok(Ok(offset)); + } + } + Bytecode::Branch(offset) => return Ok(Ok(offset)), + Bytecode::LdConst(int_const) => { + self.execution_stack.push(Local::u64(int_const)); + } + Bytecode::LdAddr(idx) => { + let top_frame = self.execution_stack.top_frame()?; + let addr_ref = top_frame.module().address_at(idx); + self.execution_stack.push(Local::address(*addr_ref)); + } + Bytecode::LdStr(idx) => { + let top_frame = self.execution_stack.top_frame()?; + let string_ref = top_frame.module().string_at(idx); + self.execution_stack + .push(Local::string(string_ref.to_string())); + } + Bytecode::LdByteArray(idx) => { + let top_frame = self.execution_stack.top_frame()?; + let byte_array = top_frame.module().byte_array_at(idx); + self.execution_stack.push(Local::bytearray(byte_array)); + } + Bytecode::LdTrue => { + self.execution_stack.push(Local::bool(true)); + } + Bytecode::LdFalse => { + self.execution_stack.push(Local::bool(false)); + } + Bytecode::CopyLoc(idx) => { + let local = self.execution_stack.top_frame()?.get_local(idx)?.clone(); + self.execution_stack.push(local); + } + Bytecode::MoveLoc(idx) => { + let local = self + .execution_stack + .top_frame_mut()? + .invalidate_local(idx)?; + self.execution_stack.push(local); + } + Bytecode::StLoc(idx) => { + let stack_top = self.execution_stack.pop()?; + try_runtime!(self + .execution_stack + .top_frame_mut()? 
+ .store_local(idx, stack_top)); + } + Bytecode::Call(idx) => { + let self_module = &self.execution_stack.top_frame()?.module(); + let callee_function_ref = self + .execution_stack + .module_cache + .resolve_function_ref(self_module, idx)? + .ok_or(VMInvariantViolation::LinkerError)?; + + if callee_function_ref.is_native() { + let module_name: &str = callee_function_ref.module().module.name(); + let function_name: &str = callee_function_ref.name(); + let native_return = dispatch_native_call( + &mut self.execution_stack, + module_name, + function_name, + ) + .map_err(|_| VMInvariantViolation::LinkerError)?; + try_runtime!(self + .gas_meter + .consume_gas(native_return.cost(), &self.execution_stack)); + match native_return.get_return_value() { + NativeReturnType::ByteArray(value) => { + self.execution_stack.push(Local::bytearray(value)); + // Call stack is not reconstructed for a native call, so we just + // proceed on to next instruction. + } + NativeReturnType::Bool(value) => { + self.execution_stack.push(Local::bool(value)); + // Call stack is not reconstructed for a native call, so we just + // proceed on to next instruction. + } + } + } else { + self.execution_stack.top_frame_mut()?.jump(pc); + try_runtime!(self.execution_stack.push_call(callee_function_ref)); + // Call stack is reconstructed, the next instruction to execute will be the + // first instruction of the callee function. Thus we should break here to + // restart the instruction sequence from there. + return Ok(Ok(0)); + } + } + Bytecode::BorrowLoc(idx) => { + match self + .execution_stack + .top_frame()? + .get_local(idx)? + .borrow_local() + { + Some(v) => { + self.execution_stack.push(v); + } + None => { + return Ok(Err(VMRuntimeError { + loc: self.execution_stack.location()?, + err: VMErrorKind::TypeError, + })) + } + } + } + Bytecode::BorrowField(fd_idx) => { + let field_offset = self + .execution_stack + .top_frame()? + .module() + .get_field_offset(fd_idx)?; + match self + .execution_stack + .pop()? + .borrow_field(u32::from(field_offset)) + { + Some(v) => { + self.execution_stack.push(v); + } + None => { + return Ok(Err(VMRuntimeError { + loc: self.execution_stack.location()?, + err: VMErrorKind::TypeError, + })) + } + } + } + Bytecode::Pack(sd_idx) => { + let self_module = self.execution_stack.top_frame()?.module(); + let struct_def = self_module.module.struct_def_at(sd_idx); + let args = self + .execution_stack + .popn(struct_def.field_count)? 
+ .into_iter() + .map(Local::value) + .collect(); + match args { + Some(args) => { + self.execution_stack.push(Local::struct_(args)); + } + None => { + return Ok(Err(VMRuntimeError { + loc: self.execution_stack.location()?, + err: VMErrorKind::TypeError, + })) + } + } + } + Bytecode::Unpack(_sd_idx) => { + let struct_arg = self.execution_stack.pop()?; + match struct_arg.value() { + Some(v) => match &*v.peek() { + Value::Struct(fields) => { + for value in fields { + self.execution_stack.push(Local::Value(value.clone())) + } + } + _ => { + return Ok(Err(VMRuntimeError { + loc: self.execution_stack.location()?, + err: VMErrorKind::TypeError, + })) + } + }, + None => { + return Ok(Err(VMRuntimeError { + loc: self.execution_stack.location()?, + err: VMErrorKind::TypeError, + })) + } + } + } + Bytecode::ReadRef => match self.execution_stack.pop()?.read_reference() { + Some(v) => { + self.execution_stack.push(v); + } + None => { + return Ok(Err(VMRuntimeError { + loc: self.execution_stack.location()?, + err: VMErrorKind::TypeError, + })) + } + }, + Bytecode::WriteRef => { + let mutate_ref = self.execution_stack.pop()?; + let mutate_val = self.execution_stack.pop()?; + match mutate_val.value() { + Some(v) => { + mutate_ref.mutate_reference(v); + } + None => { + return Ok(Err(VMRuntimeError { + loc: self.execution_stack.location()?, + err: VMErrorKind::TypeError, + })) + } + } + } + Bytecode::ReleaseRef => { + let reference = self.execution_stack.pop()?; + match reference.release_reference() { + Ok(_) => (), + Err(e) => return Ok(Err(e)), + } + } + // Arithmetic Operations + Bytecode::Add => try_runtime!(self.binop_int(u64::checked_add)), + Bytecode::Sub => try_runtime!(self.binop_int(u64::checked_sub)), + Bytecode::Mul => try_runtime!(self.binop_int(u64::checked_mul)), + Bytecode::Mod => try_runtime!(self.binop_int(u64::checked_rem)), + Bytecode::Div => try_runtime!(self.binop_int(u64::checked_div)), + Bytecode::BitOr => try_runtime!(self.binop_int(|l: u64, r| Some(l | r))), + Bytecode::BitAnd => try_runtime!(self.binop_int(|l: u64, r| Some(l & r))), + Bytecode::Xor => try_runtime!(self.binop_int(|l: u64, r| Some(l ^ r))), + Bytecode::Or => try_runtime!(self.binop_bool(|l, r| l || r)), + Bytecode::And => try_runtime!(self.binop_bool(|l, r| l && r)), + Bytecode::Lt => try_runtime!(self.binop_bool(|l: u64, r| l < r)), + Bytecode::Gt => try_runtime!(self.binop_bool(|l: u64, r| l > r)), + Bytecode::Le => try_runtime!(self.binop_bool(|l: u64, r| l <= r)), + Bytecode::Ge => try_runtime!(self.binop_bool(|l: u64, r| l >= r)), + Bytecode::Assert => { + let condition = try_runtime!(self.execution_stack.pop_as::()); + let error_code = try_runtime!(self.execution_stack.pop_as::()); + if !condition { + return Ok(Err(VMRuntimeError { + loc: self.execution_stack.location()?, + err: VMErrorKind::AssertionFailure(error_code), + })); + } + } + + // TODO: Should we emit different eq for different primitive type values? + // How should equality between references be defined? Should we just panic + // on reference values? 
+ Bytecode::Eq => { + let lhs = self.execution_stack.pop()?; + let rhs = self.execution_stack.pop()?; + self.execution_stack.push(Local::bool(lhs == rhs)); + } + Bytecode::Neq => { + let lhs = self.execution_stack.pop()?; + let rhs = self.execution_stack.pop()?; + self.execution_stack.push(Local::bool(lhs != rhs)); + } + Bytecode::GetTxnGasUnitPrice => { + self.execution_stack + .push(Local::u64(self.txn_data.gas_unit_price())); + } + Bytecode::GetTxnMaxGasUnits => { + self.execution_stack + .push(Local::u64(self.txn_data.max_gas_amount())); + } + Bytecode::GetTxnSequenceNumber => { + self.execution_stack + .push(Local::u64(self.txn_data.sequence_number())); + } + Bytecode::GetTxnSenderAddress => { + self.execution_stack + .push(Local::address(self.txn_data.sender())); + } + Bytecode::GetTxnPublicKey => { + self.execution_stack.push(Local::bytearray(ByteArray::new( + self.txn_data.public_key().to_slice().to_vec(), + ))); + } + Bytecode::BorrowGlobal(idx) => { + let address = try_runtime!(self.execution_stack.pop_as::()); + let curr_module = &self.execution_stack.top_frame()?.module(); + let ap = make_access_path(&curr_module.module, idx, address); + if let Some(struct_def) = try_runtime!(self + .execution_stack + .module_cache + .resolve_struct_def(curr_module, idx, &self.gas_meter)) + { + let global_ref = + try_runtime!(self.data_view.borrow_global(&ap, struct_def)); + try_runtime!(self.gas_meter.calculate_and_consume( + &instruction, + &self.execution_stack, + global_ref.size() + )); + self.execution_stack.push(Local::GlobalRef(global_ref)); + } else { + return Err(VMInvariantViolation::LinkerError); + } + } + Bytecode::Exists(idx) => { + let address = try_runtime!(self.execution_stack.pop_as::()); + let curr_module = &self.execution_stack.top_frame()?.module(); + let ap = make_access_path(&curr_module.module, idx, address); + if let Some(struct_def) = try_runtime!(self + .execution_stack + .module_cache + .resolve_struct_def(curr_module, idx, &self.gas_meter)) + { + let (exists, mem_size) = self.data_view.resource_exists(&ap, struct_def)?; + try_runtime!(self.gas_meter.calculate_and_consume( + &instruction, + &self.execution_stack, + mem_size + )); + self.execution_stack.push(Local::bool(exists)); + } else { + return Err(VMInvariantViolation::LinkerError); + } + } + Bytecode::MoveFrom(idx) => { + let address = try_runtime!(self.execution_stack.pop_as::()); + let curr_module = &self.execution_stack.top_frame()?.module(); + let ap = make_access_path(&curr_module.module, idx, address); + if let Some(struct_def) = try_runtime!(self + .execution_stack + .module_cache + .resolve_struct_def(curr_module, idx, &self.gas_meter)) + { + let resource = + try_runtime!(self.data_view.move_resource_from(&ap, struct_def)); + try_runtime!(self.gas_meter.calculate_and_consume( + &instruction, + &self.execution_stack, + resource.size() + )); + self.execution_stack.push(resource); + } else { + return Err(VMInvariantViolation::LinkerError); + } + } + Bytecode::MoveToSender(idx) => { + let curr_module = &self.execution_stack.top_frame()?.module(); + let ap = make_access_path(&curr_module.module, idx, self.txn_data.sender()); + if let Some(struct_def) = try_runtime!(self + .execution_stack + .module_cache + .resolve_struct_def(curr_module, idx, &self.gas_meter)) + { + let local = self.execution_stack.pop()?; + + if let Some(resource) = local.value() { + try_runtime!(self.gas_meter.calculate_and_consume( + &instruction, + &self.execution_stack, + resource.size() + )); + try_runtime!(self + .data_view + 
                            .move_resource_to(&ap, struct_def, resource));
+                        } else {
+                            return Ok(Err(VMRuntimeError {
+                                loc: Location::new(),
+                                err: VMErrorKind::TypeError,
+                            }));
+                        }
+                    } else {
+                        return Err(VMInvariantViolation::LinkerError);
+                    }
+                }
+                Bytecode::CreateAccount => {
+                    let addr = try_runtime!(self.execution_stack.pop_as::<AccountAddress>());
+                    try_runtime!(self.create_account(addr));
+                }
+                Bytecode::FreezeRef => {
+                    // FreezeRef should just be a no-op, as we don't distinguish between
+                    // mutable and immutable references at runtime.
+                }
+                Bytecode::Not => {
+                    let top = try_runtime!(self.execution_stack.pop_as::<bool>());
+                    self.execution_stack.push(Local::bool(!top));
+                }
+                Bytecode::EmitEvent => {
+                    let data = match self.execution_stack.pop()?.value() {
+                        Some(value) => value,
+                        None => {
+                            return Ok(Err(VMRuntimeError {
+                                loc: self.execution_stack.location()?,
+                                err: VMErrorKind::TypeError,
+                            }))
+                        }
+                    };
+                    let byte_array = try_runtime!(self.execution_stack.pop_as::<ByteArray>());
+
+                    let reference = self.execution_stack.pop()?;
+                    if let Some(event_data) = reference.emit_event_data(byte_array, data) {
+                        self.event_data.push(event_data);
+                    }
+                }
+                Bytecode::GetGasRemaining => {
+                    self.execution_stack
+                        .push(Local::u64(self.gas_meter.remaining_gas()));
+                }
+            }
+            pc += 1;
+        }
+
+        if cfg!(test) || cfg!(feature = "instruction_synthesis") {
+            // In order to test the behavior of an instruction stream, hitting the end of the
+            // code should report no error so that we can check the locals.
+            Ok(Ok(code.len() as CodeOffset))
+        } else {
+            Err(VMInvariantViolation::ProgramCounterOverflow)
+        }
+    }
+
+    /// Convert the transaction arguments into Move values and push them onto the top of the
+    /// stack.
+    pub(crate) fn setup_main_args(&mut self, args: Vec<TransactionArgument>) {
+        for arg in args.into_iter() {
+            self.execution_stack.push(match arg {
+                TransactionArgument::U64(i) => Local::u64(i),
+                TransactionArgument::Address(a) => Local::address(a),
+                TransactionArgument::ByteArray(b) => Local::bytearray(b),
+                TransactionArgument::String(s) => Local::string(s),
+            });
+        }
+    }
+
+    /// Create an account on the blockchain by calling into the `CREATE_ACCOUNT_NAME` function
+    /// stored in the `ACCOUNT_MODULE` on chain.
+    pub fn create_account(&mut self, addr: AccountAddress) -> VMResult<()> {
+        let account_module = self
+            .execution_stack
+            .module_cache
+            .get_loaded_module(&ACCOUNT_MODULE)?
+            .ok_or(VMInvariantViolation::LinkerError)?;
+
+        // The address will be used as the initial authentication key.
+        try_runtime!(self.execute_function(
+            &ACCOUNT_MODULE,
+            CREATE_ACCOUNT_NAME,
+            vec![Local::bytearray(ByteArray::new(addr.to_vec()))],
+        ));
+
+        let account_resource = self
+            .execution_stack
+            .pop()?
+            .value()
+            .ok_or(VMInvariantViolation::LinkerError)?;
+        let account_struct_id = account_module
+            .struct_defs_table
+            .get(ACCOUNT_STRUCT_NAME)
+            .ok_or(VMInvariantViolation::LinkerError)?;
+        let account_struct_def = try_runtime!(self
+            .execution_stack
+            .module_cache
+            .resolve_struct_def(account_module, *account_struct_id, &self.gas_meter))
+        .ok_or(VMInvariantViolation::LinkerError)?;
+
+        // TODO: Add the freshly created account's expiration date to the TransactionOutput here.
+        let account_path = make_access_path(&account_module.module, *account_struct_id, addr);
+        self.data_view
+            .move_resource_to(&account_path, account_struct_def, account_resource)
+    }
+
+    /// Run the prologue of a transaction by calling into the `PROLOGUE_NAME` function stored
+    /// in the `ACCOUNT_MODULE` on chain.
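+    ///
+    /// Note that gas metering is switched off while the prologue runs (see the body below),
+    /// so the prologue itself does not charge gas to the sender.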
+    pub(crate) fn run_prologue(&mut self) -> VMResult<()> {
+        self.gas_meter.disable_metering();
+        let result = self.execute_function(&ACCOUNT_MODULE, PROLOGUE_NAME, vec![]);
+        self.gas_meter.enable_metering();
+        result
+    }
+
+    /// Run the epilogue of a transaction by calling into the `EPILOGUE_NAME` function stored
+    /// in the `ACCOUNT_MODULE` on chain.
+    fn run_epilogue(&mut self) -> VMResult<()> {
+        self.gas_meter.disable_metering();
+        let result = self.execute_function(&ACCOUNT_MODULE, EPILOGUE_NAME, vec![]);
+        self.gas_meter.enable_metering();
+        result
+    }
+
+    /// Generate the TransactionOutput on failure. There can be two possibilities:
+    /// 1. The transaction encounters some runtime error, such as out of gas, arithmetic
+    ///    overflow, etc. In this scenario, we keep the transaction and charge the proper
+    ///    amount of gas to the sender.
+    /// 2. The transaction encounters a `VMInvariantViolation`, which indicates that some
+    ///    property the VM should have guaranteed has failed to hold. Such a transaction is
+    ///    discarded for sanity, but this implies a bug in the VM that we should take care of.
+    pub(crate) fn failed_transaction_cleanup(&mut self, result: VMResult<()>) -> TransactionOutput {
+        // Discard all the local writes, restart execution from a clean state.
+        self.clear();
+        match self.run_epilogue() {
+            Ok(Ok(_)) => match self.make_write_set(vec![], result) {
+                Ok(trans_out) => trans_out,
+                Err(err) => error_output(&err),
+            },
+            // Running the epilogue shouldn't fail here, as we've already checked for enough
+            // balance in the prologue.
+            Ok(Err(err)) => error_output(&err),
+            Err(err) => error_output(&err),
+        }
+    }
+
+    /// Clear all the writes local to this transaction.
+    fn clear(&mut self) {
+        self.data_view.clear();
+        self.event_data.clear();
+    }
+
+    /// Generate the TransactionOutput for a successful transaction.
+    pub(crate) fn transaction_cleanup(
+        &mut self,
+        to_be_published_modules: Vec<(CodeKey, Vec<u8>)>,
+    ) -> TransactionOutput {
+        // First run the epilogue.
+        match self.run_epilogue() {
+            // If the epilogue runs successfully, try to emit the write set.
+            Ok(Ok(_)) => match self.make_write_set(to_be_published_modules, Ok(Ok(()))) {
+                // This step could fail if the program has a dangling global reference.
+                Ok(trans_out) => trans_out,
+                // In case of failure, run the cleanup code.
+                Err(err) => self.failed_transaction_cleanup(Ok(Err(err))),
+            },
+            // If the sender depleted its balance and can't pay for the gas, run the cleanup code.
+            Ok(Err(err)) => self.failed_transaction_cleanup(Ok(Err(err))),
+            Err(err) => error_output(&err),
+        }
+    }
+
+    /// Execute a function given a FunctionRef.
+    pub(crate) fn execute_function_impl(&mut self, func: FunctionRef<'txn>) -> VMResult<()> {
+        // We charge an intrinsic amount of gas based upon the size of the transaction submitted
+        // (in raw bytes).
+        try_runtime!(self
+            .gas_meter
+            .charge_transaction_gas(self.txn_data.transaction_size, &self.execution_stack));
+        let beginning_height = self.execution_stack.call_stack_height();
+        try_runtime!(self.execution_stack.push_call(func));
+        // We always start execution from the first instruction.
+        let mut pc = 0;
+
+        // Execute code until the call stack goes back to its original height. At that point
+        // we will know that this function has terminated.
+        while self.execution_stack.call_stack_height() != beginning_height {
+            let code = self.execution_stack.top_frame()?.code_definition();
+
+            // Get the pc for the next instruction to be executed.
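+            // `execute_block` returns the program counter at which execution should resume
+            // in whatever frame is then on top of the call stack; for example, a `Call`
+            // returns 0 so the callee starts at its first instruction (see `Bytecode::Call`
+            // above).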
+            pc = try_runtime!(self.execute_block(code, pc));
+
+            if self.execution_stack.call_stack_height() == beginning_height {
+                return Ok(Ok(()));
+            }
+        }
+
+        Ok(Ok(()))
+    }
+
+    /// Execute a function.
+    /// `module` identifies the module in which the function is defined, and `function_name` is
+    /// the name of the function. If such a function is found, the VM executes it with the
+    /// arguments `args`. The return value is left on top of the value stack; execution aborts
+    /// if an error occurs.
+    pub fn execute_function(
+        &mut self,
+        module: &CodeKey,
+        function_name: &str,
+        args: Vec<Local>,
+    ) -> VMResult<()> {
+        let loaded_module = self
+            .execution_stack
+            .module_cache
+            .get_loaded_module(module)?
+            .ok_or(VMInvariantViolation::LinkerError)?;
+        let func_idx = loaded_module
+            .function_defs_table
+            .get(function_name)
+            .ok_or(VMInvariantViolation::LinkerError)?;
+        let func = FunctionRef::new(loaded_module, *func_idx)?;
+
+        for arg in args.into_iter() {
+            self.execution_stack.push(arg);
+        }
+
+        self.execute_function_impl(func)
+    }
+
+    /// Get the value on the top of the value stack.
+    pub fn pop_stack(&mut self) -> Result<Local, VMInvariantViolation> {
+        self.execution_stack.pop()
+    }
+
+    /// Produce a write set at the end of a transaction. This will clear all the local states in
+    /// the TransactionProcessor and turn them into a write set.
+    pub fn make_write_set(
+        &mut self,
+        to_be_published_modules: Vec<(CodeKey, Vec<u8>)>,
+        result: VMResult<()>,
+    ) -> VMRuntimeResult<TransactionOutput> {
+        // This should only be used for bookkeeping. The gas is already deducted from the
+        // sender's account in the account module's epilogue.
+        let gas: u64 = (self.txn_data.max_gas_amount - self.gas_meter.remaining_gas())
+            * self.txn_data.gas_unit_price;
+        let write_set = self.data_view.make_write_set(to_be_published_modules)?;
+
+        Ok(TransactionOutput::new(
+            write_set,
+            self.event_data.clone(),
+            gas,
+            match result {
+                Ok(Ok(())) => {
+                    TransactionStatus::from(VMStatus::Execution(ExecutionStatus::Executed))
+                }
+                Ok(Err(ref err)) => TransactionStatus::from(VMStatus::from(err)),
+                Err(ref err) => TransactionStatus::from(VMStatus::from(err)),
+            },
+        ))
+    }
+}
+
+#[inline]
+fn error_output(err: impl Into<VMStatus>) -> TransactionOutput {
+    // Since this transaction will be discarded, no write set will be included.
+    TransactionOutput::new(
+        WriteSet::default(),
+        vec![],
+        0,
+        TransactionStatus::Discard(err.into()),
+    )
+}
+
+/// A helper function for executing a single script. Will be deprecated once we have a better
+/// testing framework for executing arbitrary scripts.
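+///
+/// Illustrative usage only (hypothetical values; `script`, `modules` and `data_cache` are
+/// assumed to be built elsewhere, as in this crate's unit tests):
+/// ```ignore
+/// let result = execute_function(script, modules, vec![], &data_cache);
+/// assert!(result.is_ok());
+/// ```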
+pub fn execute_function( + caller_script: CompiledScript, + modules: Vec, + _args: Vec, + data_cache: &RemoteCache, +) -> VMResult<()> { + let allocator = Arena::new(); + let module_cache = VMModuleCache::new(&allocator); + let (main_module, entry_idx) = create_fake_module(caller_script); + let loaded_main = LoadedModule::new(main_module)?; + let entry_func = FunctionRef::new(&loaded_main, entry_idx)?; + for m in modules { + module_cache.cache_module(m)?; + } + let mut vm = TransactionExecutor { + execution_stack: ExecutionStack::new(&module_cache), + gas_meter: GasMeter::new(1_000), + txn_data: TransactionMetadata::default(), + event_data: Vec::new(), + data_view: TransactionDataCache::new(data_cache), + }; + vm.execute_function_impl(entry_func) +} diff --git a/language/vm/vm_runtime/src/unit_tests/identifier_prop_tests.rs b/language/vm/vm_runtime/src/unit_tests/identifier_prop_tests.rs new file mode 100644 index 0000000000000..cee97b790adfe --- /dev/null +++ b/language/vm/vm_runtime/src/unit_tests/identifier_prop_tests.rs @@ -0,0 +1,28 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::identifier::resource_storage_key; +use canonical_serialization::{SimpleDeserializer, SimpleSerializer}; +use proptest::prelude::*; +use vm::file_format::{CompiledModule, StructDefinitionIndex, TableIndex}; + +proptest! { + #[test] + fn identifier_serializer_roundtrip(module in CompiledModule::valid_strategy(20)) { + let code_key = module.self_code_key(); + let deserialized_code_key = { + let serialized_key = SimpleSerializer::>::serialize(&code_key).unwrap(); + SimpleDeserializer::deserialize(&serialized_key).expect("Deserialize should work") + }; + prop_assert_eq!(code_key, deserialized_code_key); + + for i in 0..module.struct_defs.len() { + let struct_key = resource_storage_key(&module, StructDefinitionIndex::new(i as TableIndex)); + let deserialized_struct_key = { + let serialized_key = SimpleSerializer::>::serialize(&struct_key).unwrap(); + SimpleDeserializer::deserialize(&serialized_key).expect("Deserialize should work") + }; + prop_assert_eq!(struct_key, deserialized_struct_key); + } + } +} diff --git a/language/vm/vm_runtime/src/unit_tests/module_cache_tests.rs b/language/vm/vm_runtime/src/unit_tests/module_cache_tests.rs new file mode 100644 index 0000000000000..495e7e700624e --- /dev/null +++ b/language/vm/vm_runtime/src/unit_tests/module_cache_tests.rs @@ -0,0 +1,575 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crate::{ + code_cache::{ + module_adapter::FakeFetcher, + module_cache::{create_fake_module, ModuleCache}, + }, + loaded_data::function::{FunctionRef, FunctionReference}, +}; +use ::compiler::{compiler, parser::parse_program}; +use hex; +use types::account_address::AccountAddress; +use vm::file_format::*; +use vm_cache_map::Arena; + +fn test_module(name: String) -> CompiledModule { + CompiledModule { + module_handles: vec![ModuleHandle { + name: StringPoolIndex::new(0), + address: AddressPoolIndex::new(0), + }], + struct_handles: vec![], + function_handles: vec![ + FunctionHandle { + module: ModuleHandleIndex::new(0), + name: StringPoolIndex::new(1), + signature: FunctionSignatureIndex::new(0), + }, + FunctionHandle { + module: ModuleHandleIndex::new(0), + name: StringPoolIndex::new(2), + signature: FunctionSignatureIndex::new(1), + }, + ], + + struct_defs: vec![], + field_defs: vec![], + function_defs: vec![ + FunctionDefinition { + function: FunctionHandleIndex::new(0), 
+ flags: CodeUnit::PUBLIC, + code: CodeUnit { + max_stack_size: 10, + locals: LocalsSignatureIndex::new(0), + code: vec![Bytecode::Add], + }, + }, + FunctionDefinition { + function: FunctionHandleIndex::new(1), + flags: CodeUnit::PUBLIC, + code: CodeUnit { + max_stack_size: 10, + locals: LocalsSignatureIndex::new(0), + code: vec![Bytecode::Ret], + }, + }, + ], + type_signatures: vec![], + function_signatures: vec![ + FunctionSignature { + return_types: vec![], + arg_types: vec![], + }, + FunctionSignature { + return_types: vec![], + arg_types: vec![SignatureToken::U64], + }, + ], + locals_signatures: vec![LocalsSignature(vec![])], + string_pool: vec![name, "func1".to_string(), "func2".to_string()], + byte_array_pool: vec![], + address_pool: vec![AccountAddress::default()], + } +} + +fn test_script() -> CompiledScript { + CompiledScript { + main: FunctionDefinition { + function: FunctionHandleIndex::new(0), + flags: CodeUnit::PUBLIC, + code: CodeUnit { + max_stack_size: 10, + locals: LocalsSignatureIndex(0), + code: vec![], + }, + }, + module_handles: vec![ + ModuleHandle { + address: AddressPoolIndex::new(0), + name: StringPoolIndex::new(0), + }, + ModuleHandle { + address: AddressPoolIndex::new(0), + name: StringPoolIndex::new(1), + }, + ], + struct_handles: vec![], + function_handles: vec![ + FunctionHandle { + name: StringPoolIndex::new(2), + signature: FunctionSignatureIndex::new(0), + module: ModuleHandleIndex::new(1), + }, + FunctionHandle { + name: StringPoolIndex::new(3), + signature: FunctionSignatureIndex::new(1), + module: ModuleHandleIndex::new(1), + }, + ], + type_signatures: vec![], + function_signatures: vec![ + FunctionSignature { + return_types: vec![], + arg_types: vec![], + }, + FunctionSignature { + return_types: vec![], + arg_types: vec![SignatureToken::U64], + }, + ], + locals_signatures: vec![LocalsSignature(vec![])], + string_pool: vec![ + "hello".to_string(), + "module".to_string(), + "func1".to_string(), + "func2".to_string(), + ], + byte_array_pool: vec![], + address_pool: vec![AccountAddress::default()], + } +} +#[test] +fn test_loader_one_module() { + // This test tests the linking of function within a single module: We have a module that defines + // two functions, each with different name and signature. This test will make sure that we + // link the function handle with the right function definition within the same module. + let module = test_module("module".to_string()); + let mod_id = module.self_code_key(); + + let allocator = Arena::new(); + let loaded_program = VMModuleCache::new(&allocator); + loaded_program.cache_module(module).unwrap(); + let module_ref = loaded_program.get_loaded_module(&mod_id).unwrap().unwrap(); + + // Get the function reference of the first two function handles. 
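+    // (Handle 0 resolves to "func1" and handle 1 to "func2" in the string pool built by
+    // `test_module` above.)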
+ let func1_ref = loaded_program + .resolve_function_ref(module_ref, FunctionHandleIndex::new(0)) + .unwrap() + .unwrap(); + let func2_ref = loaded_program + .resolve_function_ref(module_ref, FunctionHandleIndex::new(1)) + .unwrap() + .unwrap(); + + // The two references should refer to the same module + assert_eq!( + func2_ref.module() as *const LoadedModule, + func1_ref.module() as *const LoadedModule + ); + + assert_eq!(func1_ref.arg_count(), 0); + assert_eq!(func1_ref.return_count(), 0); + assert_eq!(func1_ref.code_definition(), vec![Bytecode::Add].as_slice()); + + assert_eq!(func2_ref.arg_count(), 1); + assert_eq!(func2_ref.return_count(), 0); + assert_eq!(func2_ref.code_definition(), vec![Bytecode::Ret].as_slice()); +} + +#[test] +fn test_loader_cross_modules() { + let script = test_script(); + let module = test_module("module".to_string()); + + let allocator = Arena::new(); + let loaded_program = VMModuleCache::new(&allocator); + loaded_program.cache_module(module).unwrap(); + + let (owned_entry_module, entry_idx) = create_fake_module(script); + let loaded_main = LoadedModule::new(owned_entry_module).unwrap(); + let entry_func = FunctionRef::new(&loaded_main, entry_idx).unwrap(); + let entry_module = entry_func.module(); + let func1 = loaded_program + .resolve_function_ref(entry_module, FunctionHandleIndex::new(0)) + .unwrap() + .unwrap(); + let func2 = loaded_program + .resolve_function_ref(entry_module, FunctionHandleIndex::new(1)) + .unwrap() + .unwrap(); + + assert_eq!( + func2.module() as *const LoadedModule, + func1.module() as *const LoadedModule + ); + + assert_eq!(func1.arg_count(), 0); + assert_eq!(func1.return_count(), 0); + assert_eq!(func1.code_definition(), vec![Bytecode::Add].as_slice()); + + assert_eq!(func2.arg_count(), 1); + assert_eq!(func2.return_count(), 0); + assert_eq!(func2.code_definition(), vec![Bytecode::Ret].as_slice()); +} + +#[test] +fn test_cache_with_storage() { + let allocator = Arena::new(); + + let (owned_entry_module, entry_idx) = create_fake_module(test_script()); + let loaded_main = LoadedModule::new(owned_entry_module).unwrap(); + let entry_func = FunctionRef::new(&loaded_main, entry_idx).unwrap(); + let entry_module = entry_func.module(); + + let vm_cache = VMModuleCache::new(&allocator); + + // Function is not defined locally. + assert!(vm_cache + .resolve_function_ref(entry_module, FunctionHandleIndex::new(0)) + .unwrap() + .is_none()); + + { + let fetcher = FakeFetcher::new(vec![test_module("module".to_string())]); + let mut block_cache = BlockModuleCache::new(&vm_cache, fetcher); + + // Make sure the block cache fetches the code from the view. + let func1 = block_cache + .resolve_function_ref(entry_module, FunctionHandleIndex::new(0)) + .unwrap() + .unwrap(); + let func2 = block_cache + .resolve_function_ref(entry_module, FunctionHandleIndex::new(1)) + .unwrap() + .unwrap(); + + assert_eq!( + func2.module() as *const LoadedModule, + func1.module() as *const LoadedModule + ); + + assert_eq!(func1.arg_count(), 0); + assert_eq!(func1.return_count(), 0); + assert_eq!(func1.code_definition(), vec![Bytecode::Add].as_slice()); + + assert_eq!(func2.arg_count(), 1); + assert_eq!(func2.return_count(), 0); + assert_eq!(func2.code_definition(), vec![Bytecode::Ret].as_slice()); + + // Clean the fetcher so that there's nothing in the fetcher. 
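+        // Resolution below should still succeed: the definitions fetched above are now
+        // cached in `vm_cache`, so clearing the backing storage must not matter.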
+ block_cache.storage.clear(); + + let func1 = block_cache + .resolve_function_ref(entry_module, FunctionHandleIndex::new(0)) + .unwrap() + .unwrap(); + let func2 = block_cache + .resolve_function_ref(entry_module, FunctionHandleIndex::new(1)) + .unwrap() + .unwrap(); + + assert_eq!( + func2.module() as *const LoadedModule, + func1.module() as *const LoadedModule + ); + + assert_eq!(func1.arg_count(), 0); + assert_eq!(func1.return_count(), 0); + assert_eq!(func1.code_definition(), vec![Bytecode::Add].as_slice()); + + assert_eq!(func2.arg_count(), 1); + assert_eq!(func2.return_count(), 0); + assert_eq!(func2.code_definition(), vec![Bytecode::Ret].as_slice()); + } + + // Even if the block cache goes out of scope, we should still be able to read the fetched + // definition + let func1 = vm_cache + .resolve_function_ref(entry_module, FunctionHandleIndex::new(0)) + .unwrap() + .unwrap(); + let func2 = vm_cache + .resolve_function_ref(entry_module, FunctionHandleIndex::new(1)) + .unwrap() + .unwrap(); + + assert_eq!( + func2.module() as *const LoadedModule, + func1.module() as *const LoadedModule + ); + + assert_eq!(func1.arg_count(), 0); + assert_eq!(func1.return_count(), 0); + assert_eq!(func1.code_definition(), vec![Bytecode::Add].as_slice()); + + assert_eq!(func2.arg_count(), 1); + assert_eq!(func2.return_count(), 0); + assert_eq!(func2.code_definition(), vec![Bytecode::Ret].as_slice()); +} + +#[test] +fn test_multi_level_cache_write_back() { + let allocator = Arena::new(); + let vm_cache = VMModuleCache::new(&allocator); + + // Put an existing module in the cache. + let module = test_module("existing_module".to_string()); + vm_cache.cache_module(module).unwrap(); + + // Create a new script that refers to both published and unpublished modules. + let script = CompiledScript { + main: FunctionDefinition { + function: FunctionHandleIndex::new(0), + flags: CodeUnit::PUBLIC, + code: CodeUnit { + max_stack_size: 10, + locals: LocalsSignatureIndex(0), + code: vec![], + }, + }, + module_handles: vec![ + // Self + ModuleHandle { + address: AddressPoolIndex::new(0), + name: StringPoolIndex::new(0), + }, + // To-be-published Module + ModuleHandle { + address: AddressPoolIndex::new(0), + name: StringPoolIndex::new(1), + }, + // Existing module on chain + ModuleHandle { + address: AddressPoolIndex::new(0), + name: StringPoolIndex::new(2), + }, + ], + struct_handles: vec![], + function_handles: vec![ + // Func2 defined in the new module + FunctionHandle { + name: StringPoolIndex::new(4), + signature: FunctionSignatureIndex::new(0), + module: ModuleHandleIndex::new(1), + }, + // Func1 defined in the old module + FunctionHandle { + name: StringPoolIndex::new(3), + signature: FunctionSignatureIndex::new(1), + module: ModuleHandleIndex::new(2), + }, + ], + type_signatures: vec![], + function_signatures: vec![ + FunctionSignature { + return_types: vec![], + arg_types: vec![], + }, + FunctionSignature { + return_types: vec![], + arg_types: vec![SignatureToken::U64], + }, + ], + locals_signatures: vec![LocalsSignature(vec![])], + string_pool: vec![ + "hello".to_string(), + "module".to_string(), + "existing_module".to_string(), + "func1".to_string(), + "func2".to_string(), + ], + byte_array_pool: vec![], + address_pool: vec![AccountAddress::default()], + }; + + let (owned_entry_module, entry_idx) = create_fake_module(script); + let loaded_main = LoadedModule::new(owned_entry_module).unwrap(); + let entry_func = FunctionRef::new(&loaded_main, entry_idx).unwrap(); + let entry_module = entry_func.module(); + + { + 
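+        // The nested scopes below model a transaction: modules cached against
+        // `txn_allocator` become visible in the shared `vm_cache` only once
+        // `reclaim_cached_module` is called on the transaction arena.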
        let txn_allocator = Arena::new();
+        {
+            let txn_cache = TransactionModuleCache::new(&vm_cache, &txn_allocator);
+
+            // We should be able to read existing modules in both caches.
+            let func1_vm_ref = vm_cache
+                .resolve_function_ref(entry_module, FunctionHandleIndex::new(1))
+                .unwrap()
+                .unwrap();
+            let func1_txn_ref = txn_cache
+                .resolve_function_ref(entry_module, FunctionHandleIndex::new(1))
+                .unwrap()
+                .unwrap();
+            assert_eq!(func1_vm_ref, func1_txn_ref);
+
+            txn_cache
+                .cache_module(test_module("module".to_string()))
+                .unwrap();
+
+            // We should not see the new module in the vm cache, but we should be able to
+            // read it from the txn cache.
+            assert!(vm_cache
+                .resolve_function_ref(entry_module, FunctionHandleIndex::new(0))
+                .unwrap()
+                .is_none());
+            let func2_txn_ref = txn_cache
+                .resolve_function_ref(entry_module, FunctionHandleIndex::new(0))
+                .unwrap()
+                .unwrap();
+            assert_eq!(func2_txn_ref.arg_count(), 1);
+            assert_eq!(func2_txn_ref.return_count(), 0);
+            assert_eq!(
+                func2_txn_ref.code_definition(),
+                vec![Bytecode::Ret].as_slice()
+            );
+        }
+
+        // Drop the transactional arena
+        vm_cache
+            .reclaim_cached_module(txn_allocator.into_vec())
+            .unwrap();
+    }
+
+    // After reclaiming, the new module should be visible from the shared vm cache.
+    let func2_ref = vm_cache
+        .resolve_function_ref(entry_module, FunctionHandleIndex::new(0))
+        .unwrap()
+        .unwrap();
+    assert_eq!(func2_ref.arg_count(), 1);
+    assert_eq!(func2_ref.return_count(), 0);
+    assert_eq!(func2_ref.code_definition(), vec![Bytecode::Ret].as_slice());
+}
+
+// TODO: What this function does goes well beyond what its name suggests.
+// Fix it and the code depending on it.
+fn parse_modules(s: String) -> Vec<CompiledModule> {
+    let address = AccountAddress::default();
+    let parsed_program = parse_program(&s).unwrap();
+
+    let compiled_program = compiler::compile_program(&address, &parsed_program, &[]).unwrap();
+    compiled_program.modules
+}
+
+#[test]
+fn test_same_module_struct_resolution() {
+    let allocator = Arena::new();
+    let vm_cache = VMModuleCache::new(&allocator);
+
+    let code = String::from(
+        "
+        modules:
+        module M1 {
+            struct X {}
+            struct T { i: u64, x: V#Self.X }
+        }
+        script:
+        main() {
+            return;
+        }
+        ",
+    );
+
+    let module = parse_modules(code);
+    let fetcher = FakeFetcher::new(module);
+    let block_cache = BlockModuleCache::new(&vm_cache, fetcher);
+    {
+        let code_key = CodeKey::new(AccountAddress::default(), "M1".to_string());
+        let module_ref = block_cache.get_loaded_module(&code_key).unwrap().unwrap();
+        let gas = GasMeter::new(100_000_000);
+        let struct_x = block_cache
+            .resolve_struct_def(module_ref, StructDefinitionIndex::new(0), &gas)
+            .unwrap()
+            .unwrap()
+            .unwrap();
+        let struct_t = block_cache
+            .resolve_struct_def(module_ref, StructDefinitionIndex::new(1), &gas)
+            .unwrap()
+            .unwrap()
+            .unwrap();
+        assert_eq!(struct_x, StructDef::new(vec![]));
+        assert_eq!(
+            struct_t,
+            StructDef::new(vec![Type::U64, Type::Struct(StructDef::new(vec![]))]),
+        );
+    }
+}
+
+#[test]
+fn test_multi_module_struct_resolution() {
+    let allocator = Arena::new();
+    let vm_cache = VMModuleCache::new(&allocator);
+
+    let code = format!(
+        "
+        modules:
+        module M1 {{
+            struct X {{}}
+        }}
+        module M2 {{
+            import 0x{0}.M1;
+            struct T {{ i: u64, x: V#M1.X }}
+        }}
+        script:
+        main() {{
+            return;
+        }}
+        ",
+        hex::encode(AccountAddress::default())
+    );
+
+    let module = parse_modules(code);
+    let fetcher = FakeFetcher::new(module);
+    let block_cache = BlockModuleCache::new(&vm_cache, fetcher);
+    {
+        let code_key_2 = CodeKey::new(AccountAddress::default(), "M2".to_string());
+        let
module2_ref = block_cache.get_loaded_module(&code_key_2).unwrap().unwrap(); + + let gas = GasMeter::new(100_000_000); + let struct_t = block_cache + .resolve_struct_def(module2_ref, StructDefinitionIndex::new(0), &gas) + .unwrap() + .unwrap() + .unwrap(); + assert_eq!( + struct_t, + StructDef::new(vec![Type::U64, Type::Struct(StructDef::new(vec![]))]), + ); + } +} + +#[test] +fn test_field_offset_resolution() { + let allocator = Arena::new(); + let vm_cache = VMModuleCache::new(&allocator); + + let code = String::from( + " + modules: + module M1 { + struct X { f: u64, g: bool} + struct T { i: u64, x: V#Self.X, y: u64 } + } + script: + main() { + return; + } + ", + ); + + let module = parse_modules(code); + let fetcher = FakeFetcher::new(module); + let block_cache = BlockModuleCache::new(&vm_cache, fetcher); + { + let code_key = CodeKey::new(AccountAddress::default(), "M1".to_string()); + let module_ref = block_cache.get_loaded_module(&code_key).unwrap().unwrap(); + + let f_idx = module_ref.field_defs_table.get("f").unwrap(); + assert_eq!(module_ref.get_field_offset(*f_idx).unwrap(), 0); + + let g_idx = module_ref.field_defs_table.get("g").unwrap(); + assert_eq!(module_ref.get_field_offset(*g_idx).unwrap(), 1); + + let i_idx = module_ref.field_defs_table.get("i").unwrap(); + assert_eq!(module_ref.get_field_offset(*i_idx).unwrap(), 0); + + let x_idx = module_ref.field_defs_table.get("x").unwrap(); + assert_eq!(module_ref.get_field_offset(*x_idx).unwrap(), 1); + + let y_idx = module_ref.field_defs_table.get("y").unwrap(); + assert_eq!(module_ref.get_field_offset(*y_idx).unwrap(), 2); + } +} diff --git a/language/vm/vm_runtime/src/unit_tests/runtime_tests.rs b/language/vm/vm_runtime/src/unit_tests/runtime_tests.rs new file mode 100644 index 0000000000000..bc04ae44ceb4e --- /dev/null +++ b/language/vm/vm_runtime/src/unit_tests/runtime_tests.rs @@ -0,0 +1,735 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crate::{ + code_cache::module_cache::VMModuleCache, txn_executor::TransactionExecutor, value::Local, +}; +use std::collections::HashMap; +use types::{access_path::AccessPath, account_address::AccountAddress, byte_array::ByteArray}; +use vm::{ + file_format::{ + AddressPoolIndex, Bytecode, CodeUnit, CompiledModule, CompiledScript, FunctionDefinition, + FunctionHandle, FunctionHandleIndex, FunctionSignature, FunctionSignatureIndex, + LocalsSignature, LocalsSignatureIndex, ModuleHandle, ModuleHandleIndex, SignatureToken, + StringPoolIndex, + }, + transaction_metadata::TransactionMetadata, +}; +use vm_cache_map::Arena; + +// Trait for the data cache to build a TransactionProcessor +struct FakeDataCache { + #[allow(dead_code)] + data: HashMap>, +} + +impl FakeDataCache { + fn new() -> Self { + FakeDataCache { + data: HashMap::new(), + } + } +} + +impl RemoteCache for FakeDataCache { + fn get(&self, _access_path: &AccessPath) -> Result>, VMInvariantViolation> { + Ok(None) + } +} + +fn fake_script() -> CompiledScript { + CompiledScript { + main: FunctionDefinition { + function: FunctionHandleIndex::new(0), + flags: CodeUnit::PUBLIC, + code: CodeUnit { + max_stack_size: 10, + locals: LocalsSignatureIndex(0), + code: vec![], + }, + }, + module_handles: vec![ModuleHandle { + address: AddressPoolIndex::new(0), + name: StringPoolIndex::new(0), + }], + struct_handles: vec![], + function_handles: vec![FunctionHandle { + name: StringPoolIndex::new(0), + signature: FunctionSignatureIndex::new(0), + module: ModuleHandleIndex::new(0), + }], + 
type_signatures: vec![], + function_signatures: vec![FunctionSignature { + arg_types: vec![], + return_types: vec![], + }], + locals_signatures: vec![LocalsSignature(vec![])], + string_pool: vec!["hello".to_string()], + byte_array_pool: vec![ByteArray::new(vec![0u8; 32])], + address_pool: vec![AccountAddress::default()], + } +} + +fn test_simple_instruction_impl<'alloc, 'txn>( + vm: &mut TransactionExecutor<'alloc, 'txn, VMModuleCache<'alloc>>, + instr: Bytecode, + value_stack_before: Vec, + value_stack_after: Vec, + local_before: Vec, + local_after: Vec, + expected_offset: u16, +) -> VMResult<()> { + let code = vec![instr]; + vm.execution_stack + .top_frame_mut()? + .set_with_states(0, local_before); + vm.execution_stack.set_stack(value_stack_before); + let offset = try_runtime!(vm.execute_block(code.as_slice(), 0)); + assert_eq!(vm.execution_stack.get_value_stack(), &value_stack_after); + let top_frame = vm.execution_stack.top_frame()?; + assert_eq!(top_frame.get_locals(), &local_after); + assert_eq!(offset, expected_offset); + Ok(Ok(())) +} + +fn test_simple_instruction<'alloc, 'txn>( + vm: &mut TransactionExecutor<'alloc, 'txn, VMModuleCache<'alloc>>, + instr: Bytecode, + value_stack_before: Vec, + value_stack_after: Vec, + local_before: Vec, + local_after: Vec, + expected_offset: u16, +) { + test_simple_instruction_impl( + vm, + instr, + value_stack_before, + value_stack_after, + local_before, + local_after, + expected_offset, + ) + .unwrap() + .unwrap(); +} + +fn test_binop_instruction_impl<'alloc, 'txn>( + vm: &mut TransactionExecutor<'alloc, 'txn, VMModuleCache<'alloc>>, + instr: Bytecode, + stack: Vec, + expected_value: Local, +) -> VMResult<()> { + test_simple_instruction_impl(vm, instr, stack, vec![expected_value], vec![], vec![], 1) +} + +fn test_binop_instruction<'alloc, 'txn>( + vm: &mut TransactionExecutor<'alloc, 'txn, VMModuleCache<'alloc>>, + instr: Bytecode, + stack: Vec, + expected_value: Local, +) { + test_binop_instruction_impl(vm, instr, stack, expected_value) + .unwrap() + .unwrap() +} + +fn test_binop_instruction_overflow<'alloc, 'txn>( + vm: &mut TransactionExecutor<'alloc, 'txn, VMModuleCache<'alloc>>, + instr: Bytecode, + stack: Vec, +) { + assert_eq!( + test_binop_instruction_impl(vm, instr, stack, Local::u64(0)) + .unwrap() + .unwrap_err() + .err, + VMErrorKind::ArithmeticError + ); +} + +#[test] +fn test_simple_instruction_transition() { + let allocator = Arena::new(); + let module_cache = VMModuleCache::new(&allocator); + let (main_module, entry_idx) = create_fake_module(fake_script()); + let loaded_main = LoadedModule::new(main_module).unwrap(); + let entry_func = FunctionRef::new(&loaded_main, entry_idx).unwrap(); + let data_cache = FakeDataCache::new(); + let mut vm = + TransactionExecutor::new(module_cache, &data_cache, TransactionMetadata::default()); + vm.execution_stack.push_frame(entry_func); + + test_simple_instruction( + &mut vm, + Bytecode::Pop, + vec![Local::u64(0)], + vec![], + vec![], + vec![], + 1, + ); + + test_simple_instruction( + &mut vm, + Bytecode::BrTrue(100), + vec![Local::bool(true)], + vec![], + vec![], + vec![], + 100, + ); + + test_simple_instruction( + &mut vm, + Bytecode::BrTrue(100), + vec![Local::bool(false)], + vec![], + vec![], + vec![], + 1, + ); + + test_simple_instruction( + &mut vm, + Bytecode::BrFalse(100), + vec![Local::bool(true)], + vec![], + vec![], + vec![], + 1, + ); + + test_simple_instruction( + &mut vm, + Bytecode::BrFalse(100), + vec![Local::bool(false)], + vec![], + vec![], + vec![], + 100, + ); + + 
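+    // Unconditional branch: `Branch(offset)` consumes nothing from the value stack and
+    // simply moves the program counter to `offset` (100 here).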
test_simple_instruction( + &mut vm, + Bytecode::Branch(100), + vec![], + vec![], + vec![], + vec![], + 100, + ); + + test_simple_instruction( + &mut vm, + Bytecode::LdConst(100), + vec![], + vec![Local::u64(100)], + vec![], + vec![], + 1, + ); + + let addr = AccountAddress::default(); + test_simple_instruction( + &mut vm, + Bytecode::LdAddr(AddressPoolIndex::new(0)), + vec![], + vec![Local::address(addr)], + vec![], + vec![], + 1, + ); + + test_simple_instruction( + &mut vm, + Bytecode::LdStr(StringPoolIndex::new(0)), + vec![], + vec![Local::string("hello".to_string())], + vec![], + vec![], + 1, + ); + + test_simple_instruction( + &mut vm, + Bytecode::LdTrue, + vec![], + vec![Local::bool(true)], + vec![], + vec![], + 1, + ); + + test_simple_instruction( + &mut vm, + Bytecode::LdFalse, + vec![], + vec![Local::bool(false)], + vec![], + vec![], + 1, + ); + + test_simple_instruction( + &mut vm, + Bytecode::CopyLoc(1), + vec![], + vec![Local::u64(10)], + vec![Local::Invalid, Local::u64(10)], + vec![Local::Invalid, Local::u64(10)], + 1, + ); + + test_simple_instruction( + &mut vm, + Bytecode::MoveLoc(1), + vec![], + vec![Local::u64(10)], + vec![Local::Invalid, Local::u64(10)], + vec![Local::Invalid, Local::Invalid], + 1, + ); + + test_simple_instruction( + &mut vm, + Bytecode::StLoc(0), + vec![Local::bool(true)], + vec![], + vec![Local::Invalid], + vec![Local::bool(true)], + 1, + ); + + test_simple_instruction( + &mut vm, + Bytecode::StLoc(1), + vec![Local::u64(10)], + vec![], + vec![Local::Invalid, Local::Invalid], + vec![Local::Invalid, Local::u64(10)], + 1, + ); + + test_simple_instruction( + &mut vm, + Bytecode::Assert, + vec![Local::u64(42), Local::bool(true)], + vec![], + vec![], + vec![], + 1, + ); + + assert_eq!( + test_simple_instruction_impl( + &mut vm, + Bytecode::Assert, + vec![Local::u64(777), Local::bool(false)], + vec![], + vec![], + vec![], + 1 + ) + .unwrap() + .unwrap_err() + .err, + VMErrorKind::AssertionFailure(777) + ); +} + +#[test] +fn test_arith_instructions() { + let allocator = Arena::new(); + let module_cache = VMModuleCache::new(&allocator); + let (main_module, entry_idx) = create_fake_module(fake_script()); + let loaded_main = LoadedModule::new(main_module).unwrap(); + let entry_func = FunctionRef::new(&loaded_main, entry_idx).unwrap(); + let data_cache = FakeDataCache::new(); + + let mut vm = + TransactionExecutor::new(module_cache, &data_cache, TransactionMetadata::default()); + + vm.execution_stack.push_frame(entry_func); + + test_binop_instruction( + &mut vm, + Bytecode::Add, + vec![Local::u64(1), Local::u64(2)], + Local::u64(3), + ); + test_binop_instruction_overflow( + &mut vm, + Bytecode::Add, + vec![Local::u64(u64::max_value()), Local::u64(1)], + ); + + test_binop_instruction( + &mut vm, + Bytecode::Sub, + vec![Local::u64(10), Local::u64(2)], + Local::u64(8), + ); + test_binop_instruction_overflow(&mut vm, Bytecode::Sub, vec![Local::u64(0), Local::u64(1)]); + + test_binop_instruction( + &mut vm, + Bytecode::Mul, + vec![Local::u64(2), Local::u64(3)], + Local::u64(6), + ); + test_binop_instruction_overflow( + &mut vm, + Bytecode::Mul, + vec![Local::u64(u64::max_value() / 2), Local::u64(3)], + ); + + test_binop_instruction( + &mut vm, + Bytecode::Mod, + vec![Local::u64(10), Local::u64(4)], + Local::u64(2), + ); + test_binop_instruction_overflow(&mut vm, Bytecode::Mod, vec![Local::u64(1), Local::u64(0)]); + + test_binop_instruction( + &mut vm, + Bytecode::Div, + vec![Local::u64(6), Local::u64(2)], + Local::u64(3), + ); + test_binop_instruction_overflow(&mut vm, 
Bytecode::Div, vec![Local::u64(1), Local::u64(0)]); + + test_binop_instruction( + &mut vm, + Bytecode::BitOr, + vec![Local::u64(5), Local::u64(6)], + Local::u64(7), + ); + + test_binop_instruction( + &mut vm, + Bytecode::BitAnd, + vec![Local::u64(5), Local::u64(6)], + Local::u64(4), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Xor, + vec![Local::u64(5), Local::u64(6)], + Local::u64(3), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Or, + vec![Local::bool(false), Local::bool(true)], + Local::bool(true), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Or, + vec![Local::bool(false), Local::bool(false)], + Local::bool(false), + ); + + test_binop_instruction( + &mut vm, + Bytecode::And, + vec![Local::bool(false), Local::bool(true)], + Local::bool(false), + ); + + test_binop_instruction( + &mut vm, + Bytecode::And, + vec![Local::bool(true), Local::bool(true)], + Local::bool(true), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Eq, + vec![Local::bool(false), Local::bool(true)], + Local::bool(false), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Eq, + vec![Local::u64(5), Local::u64(6)], + Local::bool(false), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Neq, + vec![Local::bool(false), Local::bool(true)], + Local::bool(true), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Neq, + vec![Local::u64(5), Local::u64(6)], + Local::bool(true), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Lt, + vec![Local::u64(5), Local::u64(6)], + Local::bool(true), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Lt, + vec![Local::u64(5), Local::u64(5)], + Local::bool(false), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Gt, + vec![Local::u64(7), Local::u64(6)], + Local::bool(true), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Gt, + vec![Local::u64(5), Local::u64(5)], + Local::bool(false), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Le, + vec![Local::u64(5), Local::u64(6)], + Local::bool(true), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Le, + vec![Local::u64(5), Local::u64(5)], + Local::bool(true), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Ge, + vec![Local::u64(7), Local::u64(6)], + Local::bool(true), + ); + + test_binop_instruction( + &mut vm, + Bytecode::Ge, + vec![Local::u64(5), Local::u64(5)], + Local::bool(true), + ); +} + +fn fake_module_with_calls(sigs: Vec<(Vec, FunctionSignature)>) -> CompiledModule { + let mut names: Vec = sigs + .iter() + .enumerate() + .map(|(i, _)| format!("func{}", i)) + .collect(); + names.insert(0, "module".to_string()); + let function_defs = sigs + .iter() + .enumerate() + .map(|(i, _)| FunctionDefinition { + function: FunctionHandleIndex::new(i as u16), + flags: CodeUnit::PUBLIC, + code: CodeUnit { + max_stack_size: 10, + locals: LocalsSignatureIndex(i as u16), + code: vec![], + }, + }) + .collect(); + let function_handles = sigs + .iter() + .enumerate() + .map(|(i, _)| FunctionHandle { + name: StringPoolIndex::new((i + 1) as u16), + signature: FunctionSignatureIndex::new(i as u16), + module: ModuleHandleIndex::new(0), + }) + .collect(); + let (local_sigs, function_sigs): (Vec<_>, Vec<_>) = sigs.into_iter().unzip(); + CompiledModule { + function_defs, + field_defs: vec![], + struct_defs: vec![], + + module_handles: vec![ModuleHandle { + address: AddressPoolIndex::new(0), + name: StringPoolIndex::new(0), + }], + struct_handles: vec![], + function_handles, + type_signatures: vec![], + function_signatures: function_sigs, + locals_signatures: 
local_sigs.into_iter().map(LocalsSignature).collect(), + string_pool: names, + byte_array_pool: vec![], + address_pool: vec![AccountAddress::default()], + } +} + +#[test] +fn test_call() { + let module = fake_module_with_calls(vec![ + // () -> (), no local + ( + vec![], + FunctionSignature { + arg_types: vec![], + return_types: vec![], + }, + ), + // () -> (), two locals + ( + vec![SignatureToken::U64, SignatureToken::U64], + FunctionSignature { + arg_types: vec![], + return_types: vec![], + }, + ), + // (Int, Int) -> (), two locals, + ( + vec![SignatureToken::U64, SignatureToken::U64], + FunctionSignature { + arg_types: vec![SignatureToken::U64, SignatureToken::U64], + return_types: vec![], + }, + ), + // (Int, Int) -> (), three locals, + ( + vec![ + SignatureToken::U64, + SignatureToken::U64, + SignatureToken::Bool, + ], + FunctionSignature { + arg_types: vec![SignatureToken::U64, SignatureToken::U64], + return_types: vec![], + }, + ), + ]); + + let mod_id = module.self_code_key(); + let allocator = Arena::new(); + let module_cache = VMModuleCache::new_from_module(module, &allocator).unwrap(); + let fake_func = { + let fake_mod_entry = module_cache.get_loaded_module(&mod_id).unwrap().unwrap(); + module_cache + .resolve_function_ref(fake_mod_entry, FunctionHandleIndex::new(0)) + .unwrap() + .unwrap() + }; + let data_cache = FakeDataCache::new(); + let mut vm = + TransactionExecutor::new(module_cache, &data_cache, TransactionMetadata::default()); + vm.execution_stack.push_frame(fake_func); + + test_simple_instruction( + &mut vm, + Bytecode::Call(FunctionHandleIndex::new(0)), + vec![], + vec![], + vec![], + vec![], + 0, + ); + test_simple_instruction( + &mut vm, + Bytecode::Call(FunctionHandleIndex::new(1)), + vec![], + vec![], + vec![], + vec![Local::Invalid, Local::Invalid], + 0, + ); + test_simple_instruction( + &mut vm, + Bytecode::Call(FunctionHandleIndex::new(2)), + vec![Local::u64(5), Local::u64(4)], + vec![], + vec![], + vec![Local::u64(5), Local::u64(4)], + 0, + ); + test_simple_instruction( + &mut vm, + Bytecode::Call(FunctionHandleIndex::new(3)), + vec![Local::u64(5), Local::u64(4)], + vec![], + vec![], + vec![Local::u64(5), Local::u64(4), Local::Invalid], + 0, + ); +} + +#[test] +fn test_transaction_info() { + let allocator = Arena::new(); + let module_cache = VMModuleCache::new(&allocator); + let (main_module, entry_idx) = create_fake_module(fake_script()); + let loaded_main = LoadedModule::new(main_module).unwrap(); + let entry_func = FunctionRef::new(&loaded_main, entry_idx).unwrap(); + + let txn_info = { + let (_, public_key) = crypto::signing::generate_genesis_keypair(); + TransactionMetadata { + sender: AccountAddress::default(), + public_key, + sequence_number: 10, + max_gas_amount: 100_000_009, + gas_unit_price: 5, + transaction_size: 100, + } + }; + let data_cache = FakeDataCache::new(); + let mut vm = TransactionExecutor::new(module_cache, &data_cache, txn_info); + + vm.execution_stack.push_frame(entry_func); + + test_simple_instruction( + &mut vm, + Bytecode::GetTxnMaxGasUnits, + vec![], + vec![Local::u64(100_000_009)], + vec![], + vec![], + 1, + ); + + test_simple_instruction( + &mut vm, + Bytecode::GetTxnSequenceNumber, + vec![], + vec![Local::u64(10)], + vec![], + vec![], + 1, + ); + + test_simple_instruction( + &mut vm, + Bytecode::GetTxnGasUnitPrice, + vec![], + vec![Local::u64(5)], + vec![], + vec![], + 1, + ); +} diff --git a/language/vm/vm_runtime/src/unit_tests/type_prop_tests.rs b/language/vm/vm_runtime/src/unit_tests/type_prop_tests.rs new file mode 100644 
index 0000000000000..9b5330c953eb5 --- /dev/null +++ b/language/vm/vm_runtime/src/unit_tests/type_prop_tests.rs @@ -0,0 +1,19 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::loaded_data::types::Type; +use canonical_serialization::*; +use proptest::prelude::*; + +proptest! { + #[test] + fn roundtrip(ty in any::()) { + let mut serializer = SimpleSerializer::new(); + ty.serialize(&mut serializer).expect("must serialize"); + let blob: Vec = serializer.get_output(); + + let mut deserializer = SimpleDeserializer::new(&blob); + let ty2 = Type::deserialize(&mut deserializer).expect("must deserialize"); + assert_eq!(ty, ty2); + } +} diff --git a/language/vm/vm_runtime/src/unit_tests/value_prop_tests.rs b/language/vm/vm_runtime/src/unit_tests/value_prop_tests.rs new file mode 100644 index 0000000000000..4d4fe537fb189 --- /dev/null +++ b/language/vm/vm_runtime/src/unit_tests/value_prop_tests.rs @@ -0,0 +1,15 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::value::Value; +use proptest::prelude::*; + +proptest! { + #[test] + fn flat_struct_test(value in Value::struct_strategy()) { + let struct_def = value.to_struct_def_FOR_TESTING(); + let blob = value.simple_serialize().expect("must serialize"); + let value1 = Value::simple_deserialize(&blob, struct_def).expect("must deserialize"); + assert_eq!(value, value1); + } +} diff --git a/language/vm/vm_runtime/src/unit_tests/value_tests.rs b/language/vm/vm_runtime/src/unit_tests/value_tests.rs new file mode 100644 index 0000000000000..2676dbda31b06 --- /dev/null +++ b/language/vm/vm_runtime/src/unit_tests/value_tests.rs @@ -0,0 +1,187 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use std::rc::Rc; +use types::access_path::AccessPath; + +#[test] +fn test_simple_mutate() { + let v = Local::u64(1); + let v_ref = v.borrow_local().unwrap(); + let v2 = Local::u64(2); + v_ref.mutate_reference(v2.value().unwrap()); + assert_eq!(v, Local::u64(2)); +} + +#[test] +fn test_cloned_value() { + let v = Local::u64(1); + let v2 = v.clone(); + + let v_ref = v.borrow_local().unwrap(); + let v3 = Local::u64(2); + v_ref.mutate_reference(v3.value().unwrap()); + assert_eq!(v, Local::u64(2)); + assert_eq!(v2, Local::u64(1)); +} + +#[test] +fn test_cloned_references() { + let v = Local::u64(1); + + let v_ref = v.borrow_local().unwrap(); + let v_ref_clone = v_ref.clone(); + + let v3 = Local::u64(2); + v_ref.mutate_reference(v3.value().unwrap()); + assert_eq!(v, Local::u64(2)); + assert_eq!(v_ref_clone.read_reference().unwrap(), Local::u64(2)); +} + +#[test] +fn test_mutate_struct() { + let v_ref = Local::Ref(MutVal::new(Value::Struct( + vec![Local::u64(1), Local::u64(2)] + .into_iter() + .map(|v| v.value().unwrap()) + .collect(), + ))); + + let field_ref = v_ref.borrow_field(1).unwrap(); + + let v2 = Local::u64(3); + field_ref.mutate_reference(v2.value().unwrap()); + + let v_after = Local::struct_( + vec![Local::u64(1), Local::u64(3)] + .into_iter() + .map(|v| v.value().unwrap()) + .collect(), + ); + assert_eq!( + v_ref.read_reference().expect("must be a reference"), + v_after + ); +} + +#[test] +fn test_simple_global_ref() { + // make a global ref to a struct + let v = Value::Struct(vec![ + MutVal::new(Value::U64(1)), + MutVal::new(Value::U64(10)), + MutVal::new(Value::Bool(true)), + ]); + let ap = AccessPath::new(AccountAddress::new([1; 32]), vec![]); + let v_ref = MutVal::new(v); + // make a root + let root = 
GlobalRef::make_root(ap, v_ref); + assert_eq!(Rc::strong_count(&root.root), 1); + assert_eq!(root.root.borrow().ref_count, 0); + // get a reference to the root (BorrowGlobal) + let global_ref = root.shallow_clone(); + assert_eq!(Rc::strong_count(&root.root), 2); + assert_eq!(root.root.borrow().ref_count, 1); + + // get a copy, drop it and verify ref count in the process + let global_ref1 = global_ref.shallow_clone(); + assert_eq!(Rc::strong_count(&global_ref1.root), 3); + assert_eq!(Rc::strong_count(&global_ref.root), 3); + assert_eq!(root.root.borrow().ref_count, 2); + global_ref1 + .release_reference() + .expect("ref count must not be 0"); + assert_eq!(Rc::strong_count(&global_ref.root), 2); + assert_eq!(root.root.borrow().ref_count, 1); + assert_eq!(global_ref.is_dirty(), false); + + // get references to 2 fields and verify ref count + let field0_ref: GlobalRef; + { + let global_ref1 = global_ref.shallow_clone(); + field0_ref = global_ref1.borrow_field(0).expect("field must exist"); + } + // ref count to 3 because global_ref1 is dropped at the end of the block + assert_eq!(Rc::strong_count(&global_ref.root), 3); + assert_eq!(root.root.borrow().ref_count, 2); + let field1_ref: GlobalRef; + { + let global_ref1 = global_ref.shallow_clone(); + field1_ref = global_ref1.borrow_field(1).expect("field must exist"); + } + // ref count to 4 because global_ref1 is dropped at the end of the block + assert_eq!(Rc::strong_count(&global_ref.root), 4); + assert_eq!(root.root.borrow().ref_count, 3); + + // read reference to first field, verify value and ref count. read_reference() drops reference + let field0_val = field0_ref.read_reference(); + match &*field0_val.peek() { + Value::U64(i) => assert_eq!(*i, 1), + _ => unreachable!("value must be int"), + } + assert_eq!(Rc::strong_count(&global_ref.root), 3); + assert_eq!(root.root.borrow().ref_count, 2); + assert_eq!(global_ref.is_dirty(), false); + + // write reference to second field, verify value and ref count. 
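+    // Unlike the read above, a write through a field reference must flip the root's dirty
+    // flag, as asserted below.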
+ // mutate_reference() drops reference + field1_ref.mutate_reference(MutVal::new(Value::U64(100))); + assert_eq!(Rc::strong_count(&global_ref.root), 2); + assert_eq!(root.root.borrow().ref_count, 1); + assert_eq!(global_ref.is_dirty(), true); + let field1_ref: GlobalRef; + { + let global_ref1 = global_ref.shallow_clone(); + field1_ref = global_ref1.borrow_field(1).expect("field must exist"); + } + assert_eq!(Rc::strong_count(&global_ref.root), 3); + assert_eq!(root.root.borrow().ref_count, 2); + let field1_val = field1_ref.read_reference(); + match &*field1_val.peek() { + Value::U64(i) => assert_eq!(*i, 100), + _ => unreachable!("value must be int"), + } + // 1 reference left and dirty flag true + assert_eq!(Rc::strong_count(&global_ref.root), 2); + assert_eq!(root.root.borrow().ref_count, 1); + assert_eq!(global_ref.is_dirty(), true); + + // drop last reference (ReleaseRef) + global_ref + .release_reference() + .expect("ref count must not be 0"); + assert_eq!(Rc::strong_count(&root.root), 1); + assert_eq!(root.root.borrow().ref_count, 0); + assert_eq!(root.is_dirty(), true); +} + +#[test] +fn test_simple_global_ref_err() { + // make a global ref to a struct + let v = Value::Struct(vec![ + MutVal::new(Value::U64(1)), + MutVal::new(Value::U64(10)), + MutVal::new(Value::Bool(true)), + ]); + let ap = AccessPath::new(AccountAddress::new([1; 32]), vec![]); + let v_ref = MutVal::new(v); + // make a root + let root = GlobalRef::make_root(ap, v_ref); + assert_eq!(Rc::strong_count(&root.root), 1); + assert_eq!(root.root.borrow().ref_count, 0); + // get a reference to the root (BorrowGlobal) + let global_ref = root.shallow_clone(); + assert_eq!(Rc::strong_count(&root.root), 2); + assert_eq!(root.root.borrow().ref_count, 1); + + // drop last reference (ReleaseRef) + global_ref + .release_reference() + .expect("ref count must not be 0"); + assert_eq!(Rc::strong_count(&root.root), 1); + assert_eq!(root.root.borrow().ref_count, 0); + + // error on another ReleaseRef + assert!(root.release_reference().is_err()); +} diff --git a/language/vm/vm_runtime/src/unit_tests/vm_types.rs b/language/vm/vm_runtime/src/unit_tests/vm_types.rs new file mode 100644 index 0000000000000..5a7001b5eec65 --- /dev/null +++ b/language/vm/vm_runtime/src/unit_tests/vm_types.rs @@ -0,0 +1,39 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::value::{MutVal, Value}; +use canonical_serialization::SimpleDeserializer; +use types::{account_config::AccountResource, byte_array::ByteArray}; + +#[test] +fn account_type() { + // mimic an Account + let authentication_key = ByteArray::new(vec![5u8; 32]); + let balance = 128u64; + let received_events_count = 8u64; + let sent_events_count = 16u64; + let sequence_number = 32u64; + + let mut account_fields: Vec = Vec::new(); + account_fields.push(MutVal::bytearray(authentication_key.clone())); + let mut coin_fields: Vec = Vec::new(); + coin_fields.push(MutVal::u64(balance)); + account_fields.push(MutVal::struct_(coin_fields.clone())); + account_fields.push(MutVal::u64(received_events_count)); + account_fields.push(MutVal::u64(sent_events_count)); + account_fields.push(MutVal::u64(sequence_number)); + + let account = Value::Struct(account_fields); + let blob = &account.simple_serialize().expect("blob must serialize"); + + let account_resource: AccountResource = + SimpleDeserializer::deserialize(blob).expect("must deserialize"); + assert_eq!(*account_resource.authentication_key(), authentication_key); + assert_eq!(account_resource.balance(), 
balance); + assert_eq!( + account_resource.received_events_count(), + received_events_count + ); + assert_eq!(account_resource.sent_events_count(), sent_events_count); + assert_eq!(account_resource.sequence_number(), sequence_number); +} diff --git a/language/vm/vm_runtime/src/value.rs b/language/vm/vm_runtime/src/value.rs new file mode 100644 index 0000000000000..cf1b8ae741cf7 --- /dev/null +++ b/language/vm/vm_runtime/src/value.rs @@ -0,0 +1,534 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::loaded_data::{struct_def::StructDef, types::Type}; +use std::{ + cell::{Ref, RefCell}, + rc::Rc, +}; +use types::{ + access_path::AccessPath, + account_address::{AccountAddress, ADDRESS_LENGTH}, + byte_array::ByteArray, + contract_event::ContractEvent, +}; +use vm::{ + errors::*, + gas_schedule::{words_in, AbstractMemorySize, CONST_SIZE, REFERENCE_SIZE, STRUCT_SIZE}, +}; + +#[cfg(test)] +#[path = "unit_tests/value_prop_tests.rs"] +mod value_prop_tests; +#[cfg(test)] +#[path = "unit_tests/value_tests.rs"] +mod value_tests; +#[cfg(test)] +#[path = "unit_tests/vm_types.rs"] +mod vm_types; + +#[derive(PartialEq, Eq, Debug, Clone)] +pub enum Value { + Address(AccountAddress), + U64(u64), + Bool(bool), + String(String), + Struct(Vec), + ByteArray(ByteArray), +} + +impl Value { + fn size(&self) -> AbstractMemorySize { + match self { + Value::U64(_) | Value::Bool(_) => CONST_SIZE, + Value::Address(_) => ADDRESS_LENGTH as AbstractMemorySize, + // Possible debate topic: Should we charge based upon the size of the string. + // At this moment, we take the view that you should be charged as though you are + // copying the string onto the stack here. This doesn't replicate + // the semantics that we utilize currently, but this string may + // need to be copied at some later time, so we need to charge based + // upon the size of the memory that will possibly need to be accessed. + Value::String(s) => words_in(s.len() as AbstractMemorySize), + Value::Struct(vals) => vals.iter().fold(STRUCT_SIZE, |acc, vl| acc + vl.size()), + Value::ByteArray(key) => key.len() as AbstractMemorySize, + } + } + + /// Normal code should always know what type this value has. This is made available only for + /// tests. 
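+    ///
+    /// Illustrative sketch only (the `MutVal` constructors shown are crate-private):
+    /// ```ignore
+    /// let v = Value::Struct(vec![MutVal::u64(1), MutVal::bool(true)]);
+    /// assert_eq!(v.to_struct_def_FOR_TESTING(), StructDef::new(vec![Type::U64, Type::Bool]));
+    /// ```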
+ #[allow(non_snake_case)] + #[doc(hidden)] + pub fn to_struct_def_FOR_TESTING(&self) -> StructDef { + let values = match self { + Value::Struct(values) => values, + _ => panic!("Value must be a struct {:?}", self), + }; + + let fields = values + .iter() + .map(|mut_val| { + let val = &*mut_val.peek(); + match val { + Value::Bool(_) => Type::Bool, + Value::Address(_) => Type::Address, + Value::U64(_) => Type::U64, + Value::String(_) => Type::String, + Value::ByteArray(_) => Type::ByteArray, + Value::Struct(_) => Type::Struct(val.to_struct_def_FOR_TESTING()), + } + }) + .collect(); + StructDef::new(fields) + } +} + +pub trait Reference +where + Self: std::marker::Sized + Clone + Eq, +{ + fn borrow_field(&self, idx: u32) -> Option; + fn read_reference(self) -> MutVal; + fn mutate_reference(self, v: MutVal); + + fn size(&self) -> AbstractMemorySize; +} + +#[derive(PartialEq, Eq, Debug)] +pub struct MutVal(pub Rc>); + +#[derive(PartialEq, Eq, Debug)] +pub enum Local { + Ref(MutVal), + GlobalRef(GlobalRef), + Value(MutVal), + Invalid, +} + +/// Status for on chain data (published resources): +/// CLEAN - the data was only read +/// DIRTY - the data was changed anywhere in the data tree of the given resource +/// DELETED - MoveFrom was called on the given AccessPath for the given resource +#[rustfmt::skip] +#[allow(non_camel_case_types)] +#[derive(PartialEq, Eq, Debug, Clone)] +enum GlobalDataStatus { + CLEAN = 0, + DIRTY = 1, + DELETED = 2, +} + +/// A root into an instance on chain. +/// Holds flags about the status of the instance and a reference count to balance +/// Borrow* and ReleaseRef +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct RootAccessPath { + status: GlobalDataStatus, + ref_count: u64, + ap: AccessPath, +} + +/// A GlobalRef holds the reference to the data and a shared reference to the root so +/// status flags and reference count can be properly managed +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct GlobalRef { + root: Rc>, + reference: MutVal, +} + +impl Clone for MutVal { + fn clone(&self) -> Self { + MutVal(Rc::new(RefCell::new(self.peek().clone()))) + } +} + +impl Clone for Local { + fn clone(&self) -> Self { + match self { + Local::Ref(v) => Local::Ref(v.shallow_clone()), + Local::GlobalRef(v) => Local::GlobalRef(v.shallow_clone()), + Local::Value(v) => Local::Value(v.clone()), + Local::Invalid => Local::Invalid, + } + } +} + +impl MutVal { + pub fn try_own(mv: Self) -> ::std::result::Result { + match Rc::try_unwrap(mv.0) { + Ok(cell) => Ok(cell.into_inner()), + Err(_) => Err(VMInvariantViolation::LocalReferenceError), + } + } + + pub fn peek(&self) -> Ref { + self.0.borrow() + } + + pub fn new(v: Value) -> Self { + MutVal(Rc::new(RefCell::new(v))) + } + + fn shallow_clone(&self) -> Self { + MutVal(Rc::clone(&self.0)) + } + + fn address(addr: AccountAddress) -> Self { + MutVal::new(Value::Address(addr)) + } + + fn u64(i: u64) -> Self { + MutVal::new(Value::U64(i)) + } + + fn bool(b: bool) -> Self { + MutVal::new(Value::Bool(b)) + } + + fn string(s: String) -> Self { + MutVal::new(Value::String(s)) + } + + fn struct_(v: Vec) -> Self { + MutVal::new(Value::Struct(v)) + } + + fn bytearray(v: ByteArray) -> Self { + MutVal::new(Value::ByteArray(v)) + } + + fn size(&self) -> AbstractMemorySize { + self.peek().size() + } +} + +impl Reference for MutVal { + fn borrow_field(&self, idx: u32) -> Option { + match &*self.peek() { + Value::Struct(ref vec) => vec.get(idx as usize).map(MutVal::shallow_clone), + _ => None, + } + } + + fn read_reference(self) -> MutVal { + 
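+        // Note: `MutVal`'s `Clone` impl above is a *deep* copy (it clones the
+        // underlying `Value` instead of bumping the `Rc`), so reading a
+        // reference hands back an independent copy of the value.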
self.clone() + } + + fn mutate_reference(self, v: MutVal) { + self.0.replace(v.peek().clone()); + } + + fn size(&self) -> AbstractMemorySize { + words_in(REFERENCE_SIZE as AbstractMemorySize) + } +} + +impl Local { + pub fn address(addr: AccountAddress) -> Self { + Local::Value(MutVal::address(addr)) + } + + pub fn u64(i: u64) -> Self { + Local::Value(MutVal::u64(i)) + } + + pub fn bool(b: bool) -> Self { + Local::Value(MutVal::bool(b)) + } + + pub fn string(s: String) -> Self { + Local::Value(MutVal::string(s)) + } + + pub fn struct_(v: Vec) -> Self { + Local::Value(MutVal::struct_(v)) + } + + pub fn bytearray(v: ByteArray) -> Self { + Local::Value(MutVal::bytearray(v)) + } + + pub fn borrow_local(&self) -> Option { + match self { + Local::Value(v) => Some(Local::Ref(v.shallow_clone())), + _ => None, + } + } + + pub fn borrow_field(&self, idx: u32) -> Option { + match self { + Local::Ref(v) => v.borrow_field(idx).map(Local::Ref), + Local::GlobalRef(v) => v.borrow_field(idx).map(Local::GlobalRef), + _ => None, + } + } + + pub fn read_reference(self) -> Option { + match self { + Local::Ref(r) => Some(Local::Value(r.read_reference())), + Local::GlobalRef(gr) => Some(Local::Value(gr.read_reference())), + _ => None, + } + } + + pub fn mutate_reference(self, v: MutVal) { + match self { + Local::Ref(r) => r.mutate_reference(v), + Local::GlobalRef(r) => r.mutate_reference(v), + _ => (), + } + } + + pub fn release_reference(self) -> Result<(), VMRuntimeError> { + if let Local::GlobalRef(r) = self { + r.release_reference() + } else { + Ok(()) + } + } + + pub fn emit_event_data(self, byte_array: ByteArray, data: MutVal) -> Option { + if let Local::GlobalRef(r) = self { + r.emit_event_data(byte_array, data) + } else { + None + } + } + + pub fn value(self) -> Option { + match self { + Local::Value(v) => Some(v), + _ => None, + } + } + + pub fn size(&self) -> AbstractMemorySize { + match self { + Local::Ref(v) => v.size(), + Local::GlobalRef(v) => v.size(), + Local::Value(v) => v.size(), + Local::Invalid => CONST_SIZE, + } + } +} + +impl RootAccessPath { + pub fn new(ap: AccessPath) -> Self { + RootAccessPath { + status: GlobalDataStatus::CLEAN, + ref_count: 0, + ap, + } + } + + fn mark_dirty(&mut self) { + self.status = GlobalDataStatus::DIRTY; + } + + fn mark_deleted(&mut self) { + self.status = GlobalDataStatus::DELETED; + } + + // REVIEW: check for overflow? 
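+    // A checked variant (sketch only, not what this change does) could guard
+    // against wrap-around:
+    //
+    //     fn inc_ref_count(&mut self) {
+    //         self.ref_count = self.ref_count.checked_add(1).expect("ref_count overflow");
+    //     }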
+ fn inc_ref_count(&mut self) { + self.ref_count += 1; + } + + // the check that the ref_count is already 0 is done in release_ref + fn dec_ref_count(&mut self) { + self.ref_count -= 1; + } + + fn emit_event_data( + &mut self, + byte_array: ByteArray, + counter: u64, + data: MutVal, + ) -> Option { + let blob = match data.peek().simple_serialize() { + Some(data) => data, + None => return None, + }; + let ap = AccessPath::new_for_event(self.ap.address, &self.ap.path, byte_array.as_bytes()); + Some(ContractEvent::new(ap, counter, blob)) + } +} + +impl GlobalRef { + pub fn make_root(ap: AccessPath, reference: MutVal) -> Self { + GlobalRef { + root: Rc::new(RefCell::new(RootAccessPath::new(ap))), + reference, + } + } + + pub fn move_to(ap: AccessPath, reference: MutVal) -> Self { + let mut root = RootAccessPath::new(ap); + root.mark_dirty(); + GlobalRef { + root: Rc::new(RefCell::new(root)), + reference, + } + } + + fn new_ref(root: &GlobalRef, reference: MutVal) -> Self { + // increment the global ref count + root.root.borrow_mut().inc_ref_count(); + GlobalRef { + root: Rc::clone(&root.root), + reference, + } + } + + // Return the resource behind the reference. + // If the reference is not exclusively held by the cache (ref count 0) returns None + pub fn get_data(self) -> Option { + if self.root.borrow().ref_count > 0 { + None + } else { + match Rc::try_unwrap(self.root) { + Ok(_) => match Rc::try_unwrap(self.reference.0) { + Ok(res) => Some(res.into_inner()), + Err(_) => None, + }, + Err(_) => None, + } + } + } + + pub fn is_loadable(&self) -> bool { + self.root.borrow().ref_count == 0 && !self.is_deleted() + } + + pub fn is_dirty(&self) -> bool { + self.root.borrow().status == GlobalDataStatus::DIRTY + } + + pub fn is_deleted(&self) -> bool { + self.root.borrow().status == GlobalDataStatus::DELETED + } + + pub fn is_clean(&self) -> bool { + self.root.borrow().status == GlobalDataStatus::CLEAN + } + + pub fn move_from(&mut self) -> MutVal { + self.root.borrow_mut().mark_deleted(); + self.reference.shallow_clone() + } + + pub fn shallow_clone(&self) -> Self { + // increment the global ref count + self.root.borrow_mut().inc_ref_count(); + GlobalRef { + root: Rc::clone(&self.root), + reference: self.reference.shallow_clone(), + } + } + + fn release_reference(self) -> Result<(), VMRuntimeError> { + if self.root.borrow().ref_count == 0 { + Err(VMRuntimeError { + loc: Location::new(), + err: VMErrorKind::GlobalRefAlreadyReleased, + }) + } else { + self.root.borrow_mut().dec_ref_count(); + Ok(()) + } + } + + fn emit_event_data(self, byte_array: ByteArray, data: MutVal) -> Option { + self.root.borrow_mut().dec_ref_count(); + let counter = match &*self.reference.peek() { + Value::U64(i) => *i, + _ => return None, + }; + self.reference.mutate_reference(MutVal::u64(counter + 1)); + self.root + .borrow_mut() + .emit_event_data(byte_array, counter, data) + } + + fn size(&self) -> AbstractMemorySize { + REFERENCE_SIZE + } +} + +impl Reference for GlobalRef { + fn borrow_field(&self, idx: u32) -> Option { + match &*self.reference.peek() { + Value::Struct(ref vec) => match vec.get(idx as usize) { + Some(field_ref) => { + self.root.borrow_mut().dec_ref_count(); + Some(GlobalRef::new_ref(self, field_ref.shallow_clone())) + } + None => None, + }, + _ => None, + } + } + + fn read_reference(self) -> MutVal { + self.root.borrow_mut().dec_ref_count(); + self.reference.clone() + } + + fn mutate_reference(self, v: MutVal) { + self.root.borrow_mut().dec_ref_count(); + self.root.borrow_mut().mark_dirty(); + 
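+        // Consuming this reference releases one count on the root, and marking
+        // the root DIRTY tells the data cache that the resource behind this
+        // AccessPath has to be written back.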
self.reference.mutate_reference(v); + } + + fn size(&self) -> AbstractMemorySize { + words_in(REFERENCE_SIZE as AbstractMemorySize) + } +} + +// +// Conversion routines for the interpreter +// + +impl From for Option { + fn from(value: MutVal) -> Option { + match &*value.peek() { + Value::U64(i) => Some(*i), + _ => None, + } + } +} + +impl From for Option { + fn from(value: MutVal) -> Option { + match &*value.peek() { + Value::Bool(b) => Some(*b), + _ => None, + } + } +} + +impl From for Option { + fn from(value: MutVal) -> Option { + match *value.peek() { + Value::Address(addr) => Some(addr), + _ => None, + } + } +} + +impl From for Option { + fn from(value: MutVal) -> Option { + match &*value.peek() { + Value::ByteArray(blob) => Some(blob.clone()), + _ => None, + } + } +} + +impl From for Option { + fn from(value: GlobalRef) -> Option { + match *value.reference.peek() { + Value::Address(addr) => Some(addr), + _ => None, + } + } +} diff --git a/language/vm/vm_runtime/src/value_serializer.rs b/language/vm/vm_runtime/src/value_serializer.rs new file mode 100644 index 0000000000000..bc8c6cf3067f4 --- /dev/null +++ b/language/vm/vm_runtime/src/value_serializer.rs @@ -0,0 +1,145 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + loaded_data::{struct_def::StructDef, types::Type}, + value::{MutVal, Value}, +}; +use canonical_serialization::*; +use failure::prelude::*; +use std::convert::TryFrom; +use types::{account_address::AccountAddress, byte_array::ByteArray}; +use vm::errors::*; + +impl Value { + /// Serialize this value using `SimpleSerializer`. + pub fn simple_serialize(&self) -> Option> { + SimpleSerializer::>::serialize(self).ok() + } + + /// Deserialize this value using `SimpleDeserializer` and a provided struct definition. 
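+    ///
+    /// Illustrative round-trip (a sketch, assuming a one-field `U64` struct):
+    /// ```ignore
+    /// let def = StructDef::new(vec![Type::U64]);
+    /// let blob = Value::Struct(vec![MutVal::new(Value::U64(7))])
+    ///     .simple_serialize()
+    ///     .expect("must serialize");
+    /// let value = Value::simple_deserialize(&blob, def).expect("must deserialize");
+    /// ```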
+ pub fn simple_deserialize(blob: &[u8], resource: StructDef) -> VMRuntimeResult { + let mut deserializer = SimpleDeserializer::new(blob); + deserialize_struct(&mut deserializer, &resource) + } +} + +fn deserialize_struct( + deserializer: &mut SimpleDeserializer, + struct_def: &StructDef, +) -> VMRuntimeResult { + let mut s_vals: Vec = Vec::new(); + for field_type in struct_def.field_definitions() { + match field_type { + Type::Bool => { + if let Ok(b) = deserializer.decode_bool() { + s_vals.push(MutVal::new(Value::Bool(b))); + } else { + return Err(VMRuntimeError { + loc: Location::new(), + err: VMErrorKind::DataFormatError, + }); + } + } + Type::U64 => { + if let Ok(val) = deserializer.decode_u64() { + s_vals.push(MutVal::new(Value::U64(val))); + } else { + return Err(VMRuntimeError { + loc: Location::new(), + err: VMErrorKind::DataFormatError, + }); + } + } + Type::String => { + if let Ok(bytes) = deserializer.decode_variable_length_bytes() { + if let Ok(s) = String::from_utf8(bytes) { + s_vals.push(MutVal::new(Value::String(s))); + continue; + } + } + return Err(VMRuntimeError { + loc: Location::new(), + err: VMErrorKind::DataFormatError, + }); + } + Type::ByteArray => { + if let Ok(bytes) = deserializer.decode_variable_length_bytes() { + s_vals.push(MutVal::new(Value::ByteArray(ByteArray::new(bytes)))); + continue; + } + return Err(VMRuntimeError { + loc: Location::new(), + err: VMErrorKind::DataFormatError, + }); + } + Type::Address => { + if let Ok(bytes) = deserializer.decode_variable_length_bytes() { + if let Ok(addr) = AccountAddress::try_from(bytes) { + s_vals.push(MutVal::new(Value::Address(addr))); + continue; + } + } + return Err(VMRuntimeError { + loc: Location::new(), + err: VMErrorKind::DataFormatError, + }); + } + Type::Struct(s_fields) => { + if let Ok(s) = deserialize_struct(deserializer, s_fields) { + s_vals.push(MutVal::new(s)); + } else { + return Err(VMRuntimeError { + loc: Location::new(), + err: VMErrorKind::DataFormatError, + }); + } + } + Type::Reference(_) => { + return Err(VMRuntimeError { + loc: Location::new(), + err: VMErrorKind::InvalidData, + }) + } + Type::MutableReference(_) => { + return Err(VMRuntimeError { + loc: Location::new(), + err: VMErrorKind::InvalidData, + }) + } + } + } + Ok(Value::Struct(s_vals)) +} + +impl CanonicalSerialize for Value { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + match self { + Value::Address(addr) => { + // TODO: this is serializing as a vector but we want just raw bytes + // however the AccountAddress story is a bit difficult to work with right now + serializer.encode_variable_length_bytes(addr.as_ref())?; + } + Value::Bool(b) => { + serializer.encode_bool(*b)?; + } + Value::U64(val) => { + serializer.encode_u64(*val)?; + } + Value::String(s) => { + // TODO: must define an api for canonical serializations of string. 
+                // Right now we are just using Rust to serialize the string
+                serializer.encode_variable_length_bytes(s.as_bytes())?;
+            }
+            Value::Struct(vals) => {
+                for mut_val in vals {
+                    (*mut_val.peek()).serialize(serializer)?;
+                }
+            }
+            Value::ByteArray(bytearray) => {
+                serializer.encode_variable_length_bytes(bytearray.as_bytes())?;
+            }
+        }
+        Ok(())
+    }
+}
diff --git a/language/vm/vm_runtime/vm_cache_map/Cargo.toml b/language/vm/vm_runtime/vm_cache_map/Cargo.toml
new file mode 100644
index 0000000000000..7035af135759d
--- /dev/null
+++ b/language/vm/vm_runtime/vm_cache_map/Cargo.toml
@@ -0,0 +1,16 @@
+[package]
+name = "vm_cache_map"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+chashmap = "2.2.2"
+typed-arena = "1.4.1"
+
+[dev-dependencies]
+crossbeam = "0.7"
+proptest = "0.9"
+rand = "0.6"
diff --git a/language/vm/vm_runtime/vm_cache_map/src/arena.rs b/language/vm/vm_runtime/vm_cache_map/src/arena.rs
new file mode 100644
index 0000000000000..cb5234c472e53
--- /dev/null
+++ b/language/vm/vm_runtime/vm_cache_map/src/arena.rs
@@ -0,0 +1,60 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use std::sync::Mutex;
+use typed_arena::Arena as TypedArena;
+
+/// A thread-safe variant of `typed_arena::Arena`.
+///
+/// This implements `Send` and `Sync` if `T` is `Send`.
+pub struct Arena<T> {
+    inner: Mutex<TypedArena<T>>,
+}
+
+impl<T> Default for Arena<T> {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<T> Arena<T> {
+    #[inline]
+    pub fn new() -> Self {
+        Self {
+            inner: Mutex::new(TypedArena::new()),
+        }
+    }
+
+    #[inline]
+    pub fn with_capacity(n: usize) -> Self {
+        Self {
+            inner: Mutex::new(TypedArena::with_capacity(n)),
+        }
+    }
+
+    // This is safe because it's part of the API design.
+    #[allow(clippy::mut_from_ref)]
+    pub fn alloc(&self, value: T) -> &mut T {
+        let arena = self.inner.lock().expect("lock poisoned");
+        let value = arena.alloc(value);
+        // Extend the lifetime of the value to that of the arena. typed_arena::Arena guarantees
+        // that the value will never be moved out from underneath, and this wrapper guarantees
+        // that the arena will not be dropped.
+        unsafe { ::std::mem::transmute::<&mut T, &mut T>(value) }
+    }
+
+    #[inline]
+    pub fn into_vec(self) -> Vec<T> {
+        let arena = self.inner.into_inner().expect("lock poisoned");
+        arena.into_vec()
+    }
+}
+
+#[test]
+fn arena_thread_safe() {
+    fn assert_send<T: Send>() {}
+    fn assert_sync<T: Sync>() {}
+
+    assert_send::<Arena<String>>();
+    assert_sync::<Arena<String>>();
+}
diff --git a/language/vm/vm_runtime/vm_cache_map/src/cache_map.rs b/language/vm/vm_runtime/vm_cache_map/src/cache_map.rs
new file mode 100644
index 0000000000000..8eff68042e24d
--- /dev/null
+++ b/language/vm/vm_runtime/vm_cache_map/src/cache_map.rs
@@ -0,0 +1,140 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::Arena;
+use chashmap::CHashMap;
+use std::{borrow::Borrow, hash::Hash};
+
+/// The most common case of `CacheMap`, where references to stored values are handed out.
+pub type CacheRefMap<'a, K, V> = CacheMap<'a, K, V, &'a V>;
+
+/// A map custom designed for the VM runtime caches. Allocations are done in an Arena instead of
+/// directly in a hash table, which allows for new entries to be added while existing entries are
+/// borrowed out.
+///
+/// TODO: Entry-like API? Current one is somewhat awkward to use.
+/// TODO: eviction -- how to do it safely?
+/// TODO: should the map own the arena?
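+///
+/// # Usage (illustrative sketch)
+///
+/// ```ignore
+/// let arena = Arena::new();
+/// let map: CacheRefMap<String, String> = CacheRefMap::new(&arena);
+/// // First insert for a key wins; later inserts for it are discarded.
+/// let v: &String = map.or_insert("key".to_string(), "value".to_string());
+/// ```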
+pub struct CacheMap<'a, K, V, W> {
+    alloc: &'a Arena<V>,
+    map: CHashMap<K, W>,
+}
+
+impl<'a, K, V, W> CacheMap<'a, K, V, W>
+where
+    K: Hash + PartialEq,
+    W: Clone,
+{
+    #[inline]
+    pub fn new(alloc: &'a Arena<V>) -> Self {
+        Self {
+            alloc,
+            map: CHashMap::new(),
+        }
+    }
+
+    /// Get the value of the given key in the map.
+    #[inline]
+    pub fn get<Q>(&self, key: &Q) -> Option<W>
+    where
+        K: Borrow<Q>,
+        Q: Hash + PartialEq,
+    {
+        self.map.get(key).map(|value| (*value).clone())
+    }
+
+    /// Try inserting the value V if missing. The insert function is not called if the value is
+    /// present.
+    ///
+    /// The first value is picked to avoid multiple cached results floating around. This assumes
+    /// that the cache is immutable (i.e. there's no invalidation).
+    ///
+    /// Returns a reference to the inserted value.
+    pub fn or_insert_with_transform<F, G>(&self, key: K, insert: F, transform: G) -> W
+    where
+        F: FnOnce() -> V,
+        G: FnOnce(&'a V) -> W,
+    {
+        let mut ret: Option<W> = None;
+        let ret_mut = &mut ret;
+        self.map.alter(key, move |value| match value {
+            Some(value) => {
+                ret_mut.replace(value.clone());
+                Some(value)
+            }
+            None => {
+                let alloc_value: &'a V = self.alloc.alloc(insert());
+                let value = transform(alloc_value);
+                ret_mut.replace(value.clone());
+                Some(value)
+            }
+        });
+        ret.expect("return value should always be initialized")
+    }
+
+    /// A version of insert_with_transform where the transform can fail. If it does then the value
+    /// is not inserted into the map and is left allocated as garbage in the arena.
+    pub fn or_insert_with_try_transform<F, G, E>(
+        &self,
+        key: K,
+        insert: F,
+        try_transform: G,
+    ) -> Result<W, E>
+    where
+        F: FnOnce() -> V,
+        G: FnOnce(&'a V) -> Result<W, E>,
+    {
+        let mut ret: Option<Result<W, E>> = None;
+        let ret_mut = &mut ret;
+        self.map.alter(key, move |value| match value {
+            Some(value) => {
+                ret_mut.replace(Ok(value.clone()));
+                Some(value)
+            }
+            None => {
+                let alloc_value: &'a V = self.alloc.alloc(insert());
+                let res = try_transform(alloc_value);
+                let (cloned_res, stored_value) = match res {
+                    Ok(value) => (Ok(value.clone()), Some(value)),
+                    Err(err) => (Err(err), None),
+                };
+                ret_mut.replace(cloned_res);
+                stored_value
+            }
+        });
+        ret.expect("return value should always be initialized")
+    }
+}
+
+impl<'a, K, V> CacheRefMap<'a, K, V>
+where
+    K: Hash + PartialEq,
+{
+    /// Insert the value if not present. Discard the value if present.
+    ///
+    /// The first value is picked to avoid multiple cached results floating around. This assumes
+    /// that the cache is immutable (i.e. there's no invalidation).
+    ///
+    /// Returns the address of the inserted value.
+    #[inline]
+    pub fn or_insert(&self, key: K, value: V) -> &'a V {
+        self.or_insert_with_transform(key, move || value, |value_ref| value_ref)
+    }
+
+    #[inline]
+    pub fn or_insert_with<F>(&self, key: K, insert: F) -> &'a V
+    where
+        F: FnOnce() -> V,
+    {
+        self.or_insert_with_transform(key, insert, |value_ref| value_ref)
+    }
+}
+
+#[test]
+fn cache_map_thread_safe() {
+    fn assert_send<T: Send>() {}
+    fn assert_sync<T: Sync>() {}
+
+    assert_send::<CacheRefMap<'static, String, String>>();
+    assert_sync::<CacheRefMap<'static, String, String>>();
+}
diff --git a/language/vm/vm_runtime/vm_cache_map/src/lib.rs b/language/vm/vm_runtime/vm_cache_map/src/lib.rs
new file mode 100644
index 0000000000000..57bae7a91856d
--- /dev/null
+++ b/language/vm/vm_runtime/vm_cache_map/src/lib.rs
@@ -0,0 +1,13 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! A map used for VM runtime caches. The data structures are highly specialized for the
+//! VM runtime and are probably not useful outside of it.
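+//!
+//! Two pieces are exported: [`Arena`], a thread-safe wrapper around
+//! `typed_arena::Arena`, and [`CacheMap`]/[`CacheRefMap`], an insert-only map
+//! whose values live in such an arena so they can stay borrowed while the map
+//! keeps growing.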
+ +mod arena; +mod cache_map; +#[cfg(test)] +mod unit_tests; + +pub use arena::Arena; +pub use cache_map::{CacheMap, CacheRefMap}; diff --git a/language/vm/vm_runtime/vm_cache_map/src/unit_tests/arena_tests.rs b/language/vm/vm_runtime/vm_cache_map/src/unit_tests/arena_tests.rs new file mode 100644 index 0000000000000..cfd57b1d9915d --- /dev/null +++ b/language/vm/vm_runtime/vm_cache_map/src/unit_tests/arena_tests.rs @@ -0,0 +1,36 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::Arena; +use crossbeam::scope; +use proptest::{collection::vec, prelude::*}; + +proptest! { + #[test] + fn one_thread(strings in vec(".*", 0..50)) { + let arena: Arena = Arena::new(); + for string in strings { + prop_assert_eq!(arena.alloc(string.clone()), &string); + } + } +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(32))] + #[test] + fn many_threads(string_vecs in vec(vec(".*", 0..50), 0..16)) { + let arena: Arena = Arena::new(); + let arena_ref = &arena; + let res = scope(|s| { + for strings in &string_vecs { + s.spawn(move |_| { + for string in strings { + prop_assert_eq!(arena_ref.alloc(string.clone()), string); + } + Ok(()) + }); + } + }); + res.expect("threads should succeed"); + } +} diff --git a/language/vm/vm_runtime/vm_cache_map/src/unit_tests/cache_map_tests.rs b/language/vm/vm_runtime/vm_cache_map/src/unit_tests/cache_map_tests.rs new file mode 100644 index 0000000000000..a53cc534f0ada --- /dev/null +++ b/language/vm/vm_runtime/vm_cache_map/src/unit_tests/cache_map_tests.rs @@ -0,0 +1,84 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{Arena, CacheRefMap}; +use crossbeam::scope; +use proptest::{ + collection::{hash_map, vec}, + prelude::*, +}; +use rand::{seq::SliceRandom, thread_rng}; + +const NUM_THREADS: usize = 8; + +proptest! { + #![proptest_config(ProptestConfig::with_cases(32))] + #[test] + fn or_insert(kv_pairs in hash_map(".*", ".*", 0..100)) { + let arena = Arena::new(); + let map = CacheRefMap::new(&arena); + for (key, value) in kv_pairs { + prop_assert_eq!(map.get(&key), None); + prop_assert_eq!(map.or_insert(key.clone(), value.clone()), &value); + prop_assert_eq!(map.get(&key), Some(&value)); + } + } + + #[test] + fn or_insert_duplicates(kv_lists in hash_map(".*", vec(".*", 1..4), 0..100)) { + let arena = Arena::new(); + let map = CacheRefMap::new(&arena); + for (key, values) in kv_lists { + let first = values[0].clone(); + prop_assert_eq!(map.get(&key), None); + prop_assert_eq!(map.or_insert(key.clone(), first.clone()), &first); + prop_assert_eq!(map.get(&key), Some(&first)); + + // Further values for the same key should be ignored. + for value in values.into_iter().skip(1) { + prop_assert_eq!(map.or_insert_with(key.clone(), || value.clone()), &first); + prop_assert_eq!(map.or_insert(key.clone(), value), &first); + } + } + } + + #[test] + fn or_insert_many_threads(kv_lists in hash_map(".*", vec(".*", NUM_THREADS), 0..50)) { + // Try inserting to the list concurrently with NUM_THREADS threads. + let arena: Arena = Arena::new(); + let map: CacheRefMap = CacheRefMap::new(&arena); + let map_ref = ↦ + + let mut kv_thread_lists: Vec> = vec![vec![]; NUM_THREADS]; + for (key, values) in &kv_lists { + for (idx, value) in values.iter().enumerate() { + kv_thread_lists[idx].push((key.clone(), value.to_string())); + } + } + + // Shuffle the lists so each thread gets a chance to insert the first value. 
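+        // `or_insert` keeps only the first value stored for a key, so with
+        // shuffled inputs the winner per key is nondeterministic; the checks at
+        // the end only assert that the cached value is *one of* the candidates.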
+ let mut rng = thread_rng(); + for kv_pairs in &mut kv_thread_lists { + kv_pairs.shuffle(&mut rng); + } + + let res = scope(move |s| { + for kv_pairs in kv_thread_lists { + s.spawn(move |_| { + for (key, value) in kv_pairs { + // This is nondeterministic so can't really be compared. + map_ref.or_insert(key, value); + } + Ok::<_, ()>(()) + }); + } + }); + res.expect("threads should succeed"); + + // The final value for each key should be one of the values in kv_lists. + for (key, values) in &kv_lists { + let cached_value = map.get(key).expect("at least one value should have been cached"); + prop_assert!(values.contains(cached_value)); + } + } +} diff --git a/language/vm/vm_runtime/vm_cache_map/src/unit_tests/mod.rs b/language/vm/vm_runtime/vm_cache_map/src/unit_tests/mod.rs new file mode 100644 index 0000000000000..bed89805d83a6 --- /dev/null +++ b/language/vm/vm_runtime/vm_cache_map/src/unit_tests/mod.rs @@ -0,0 +1,5 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod arena_tests; +mod cache_map_tests; diff --git a/language/vm/vm_runtime/vm_runtime_tests/Cargo.toml b/language/vm/vm_runtime/vm_runtime_tests/Cargo.toml new file mode 100644 index 0000000000000..cdfbc6b742a63 --- /dev/null +++ b/language/vm/vm_runtime/vm_runtime_tests/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "vm_runtime_tests" +version = "0.1.0" +edition = "2018" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false + +[dependencies] +canonical_serialization = { path = "../../../../common/canonical_serialization" } +crypto = { path = "../../../../crypto/legacy_crypto"} +failure = { path = "../../../../common/failure_ext", package = "failure_ext" } +compiler = { path = "../../../compiler"} +lazy_static = "1.3.0" +state_view = { path = "../../../../storage/state_view" } +types = { path = "../../../../types" } +vm = { path = "../../" } +vm_runtime = { path = "../" } +proptest = "0.9.3" +proptest-derive = "0.1.1" +assert_matches = "1.3.0" +proptest_helpers = { path = "../../../../common/proptest_helpers" } +protobuf = "2.6" +proto_conv = { path = "../../../../common/proto_conv", features = ["derive"] } +tiny-keccak = "1.4.2" +vm_genesis = { path = "../../vm_genesis"} +hex = "0.3.2" +getopts = "0.2.18" +config = { path = "../../../../config"} +logger = { path = "../../../../common/logger" } +stdlib = { path = "../../../stdlib" } + +[[bin]] +name = "vm_repl" +path = "src/bin/repl.rs" diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/account.rs b/language/vm/vm_runtime/vm_runtime_tests/src/account.rs new file mode 100644 index 0000000000000..ad92ca6094d94 --- /dev/null +++ b/language/vm/vm_runtime/vm_runtime_tests/src/account.rs @@ -0,0 +1,413 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Test infrastructure for modeling Libra accounts. + +use crypto::{PrivateKey, PublicKey}; +use lazy_static::lazy_static; +use std::{convert::TryInto, time::Duration}; +use types::{ + access_path::AccessPath, + account_address::AccountAddress, + account_config, + byte_array::ByteArray, + transaction::{Program, RawTransaction, SignedTransaction, TransactionArgument}, +}; +use vm_genesis::GENESIS_KEYPAIR; +use vm_runtime::{ + identifier::create_access_path, + loaded_data::struct_def::StructDef, + value::{MutVal, Value}, +}; + +// StdLib account, it is where the code is and needed to make access path to Account resources +lazy_static! 
{ + static ref STDLIB_ADDRESS: AccountAddress = { account_config::core_code_address() }; +} + +/// Details about a Libra account. +/// +/// Tests will typically create a set of `Account` instances to run transactions on. This type +/// encodes the logic to operate on and verify operations on any Libra account. +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Account { + addr: AccountAddress, + /// The current private key for this account. + pub privkey: PrivateKey, + /// The current public key for this account. + pub pubkey: PublicKey, +} + +impl Account { + /// Creates a new account in memory. + /// + /// The account returned by this constructor is a purely logical entity, meaning that it does + /// not automatically get added to the Libra store. To add an account to the store, use + /// [`AccountData`] instances with + /// [`FakeExecutor::add_account_data`][crate::executor::FakeExecutor::add_account_data]. + pub fn new() -> Self { + let (privkey, pubkey) = crypto::signing::generate_keypair(); + Self::with_keypair(privkey, pubkey) + } + + /// Creates a new account with the given keypair. + /// + /// Like with [`Account::new`], the account returned by this constructor is a purely logical + /// entity. + pub fn with_keypair(privkey: PrivateKey, pubkey: PublicKey) -> Self { + let addr = pubkey.into(); + Account { + addr, + privkey, + pubkey, + } + } + + /// Creates a new account representing the association in memory. + /// + /// The address will be [`association_address`][account_config::association_address], and + /// the account will use [`GENESIS_KEYPAIR`][struct@GENESIS_KEYPAIR] as its keypair. + pub fn new_association() -> Self { + Account { + addr: account_config::association_address(), + pubkey: GENESIS_KEYPAIR.1, + privkey: GENESIS_KEYPAIR.0.clone(), + } + } + + /// Returns the address of the account. This is a hash of the public key the account was created + /// with. + /// + /// The address does not change if the account's [keys are rotated][Account::rotate_key]. + pub fn address(&self) -> &AccountAddress { + &self.addr + } + + /// Returns the AccessPath that describes the Account resource instance. + /// + /// Use this to retrieve or publish the Account blob. + // TODO: plug in the account type + pub fn make_access_path(&self) -> AccessPath { + // TODO: we need a way to get the type (StructDef) of the Account in place + create_access_path(&self.addr, account_config::account_struct_tag()) + } + + /// Changes the keys for this account to the provided ones. + pub fn rotate_key(&mut self, privkey: PrivateKey, pubkey: PublicKey) { + self.privkey = privkey; + self.pubkey = pubkey; + } + + /// Computes the authentication key for this account, as stored on the chain. + /// + /// This is the same as the account's address if the keys have never been rotated. + pub fn auth_key(&self) -> AccountAddress { + AccountAddress::from(self.pubkey) + } + + // + // Helpers to read data from an Account resource + // + + // + // Helpers for transaction creation with Account instance as sender + // + + /// Returns a [`SignedTransaction`] with no arguments and this account as the sender. + pub fn create_signed_txn( + &self, + program: Vec, + sequence_number: u64, + max_gas_amount: u64, + gas_unit_price: u64, + ) -> SignedTransaction { + self.create_signed_txn_with_args( + program, + vec![], + sequence_number, + max_gas_amount, + gas_unit_price, + ) + } + + /// Returns a [`SignedTransaction`] with the arguments defined in `args` and this account as + /// the sender. 
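+    ///
+    /// Illustrative call (a sketch; `program_bytes` stands in for compiled
+    /// program byte-code):
+    /// ```ignore
+    /// let sender = Account::new();
+    /// let txn = sender.create_signed_txn_with_args(
+    ///     program_bytes,
+    ///     vec![], // no TransactionArgument values
+    ///     0,      // sequence_number
+    ///     10_000, // max_gas_amount
+    ///     1,      // gas_unit_price
+    /// );
+    /// ```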
+ pub fn create_signed_txn_with_args( + &self, + program: Vec, + args: Vec, + sequence_number: u64, + max_gas_amount: u64, + gas_unit_price: u64, + ) -> SignedTransaction { + self.create_signed_txn_impl( + *self.address(), + Program::new(program, vec![], args), + sequence_number, + max_gas_amount, + gas_unit_price, + ) + } + + /// Returns a [`SignedTransaction`] with the arguments defined in `args` and a custom sender. + /// + /// The transaction is signed with the key corresponding to this account, not the custom sender. + pub fn create_signed_txn_with_args_and_sender( + &self, + sender: AccountAddress, + program: Vec, + args: Vec, + sequence_number: u64, + max_gas_amount: u64, + gas_unit_price: u64, + ) -> SignedTransaction { + self.create_signed_txn_impl( + sender, + Program::new(program, vec![], args), + sequence_number, + max_gas_amount, + gas_unit_price, + ) + } + + /// Returns a [`SignedTransaction`] with the arguments defined in `args` and a custom sender. + /// + /// The transaction is signed with the key corresponding to this account, not the custom sender. + pub fn create_signed_txn_impl( + &self, + sender: AccountAddress, + program: Program, + sequence_number: u64, + max_gas_amount: u64, + gas_unit_price: u64, + ) -> SignedTransaction { + RawTransaction::new( + sender, + sequence_number, + program, + max_gas_amount, + gas_unit_price, + Duration::from_secs(u64::max_value()), + ) + .sign(&self.privkey, self.pubkey) + .unwrap() + } + + /// Given a blob, materializes the VM Value behind it. + pub(crate) fn read_account_resource(blob: &[u8], account_type: StructDef) -> Option { + match Value::simple_deserialize(blob, account_type) { + Ok(account) => Some(account), + Err(_) => None, + } + } +} + +impl Default for Account { + fn default() -> Self { + Self::new() + } +} + +/// Represents an account along with initial state about it. +/// +/// `AccountData` captures the initial state needed to create accounts for tests. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct AccountData { + account: Account, + balance: u64, + sequence_number: u64, + sent_events_count: u64, + received_events_count: u64, +} + +impl AccountData { + /// Creates a new `AccountData` with a new account. + /// + /// Most tests will want to use this constructor. + pub fn new(balance: u64, sequence_number: u64) -> Self { + Self::with_account(Account::new(), balance, sequence_number) + } + + /// Creates a new `AccountData` with the provided account. + pub fn with_account(account: Account, balance: u64, sequence_number: u64) -> Self { + Self::with_account_and_event_counts(account, balance, sequence_number, 0, 0) + } + + /// Creates a new `AccountData` with custom parameters. + pub fn with_account_and_event_counts( + account: Account, + balance: u64, + sequence_number: u64, + sent_events_count: u64, + received_events_count: u64, + ) -> Self { + Self { + account, + balance, + sequence_number, + sent_events_count, + received_events_count, + } + } + + /// Changes the keys for this account to the provided ones. + pub fn rotate_key(&mut self, privkey: PrivateKey, pubkey: PublicKey) { + self.account.rotate_key(privkey, pubkey) + } + + /// Creates and returns a resource [`Value`] for this data. 
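+    ///
+    /// The field order mirrors the on-chain Account resource read back by
+    /// [`AccountResource`]: 0 = authentication key, 1 = `LibraCoin.T { balance }`,
+    /// 2 = received events count, 3 = sent events count, 4 = sequence number.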
+ pub fn to_resource(&self) -> Value { + // TODO: publish some concept of Account + let coin = Value::Struct(vec![MutVal::new(Value::U64(self.balance))]); + Value::Struct(vec![ + MutVal::new(Value::ByteArray(ByteArray::new( + AccountAddress::from(self.account.pubkey).to_vec(), + ))), + MutVal::new(coin), + MutVal::new(Value::U64(self.received_events_count)), + MutVal::new(Value::U64(self.sent_events_count)), + MutVal::new(Value::U64(self.sequence_number)), + ]) + } + + /// Returns the AccessPath that describes the Account resource instance. + /// + /// Use this to retrieve or publish the Account blob. + // TODO: plug in the account type + pub fn make_access_path(&self) -> AccessPath { + self.account.make_access_path() + } + + /// Returns the address of the account. This is a hash of the public key the account was created + /// with. + /// + /// The address does not change if the account's [keys are rotated][AccountData::rotate_key]. + pub fn address(&self) -> &AccountAddress { + self.account.address() + } + + /// Returns the underlying [`Account`] instance. + pub fn account(&self) -> &Account { + &self.account + } + + /// Converts this data into an `Account` instance. + pub fn into_account(self) -> Account { + self.account + } + + /// Returns the initial balance. + pub fn balance(&self) -> u64 { + self.balance + } + + /// Returns the initial sequence number. + pub fn sequence_number(&self) -> u64 { + self.sequence_number + } + + /// Returns the initial sent events count. + pub fn sent_events_count(&self) -> u64 { + self.sent_events_count + } + + /// Returns the initial received events count. + pub fn received_events_count(&self) -> u64 { + self.received_events_count + } +} + +/// Helper methods for dealing with account resources as seen by the Libra VM. +pub enum AccountResource {} + +impl AccountResource { + /// Returns the authentication key read from a [`Value`] representing the account. + pub fn read_auth_key(account: &Value) -> AccountAddress { + // The return type is slightly confusing -- the auth key stored on the chain is actually + // just a hash of the public key from the account. This may change in the future with + // flexible authentication. + match account { + Value::Struct(fields) => { + let auth_key = fields.get(0).expect("Auth key must be field 0 in Account"); + match &*auth_key.peek() { + Value::ByteArray(bytes) => bytes + .as_bytes() + .try_into() + .expect("Auth key must be parseable as an account address"), + _ => panic!("auth key must be a ByteArray"), + } + } + _ => panic!("Account must be a Value::Struct"), + } + } + + /// Returns the balance read from a [`Value`] representing the account. + pub fn read_balance(account: &Value) -> u64 { + match account { + Value::Struct(fields) => { + let coin = fields + .get(1) + .expect("LibraCoin.T must be the second field in Account"); + match &*coin.peek() { + Value::Struct(balance) => { + let value = balance.get(0).expect("balance field must exist"); + match &*value.peek() { + Value::U64(val) => *val, + _ => panic!("balance field must exist"), + } + } + _ => panic!("account must contain LibraCoin.T as second field"), + } + } + _ => panic!("Account must be a Value::Struct"), + } + } + + /// Returns the received events count read from a [`Value`] representing the account. 
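+    ///
+    /// Illustrative use against a deserialized account blob (a sketch; `blob`
+    /// and `account_type` are assumed to be in scope):
+    /// ```ignore
+    /// let account = Value::simple_deserialize(&blob, account_type).unwrap();
+    /// let received = AccountResource::read_received_events_count(&account);
+    /// ```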
+    pub fn read_received_events_count(account: &Value) -> u64 {
+        match account {
+            Value::Struct(fields) => {
+                let received_events_count = fields
+                    .get(2)
+                    .expect("received_events_count must be field 2 in Account");
+                match &*received_events_count.peek() {
+                    Value::U64(val) => *val,
+                    _ => panic!("received events count field must be a U64"),
+                }
+            }
+            _ => panic!("Account must be a Value::Struct"),
+        }
+    }
+
+    /// Returns the sent events count read from a [`Value`] representing the account.
+    pub fn read_sent_events_count(account: &Value) -> u64 {
+        match account {
+            Value::Struct(fields) => {
+                let sent_events_count = fields
+                    .get(3)
+                    .expect("sent_events_count must be field 3 in Account");
+                match &*sent_events_count.peek() {
+                    Value::U64(val) => *val,
+                    _ => panic!("sent events count field must be a U64"),
+                }
+            }
+            _ => panic!("Account must be a Value::Struct"),
+        }
+    }
+
+    /// Returns the sequence number read from a [`Value`] representing the account.
+    pub fn read_sequence_number(account: &Value) -> u64 {
+        match account {
+            Value::Struct(fields) => {
+                let sequence_number = fields
+                    .get(4)
+                    .expect("sequence number must be the fifth field in Account");
+                match &*sequence_number.peek() {
+                    Value::U64(val) => *val,
+                    _ => panic!("sequence number field must be a U64"),
+                }
+            }
+            _ => panic!("Account must be a Value::Struct"),
+        }
+    }
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/account_universe.rs b/language/vm/vm_runtime/vm_runtime_tests/src/account_universe.rs
new file mode 100644
index 0000000000000..cfe24034f78c5
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/account_universe.rs
@@ -0,0 +1,443 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! A model to test properties of common Libra transactions.
+//!
+//! The structs and functions in this module together form a simplified *model* of how common Libra
+//! transactions should behave. This model is then used as an *oracle* for property-based tests --
+//! the results of executing transactions through the VM should match the results computed using
+//! this model.
+//!
+//! For examples of property-based tests written against this model, see the
+//! `tests/account_universe` directory.
+
+// clippy warns on the Arbitrary impl for `AccountPairGen` -- it's how Arbitrary works so ignore it.
+#![allow(clippy::unit_arg)]
+
+mod create_account;
+mod peer_to_peer;
+mod rotate_key;
+pub use create_account::*;
+pub use peer_to_peer::*;
+pub use rotate_key::*;
+
+use crate::{
+    account::{Account, AccountData},
+    executor::FakeExecutor,
+    gas_costs,
+};
+use crypto::{PrivateKey, PublicKey};
+use lazy_static::lazy_static;
+use proptest::{
+    collection::{vec, SizeRange},
+    prelude::*,
+    strategy::Union,
+};
+use proptest_derive::Arbitrary;
+use proptest_helpers::{pick_slice_idxs, Index};
+use std::fmt;
+use types::{
+    transaction::{SignedTransaction, TransactionStatus},
+    vm_error::{ExecutionStatus, VMStatus, VMValidationStatus},
+};
+
+lazy_static! {
+    static ref UNIVERSE_SIZE: usize = {
+        use std::{env, process::abort};
+
+        match env::var("UNIVERSE_SIZE") {
+            Ok(s) => match s.parse::<usize>() {
+                Ok(val) => val,
+                Err(err) => {
+                    println!("Could not parse universe size, aborting: {:?}", err);
+                    // Abort because lazy_static with panics causes poisoning and isn't very
+                    // helpful overall.
+ abort(); + } + } + Err(env::VarError::NotPresent) => 20, + Err(err) => { + println!("Could not read universe size from the environment, aborting: {:?}", err); + abort(); + } + } + }; +} + +/// The number of accounts to run universe-based proptests with. Set with the `UNIVERSE_SIZE` +/// environment variable. +/// +/// Larger values will provide greater testing but will take longer to run and shrink. Release mode +/// is recommended for values above 100. +#[inline] +pub(crate) fn num_accounts() -> usize { + *UNIVERSE_SIZE +} + +/// The number of transactions to run universe-based proptests with. Set with the `UNIVERSE_SIZE` +/// environment variable (this function will return twice that). +/// +/// Larger values will provide greater testing but will take longer to run and shrink. Release mode +/// is recommended for values above 100. +#[inline] +pub(crate) fn num_transactions() -> usize { + *UNIVERSE_SIZE * 2 +} + +/// A set of accounts which can be used to construct an initial state. +/// +/// For more, see the [`account_universe` module documentation][self]. +#[derive(Clone, Debug)] +pub struct AccountUniverseGen { + accounts: Vec, +} + +/// A set of accounts that has been set up and can now be used to conduct transactions on. +/// +/// For more, see the [`account_universe` module documentation][self]. +#[derive(Clone, Debug)] +pub struct AccountUniverse { + accounts: Vec, + /// Whether to ignore any new accounts that transactions add to the universe. + ignore_new_accounts: bool, +} + +/// Represents any sort of transaction that can be done in an account universe. +pub trait AUTransactionGen: fmt::Debug { + /// Applies this transaction onto the universe, updating balances within the universe as + /// necessary. Returns a signed transaction that can be run on the VM and the expected output. + fn apply(&self, universe: &mut AccountUniverse) -> (SignedTransaction, TransactionStatus); + + /// Creates a boxed version of this transaction, suitable for dynamic dispatch. + fn boxed(self) -> Box + where + Self: 'static + Sized, + { + Box::new(self) + } +} + +impl AUTransactionGen for Box { + fn apply(&self, universe: &mut AccountUniverse) -> (SignedTransaction, TransactionStatus) { + (**self).apply(universe) + } +} + +/// Allows pairs of accounts to be uniformly randomly selected from an account universe. +#[derive(Arbitrary, Clone, Debug)] +pub struct AccountPairGen { + pair: [Index; 2], + // The pick_slice_idx method used by this struct returns values in order, so use this flag + // to determine whether to reverse it. + reverse: bool, +} + +impl AccountUniverseGen { + /// Returns a [`Strategy`] that generates a universe of accounts with pre-populated initial + /// balances. + pub fn strategy( + num_accounts: impl Into, + balance_strategy: impl Strategy, + ) -> impl Strategy { + // Pick a sequence number in a smaller range so that valid transactions can be generated. + // XXX should we also test edge cases around large sequence numbers? + // Note that using a function as a strategy directly means that shrinking will not occur, + // but that should be fine because there's nothing to really shrink within accounts anyway. + vec(AccountData::strategy(balance_strategy), num_accounts) + .prop_map(|accounts| Self { accounts }) + } + + /// Returns a [`Strategy`] that generates a universe of accounts that's guaranteed to succeed, + /// assuming that any transfers out of accounts will be 10_000 or below. 
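+    ///
+    /// Each account starts with at least `10_000 * num_transactions() * 5`
+    /// microlibra: enough for every transaction in a run to transfer 10_000,
+    /// with a 5x margin left over for gas.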
+ pub fn success_strategy(min_accounts: usize) -> impl Strategy { + // Set the minimum balance to be 5x possible transfers out to handle potential gas cost + // issues. + let min_balance = (10_000 * (num_transactions()) * 5) as u64; + let max_balance = min_balance * 10; + Self::strategy(min_accounts..num_accounts(), min_balance..max_balance) + } + + /// Returns an [`AccountUniverse`] with the initial state generated in this universe. + pub fn setup(self, executor: &mut FakeExecutor) -> AccountUniverse { + for account_data in &self.accounts { + executor.add_account_data(account_data); + } + + AccountUniverse::new(self.accounts, false) + } + + /// Returns an [`AccountUniverse`] with the initial state generated in this universe, and + /// configures the universe to run tests in gas-cost-stability mode. + /// + /// The stability mode causes new accounts to be dropped, since those accounts will usually + /// not be funded enough. + pub fn setup_gas_cost_stability(self, executor: &mut FakeExecutor) -> AccountUniverse { + for account_data in &self.accounts { + executor.add_account_data(account_data); + } + + AccountUniverse::new(self.accounts, true) + } +} + +impl AccountUniverse { + fn new(accounts: Vec, ignore_new_accounts: bool) -> Self { + let accounts = accounts.into_iter().map(AccountCurrent::new).collect(); + Self { + accounts, + ignore_new_accounts, + } + } + + /// Returns the number of accounts currently in this universe. + /// + /// Some transactions might cause new accounts to be created. The return value of this method + /// will include those new accounts. + pub fn num_accounts(&self) -> usize { + self.accounts.len() + } + + /// Returns the accounts currently in this universe. + /// + /// Some transactions might cause new accounts to be created. The return value of this method + /// will include those new accounts. + pub fn accounts(&self) -> &[AccountCurrent] { + &self.accounts + } + + /// Adds an account to the universe so that future transactions can be made out of this account. + /// + /// This is ignored if the universe was configured to be in gas-cost-stability mode. + pub fn add_account(&mut self, account_data: AccountData) { + if !self.ignore_new_accounts { + self.accounts.push(AccountCurrent::new(account_data)); + } + } +} + +impl AccountPairGen { + /// Picks two accounts uniformly randomly from this universe and returns shared references to + /// them. + pub fn pick<'a>(&self, universe: &'a AccountUniverse) -> AccountPair<'a> { + let idxs = pick_slice_idxs(universe.num_accounts(), &self.pair); + assert_eq!(idxs.len(), 2, "universe should have at least two accounts"); + let (low_idx, high_idx) = (idxs[0], idxs[1]); + assert_ne!(low_idx, high_idx, "accounts picked must be distinct"); + + if self.reverse { + AccountPair { + idx_1: high_idx, + idx_2: low_idx, + account_1: &universe.accounts[high_idx], + account_2: &universe.accounts[low_idx], + } + } else { + AccountPair { + idx_1: low_idx, + idx_2: high_idx, + account_1: &universe.accounts[low_idx], + account_2: &universe.accounts[high_idx], + } + } + } + + /// Picks two accounts uniformly randomly from this universe and returns mutable references to + /// them. 
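+    ///
+    /// The two indices are always distinct, so the accounts slice is divided
+    /// with `split_at_mut` to hand out two disjoint `&mut` borrows safely.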
+ pub fn pick_mut<'a>(&self, universe: &'a mut AccountUniverse) -> AccountPairMut<'a> { + let idxs = pick_slice_idxs(universe.num_accounts(), &self.pair); + assert_eq!(idxs.len(), 2, "universe should have at least two accounts"); + let (low_idx, high_idx) = (idxs[0], idxs[1]); + assert_ne!(low_idx, high_idx, "accounts picked must be distinct"); + // Need to use `split_at_mut` because you can't have multiple mutable references to items + // from a single slice at any given time. + let (head, tail) = universe.accounts.split_at_mut(low_idx + 1); + let (low_account, high_account) = (&mut head[low_idx], &mut tail[high_idx - low_idx - 1]); + + if self.reverse { + AccountPairMut { + idx_1: high_idx, + idx_2: low_idx, + account_1: high_account, + account_2: low_account, + } + } else { + AccountPairMut { + idx_1: low_idx, + idx_2: high_idx, + account_1: low_account, + account_2: high_account, + } + } + } +} + +/// Shared references to a pair of distinct accounts picked from a universe. +pub struct AccountPair<'a> { + /// The index of the first account picked. + pub idx_1: usize, + /// The index of the second account picked. + pub idx_2: usize, + /// A reference to the first account picked. + pub account_1: &'a AccountCurrent, + /// A reference to the second account picked. + pub account_2: &'a AccountCurrent, +} + +/// Mutable references to a pair of distinct accounts picked from a universe. +pub struct AccountPairMut<'a> { + /// The index of the first account picked. + pub idx_1: usize, + /// The index of the second account picked. + pub idx_2: usize, + /// A mutable reference to the first account picked. + pub account_1: &'a mut AccountCurrent, + /// A mutable reference to the second account picked. + pub account_2: &'a mut AccountCurrent, +} + +/// Represents the current state of account in a universe, possibly after its state has been updated +/// by running transactions against the universe. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct AccountCurrent { + initial_data: AccountData, + balance: u64, + sequence_number: u64, + sent_events_count: u64, + received_events_count: u64, +} + +impl AccountCurrent { + fn new(initial_data: AccountData) -> Self { + let balance = initial_data.balance(); + let sequence_number = initial_data.sequence_number(); + let sent_events_count = initial_data.sent_events_count(); + let received_events_count = initial_data.received_events_count(); + Self { + initial_data, + balance, + sequence_number, + sent_events_count, + received_events_count, + } + } + + /// Returns the underlying account. + pub fn account(&self) -> &Account { + &self.initial_data.account() + } + + /// Rotates the key in this account. + pub fn rotate_key(&mut self, privkey: PrivateKey, pubkey: PublicKey) { + self.initial_data.rotate_key(privkey, pubkey); + } + + /// Returns the current balance for this account, assuming all transactions seen so far are + /// applied. + pub fn balance(&self) -> u64 { + self.balance + } + + /// Returns the current sequence number for this account, assuming all transactions seen so far + /// are applied. + pub fn sequence_number(&self) -> u64 { + self.sequence_number + } + + /// Returns the current sent events count for this account, assuming all transactions seen so + /// far are applied. + pub fn sent_events_count(&self) -> u64 { + self.sent_events_count + } + + /// Returns the current received events count for this account, assuming all transactions seen + /// so far are applied. 
+ pub fn received_events_count(&self) -> u64 { + self.received_events_count + } +} + +/// Computes the result for running a transfer out of one account. Also updates the account to +/// reflect this transaction. +/// +/// The return value is a pair of the expected status and whether the transaction was successful. +pub fn txn_one_account_result( + sender: &mut AccountCurrent, + amount: u64, + gas_cost: u64, + low_gas_cost: u64, +) -> (TransactionStatus, bool) { + // The transactions set the gas cost to 1 microlibra. + let enough_max_gas = sender.balance >= gas_costs::TXN_RESERVED; + // This means that we'll get through the main part of the transaction. + let enough_to_transfer = sender.balance >= amount; + let to_deduct = amount + gas_cost; + // This means that we'll get through the entire transaction, including the epilogue + // (where gas costs are deducted). + let enough_to_succeed = sender.balance >= to_deduct; + + match (enough_max_gas, enough_to_transfer, enough_to_succeed) { + (true, true, true) => { + // Success! + sender.sequence_number += 1; + sender.sent_events_count += 1; + sender.balance -= to_deduct; + ( + TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)), + true, + ) + } + (true, true, false) => { + // Enough gas to pass validation and to do the transfer, but not enough to succeed + // in the epilogue. The transaction will be run and gas will be deducted from the + // sender, but no other changes will happen. + sender.sequence_number += 1; + sender.balance -= gas_cost; + ( + TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::AssertionFailure(6))), + false, + ) + } + (true, false, _) => { + // Enough gas to pass validation but not enough to succeed. The transaction will + // be run and gas will be deducted from the sender, but no other changes will + // happen. + sender.sequence_number += 1; + sender.balance -= low_gas_cost; + ( + TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::AssertionFailure(10))), + false, + ) + } + (false, _, _) => { + // Not enough gas to pass validation. Nothing will happen. + ( + TransactionStatus::Discard(VMStatus::Validation( + VMValidationStatus::InsufficientBalanceForTransactionFee, + )), + false, + ) + } + } +} + +/// Returns a [`Strategy`] that provides a variety of balances (or transfer amounts) over a roughly +/// logarithmic distribution. +pub fn log_balance_strategy(max_balance: u64) -> impl Strategy { + // The logarithmic distribution is modeled by uniformly picking from ranges of powers of 2. + let minimum = gas_costs::TXN_RESERVED.next_power_of_two(); + assert!(max_balance >= minimum, "minimum to make sense"); + let mut strategies = vec![]; + // Balances below and around the minimum are interesting but don't cover *every* power of 2, + // just those starting from the minimum. 
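+    // For example (hypothetical numbers), with minimum = 2^17 and
+    // max_balance = 2^20 the generated ranges are:
+    //     0..2^17, 2^17..2^18, 2^18..2^19, 2^19..2^20
+    // and Union::new picks one range uniformly before sampling within it.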
+ let mut lower_bound: u64 = 0; + let mut upper_bound: u64 = minimum; + loop { + strategies.push(lower_bound..upper_bound); + if upper_bound >= max_balance { + break; + } + lower_bound = upper_bound; + upper_bound = (upper_bound * 2).min(max_balance); + } + Union::new(strategies) +} diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/account_universe/create_account.rs b/language/vm/vm_runtime/vm_runtime_tests/src/account_universe/create_account.rs new file mode 100644 index 0000000000000..a58a1a3f963f2 --- /dev/null +++ b/language/vm/vm_runtime/vm_runtime_tests/src/account_universe/create_account.rs @@ -0,0 +1,107 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + account::{Account, AccountData}, + account_universe::{ + txn_one_account_result, AUTransactionGen, AccountPairGen, AccountPairMut, AccountUniverse, + }, + common_transactions::create_account_txn, + gas_costs, +}; +use proptest::prelude::*; +use proptest_derive::Arbitrary; +use proptest_helpers::Index; +use types::{ + transaction::{SignedTransaction, TransactionStatus}, + vm_error::{ExecutionStatus, VMStatus, VMValidationStatus}, +}; + +/// Represents a create-account transaction performed in the account universe. +/// +/// The parameters are the minimum and maximum balances to transfer. +#[derive(Arbitrary, Clone, Debug)] +#[proptest(params = "(u64, u64)")] +pub struct CreateAccountGen { + sender: Index, + new_account: Account, + #[proptest(strategy = "params.0 ..= params.1")] + amount: u64, +} + +impl AUTransactionGen for CreateAccountGen { + fn apply(&self, universe: &mut AccountUniverse) -> (SignedTransaction, TransactionStatus) { + let sender_idx = self.sender.index(universe.num_accounts()); + let sender = &mut universe.accounts[sender_idx]; + + let txn = create_account_txn( + sender.account(), + &self.new_account, + sender.sequence_number, + self.amount, + ); + + let (status, is_success) = txn_one_account_result( + sender, + self.amount, + *gas_costs::CREATE_ACCOUNT, + *gas_costs::CREATE_ACCOUNT_TOO_LOW, + ); + if is_success { + universe.add_account(AccountData::with_account( + self.new_account.clone(), + self.amount, + 0, + )); + } + + (txn, status) + } +} + +/// Represents a create-account transaction in the account universe where the destination already +/// exists. +/// +/// The parameters are the minimum and maximum balances to transfer. +#[derive(Arbitrary, Clone, Debug)] +#[proptest(params = "(u64, u64)")] +pub struct CreateExistingAccountGen { + sender_receiver: AccountPairGen, + #[proptest(strategy = "params.0 ..= params.1")] + amount: u64, +} + +impl AUTransactionGen for CreateExistingAccountGen { + fn apply(&self, universe: &mut AccountUniverse) -> (SignedTransaction, TransactionStatus) { + let AccountPairMut { + account_1: sender, + account_2: receiver, + .. + } = self.sender_receiver.pick_mut(universe); + + let txn = create_account_txn( + sender.account(), + receiver.account(), + sender.sequence_number, + self.amount, + ); + + // This transaction should never work, but it will fail differently if there's not enough + // gas to reserve. + let enough_max_gas = sender.balance >= gas_costs::TXN_RESERVED; + let status = if enough_max_gas { + sender.sequence_number += 1; + sender.balance -= *gas_costs::CREATE_EXISTING_ACCOUNT; + TransactionStatus::Keep(VMStatus::Execution( + ExecutionStatus::CannotWriteExistingResource, + )) + } else { + // Not enough gas to get past the prologue. 
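+            // (Same discard status as the (false, _, _) arm of
+            // `txn_one_account_result`.)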
+ TransactionStatus::Discard(VMStatus::Validation( + VMValidationStatus::InsufficientBalanceForTransactionFee, + )) + }; + + (txn, status) + } +} diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/account_universe/peer_to_peer.rs b/language/vm/vm_runtime/vm_runtime_tests/src/account_universe/peer_to_peer.rs new file mode 100644 index 0000000000000..e6b83f55e98b8 --- /dev/null +++ b/language/vm/vm_runtime/vm_runtime_tests/src/account_universe/peer_to_peer.rs @@ -0,0 +1,147 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + account::{Account, AccountData}, + account_universe::{ + txn_one_account_result, AUTransactionGen, AccountPairGen, AccountPairMut, AccountUniverse, + }, + common_transactions::peer_to_peer_txn, + gas_costs, +}; +use proptest::prelude::*; +use proptest_derive::Arbitrary; +use proptest_helpers::Index; +use types::{ + transaction::{SignedTransaction, TransactionStatus}, + vm_error::{ExecutionStatus, VMStatus, VMValidationStatus}, +}; + +/// Represents a peer-to-peer transaction performed in the account universe. +/// +/// The parameters are the minimum and maximum balances to transfer. +#[derive(Arbitrary, Clone, Debug)] +#[proptest(params = "(u64, u64)")] +pub struct P2PTransferGen { + sender_receiver: AccountPairGen, + #[proptest(strategy = "params.0 ..= params.1")] + amount: u64, +} + +/// Represents a peer-to-peer transaction performed in the account universe to a new receiver. +/// +/// The parameters are the minimum and maximum balances to transfer. +#[derive(Arbitrary, Clone, Debug)] +#[proptest(params = "(u64, u64)")] +pub struct P2PNewReceiverGen { + sender: Index, + receiver: Account, + #[proptest(strategy = "params.0 ..= params.1")] + amount: u64, +} + +impl AUTransactionGen for P2PTransferGen { + fn apply(&self, universe: &mut AccountUniverse) -> (SignedTransaction, TransactionStatus) { + let AccountPairMut { + account_1: sender, + account_2: receiver, + .. + } = self.sender_receiver.pick_mut(universe); + + let txn = peer_to_peer_txn( + sender.account(), + receiver.account(), + sender.sequence_number, + self.amount, + ); + + // Now figure out whether the transaction will actually work. (The transactions set the + // gas cost to 1 microlibra.) + let enough_max_gas = sender.balance >= gas_costs::TXN_RESERVED; + // This means that we'll get through the main part of the transaction. + let enough_to_transfer = sender.balance >= self.amount; + let to_deduct = self.amount + *gas_costs::PEER_TO_PEER; + // This means that we'll get through the entire transaction, including the epilogue + // (where gas costs are deducted). + let enough_to_succeed = sender.balance >= to_deduct; + + // Expect a failure if the amount is greater than the current balance. + // XXX return the failure somehow? + let status; + match (enough_max_gas, enough_to_transfer, enough_to_succeed) { + (true, true, true) => { + // Success! + sender.sequence_number += 1; + sender.sent_events_count += 1; + sender.balance -= to_deduct; + + receiver.balance += self.amount; + receiver.received_events_count += 1; + + status = TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)); + } + (true, true, false) => { + // Enough gas to pass validation and to do the transfer, but not enough to succeed + // in the epilogue. The transaction will be run and gas will be deducted from the + // sender, but no other changes will happen. 
+                sender.sequence_number += 1;
+                sender.balance -= *gas_costs::PEER_TO_PEER;
+                // 6 means the balance was insufficient while trying to deduct gas costs in the
+                // epilogue.
+                // TODO: define these values in a central location
+                status = TransactionStatus::Keep(VMStatus::Execution(
+                    ExecutionStatus::AssertionFailure(6),
+                ));
+            }
+            (true, false, _) => {
+                // Enough to pass validation but not to do the transfer. The transaction will be
+                // run and gas will be deducted from the sender, but no other changes will happen.
+                sender.sequence_number += 1;
+                sender.balance -= *gas_costs::PEER_TO_PEER_TOO_LOW;
+                // 10 means the balance was insufficient while trying to transfer.
+                status = TransactionStatus::Keep(VMStatus::Execution(
+                    ExecutionStatus::AssertionFailure(10),
+                ));
+            }
+            (false, _, _) => {
+                // Not enough gas to pass validation. Nothing will happen.
+                status = TransactionStatus::Discard(VMStatus::Validation(
+                    VMValidationStatus::InsufficientBalanceForTransactionFee,
+                ));
+            }
+        }
+
+        (txn, status)
+    }
+}
+
+impl AUTransactionGen for P2PNewReceiverGen {
+    fn apply(&self, universe: &mut AccountUniverse) -> (SignedTransaction, TransactionStatus) {
+        let sender_idx = self.sender.index(universe.num_accounts());
+        let sender = &mut universe.accounts[sender_idx];
+
+        // Create a new, nonexistent account for the receiver.
+        let txn = peer_to_peer_txn(
+            sender.account(),
+            &self.receiver,
+            sender.sequence_number,
+            self.amount,
+        );
+
+        let (status, is_success) = txn_one_account_result(
+            sender,
+            self.amount,
+            *gas_costs::PEER_TO_PEER_NEW_RECEIVER,
+            *gas_costs::PEER_TO_PEER_NEW_RECEIVER_TOO_LOW,
+        );
+        if is_success {
+            universe.add_account(AccountData::with_account(
+                self.receiver.clone(),
+                self.amount,
+                0,
+            ));
+        }
+
+        (txn, status)
+    }
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/account_universe/rotate_key.rs b/language/vm/vm_runtime/vm_runtime_tests/src/account_universe/rotate_key.rs
new file mode 100644
index 0000000000000..6b741660ecc83
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/account_universe/rotate_key.rs
@@ -0,0 +1,53 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    account_universe::{AUTransactionGen, AccountUniverse},
+    common_transactions::rotate_key_txn,
+    gas_costs,
+};
+use crypto::{utils::keypair_strategy, PrivateKey, PublicKey};
+use proptest::prelude::*;
+use proptest_derive::Arbitrary;
+use proptest_helpers::Index;
+use types::{
+    account_address::AccountAddress,
+    transaction::{SignedTransaction, TransactionStatus},
+    vm_error::{ExecutionStatus, VMStatus, VMValidationStatus},
+};
+
+/// Represents a rotate-key transaction performed in the account universe.
+#[derive(Arbitrary, Clone, Debug)]
+#[proptest(no_params)]
+pub struct RotateKeyGen {
+    sender: Index,
+    #[proptest(strategy = "keypair_strategy()")]
+    new_key: (PrivateKey, PublicKey),
+}
+
+impl AUTransactionGen for RotateKeyGen {
+    fn apply(&self, universe: &mut AccountUniverse) -> (SignedTransaction, TransactionStatus) {
+        let sender_idx = self.sender.index(universe.num_accounts());
+        let sender = &mut universe.accounts[sender_idx];
+
+        let new_key_hash = AccountAddress::from(self.new_key.1);
+        let txn = rotate_key_txn(sender.account(), new_key_hash, sender.sequence_number);
+
+        // This should work all the time except when the balance is too low for gas.
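+        // On success the local model mirrors the expected on-chain effect: bump the sequence
+        // number, deduct the fixed ROTATE_KEY gas cost, and swap in the new keypair so later
+        // transactions from this account sign with it.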
+        let enough_max_gas = sender.balance >= gas_costs::TXN_RESERVED;
+        let status = if enough_max_gas {
+            sender.sequence_number += 1;
+            sender.balance -= *gas_costs::ROTATE_KEY;
+            let (privkey, pubkey) = (self.new_key.0.clone(), self.new_key.1);
+            sender.rotate_key(privkey, pubkey);
+
+            TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed))
+        } else {
+            TransactionStatus::Discard(VMStatus::Validation(
+                VMValidationStatus::InsufficientBalanceForTransactionFee,
+            ))
+        };
+
+        (txn, status)
+    }
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/bin/repl.rs b/language/vm/vm_runtime/vm_runtime_tests/src/bin/repl.rs
new file mode 100644
index 0000000000000..a2cbd22c06ec2
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/bin/repl.rs
@@ -0,0 +1,280 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use std::{
+    fs::File,
+    io::{self, Read, Write},
+    time::Duration,
+};
+
+use compiler::{compiler::compile_program, parser::parse_program};
+use failure::Error;
+use getopts::{Options, ParsingStyle};
+use hex;
+use types::{
+    account_address::AccountAddress,
+    byte_array::ByteArray,
+    transaction::{Program, RawTransaction, SignedTransaction, TransactionArgument},
+};
+use vm::CompiledModule;
+use vm_genesis::STDLIB_MODULES;
+use vm_runtime::static_verify_program;
+use vm_runtime_tests::{
+    account::{Account, AccountResource},
+    executor::FakeExecutor,
+};
+
+struct Repl {
+    accounts: Vec<Account>,
+    executor: FakeExecutor,
+    modules: Vec<CompiledModule>,
+    source_parser: Options,
+    publish_parser: Options,
+    get_account_parser: Options,
+}
+
+pub fn parse_address(s: Option<String>) -> usize {
+    s.map(|s| {
+        if s.is_empty() || s.as_bytes()[0] != b'a' {
+            0
+        } else {
+            s[1..].parse::<usize>().unwrap_or(0)
+        }
+    })
+    .unwrap_or(0)
+}
+
+const ACCOUNT_SIZE: usize = 10;
+const GENESIS_BALANCE: u64 = 100_000_000;
+const DEFAULT_GAS_COST: u64 = 1;
+const DEFAULT_MAX_GAS: u64 = 10_000;
+
+impl Repl {
+    pub fn get_sequence_number(&self, account: &Account) -> u64 {
+        let sender_resource = self
+            .executor
+            .read_account_resource(account)
+            .expect("sender must exist");
+        AccountResource::read_sequence_number(&sender_resource)
+    }
+
+    pub fn new() -> Self {
+        let mut executor = FakeExecutor::from_genesis_file();
+        let accounts = executor.create_accounts(ACCOUNT_SIZE, GENESIS_BALANCE, 0);
+        let mut source_parser = Options::new();
+        source_parser.parsing_style(ParsingStyle::FloatingFrees);
+        source_parser.reqopt("f", "file", "Script that you want to execute", "FILE");
+        source_parser.optopt("s", "sender", "Sender of this transaction", "ADDRESS");
+        source_parser.optopt(
+            "",
+            "sign_with",
+            "PrivKey of account used to sign this transaction",
+            "PRIVKEY",
+        );
+        source_parser.optflag("v", "verbose", "Display the transaction output");
+
+        let mut publish_parser = Options::new();
+        publish_parser.reqopt("f", "file", "Module that you want to publish", "FILE");
+        publish_parser.optopt("s", "sender", "Publisher of the module", "ADDRESS");
+        publish_parser.optflag("v", "verbose", "Display the transaction output");
+
+        let mut get_account_parser = Options::new();
+        get_account_parser.optopt("s", "sender", "Account you want to query", "ADDRESS");
+
+        Repl {
+            executor,
+            accounts,
+            modules: STDLIB_MODULES.clone(),
+            source_parser,
+            publish_parser,
+            get_account_parser,
+        }
+    }
+
+    pub fn create_signed_txn_with_args(
+        &mut self,
+        program_str: String,
+        args: Vec<TransactionArgument>,
+        sender_address: AccountAddress,
+        signer: Account,
+        sequence_number: u64,
+        max_gas_amount: u64,
+        gas_unit_price: u64,
+    ) -> SignedTransaction {
+        let parsed_program =
+            parse_program(&program_str).unwrap();
+
+        let modules = self.modules.clone();
+        let compiled_program = compile_program(&sender_address, &parsed_program, &modules).unwrap();
+
+        let (verified_script, to_be_published_modules, statuses) = static_verify_program(
+            &sender_address,
+            compiled_program.script,
+            compiled_program.modules,
+        );
+        assert_eq!(statuses, vec![]);
+
+        self.modules.extend(to_be_published_modules.clone());
+
+        let mut script_blob = vec![];
+        verified_script.serialize(&mut script_blob).unwrap();
+        let mut modules_blob = vec![];
+        for m in to_be_published_modules {
+            let mut module_blob = vec![];
+            m.serialize(&mut module_blob).unwrap();
+            modules_blob.push(module_blob);
+        }
+        println!("program: {}, args: {:?}", program_str, args);
+
+        let program = Program::new(script_blob, modules_blob, args);
+        RawTransaction::new(
+            sender_address,
+            sequence_number,
+            program,
+            max_gas_amount,
+            gas_unit_price,
+            Duration::from_secs(u64::max_value()),
+        )
+        .sign(&signer.privkey, signer.pubkey)
+        .unwrap()
+    }
+
+    pub fn eval_arg(&mut self, input: String) {
+        let args: Vec<&str> = input.trim().split(' ').collect();
+        match args[0] {
+            "publish" => self.publish(&args[1..]),
+            "source" => self.source(&args[1..]),
+            "get_account_info" => self.get_account_info(&args[1..]),
+            "new_key_pair" => {
+                let account = Account::new();
+                println!("New Account at {}: {:?}", self.accounts.len(), account);
+                self.accounts.push(account);
+                Ok(())
+            }
+            _ => {
+                println!("Try these commands: publish, source, get_account_info, new_key_pair");
+                Ok(())
+            }
+        }
+        .unwrap_or_else(|e| println!("{:?}", e))
+    }
+
+    pub fn source(&mut self, args: &[&str]) -> Result<(), Error> {
+        let matches = self.source_parser.parse(args).map_err(|e| {
+            println!("{}", self.source_parser.usage("Execute a transaction. Free parameters will be parsed as arguments to the transaction"));
+            e
+        })?;
+        let sender = parse_address(matches.opt_str("s"));
+        let txn_code = {
+            let mut buffer = String::new();
+            File::open(matches.opt_str("f").unwrap())?
+                .read_to_string(&mut buffer)
+                .unwrap();
+            buffer
+        };
+        let signer = parse_address(matches.opt_str("sign_with"));
+        let txn_args = {
+            let mut v = vec![];
+            for s in matches.free.iter() {
+                v.push(if s.starts_with('a') {
+                    TransactionArgument::Address(
+                        *self.accounts[parse_address(Some(s.clone()))].address(),
+                    )
+                } else if s.starts_with("b0x") {
+                    TransactionArgument::ByteArray(ByteArray::new(hex::decode(&s[3..])?))
+                } else {
+                    TransactionArgument::U64(s.parse::<u64>()?)
+                });
+            }
+            v
+        };
+        let txn = self.create_signed_txn_with_args(
+            txn_code,
+            txn_args,
+            *self.accounts[sender].address(),
+            self.accounts[signer].clone(),
+            self.get_sequence_number(&self.accounts[sender]),
+            DEFAULT_MAX_GAS,
+            DEFAULT_GAS_COST,
+        );
+        for o in self.executor.execute_block(vec![txn]).iter() {
+            if matches.opt_present("v") {
+                println!("{:?}", o);
+            } else {
+                println!("Gas Consumed: {}", o.gas_used());
+            }
+            self.executor.apply_write_set(o.write_set());
+        }
+        Ok(())
+    }
+
+    pub fn publish(&mut self, args: &[&str]) -> Result<(), Error> {
+        let matches = self.publish_parser.parse(args).map_err(|e| {
+            println!(
+                "{}",
+                self.publish_parser
+                    .usage("Publish a module under a given sender")
+            );
+            e
+        })?;
+        let file = {
+            let mut buffer = String::new();
+            File::open(matches.opt_str("f").unwrap())?
+                .read_to_string(&mut buffer)
+                .unwrap();
+            buffer
+        };
+        let sender = matches
+            .opt_str("s")
+            .map(|s| s.parse::<usize>().unwrap_or(0))
+            .unwrap_or(0);
+
+        let txn = self.create_signed_txn_with_args(
+            format!("modules: {} script: main() {{ return; }}", file),
+            vec![],
+            *self.accounts[sender].address(),
+            self.accounts[sender].clone(),
+            self.get_sequence_number(&self.accounts[sender]),
+            DEFAULT_MAX_GAS,
+            DEFAULT_GAS_COST,
+        );
+        for o in self.executor.execute_block(vec![txn]).iter() {
+            // The verbose flag is declared on publish_parser above so it can be checked here.
+            if matches.opt_present("v") {
+                println!("{:?}", o);
+            } else {
+                println!("Gas Consumed: {}", o.gas_used());
+            }
+        }
+        Ok(())
+    }
+
+    pub fn get_account_info(&mut self, args: &[&str]) -> Result<(), Error> {
+        let matches = self.get_account_parser.parse(args)?;
+        let sender = parse_address(matches.opt_str("s"));
+        let account = &self.accounts[sender];
+        println!(
+            "Address: 0x{}",
+            hex::encode(self.accounts[sender].address())
+        );
+        if let Some(v) = self.executor.read_account_resource(account) {
+            println!("balance: {}", AccountResource::read_balance(&v));
+            println!(
+                "sequence_number: {}",
+                AccountResource::read_sequence_number(&v)
+            );
+        } else {
+            println!("Account doesn't exist");
+        }
+        Ok(())
+    }
+}
+
+fn main() {
+    let mut repl = Repl::new();
+    loop {
+        let mut input = String::new();
+        print!("> ");
+        io::stdout().flush().unwrap();
+        if io::stdin().read_line(&mut input).is_ok() {
+            repl.eval_arg(input);
+        }
+    }
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/common_transactions.rs b/language/vm/vm_runtime/vm_runtime_tests/src/common_transactions.rs
new file mode 100644
index 0000000000000..c40a1cef805e2
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/common_transactions.rs
@@ -0,0 +1,127 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Support for encoding transactions for common situations.
+
+use crate::{account::Account, compile::compile_script, gas_costs};
+use lazy_static::lazy_static;
+use types::{
+    account_address::AccountAddress,
+    byte_array::ByteArray,
+    transaction::{SignedTransaction, TransactionArgument},
+};
+
+lazy_static! {
+    /// A serialized transaction to create a new account.
+    pub static ref CREATE_ACCOUNT: Vec<u8> = { create_account() };
+    /// A serialized transaction to mint new funds.
+    pub static ref MINT: Vec<u8> = { mint() };
+    /// A serialized transaction to transfer coin from one account to another (possibly new)
+    /// one.
+    pub static ref PEER_TO_PEER: Vec<u8> = { peer_to_peer() };
+    /// A serialized transaction to change the keys for an account.
+    pub static ref ROTATE_KEY: Vec<u8> = { rotate_key() };
+}
+
+/// Returns a transaction to create a new account with the given arguments.
+pub fn create_account_txn(
+    sender: &Account,
+    new_account: &Account,
+    seq_num: u64,
+    initial_amount: u64,
+) -> SignedTransaction {
+    let mut args: Vec<TransactionArgument> = Vec::new();
+    args.push(TransactionArgument::Address(*new_account.address()));
+    args.push(TransactionArgument::U64(initial_amount));
+
+    sender.create_signed_txn_with_args(
+        CREATE_ACCOUNT.clone(),
+        args,
+        seq_num,
+        gas_costs::TXN_RESERVED,
+        1,
+    )
+}
+
+/// Returns a transaction to transfer coin from one account to another (possibly new) one, with the
+/// given arguments.
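+// For instance, a test can build and run a transfer roughly like this (sketch; `sender`,
+// `receiver` and `executor` stand for values set up elsewhere in the test):
+//
+//     let txn = peer_to_peer_txn(sender.account(), receiver.account(), 10, 1_000);
+//     let output = executor.execute_transaction(txn);
+//
+// The receiver address and the amount travel as `TransactionArgument`s, as the body below shows.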
+pub fn peer_to_peer_txn(
+    sender: &Account,
+    receiver: &Account,
+    seq_num: u64,
+    transfer_amount: u64,
+) -> SignedTransaction {
+    let mut args: Vec<TransactionArgument> = Vec::new();
+    args.push(TransactionArgument::Address(*receiver.address()));
+    args.push(TransactionArgument::U64(transfer_amount));
+
+    // get a SignedTransaction
+    sender.create_signed_txn_with_args(
+        PEER_TO_PEER.clone(),
+        args,
+        seq_num,
+        gas_costs::TXN_RESERVED, // this is a default for gas
+        1,                       // this is a default for gas
+    )
+}
+
+/// Returns a transaction to change the keys for the given account.
+pub fn rotate_key_txn(
+    sender: &Account,
+    new_key_hash: AccountAddress,
+    seq_num: u64,
+) -> SignedTransaction {
+    let args = vec![TransactionArgument::ByteArray(ByteArray::new(
+        new_key_hash.to_vec(),
+    ))];
+    sender.create_signed_txn_with_args(
+        ROTATE_KEY.clone(),
+        args,
+        seq_num,
+        gas_costs::TXN_RESERVED,
+        1,
+    )
+}
+
+/// Returns a transaction to mint new funds with the given arguments.
+pub fn mint_txn(
+    sender: &Account,
+    receiver: &Account,
+    seq_num: u64,
+    transfer_amount: u64,
+) -> SignedTransaction {
+    let mut args: Vec<TransactionArgument> = Vec::new();
+    args.push(TransactionArgument::Address(*receiver.address()));
+    args.push(TransactionArgument::U64(transfer_amount));
+
+    // get a SignedTransaction
+    sender.create_signed_txn_with_args(
+        MINT.clone(),
+        args,
+        seq_num,
+        gas_costs::TXN_RESERVED, // this is a default for gas
+        1,                       // this is a default for gas
+    )
+}
+
+// TODO: replace these with helper functions/consts in stdlib
+fn create_account() -> Vec<u8> {
+    let code = include_str!("../../../../stdlib/transaction_scripts/create_account.mvir");
+    compile_script(code)
+}
+
+fn mint() -> Vec<u8> {
+    let code = include_str!("../../../../stdlib/transaction_scripts/mint.mvir");
+    compile_script(code)
+}
+
+fn peer_to_peer() -> Vec<u8> {
+    let code = include_str!("../../../../stdlib/transaction_scripts/peer_to_peer_transfer.mvir");
+    compile_script(code)
+}
+
+fn rotate_key() -> Vec<u8> {
+    let code =
+        include_str!("../../../../stdlib/transaction_scripts/rotate_authentication_key.mvir");
+    compile_script(code)
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/compile.rs b/language/vm/vm_runtime/vm_runtime_tests/src/compile.rs
new file mode 100644
index 0000000000000..66b599116c7f3
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/compile.rs
@@ -0,0 +1,87 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Support for compiling scripts and modules in tests.
+
+use compiler::{
+    compiler::{compile_module, compile_program as compiler_compile_program},
+    parser::parse_program,
+};
+use stdlib::stdlib::*;
+use types::{
+    account_address::AccountAddress,
+    transaction::{Program, TransactionArgument},
+};
+use vm::file_format::CompiledModule;
+
+/// Compile the provided Move code into a blob which can be used as the code for a [`Program`].
+///
+/// The script is compiled with the default account address (`0x0`).
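+// For example, the smallest script used by the tests compiles with:
+//
+//     let blob: Vec<u8> = compile_script("main() { return; }");
+//
+// (Sketch only; `compile_script` panics via `expect` if parsing or compilation fails.)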
+pub fn compile_script(code: &str) -> Vec<u8> {
+    let address = AccountAddress::default();
+    let parsed_program = parse_program(code).expect("program must parse");
+    let deps = stdlib_deps(&address);
+    let compiled_program =
+        compiler_compile_program(&address, &parsed_program, &deps).expect("program must compile");
+
+    let mut serialized_script = Vec::<u8>::new();
+    compiled_program
+        .script
+        .serialize(&mut serialized_script)
+        .expect("script must serialize");
+    serialized_script
+}
+
+/// Compile the provided Move code and arguments into a `Program` using `address` as the
+/// self address for any modules in `code`.
+pub fn compile_program_with_address(
+    address: &AccountAddress,
+    code: &str,
+    args: Vec<TransactionArgument>,
+) -> Program {
+    let deps = stdlib_deps(&AccountAddress::default());
+    let parsed_program = parse_program(code).expect("program must parse");
+    let compiled_program =
+        compiler_compile_program(address, &parsed_program, &deps).expect("program must compile");
+
+    let mut serialized_script = Vec::<u8>::new();
+    compiled_program
+        .script
+        .serialize(&mut serialized_script)
+        .expect("script must serialize");
+    let mut serialized_modules = vec![];
+    for m in compiled_program.modules {
+        let mut module = vec![];
+        m.serialize(&mut module).expect("module must serialize");
+        serialized_modules.push(module);
+    }
+    Program::new(serialized_script, serialized_modules, args)
+}
+
+/// Compile the provided Move code and arguments into a `Program`.
+///
+/// This supports both scripts and modules defined in the same Move code. The code is compiled with
+/// the default account address (`0x0`).
+pub fn compile_program(code: &str, args: Vec<TransactionArgument>) -> Program {
+    let address = AccountAddress::default();
+    compile_program_with_address(&address, code, args)
+}
+
+fn stdlib_deps(address: &AccountAddress) -> Vec<CompiledModule> {
+    let coin_module = coin_module();
+    let compiled_coin_module =
+        compile_module(&address, &coin_module, &[]).expect("coin must compile");
+
+    let hash_module = native_hash_module();
+    let compiled_hash_module =
+        compile_module(&address, &hash_module, &[]).expect("hash must compile");
+
+    let account_module = account_module();
+
+    let mut deps = vec![compiled_coin_module, compiled_hash_module];
+    let compiled_account_module =
+        compile_module(&address, &account_module, &deps).expect("account must compile");
+
+    deps.push(compiled_account_module);
+    deps
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/data_store.rs b/language/vm/vm_runtime/vm_runtime_tests/src/data_store.rs
new file mode 100644
index 0000000000000..470ec10fc9439
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/data_store.rs
@@ -0,0 +1,122 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Support for mocking the Libra data store.
+
+use crate::account::AccountData;
+use failure::prelude::*;
+use lazy_static::lazy_static;
+use proto_conv::FromProto;
+use protobuf::parse_from_bytes;
+use state_view::StateView;
+use std::{collections::HashMap, fs::File, io::prelude::*, path::PathBuf};
+use types::{
+    access_path::AccessPath,
+    transaction::{SignedTransaction, TransactionPayload},
+    write_set::{WriteOp, WriteSet},
+};
+use vm::errors::*;
+use vm_runtime::data_cache::RemoteCache;
+
+lazy_static! {
+    /// The write set encoded in the genesis transaction.
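+    ///
+    /// The blob is read once from `vm_genesis/genesis/genesis.blob` and then shared, e.g. via
+    /// `FakeExecutor::from_genesis(&GENESIS_WRITE_SET, None)`.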
+    pub static ref GENESIS_WRITE_SET: WriteSet = {
+        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
+        path.pop();
+        path.pop();
+        path.push("vm_genesis/genesis/genesis.blob");
+
+        let mut f = File::open(&path).unwrap();
+        let mut bytes = vec![];
+        f.read_to_end(&mut bytes).unwrap();
+        let txn = SignedTransaction::from_proto(parse_from_bytes(&bytes).unwrap()).unwrap();
+        match txn.payload() {
+            TransactionPayload::WriteSet(ws) => ws.clone(),
+            _ => panic!("Expected writeset txn in genesis txn"),
+        }
+    };
+}
+
+/// An in-memory implementation of [`StateView`] and [`RemoteCache`] for the VM.
+///
+/// Tests use this to set up state, and pass in a reference to the cache whenever a `StateView` or
+/// `RemoteCache` is needed.
+#[derive(Debug, Default)]
+pub struct FakeDataStore {
+    data: HashMap<AccessPath, Vec<u8>>,
+}
+
+impl FakeDataStore {
+    /// Creates a new `FakeDataStore` with the provided initial data.
+    pub fn new(data: HashMap<AccessPath, Vec<u8>>) -> Self {
+        FakeDataStore { data }
+    }
+
+    /// Adds a [`WriteSet`] to this data store.
+    pub fn add_write_set(&mut self, write_set: &WriteSet) {
+        for (access_path, write_op) in write_set {
+            match write_op {
+                WriteOp::Value(blob) => {
+                    self.set(access_path.clone(), blob.clone());
+                }
+                WriteOp::Deletion => {
+                    self.remove(access_path);
+                }
+            }
+        }
+    }
+
+    /// Sets a (key, value) pair within this data store.
+    ///
+    /// Returns the previous data if the key was occupied.
+    pub fn set(&mut self, access_path: AccessPath, data_blob: Vec<u8>) -> Option<Vec<u8>> {
+        self.data.insert(access_path, data_blob)
+    }
+
+    /// Deletes a key from this data store.
+    ///
+    /// Returns the previous data if the key was occupied.
+    pub fn remove(&mut self, access_path: &AccessPath) -> Option<Vec<u8>> {
+        self.data.remove(access_path)
+    }
+
+    /// Adds an [`AccountData`] to this data store.
+    pub fn add_account_data(&mut self, account_data: &AccountData) {
+        match account_data.to_resource().simple_serialize() {
+            Some(blob) => {
+                self.set(account_data.make_access_path(), blob);
+            }
+            None => panic!("can't create Account data"),
+        }
+    }
+}
+
+// This is used by the `execute_block` API.
+// TODO: only the "sync" get is implemented
+impl StateView for FakeDataStore {
+    fn get(&self, access_path: &AccessPath) -> Result<Option<Vec<u8>>> {
+        // Since the data is in-memory, it can't fail.
+        match self.data.get(access_path) {
+            None => Ok(None),
+            Some(blob) => Ok(Some(blob.clone())),
+        }
+    }
+
+    fn multi_get(&self, _access_paths: &[AccessPath]) -> Result<Vec<Option<Vec<u8>>>> {
+        unimplemented!();
+    }
+
+    fn is_genesis(&self) -> bool {
+        self.data.is_empty()
+    }
+}
+
+// This is used by the `process_transaction` API.
+impl RemoteCache for FakeDataStore {
+    fn get(
+        &self,
+        access_path: &AccessPath,
+    ) -> ::std::result::Result<Option<Vec<u8>>, VMInvariantViolation> {
+        Ok(StateView::get(self, access_path).expect("it should not error"))
+    }
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/executor.rs b/language/vm/vm_runtime/vm_runtime_tests/src/executor.rs
new file mode 100644
index 0000000000000..a59785f9901f2
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/executor.rs
@@ -0,0 +1,150 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Support for running the VM to execute and verify transactions.
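+//!
+//! A typical test drives the executor roughly like this (sketch):
+//!
+//! ```text
+//! let mut executor = FakeExecutor::from_genesis_file();
+//! executor.add_account_data(&AccountData::new(1_000_000, 0));
+//! let outputs = executor.execute_block(txns);
+//! for output in &outputs {
+//!     executor.apply_write_set(output.write_set());
+//! }
+//! ```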
+
+use crate::{
+    account::{Account, AccountData},
+    data_store::{FakeDataStore, GENESIS_WRITE_SET},
+};
+use config::config::{NodeConfig, NodeConfigHelpers, VMPublishingOption};
+use state_view::StateView;
+use types::{
+    access_path::AccessPath,
+    transaction::{SignedTransaction, TransactionOutput},
+    vm_error::VMStatus,
+    write_set::WriteSet,
+};
+use vm_runtime::{
+    loaded_data::{struct_def::StructDef, types::Type},
+    value::Value,
+    MoveVM, VMExecutor, VMVerifier,
+};
+
+/// Provides an environment to run a VM instance.
+///
+/// This struct is a mock in-memory implementation of the Libra executor.
+#[derive(Debug)]
+pub struct FakeExecutor {
+    config: NodeConfig,
+    data_store: FakeDataStore,
+}
+
+impl FakeExecutor {
+    /// Creates an executor from a genesis [`WriteSet`].
+    pub fn from_genesis(
+        write_set: &WriteSet,
+        publishing_options: Option<VMPublishingOption>,
+    ) -> Self {
+        let mut executor = FakeExecutor {
+            config: NodeConfigHelpers::get_single_node_test_config_publish_options(
+                false,
+                publishing_options,
+            ),
+            data_store: FakeDataStore::default(),
+        };
+        executor.apply_write_set(write_set);
+        executor
+    }
+
+    /// Creates an executor from the genesis file GENESIS_FILE_LOCATION.
+    pub fn from_genesis_file() -> Self {
+        Self::from_genesis(&GENESIS_WRITE_SET, None)
+    }
+
+    /// Creates an executor from the genesis file GENESIS_FILE_LOCATION with script/module
+    /// publishing options given by `publishing_options`. These can only be either `Open` or
+    /// `CustomScript`.
+    pub fn from_genesis_with_options(publishing_options: VMPublishingOption) -> Self {
+        if let VMPublishingOption::Locked(_) = publishing_options {
+            panic!("Whitelisted transactions are not supported as a publishing option")
+        }
+        Self::from_genesis(&GENESIS_WRITE_SET, Some(publishing_options))
+    }
+
+    /// Creates an executor in which no genesis state has been applied yet.
+    pub fn no_genesis() -> Self {
+        FakeExecutor {
+            config: NodeConfigHelpers::get_single_node_test_config(false),
+            data_store: FakeDataStore::default(),
+        }
+    }
+
+    /// Creates a number of [`Account`] instances all with the same balance and sequence number,
+    /// and publishes them to this executor's data store.
+    pub fn create_accounts(&mut self, size: usize, balance: u64, seq_num: u64) -> Vec<Account> {
+        let mut accounts: Vec<Account> = Vec::with_capacity(size);
+        for _i in 0..size {
+            let account_data = AccountData::new(balance, seq_num);
+            self.add_account_data(&account_data);
+            accounts.push(account_data.into_account());
+        }
+        accounts
+    }
+
+    /// Applies a [`WriteSet`] to this executor's data store.
+    pub fn apply_write_set(&mut self, write_set: &WriteSet) {
+        self.data_store.add_write_set(write_set);
+    }
+
+    /// Adds an account to this executor's data store.
+    pub fn add_account_data(&mut self, account_data: &AccountData) {
+        self.data_store.add_account_data(account_data)
+    }
+
+    /// Reads the resource [`Value`] for an account from this executor's data store.
+    pub fn read_account_resource(&self, account: &Account) -> Option<Value> {
+        let ap = account.make_access_path();
+        let data_blob = StateView::get(&self.data_store, &ap)
+            .expect("account must exist in data store")
+            .expect("data must exist in data store");
+        let account_type = Self::get_account_struct_def();
+        Account::read_account_resource(&data_blob, account_type)
+    }
+
+    /// Executes the given block of transactions.
+    ///
+    /// Typical tests will call this method and check that the output matches what was expected.
+    /// However, this doesn't apply the results of successful transactions to the data store.
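+    /// Callers that want later reads to observe those effects must apply each output's write
+    /// set themselves with `apply_write_set`, as the tests in this crate do.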
+    pub fn execute_block(&self, txn_block: Vec<SignedTransaction>) -> Vec<TransactionOutput> {
+        MoveVM::execute_block(txn_block, &self.config.vm_config, &self.data_store)
+    }
+
+    pub fn execute_transaction(&self, txn: SignedTransaction) -> TransactionOutput {
+        let txn_block = vec![txn];
+        let mut outputs = self.execute_block(txn_block);
+        outputs
+            .pop()
+            .expect("A block with one transaction should have one output")
+    }
+
+    /// Get the blob for the associated AccessPath.
+    pub fn read_from_access_path(&self, path: &AccessPath) -> Option<Vec<u8>> {
+        StateView::get(&self.data_store, path).unwrap()
+    }
+
+    /// Verifies the given transaction by running it through the VM verifier.
+    pub fn verify_transaction(&self, txn: SignedTransaction) -> Option<VMStatus> {
+        let vm = MoveVM::new(&self.config.vm_config);
+        vm.validate_transaction(txn, &self.data_store)
+    }
+
+    /// TODO: This is a hack and likely to break soon. The Account type is replicated here with
+    /// no checks that it is the right one. Fix it!
+    fn get_account_struct_def() -> StructDef {
+        // The account resource layout being replicated:
+        //     StructDef(StructDefInner { field_definitions: [ByteArray,
+        //         Struct(StructDef(StructDefInner { field_definitions: [U64] })), U64, U64, U64] })
+        // where the nested struct is the coin:
+        //     StructDef(StructDefInner { field_definitions: [Type::U64] })
+        let int_type = Type::U64;
+        let byte_array_type = Type::ByteArray;
+        let coin = Type::Struct(StructDef::new(vec![int_type.clone()]));
+        StructDef::new(vec![
+            byte_array_type,
+            coin,
+            int_type.clone(),
+            int_type.clone(),
+            int_type.clone(),
+        ])
+    }
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/gas_costs.rs b/language/vm/vm_runtime/vm_runtime_tests/src/gas_costs.rs
new file mode 100644
index 0000000000000..13ae57f55a8b4
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/gas_costs.rs
@@ -0,0 +1,144 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Gas costs for common transactions.
+
+use crate::{
+    account::{Account, AccountData},
+    common_transactions::{create_account_txn, peer_to_peer_txn, rotate_key_txn},
+    executor::FakeExecutor,
+};
+use lazy_static::lazy_static;
+use types::{account_address::AccountAddress, transaction::SignedTransaction};
+
+/// The gas each transaction is configured to reserve. If the gas available in the account,
+/// converted to microlibra, falls below this threshold, transactions are expected to fail with
+/// an insufficient balance.
+pub const TXN_RESERVED: u64 = 10_000;
+
+lazy_static! {
+    /// The gas cost of a create-account transaction.
+    ///
+    /// All such transactions are expected to cost the same gas.
+    pub static ref CREATE_ACCOUNT: u64 = {
+        let mut executor = FakeExecutor::from_genesis_file();
+        let sender = AccountData::new(1_000_000, 10);
+        executor.add_account_data(&sender);
+        let receiver = Account::new();
+
+        let txn = create_account_txn(sender.account(), &receiver, 10, 20_000);
+        compute_gas_used(txn, &mut executor)
+    };
+
+    /// The gas cost of a create-account transaction where the sender has an insufficient balance.
+    ///
+    /// All such transactions are expected to cost the same gas.
+    pub static ref CREATE_ACCOUNT_TOO_LOW: u64 = {
+        let mut executor = FakeExecutor::from_genesis_file();
+        // The gas amount is the minimum that needs to be reserved, so use a value that's
+        // clearly higher than that.
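+        // Concretely: balance = TXN_RESERVED + 10_000 = 20_000 covers the gas reservation,
+        // while transferring balance + 1 = 20_001 cannot succeed, so this exercises (and
+        // prices) the failure path rather than a validation rejection.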
+ let balance = TXN_RESERVED + 10_000; + let sender = AccountData::new(balance, 10); + executor.add_account_data(&sender); + let receiver = Account::new(); + + let txn = create_account_txn(sender.account(), &receiver, 10, balance + 1); + compute_gas_used(txn, &mut executor) + }; + + /// The gas cost of a create-account transaction where the receiver already exists. + /// + /// All such transactions are expected to cost the same gas. + pub static ref CREATE_EXISTING_ACCOUNT: u64 = { + let mut executor = FakeExecutor::from_genesis_file(); + let sender = AccountData::new(1_000_000, 10); + let receiver = AccountData::new(1_000_000, 10); + executor.add_account_data(&sender); + executor.add_account_data(&receiver); + + let txn = create_account_txn(sender.account(), receiver.account(), 10, 20_000); + compute_gas_used(txn, &mut executor) + }; + + /// The gas cost of a peer-to-peer transaction. + /// + /// All such transactions are expected to cost the same gas. + pub static ref PEER_TO_PEER: u64 = { + // Compute gas used by running a placeholder transaction. + let mut executor = FakeExecutor::from_genesis_file(); + let sender = AccountData::new(1_000_000, 10); + let receiver = AccountData::new(1_000_000, 10); + executor.add_account_data(&sender); + executor.add_account_data(&receiver); + + let txn = peer_to_peer_txn(sender.account(), receiver.account(), 10, 20_000); + compute_gas_used(txn, &mut executor) + }; + + /// The gas cost of a peer-to-peer transaction with an insufficient balance. + /// + /// All such transactions are expected to cost the same gas. + pub static ref PEER_TO_PEER_TOO_LOW: u64 = { + let mut executor = FakeExecutor::from_genesis_file(); + // The gas amount is the minimum that needs to be reserved, so use a value that's clearly + // higher than that. + let balance = TXN_RESERVED + 10_000; + let sender = AccountData::new(balance, 10); + let receiver = AccountData::new(1_000_000, 10); + executor.add_account_data(&sender); + executor.add_account_data(&receiver); + + let txn = peer_to_peer_txn(sender.account(), receiver.account(), 10, balance + 1); + compute_gas_used(txn, &mut executor) + }; + + /// The gas cost of a peer-to-peer transaction that creates a new account. + /// + /// All such transactions are expected to cost the same gas. + pub static ref PEER_TO_PEER_NEW_RECEIVER: u64 = { + // Compute gas used by running a placeholder transaction. + let mut executor = FakeExecutor::from_genesis_file(); + let sender = AccountData::new(1_000_000, 10); + executor.add_account_data(&sender); + let receiver = Account::new(); + + let txn = peer_to_peer_txn(sender.account(), &receiver, 10, 20_000); + compute_gas_used(txn, &mut executor) + }; + + /// The gas cost of a peer-to-peer transaction that tries to create a new account, but fails + /// because of an insufficient balance. + /// + /// All such transactions are expected to cost the same gas. + pub static ref PEER_TO_PEER_NEW_RECEIVER_TOO_LOW: u64 = { + let mut executor = FakeExecutor::from_genesis_file(); + // The gas amount is the minimum that needs to be reserved, so use a value that's + // clearly higher than that. + let balance = TXN_RESERVED + 10_000; + let sender = AccountData::new(balance, 10); + executor.add_account_data(&sender); + let receiver = Account::new(); + + let txn = peer_to_peer_txn(sender.account(), &receiver, 10, balance + 1); + compute_gas_used(txn, &mut executor) + }; + + /// The gas cost of a rotate-key transaction. + /// + /// All such transactions are expected to cost the same gas. 
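+    ///
+    /// Like the other constants in this module, the value is computed once, on first use, by
+    /// executing a sample transaction against a fresh genesis state; call sites read it with
+    /// a deref, e.g. `*gas_costs::ROTATE_KEY`.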
+    pub static ref ROTATE_KEY: u64 = {
+        let mut executor = FakeExecutor::from_genesis_file();
+        let sender = AccountData::new(1_000_000, 10);
+        executor.add_account_data(&sender);
+        let (_privkey, pubkey) = crypto::signing::generate_keypair();
+        let new_key_hash = AccountAddress::from(pubkey);
+
+        let txn = rotate_key_txn(sender.account(), new_key_hash, 10);
+        compute_gas_used(txn, &mut executor)
+    };
+}
+
+fn compute_gas_used(txn: SignedTransaction, executor: &mut FakeExecutor) -> u64 {
+    let output = &executor.execute_block(vec![txn])[0];
+    output.gas_used()
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/lib.rs b/language/vm/vm_runtime/vm_runtime_tests/src/lib.rs
new file mode 100644
index 0000000000000..eb6b7481cccbb
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/lib.rs
@@ -0,0 +1,82 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Test infrastructure for the Libra VM.
+//!
+//! This crate contains helpers for executing tests against the Libra VM.
+
+use ::compiler::{compiler, parser::parse_program};
+use data_store::FakeDataStore;
+use types::{
+    access_path::AccessPath, account_address::AccountAddress, transaction::TransactionArgument,
+};
+use vm::{
+    errors::*,
+    file_format::{CompiledModule, CompiledScript},
+};
+use vm_runtime::{execute_function, static_verify_program};
+
+#[cfg(test)]
+mod tests;
+
+pub mod account;
+pub mod account_universe;
+pub mod common_transactions;
+pub mod compile;
+pub mod data_store;
+pub mod executor;
+pub mod gas_costs;
+mod proptest_types;
+
+/// Compiles a program with the given arguments and executes it in the VM.
+pub fn compile_and_execute(program: &str, args: Vec<TransactionArgument>) -> VMResult<()> {
+    let address = AccountAddress::default();
+
+    let parsed_program = parse_program(program).unwrap();
+    let compiled_program = compiler::compile_program(&address, &parsed_program, &[]).unwrap();
+
+    let (compiled_script, modules) =
+        verify(&address, compiled_program.script, compiled_program.modules);
+    execute(compiled_script, args, modules)
+}
+
+pub fn execute(
+    script: CompiledScript,
+    args: Vec<TransactionArgument>,
+    modules: Vec<CompiledModule>,
+) -> VMResult<()> {
+    // set up the DB
+    let mut data_view = FakeDataStore::default();
+    data_view.set(
+        AccessPath::new(AccountAddress::random(), vec![]),
+        vec![0, 0],
+    );
+    execute_function(script, modules, args, &data_view)
+}
+
+fn verify(
+    sender_address: &AccountAddress,
+    compiled_script: CompiledScript,
+    modules: Vec<CompiledModule>,
+) -> (CompiledScript, Vec<CompiledModule>) {
+    let (verified_script, verified_modules, statuses) =
+        static_verify_program(sender_address, compiled_script, modules);
+    assert_eq!(statuses, vec![]);
+    (verified_script, verified_modules)
+}
+
+#[macro_export]
+macro_rules! assert_prologue_parity {
+    ($e1:expr, $e2:expr, $e3:pat) => {
+        assert_matches!($e1, Some($e3));
+        assert_matches!($e2, TransactionStatus::Discard($e3));
+    };
+}
+
+#[macro_export]
+macro_rules! assert_prologue_disparity {
+    ($e1:expr => $e2:pat, $e3:expr => $e4:pat) => {
+        assert_matches!($e1, $e2);
+        assert_matches!($e3, &$e4);
+    };
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/proptest_types.rs b/language/vm/vm_runtime/vm_runtime_tests/src/proptest_types.rs
new file mode 100644
index 0000000000000..0816ba7e5680e
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/proptest_types.rs
@@ -0,0 +1,46 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::account::{Account, AccountData};
+use proptest::prelude::*;
+
+impl Arbitrary for Account {
+    type Strategy = fn() -> Account;
+    type Parameters = ();
+
+    fn arbitrary_with(_params: ()) -> Self::Strategy {
+        // Provide Account::new as the canonical strategy. This means that no shrinking will
+        // happen, but that's fine as accounts have nothing to shrink inside them anyway.
+        Account::new as Self::Strategy
+    }
+}
+
+impl AccountData {
+    /// Returns a [`Strategy`] that creates `AccountData` instances.
+    pub fn strategy(
+        balance_strategy: impl Strategy<Value = u64>,
+    ) -> impl Strategy<Value = AccountData> {
+        // Pick sequence numbers and event counts in a smaller range so that valid transactions can
+        // be generated.
+        // XXX should we also test edge cases around large sequence numbers?
+        let sequence_strategy = 0u64..(1 << 32);
+        let event_count_strategy = 0u64..(1 << 32);
+
+        (
+            any::<Account>(),
+            balance_strategy,
+            sequence_strategy,
+            event_count_strategy.clone(),
+            event_count_strategy,
+        )
+            .prop_map(
+                |(account, balance, sequence_number, sent_events_count, received_events_count)| {
+                    AccountData::with_account_and_event_counts(
+                        account,
+                        balance,
+                        sequence_number,
+                        sent_events_count,
+                        received_events_count,
+                    )
+                },
+            )
+    }
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests.rs
new file mode 100644
index 0000000000000..a860980715811
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests.rs
@@ -0,0 +1,23 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Test module.
+//!
+//! Add new test modules to this list.
+//!
+//! This is not in a top-level tests directory because each file there gets compiled into a
+//! separate binary. The linker ends up repeating a lot of work for each binary to not much
+//! benefit.
+
+mod account_universe;
+mod arithmetic;
+mod create_account;
+mod function_call;
+mod genesis;
+mod mint;
+mod module_publishing;
+mod pack_unpack;
+mod peer_to_peer;
+mod rotate_key;
+mod validator_set;
+mod verify_txn;
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/account_universe.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/account_universe.rs
new file mode 100644
index 0000000000000..02a0904644eb3
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/account_universe.rs
@@ -0,0 +1,180 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+mod create_account;
+mod peer_to_peer;
+mod rotate_key;
+
+use crate::{
+    account::AccountResource,
+    account_universe::{
+        log_balance_strategy, num_accounts, num_transactions, AUTransactionGen, AccountCurrent,
+        AccountPairGen, AccountUniverse, AccountUniverseGen, RotateKeyGen,
+    },
+    executor::FakeExecutor,
+};
+use proptest::{collection::vec, prelude::*};
+
+proptest! {
+    // These tests are pretty slow but quite comprehensive, so run a smaller number of them.
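+    // (Proptest's default is 256 cases per test; with_cases(32) overrides just that field.)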
+    #![proptest_config(ProptestConfig::with_cases(32))]
+
+    /// Ensure that account pair generators return the correct indexes.
+    #[test]
+    fn account_pair_gen(
+        universe in AccountUniverseGen::strategy(2..num_accounts(), 0u64..10000),
+        pairs in vec(any::<AccountPairGen>(), 0..num_transactions()),
+    ) {
+        let mut executor = FakeExecutor::from_genesis_file();
+        let mut universe = universe.setup(&mut executor);
+
+        for pair in pairs {
+            let (idx_1, idx_2, account_1, account_2) = {
+                let pick = pair.pick(&universe);
+                prop_assert_eq!(pick.account_1, &universe.accounts()[pick.idx_1]);
+                prop_assert_eq!(pick.account_2, &universe.accounts()[pick.idx_2]);
+                (
+                    pick.idx_1,
+                    pick.idx_2,
+                    // Need to convert to raw pointers to avoid holding an immutable reference
+                    // (pick_mut below borrows universe mutably, which would conflict).
+                    // This is safe as all we're doing is comparing pointer equality.
+                    pick.account_1 as *const AccountCurrent,
+                    pick.account_2 as *const AccountCurrent,
+                )
+            };
+
+            let pick_mut = pair.pick_mut(&mut universe);
+            prop_assert_eq!(pick_mut.idx_1, idx_1);
+            prop_assert_eq!(pick_mut.idx_2, idx_2);
+            prop_assert_eq!(pick_mut.account_1 as *const AccountCurrent, account_1);
+            prop_assert_eq!(pick_mut.account_2 as *const AccountCurrent, account_2);
+        }
+    }
+
+    #[test]
+    fn all_transactions(
+        universe in AccountUniverseGen::strategy(2..num_accounts(), log_balance_strategy(10_000_000)),
+        transactions in vec(all_transactions_strategy(1, 1_000_000), 0..num_transactions()),
+    ) {
+        run_and_assert_universe(universe, transactions)?;
+    }
+}
+
+/// A strategy that returns a random transaction.
+fn all_transactions_strategy(
+    min: u64,
+    max: u64,
+) -> impl Strategy<Value = Box<dyn AUTransactionGen>> {
+    prop_oneof![
+        // Most transactions should be p2p payments.
+        8 => peer_to_peer::p2p_strategy(min, max),
+        1 => create_account::create_account_strategy(min, max),
+        1 => any::<RotateKeyGen>().prop_map(RotateKeyGen::boxed),
+    ]
+}
+
+/// Run these transactions and make sure that they all cost the same amount of gas.
+pub(crate) fn run_and_assert_gas_cost_stability(
+    universe: AccountUniverseGen,
+    transaction_gens: Vec<impl AUTransactionGen>,
+    gas_cost: u64,
+) -> Result<(), TestCaseError> {
+    let mut executor = FakeExecutor::from_genesis_file();
+    let mut universe = universe.setup_gas_cost_stability(&mut executor);
+    let (transactions, expected_statuses): (Vec<_>, Vec<_>) = transaction_gens
+        .into_iter()
+        .map(|transaction_gen| transaction_gen.apply(&mut universe))
+        .unzip();
+    let outputs = executor.execute_block(transactions);
+
+    for (idx, (output, expected)) in outputs.iter().zip(&expected_statuses).enumerate() {
+        prop_assert_eq!(
+            output.status(),
+            expected,
+            "unexpected status for transaction {}",
+            idx
+        );
+        prop_assert_eq!(
+            output.gas_used(),
+            gas_cost,
+            "transaction at idx {} did not have expected gas cost",
+            idx,
+        );
+    }
+
+    Ok(())
+}
+
+/// Run these transactions and verify the expected output.
+pub(crate) fn run_and_assert_universe(
+    universe: AccountUniverseGen,
+    transaction_gens: Vec<impl AUTransactionGen>,
+) -> Result<(), TestCaseError> {
+    let mut executor = FakeExecutor::from_genesis_file();
+    let mut universe = universe.setup(&mut executor);
+    let (transactions, expected_statuses): (Vec<_>, Vec<_>) = transaction_gens
+        .into_iter()
+        .map(|transaction_gen| transaction_gen.apply(&mut universe))
+        .unzip();
+    let outputs = executor.execute_block(transactions);
+
+    prop_assert_eq!(outputs.len(), expected_statuses.len());
+
+    for (idx, (output, expected)) in outputs.iter().zip(&expected_statuses).enumerate() {
+        prop_assert_eq!(
+            output.status(),
+            expected,
+            "unexpected status for transaction {}",
+            idx
+        );
+        executor.apply_write_set(output.write_set());
+    }
+
+    assert_accounts_match(&universe, &executor)?;
+    Ok(())
+}
+
+/// Verify that the account information in the universe matches the information in the executor.
+pub(crate) fn assert_accounts_match(
+    universe: &AccountUniverse,
+    executor: &FakeExecutor,
+) -> Result<(), TestCaseError> {
+    for (idx, account) in universe.accounts().iter().enumerate() {
+        let resource = executor
+            .read_account_resource(account.account())
+            .expect("resource for this account must exist");
+        prop_assert_eq!(
+            account.account().auth_key(),
+            AccountResource::read_auth_key(&resource),
+            "account {} should have correct auth key",
+            idx
+        );
+        prop_assert_eq!(
+            account.balance(),
+            AccountResource::read_balance(&resource),
+            "account {} should have correct balance",
+            idx
+        );
+        // XXX These two don't work at the moment because the VM doesn't bump up event counts.
+        // prop_assert_eq!(
+        //     account.received_events_count(),
+        //     AccountResource::read_received_events_count(&resource),
+        //     "account {} should have correct received_events_count",
+        //     idx
+        // );
+        // prop_assert_eq!(
+        //     account.sent_events_count(),
+        //     AccountResource::read_sent_events_count(&resource),
+        //     "account {} should have correct sent_events_count",
+        //     idx
+        // );
+        prop_assert_eq!(
+            account.sequence_number(),
+            AccountResource::read_sequence_number(&resource),
+            "account {} should have correct sequence number",
+            idx
+        );
+    }
+    Ok(())
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/account_universe/create_account.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/account_universe/create_account.rs
new file mode 100644
index 0000000000000..3c9a21b4e40a6
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/account_universe/create_account.rs
@@ -0,0 +1,84 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    account_universe::{
+        log_balance_strategy, num_accounts, num_transactions, AUTransactionGen, AccountUniverseGen,
+        CreateAccountGen, CreateExistingAccountGen,
+    },
+    gas_costs,
+    tests::account_universe::{run_and_assert_gas_cost_stability, run_and_assert_universe},
+};
+use proptest::{collection::vec, prelude::*};
+
+proptest! {
+    // These tests are pretty slow but quite comprehensive, so run a smaller number of them.
+    #![proptest_config(ProptestConfig::with_cases(32))]
+
+    // Need a minimum of one account for create_account.
+    // Set balances high enough that transactions will always succeed.
+    #[test]
+    fn create_account_gas_cost_stability(
+        universe in AccountUniverseGen::success_strategy(1),
+        transfers in vec(any_with::<CreateAccountGen>((1, 10_000)), 0..num_transactions()),
+    ) {
+        run_and_assert_gas_cost_stability(universe, transfers, *gas_costs::CREATE_ACCOUNT)?;
+    }
+
+    #[test]
+    fn create_account_high_balance(
+        universe in AccountUniverseGen::strategy(1..num_accounts(), 1_000_000u64..10_000_000),
+        transfers in vec(any_with::<CreateAccountGen>((1, 10_000)), 0..num_transactions()),
+    ) {
+        run_and_assert_universe(universe, transfers)?;
+    }
+
+    /// Test with balances small enough to possibly trigger failures.
+    #[test]
+    fn create_account_low_balance(
+        universe in AccountUniverseGen::strategy(1..num_accounts(), 0u64..100_000),
+        transfers in vec(any_with::<CreateAccountGen>((1, 50_000)), 0..num_transactions()),
+    ) {
+        run_and_assert_universe(universe, transfers)?;
+    }
+
+    // Need a minimum of two accounts for create account with existing receiver.
+    // Set balances high enough that transactions will always succeed.
+    #[test]
+    fn create_existing_account_gas_cost_stability(
+        universe in AccountUniverseGen::success_strategy(2),
+        transfers in vec(any_with::<CreateExistingAccountGen>((1, 10_000)), 0..num_transactions()),
+    ) {
+        run_and_assert_gas_cost_stability(universe, transfers, *gas_costs::CREATE_EXISTING_ACCOUNT)?;
+    }
+
+    #[test]
+    fn create_existing_account(
+        universe in AccountUniverseGen::strategy(2..num_accounts(), log_balance_strategy(10_000_000)),
+        transfers in vec(any_with::<CreateExistingAccountGen>((1, 1_000_000)), 0..num_transactions()),
+    ) {
+        run_and_assert_universe(universe, transfers)?;
+    }
+
+    /// Mixed tests with the different kinds of create-account transactions and a large variety
+    /// of balances.
+    #[test]
+    fn create_account_mixed(
+        universe in AccountUniverseGen::strategy(2..num_accounts(), log_balance_strategy(10_000_000)),
+        transfers in vec(create_account_strategy(1, 1_000_000), 0..num_transactions()),
+    ) {
+        run_and_assert_universe(universe, transfers)?;
+    }
+}
+
+pub(super) fn create_account_strategy(
+    min: u64,
+    max: u64,
+) -> impl Strategy<Value = Box<dyn AUTransactionGen>> {
+    prop_oneof![
+        3 => any_with::<CreateAccountGen>((min, max)).prop_map(CreateAccountGen::boxed),
+        1 => any_with::<CreateExistingAccountGen>((min, max)).prop_map(
+            CreateExistingAccountGen::boxed,
+        ),
+    ]
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/account_universe/peer_to_peer.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/account_universe/peer_to_peer.rs
new file mode 100644
index 0000000000000..1051fa3b96dbd
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/account_universe/peer_to_peer.rs
@@ -0,0 +1,96 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    account_universe::{
+        log_balance_strategy, num_accounts, num_transactions, AUTransactionGen, AccountUniverseGen,
+        P2PNewReceiverGen, P2PTransferGen,
+    },
+    gas_costs,
+    tests::account_universe::{run_and_assert_gas_cost_stability, run_and_assert_universe},
+};
+use proptest::{collection::vec, prelude::*};
+
+proptest! {
+    // These tests are pretty slow but quite comprehensive, so run a smaller number of them.
+    #![proptest_config(ProptestConfig::with_cases(32))]
+
+    // Need a minimum of two accounts to send p2p transactions over.
+    // Set balances high enough that transactions will always succeed.
+    #[test]
+    fn p2p_gas_cost_stability(
+        universe in AccountUniverseGen::success_strategy(2),
+        transfers in vec(any_with::<P2PTransferGen>((1, 10_000)), 0..num_transactions()),
+    ) {
+        run_and_assert_gas_cost_stability(universe, transfers, *gas_costs::PEER_TO_PEER)?;
+    }
+
+    #[test]
+    fn p2p_high_balance(
+        universe in AccountUniverseGen::strategy(2..num_accounts(), 1_000_000u64..10_000_000),
+        transfers in vec(any_with::<P2PTransferGen>((1, 10_000)), 0..num_transactions()),
+    ) {
+        run_and_assert_universe(universe, transfers)?;
+    }
+
+    /// Test with balances small enough to possibly trigger failures.
+    #[test]
+    fn p2p_low_balance(
+        universe in AccountUniverseGen::strategy(2..num_accounts(), 0u64..100_000),
+        transfers in vec(any_with::<P2PTransferGen>((1, 50_000)), 0..num_transactions()),
+    ) {
+        run_and_assert_universe(universe, transfers)?;
+    }
+
+    // Need a minimum of one account to send p2p transactions to other accounts.
+    // Set balances high enough that transactions will always succeed.
+    #[test]
+    fn p2p_new_receiver_gas_cost_stability(
+        universe in AccountUniverseGen::success_strategy(1),
+        transfers in vec(any_with::<P2PNewReceiverGen>((1, 10_000)), 0..num_transactions()),
+    ) {
+        run_and_assert_gas_cost_stability(
+            universe,
+            transfers,
+            *gas_costs::PEER_TO_PEER_NEW_RECEIVER,
+        )?;
+    }
+
+    /// Test that p2p transfers can be done to new accounts.
+    #[test]
+    fn p2p_new_receiver_high_balance(
+        universe in AccountUniverseGen::strategy(1..num_accounts(), 1_000_000u64..10_000_000),
+        transfers in vec(any_with::<P2PNewReceiverGen>((1, 10_000)), 0..num_transactions()),
+    ) {
+        run_and_assert_universe(universe, transfers)?;
+    }
+
+    /// Test with balances small enough to possibly trigger failures.
+    #[test]
+    fn p2p_new_receiver_low_balance(
+        universe in AccountUniverseGen::strategy(1..num_accounts(), 0u64..100_000),
+        transfers in vec(any_with::<P2PNewReceiverGen>((1, 50_000)), 0..num_transactions()),
+    ) {
+        run_and_assert_universe(universe, transfers)?;
+    }
+
+    /// Mixed tests with all the different kinds of peer to peer transactions and a large
+    /// variety of balances.
+    #[test]
+    fn p2p_mixed(
+        universe in AccountUniverseGen::strategy(2..num_accounts(), log_balance_strategy(10_000_000)),
+        transfers in vec(p2p_strategy(1, 1_000_000), 0..num_transactions()),
+    ) {
+        run_and_assert_universe(universe, transfers)?;
+    }
+}
+
+pub(super) fn p2p_strategy(
+    min: u64,
+    max: u64,
+) -> impl Strategy<Value = Box<dyn AUTransactionGen>> {
+    prop_oneof![
+        3 => any_with::<P2PTransferGen>((min, max)).prop_map(P2PTransferGen::boxed),
+        1 => any_with::<P2PNewReceiverGen>((min, max)).prop_map(P2PNewReceiverGen::boxed),
+    ]
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/account_universe/rotate_key.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/account_universe/rotate_key.rs
new file mode 100644
index 0000000000000..008b7dfcc94f0
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/account_universe/rotate_key.rs
@@ -0,0 +1,38 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    account_universe::{num_accounts, num_transactions, AccountUniverseGen, RotateKeyGen},
+    gas_costs,
+    tests::account_universe::{run_and_assert_gas_cost_stability, run_and_assert_universe},
+};
+use proptest::{collection::vec, prelude::*};
+
+proptest! {
+    // These tests are pretty slow but quite comprehensive, so run a smaller number of them.
+    #![proptest_config(ProptestConfig::with_cases(32))]
+
+    #[test]
+    fn rotate_key_gas_cost_stability(
+        universe in AccountUniverseGen::success_strategy(1),
+        key_rotations in vec(any::<RotateKeyGen>(), 0..num_transactions()),
+    ) {
+        run_and_assert_gas_cost_stability(universe, key_rotations, *gas_costs::ROTATE_KEY)?;
+    }
+
+    #[test]
+    fn rotate_key_high_balance(
+        universe in AccountUniverseGen::strategy(1..num_accounts(), 1_000_000u64..10_000_000),
+        key_rotations in vec(any::<RotateKeyGen>(), 0..num_transactions()),
+    ) {
+        run_and_assert_universe(universe, key_rotations)?;
+    }
+
+    #[test]
+    fn rotate_key_low_balance(
+        universe in AccountUniverseGen::strategy(1..num_accounts(), 0u64..100_000),
+        key_rotations in vec(any::<RotateKeyGen>(), 0..num_transactions()),
+    ) {
+        run_and_assert_universe(universe, key_rotations)?;
+    }
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/arithmetic.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/arithmetic.rs
new file mode 100644
index 0000000000000..3eb0d202c1165
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/arithmetic.rs
@@ -0,0 +1,37 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::compile_and_execute;
+use vm::assert_ok;
+
+#[test]
+fn simple_main() {
+    let program = String::from(
+        "
+        main() {
+            return;
+        }
+        ",
+    );
+
+    assert_ok!(compile_and_execute(&program, vec![]));
+}
+
+#[test]
+fn simple_arithmetic() {
+    let program = String::from(
+        "
+        main() {
+            let a: u64;
+            let b: u64;
+            a = 2 + 3;
+            assert(copy(a) == 5, 42);
+            b = copy(a) - 1;
+            assert(copy(b) == 4, 42);
+            return;
+        }
+        ",
+    );
+
+    assert_ok!(compile_and_execute(&program, vec![]));
}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/create_account.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/create_account.rs
new file mode 100644
index 0000000000000..66efbbbe354de
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/create_account.rs
@@ -0,0 +1,100 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    account::{Account, AccountData, AccountResource},
+    common_transactions::create_account_txn,
+    executor::FakeExecutor,
+};
+use types::{
+    transaction::{SignedTransaction, TransactionStatus},
+    vm_error::{ExecutionStatus, VMStatus},
+};
+
+#[test]
+fn create_account() {
+    // create a FakeExecutor with a genesis from file
+    let mut executor = FakeExecutor::from_genesis_file();
+
+    // create and publish a sender with 1_000_000 coins
+    let sender = AccountData::new(1_000_000, 10);
+    executor.add_account_data(&sender);
+    let new_account = Account::new();
+    let initial_amount = 1_000;
+    let txn = create_account_txn(sender.account(), &new_account, 10, initial_amount);
+
+    // execute transaction
+    let txns: Vec<SignedTransaction> = vec![txn];
+    let output = executor.execute_block(txns);
+    let txn_output = output.get(0).expect("must have a transaction output");
+    assert_eq!(
+        output[0].status(),
+        &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed))
+    );
+    println!("write set {:?}", txn_output.write_set());
+    executor.apply_write_set(txn_output.write_set());
+
+    // check that numbers in stored DB are correct
+    let gas = txn_output.gas_used();
+    let sender_balance = 1_000_000 - initial_amount - gas;
+    let updated_sender = executor
+        .read_account_resource(sender.account())
+        .expect("sender must exist");
+    let updated_receiver = executor
+        .read_account_resource(&new_account)
+        .expect("receiver must exist");
+    assert_eq!(
+        initial_amount,
+ AccountResource::read_balance(&updated_receiver) + ); + assert_eq!( + sender_balance, + AccountResource::read_balance(&updated_sender) + ); + assert_eq!(11, AccountResource::read_sequence_number(&updated_sender)); +} + +#[test] +fn create_account_zero_balance() { + // create a FakeExecutor with a genesis from file + let mut executor = FakeExecutor::from_genesis_file(); + + // create and publish a sender with 1_000_000 coins + let sender = AccountData::new(1_000_000, 10); + executor.add_account_data(&sender); + let new_account = Account::new(); + + // define the arguments to the create account transaction + let initial_amount = 0; + let txn = create_account_txn(sender.account(), &new_account, 10, initial_amount); + + // execute transaction + let txns: Vec = vec![txn]; + let output = executor.execute_block(txns); + let txn_output = output.get(0).expect("must have a transaction output"); + assert_eq!( + output[0].status(), + &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)) + ); + println!("write set {:?}", txn_output.write_set()); + executor.apply_write_set(txn_output.write_set()); + + // check that numbers in stored DB are correct + let gas = txn_output.gas_used(); + let sender_balance = 1_000_000 - initial_amount - gas; + let updated_sender = executor + .read_account_resource(sender.account()) + .expect("sender must exist"); + let updated_receiver = executor + .read_account_resource(&new_account) + .expect("receiver must exist"); + assert_eq!( + initial_amount, + AccountResource::read_balance(&updated_receiver) + ); + assert_eq!( + sender_balance, + AccountResource::read_balance(&updated_sender) + ); + assert_eq!(11, AccountResource::read_sequence_number(&updated_sender)); +} diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/function_call.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/function_call.rs new file mode 100644 index 0000000000000..54c69eb0f2fba --- /dev/null +++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/function_call.rs @@ -0,0 +1,111 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::compile_and_execute; +use vm::assert_ok; + +#[test] +fn simple_main() { + let program = String::from( + " + modules: + module M { + public max(a: u64, b: u64): u64 { + if (copy(a) > copy(b)) { + return copy(a); + } else { + return copy(b); + } + return 0; + } + + public sum(a: u64, b: u64): u64 { + let c: u64; + c = copy(a) + copy(b); + return copy(c); + } + } + script: + import 0x0000000000000000000000000000000000000000000000000000000000000000.M; + + main() { + let a: u64; + let b: u64; + let c: u64; + let d: u64; + + a = 10; + b = 2; + c = M.max(copy(a), copy(b)); + d = M.sum(copy(a), copy(b)); + assert(copy(c) == 10, 42); + assert(copy(d) == 12, 42); + return; + } + ", + ); + assert_ok!(compile_and_execute(&program, vec![])); +} + +#[test] +fn diff_type_args() { + let program = String::from( + " + modules: + module M { + public case(a: u64, b: bool): u64 { + if (copy(b)) { + return copy(a); + } else { + return 42; + } + return 112; + } + } + script: + import 0x0.M; + + main() { + let a: u64; + a = 10; + a = M.case(move(a), false); + assert(copy(a) == 42, 41); + return; + } + ", + ); + assert_ok!(compile_and_execute(&program, vec![])); +} + +#[test] +fn multiple_return_values() { + let program = String::from( + " + modules: + module M { + public id3(a: u64, b: bool, c: address): u64 * bool * address { + return move(a), move(b), move(c); + } + } + script: + import 0x0.M; + + main() { + let 
a: u64; + let b: bool; + let c: address; + + a = 10; + b = false; + c = 0x0; + + a, b, c = M.id3(move(a), move(b), move(c)); + assert(move(a) == 10, 42); + assert(move(b) == false, 43); + assert(move(c) == 0x0, 44); + return; + } + ", + ); + assert_ok!(compile_and_execute(&program, vec![])); +} diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/genesis.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/genesis.rs new file mode 100644 index 0000000000000..435a8469f8132 --- /dev/null +++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/genesis.rs @@ -0,0 +1,36 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{assert_prologue_parity, executor::FakeExecutor}; +use assert_matches::assert_matches; +use crypto::signing::KeyPair; +use types::{ + access_path::AccessPath, + account_config, + test_helpers::transaction_test_helpers, + transaction::TransactionStatus, + vm_error::{VMStatus, VMValidationStatus}, + write_set::{WriteOp, WriteSetMut}, +}; + +#[test] +fn invalid_genesis_write_set() { + let executor = FakeExecutor::no_genesis(); + // Genesis write sets are not allowed to contain deletions. + let write_op = (AccessPath::default(), WriteOp::Deletion); + let write_set = WriteSetMut::new(vec![write_op]).freeze().unwrap(); + let address = account_config::association_address(); + let keypair = KeyPair::new(::crypto::signing::generate_keypair().0); + let signed_txn = transaction_test_helpers::get_write_set_txn( + address, + 0, + keypair.private_key().clone(), + keypair.public_key(), + Some(write_set), + ); + assert_prologue_parity!( + executor.verify_transaction(signed_txn.clone()), + executor.execute_transaction(signed_txn).status(), + VMStatus::Validation(VMValidationStatus::InvalidWriteSet) + ); +} diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/mint.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/mint.rs new file mode 100644 index 0000000000000..c8d08546781d5 --- /dev/null +++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/mint.rs @@ -0,0 +1,114 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + account::{Account, AccountData, AccountResource}, + common_transactions::mint_txn, + executor::FakeExecutor, +}; +use types::{ + transaction::{SignedTransaction, TransactionStatus}, + vm_error::{ExecutionStatus, VMStatus}, +}; + +#[test] +fn mint_to_existing() { + // create a FakeExecutor with a genesis from file + let mut executor = FakeExecutor::from_genesis_file(); + let genesis_account = Account::new_association(); + + // create and publish a sender with 1_000_000 coins + let receiver = AccountData::new(1_000_000, 10); + executor.add_account_data(&receiver); + + let mint_amount = 1_000; + let txn = mint_txn(&genesis_account, receiver.account(), 0, mint_amount); + + // execute transaction + let txns: Vec = vec![txn]; + let output = executor.execute_block(txns); + let txn_output = output.get(0).expect("must have a transaction output"); + assert_eq!( + output[0].status(), + &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)) + ); + println!("write set {:?}", txn_output.write_set()); + executor.apply_write_set(txn_output.write_set()); + + // check that numbers in stored DB are correct + let gas = txn_output.gas_used(); + let sender_balance = 1_000_000_000 - gas; + let receiver_balance = 1_000_000 + mint_amount; + + let updated_sender = executor + .read_account_resource(&genesis_account) + .expect("sender must exist"); + let 
updated_receiver = executor + .read_account_resource(receiver.account()) + .expect("receiver must exist"); + assert_eq!( + sender_balance, + AccountResource::read_balance(&updated_sender) + ); + assert_eq!( + receiver_balance, + AccountResource::read_balance(&updated_receiver) + ); + assert_eq!(1, AccountResource::read_sequence_number(&updated_sender)); + assert_eq!(10, AccountResource::read_sequence_number(&updated_receiver)); +} + +#[test] +fn mint_to_new_account() { + // create a FakeExecutor with a genesis from file + let mut executor = FakeExecutor::from_genesis_file(); + let genesis_account = Account::new_association(); + + // create and publish a sender with 1_000_000 coins + let new_account = Account::new(); + + let mint_amount = 10_000; + let txn = mint_txn(&genesis_account, &new_account, 0, mint_amount); + + // execute transaction + let txns: Vec = vec![txn]; + let output = executor.execute_block(txns); + let txn_output = output.get(0).expect("must have a transaction output"); + assert_eq!( + output[0].status(), + &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)) + ); + executor.apply_write_set(txn_output.write_set()); + + // check that numbers in stored DB are correct + let gas = txn_output.gas_used(); + let sender_balance = 1_000_000_000 - gas; + let receiver_balance = mint_amount; + + let updated_sender = executor + .read_account_resource(&genesis_account) + .expect("sender must exist"); + let updated_receiver = executor + .read_account_resource(&new_account) + .expect("receiver must exist"); + assert_eq!( + sender_balance, + AccountResource::read_balance(&updated_sender) + ); + assert_eq!( + receiver_balance, + AccountResource::read_balance(&updated_receiver) + ); + assert_eq!(1, AccountResource::read_sequence_number(&updated_sender)); + assert_eq!(0, AccountResource::read_sequence_number(&updated_receiver)); + + // Mint can only be called from genesis address; + let txn = mint_txn(&new_account, &new_account, 0, mint_amount); + let txns: Vec = vec![txn]; + let output = executor.execute_block(txns); + + assert_eq!( + output[0].status(), + &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::MissingData)) + ); +} diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/module_publishing.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/module_publishing.rs new file mode 100644 index 0000000000000..53993f7847e68 --- /dev/null +++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/module_publishing.rs @@ -0,0 +1,261 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + account::AccountData, assert_prologue_parity, compile::compile_program_with_address, + executor::FakeExecutor, +}; +use assert_matches::assert_matches; +use config::config::VMPublishingOption; +use types::{ + transaction::TransactionStatus, + vm_error::{ + ExecutionStatus, VMStatus, VMValidationStatus, VMVerificationError, VMVerificationStatus, + }, +}; + +// A module with an address different from the sender's address should be rejected +#[test] +fn bad_module_address() { + let mut executor = FakeExecutor::from_genesis_with_options(VMPublishingOption::Open); + + // create a transaction trying to publish a new module. 
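+    // (Illustrative sketch, not part of the original test: the address check
+    // exercised below boils down to
+    //
+    //     fn module_address_matches(sender: &AccountAddress, declared: &AccountAddress) -> bool {
+    //         sender == declared
+    //     }
+    //
+    // so a module compiled with account1's address but sent from account2 should
+    // fail with `ModuleAddressDoesNotMatchSender` at verification and execution alike.)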
+ let account1 = AccountData::new(1_000_000, 10); + let account2 = AccountData::new(1_000_000, 10); + + executor.add_account_data(&account1); + executor.add_account_data(&account2); + + let program = String::from( + " + modules: + module M { + + } + + script: + main() { + return; + } + ", + ); + + // compile with account 1's address + let compiled_script = compile_program_with_address(account1.address(), &program, vec![]); + // send with account 2's address + let txn = account2.account().create_signed_txn_impl( + *account2.address(), + compiled_script, + 10, + 10_000, + 1, + ); + + // verify and fail because the addresses don't match + let vm_status = executor.verify_transaction(txn.clone()); + let status = match vm_status { + Some(VMStatus::Verification(status)) => status, + vm_status => panic!("Unexpected verification status: {:?}", vm_status), + }; + match status.as_slice() { + &[VMVerificationStatus::Module( + 0, + VMVerificationError::ModuleAddressDoesNotMatchSender(_), + )] => {} + err => panic!("Unexpected verification error: {:?}", err), + }; + + // execute and fail for the same reason + let output = executor.execute_transaction(txn); + let status = match output.status() { + TransactionStatus::Discard(VMStatus::Verification(status)) => status, + vm_status => panic!("Unexpected verification status: {:?}", vm_status), + }; + match status.as_slice() { + &[VMVerificationStatus::Module( + 0, + VMVerificationError::ModuleAddressDoesNotMatchSender(_), + )] => {} + err => panic!("Unexpected verification error: {:?}", err), + }; +} + +// Publishing a module named M under the same address twice should be rejected +#[test] +fn duplicate_module() { + let mut executor = FakeExecutor::from_genesis_with_options(VMPublishingOption::Open); + + let sequence_number = 2; + let account = AccountData::new(1_000_000, sequence_number); + executor.add_account_data(&account); + + let program = String::from( + " + modules: + module M { + + } + + script: + main() { + return; + } + ", + ); + let compiled_script = compile_program_with_address(account.address(), &program, vec![]); + + let txn1 = account.account().create_signed_txn_impl( + *account.address(), + compiled_script.clone(), + sequence_number, + 10_000, + 1, + ); + + let txn2 = account.account().create_signed_txn_impl( + *account.address(), + compiled_script, + sequence_number + 1, + 10_000, + 1, + ); + + let output1 = executor.execute_transaction(txn1); + executor.apply_write_set(output1.write_set()); + // first tx should succeed + assert_eq!( + output1.status(), + &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)), + ); + + // second one should fail because it tries to re-publish a module named M + let output2 = executor.execute_transaction(txn2); + assert_eq!( + output2.status(), + &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::DuplicateModuleName)), + ); +} + +#[test] +pub fn test_publishing_no_modules_non_whitelist_script() { + // create a FakeExecutor with a genesis from file + let mut executor = FakeExecutor::from_genesis_with_options(VMPublishingOption::CustomScripts); + + // create a transaction trying to publish a new module. 
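+    // (Illustrative note, inferred from the tests in this file:
+    // `VMPublishingOption::CustomScripts` lifts the script whitelist but still
+    // forbids publishing modules, while `Open` allows both. A rough sketch of
+    // that gating, under this assumption, is
+    //
+    //     fn can_publish_modules(option: &VMPublishingOption) -> bool {
+    //         match option {
+    //             VMPublishingOption::Open => true,
+    //             _ => false,
+    //         }
+    //     }
+    //
+    // which is why the transaction below should fail validation with `UnknownModule`.)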
+    let sender = AccountData::new(1_000_000, 10);
+    executor.add_account_data(&sender);
+
+    let program = String::from(
+        "
+        modules:
+        module M {
+        }
+        script:
+        main () {
+            return;
+        }
+        ",
+    );
+
+    let random_script = compile_program_with_address(sender.address(), &program, vec![]);
+    let txn =
+        sender
+            .account()
+            .create_signed_txn_impl(*sender.address(), random_script, 10, 10_000, 1);
+
+    assert_prologue_parity!(
+        executor.verify_transaction(txn.clone()),
+        executor.execute_transaction(txn).status(),
+        VMStatus::Validation(VMValidationStatus::UnknownModule)
+    );
+}
+
+#[test]
+pub fn test_publishing_allow_modules() {
+    // create a FakeExecutor with a genesis from file
+    let mut executor = FakeExecutor::from_genesis_with_options(VMPublishingOption::Open);
+
+    // create a transaction trying to publish a new module.
+    let sender = AccountData::new(1_000_000, 10);
+    executor.add_account_data(&sender);
+
+    let program = String::from(
+        "
+        modules:
+        module M {
+        }
+        script:
+        main () {
+            return;
+        }",
+    );
+
+    let random_script = compile_program_with_address(sender.address(), &program, vec![]);
+    let txn =
+        sender
+            .account()
+            .create_signed_txn_impl(*sender.address(), random_script, 10, 10_000, 1);
+    assert_eq!(executor.verify_transaction(txn.clone()), None);
+    assert_eq!(
+        executor.execute_transaction(txn).status(),
+        &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed))
+    );
+}
+
+#[test]
+pub fn test_publishing_with_error() {
+    // create a FakeExecutor with a genesis from file
+    let mut executor = FakeExecutor::from_genesis_with_options(VMPublishingOption::Open);
+
+    // create a transaction trying to publish a new module.
+    let sender = AccountData::new(1_000_000, 10);
+    executor.add_account_data(&sender);
+
+    let program = String::from(
+        "
+        modules:
+        module M {
+        }
+        script:
+        main () {
+            assert(false, 42);
+            return;
+        }",
+    );
+
+    let random_script = compile_program_with_address(sender.address(), &program, vec![]);
+    let txn1 =
+        sender
+            .account()
+            .create_signed_txn_impl(*sender.address(), random_script, 10, 10_000, 1);
+    let program = String::from(
+        "
+        modules:
+        module M {
+        }
+        script:
+        main () {
+            return;
+        }",
+    );
+
+    let random_script = compile_program_with_address(sender.address(), &program, vec![]);
+    let txn2 =
+        sender
+            .account()
+            .create_signed_txn_impl(*sender.address(), random_script, 11, 10_000, 1);
+
+    assert_eq!(executor.verify_transaction(txn1.clone()), None);
+    assert_eq!(executor.verify_transaction(txn2.clone()), None);
+
+    let result = executor.execute_block(vec![txn1, txn2]);
+    assert_eq!(
+        result[0].status(),
+        &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::AssertionFailure(42)))
+    );
+
+    assert_eq!(
+        result[1].status(),
+        &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed))
+    );
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/pack_unpack.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/pack_unpack.rs
new file mode 100644
index 0000000000000..2e1c608578fe8
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/pack_unpack.rs
@@ -0,0 +1,39 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::compile_and_execute;
+use vm::assert_ok;
+
+#[test]
+fn simple_unpack() {
+    let program = String::from(
+        "
+modules:
+module Test {
+    resource T { i: u64, b: bool }
+
+    public new_t(): R#Self.T {
+        return T { i: 0, b: false };
+    }
+
+    public unpack_t(t: R#Self.T) {
+        let i: u64;
+        let flag: bool;
+        T { i, b: flag } = move(t);
+        return;
+    }
+
+}
+script:
+import 0x0.Test;
+main() {
+    let t: R#Test.T;
+
+    t = Test.new_t();
+    Test.unpack_t(move(t));
+
+    return;
+}",
+    );
+    assert_ok!(compile_and_execute(&program, vec![]));
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/peer_to_peer.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/peer_to_peer.rs
new file mode 100644
index 0000000000000..6049d70baf7fe
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/peer_to_peer.rs
@@ -0,0 +1,576 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    account::{Account, AccountData, AccountResource},
+    common_transactions::peer_to_peer_txn,
+    executor::FakeExecutor,
+};
+use canonical_serialization::SimpleDeserializer;
+use std::time::Instant;
+use types::{
+    account_config::{account_received_event_path, account_sent_event_path, AccountEvent},
+    transaction::{SignedTransaction, TransactionOutput, TransactionStatus},
+    vm_error::{ExecutionStatus, VMStatus},
+};
+
+#[test]
+fn single_peer_to_peer_with_event() {
+    ::logger::try_init_for_testing();
+    // create a FakeExecutor with a genesis from file
+    let mut executor = FakeExecutor::from_genesis_file();
+
+    // create and publish a sender with 1_000_000 coins and a receiver with 100_000 coins
+    let sender = AccountData::new(1_000_000, 10);
+    let receiver = AccountData::new(100_000, 10);
+    executor.add_account_data(&sender);
+    executor.add_account_data(&receiver);
+
+    let transfer_amount = 1_000;
+    let txn = peer_to_peer_txn(sender.account(), receiver.account(), 10, transfer_amount);
+
+    // execute transaction
+    let txns: Vec<SignedTransaction> = vec![txn];
+    let output = executor.execute_block(txns);
+    let txn_output = output.get(0).expect("must have a transaction output");
+    assert_eq!(
+        output[0].status(),
+        &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed))
+    );
+    let rec_ev_path = account_received_event_path();
+    let sent_ev_path = account_sent_event_path();
+    for event in txn_output.events() {
+        assert!(
+            rec_ev_path == event.access_path().path || sent_ev_path == event.access_path().path
+        );
+    }
+    executor.apply_write_set(txn_output.write_set());
+
+    // check that numbers in stored DB are correct
+    let gas = txn_output.gas_used();
+    let sender_balance = 1_000_000 - transfer_amount - gas;
+    let receiver_balance = 100_000 + transfer_amount;
+    let updated_sender = executor
+        .read_account_resource(sender.account())
+        .expect("sender must exist");
+    let updated_receiver = executor
+        .read_account_resource(receiver.account())
+        .expect("receiver must exist");
+    assert_eq!(
+        receiver_balance,
+        AccountResource::read_balance(&updated_receiver)
+    );
+    assert_eq!(
+        sender_balance,
+        AccountResource::read_balance(&updated_sender)
+    );
+    assert_eq!(11, AccountResource::read_sequence_number(&updated_sender));
+    assert_eq!(
+        0,
+        AccountResource::read_received_events_count(&updated_sender)
+    );
+    assert_eq!(1, AccountResource::read_sent_events_count(&updated_sender));
+    assert_eq!(
+        1,
+        AccountResource::read_received_events_count(&updated_receiver)
+    );
+    assert_eq!(
+        0,
+        AccountResource::read_sent_events_count(&updated_receiver)
+    );
+}
+
+#[test]
+fn few_peer_to_peer_with_event() {
+    // create a FakeExecutor with a genesis from file
+    let mut executor = FakeExecutor::from_genesis_file();
+
+    // create and publish a sender with 1_000_000 coins and a receiver with 100_000 coins
+    let sender = AccountData::new(1_000_000, 10);
+    let receiver = AccountData::new(100_000, 10);
+    executor.add_account_data(&sender);
+    executor.add_account_data(&receiver);
+
+    let transfer_amount = 1_000;
+
+    // execute transaction
+    let txns: Vec<SignedTransaction> = vec![
+        peer_to_peer_txn(sender.account(), receiver.account(), 10, transfer_amount),
+        peer_to_peer_txn(sender.account(), receiver.account(), 11, transfer_amount),
+        peer_to_peer_txn(sender.account(), receiver.account(), 12, transfer_amount),
+        peer_to_peer_txn(sender.account(), receiver.account(), 13, transfer_amount),
+    ];
+    let output = executor.execute_block(txns);
+    for (idx, txn_output) in output.iter().enumerate() {
+        assert_eq!(
+            txn_output.status(),
+            &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed))
+        );
+
+        // check events
+        for event in txn_output.events() {
+            let account_event: AccountEvent =
+                SimpleDeserializer::deserialize(event.event_data()).expect("event data must parse");
+            assert_eq!(transfer_amount, account_event.amount());
+            assert!(
+                &account_event.account() == sender.address()
+                    || &account_event.account() == receiver.address()
+            );
+        }
+
+        let original_sender = executor
+            .read_account_resource(sender.account())
+            .expect("sender must exist");
+        let original_receiver = executor
+            .read_account_resource(receiver.account())
+            .expect("receiver must exist");
+        executor.apply_write_set(txn_output.write_set());
+
+        // check that numbers in stored DB are correct
+        let gas = txn_output.gas_used();
+        let sender_balance =
+            AccountResource::read_balance(&original_sender) - transfer_amount - gas;
+        let receiver_balance = AccountResource::read_balance(&original_receiver) + transfer_amount;
+        let updated_sender = executor
+            .read_account_resource(sender.account())
+            .expect("sender must exist");
+        let updated_receiver = executor
+            .read_account_resource(receiver.account())
+            .expect("receiver must exist");
+        assert_eq!(
+            receiver_balance,
+            AccountResource::read_balance(&updated_receiver)
+        );
+        assert_eq!(
+            sender_balance,
+            AccountResource::read_balance(&updated_sender)
+        );
+        assert_eq!(
+            11 + idx as u64,
+            AccountResource::read_sequence_number(&updated_sender)
+        );
+        assert_eq!(
+            0,
+            AccountResource::read_received_events_count(&updated_sender)
+        );
+        assert_eq!(
+            idx as u64 + 1,
+            AccountResource::read_sent_events_count(&updated_sender)
+        );
+        assert_eq!(
+            idx as u64 + 1,
+            AccountResource::read_received_events_count(&updated_receiver)
+        );
+        assert_eq!(
+            0,
+            AccountResource::read_sent_events_count(&updated_receiver)
+        );
+    }
+}
+
+/// Test that a zero-amount transaction fails, per policy.
+#[test]
+fn zero_amount_peer_to_peer() {
+    let mut executor = FakeExecutor::from_genesis_file();
+    let sequence_number = 10;
+    let sender = AccountData::new(1_000_000, sequence_number);
+    let receiver = AccountData::new(100_000, sequence_number);
+    executor.add_account_data(&sender);
+    executor.add_account_data(&receiver);
+
+    let transfer_amount = 0;
+    let txn = peer_to_peer_txn(
+        sender.account(),
+        receiver.account(),
+        sequence_number,
+        transfer_amount,
+    );
+
+    let output = &executor.execute_block(vec![txn])[0];
+    // Error code 7 means that the transaction was a zero-amount one.
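+    // (Illustrative sketch: the transfer script rejects such transactions with
+    // `assert(cond, code)`, which the VM surfaces as
+    // `ExecutionStatus::AssertionFailure(code)`. A hypothetical matcher:
+    //
+    //     fn assertion_code(status: &TransactionStatus) -> Option<u64> {
+    //         match status {
+    //             TransactionStatus::Keep(VMStatus::Execution(
+    //                 ExecutionStatus::AssertionFailure(code),
+    //             )) => Some(*code),
+    //             _ => None,
+    //         }
+    //     }
+    // )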
+    assert_eq!(
+        output.status(),
+        &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::AssertionFailure(7))),
+    );
+}
+
+#[test]
+fn peer_to_peer_create_account() {
+    // create a FakeExecutor with a genesis from file
+    let mut executor = FakeExecutor::from_genesis_file();
+
+    // create and publish a sender with 1_000_000 coins
+    let sender = AccountData::new(1_000_000, 10);
+    executor.add_account_data(&sender);
+    let new_account = Account::new();
+
+    // define the arguments to the peer to peer transaction
+    let transfer_amount = 1_000;
+    let txn = peer_to_peer_txn(sender.account(), &new_account, 10, transfer_amount);
+
+    // execute transaction
+    let txns: Vec<SignedTransaction> = vec![txn];
+    let output = executor.execute_block(txns);
+    let txn_output = output.get(0).expect("must have a transaction output");
+    assert_eq!(
+        output[0].status(),
+        &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed))
+    );
+    executor.apply_write_set(txn_output.write_set());
+
+    // check that numbers in stored DB are correct
+    let gas = txn_output.gas_used();
+    let sender_balance = 1_000_000 - transfer_amount - gas;
+    let receiver_balance = transfer_amount;
+    let updated_sender = executor
+        .read_account_resource(sender.account())
+        .expect("sender must exist");
+    let updated_receiver = executor
+        .read_account_resource(&new_account)
+        .expect("receiver must exist");
+    assert_eq!(
+        receiver_balance,
+        AccountResource::read_balance(&updated_receiver)
+    );
+    assert_eq!(
+        sender_balance,
+        AccountResource::read_balance(&updated_sender)
+    );
+    assert_eq!(11, AccountResource::read_sequence_number(&updated_sender));
+}
+
+// Holder for transaction data; arguments to transactions.
+struct TxnInfo {
+    pub sender: Account,
+    pub receiver: Account,
+    pub transfer_amount: u64,
+}
+
+impl TxnInfo {
+    fn new(sender: &Account, receiver: &Account, transfer_amount: u64) -> Self {
+        TxnInfo {
+            sender: sender.clone(),
+            receiver: receiver.clone(),
+            transfer_amount,
+        }
+    }
+}
+
+// Create a cyclic transfer around a slice of Accounts.
+// Each account makes a transfer for the same amount to the next account in the slice.
+fn create_cyclic_transfers(
+    executor: &FakeExecutor,
+    accounts: &[Account],
+    transfer_amount: u64,
+) -> (Vec<TxnInfo>, Vec<SignedTransaction>) {
+    let mut txns: Vec<SignedTransaction> = Vec::new();
+    let mut txns_info: Vec<TxnInfo> = Vec::new();
+    // loop through all accounts and let each transfer the same amount to the next one
+    let count = accounts.len();
+    for i in 0..count {
+        let sender = &accounts[i];
+        let sender_resource = executor
+            .read_account_resource(&sender)
+            .expect("sender must exist");
+        let seq_num = AccountResource::read_sequence_number(&sender_resource);
+        let receiver = &accounts[(i + 1) % count];
+
+        let txn = peer_to_peer_txn(sender, receiver, seq_num, transfer_amount);
+        txns.push(txn);
+        txns_info.push(TxnInfo::new(sender, receiver, transfer_amount));
+    }
+    (txns_info, txns)
+}
+
+// Create a one to many transfer around a slice of Accounts.
+// The first account is the payer and all others are receivers.
+fn create_one_to_many_transfers(
+    executor: &FakeExecutor,
+    accounts: &[Account],
+    transfer_amount: u64,
+) -> (Vec<TxnInfo>, Vec<SignedTransaction>) {
+    let mut txns: Vec<SignedTransaction> = Vec::new();
+    let mut txns_info: Vec<TxnInfo> = Vec::new();
+    // grab account 0 as the sender
+    let sender = &accounts[0];
+    let sender_resource = executor
+        .read_account_resource(&sender)
+        .expect("sender must exist");
+    let seq_num = AccountResource::read_sequence_number(&sender_resource);
+    // let the sender transfer the same amount to every other account
+    let count = accounts.len();
+    for (i, receiver) in accounts.iter().enumerate().take(count).skip(1) {
+        // let receiver = &accounts[i];
+
+        let txn = peer_to_peer_txn(sender, receiver, seq_num + i as u64 - 1, transfer_amount);
+        txns.push(txn);
+        txns_info.push(TxnInfo::new(sender, receiver, transfer_amount));
+    }
+    (txns_info, txns)
+}
+
+// Create a many to one transfer around a slice of Accounts.
+// The first account is the receiver and all others are payers.
+fn create_many_to_one_transfers(
+    executor: &FakeExecutor,
+    accounts: &[Account],
+    transfer_amount: u64,
+) -> (Vec<TxnInfo>, Vec<SignedTransaction>) {
+    let mut txns: Vec<SignedTransaction> = Vec::new();
+    let mut txns_info: Vec<TxnInfo> = Vec::new();
+    // grab account 0 as the receiver
+    let receiver = &accounts[0];
+    // let every other account transfer the same amount to the receiver
+    let count = accounts.len();
+    for sender in accounts.iter().take(count).skip(1) {
+        //let sender = &accounts[i];
+        let sender_resource = executor
+            .read_account_resource(sender)
+            .expect("sender must exist");
+        let seq_num = AccountResource::read_sequence_number(&sender_resource);
+
+        let txn = peer_to_peer_txn(sender, receiver, seq_num, transfer_amount);
+        txns.push(txn);
+        txns_info.push(TxnInfo::new(sender, receiver, transfer_amount));
+    }
+    (txns_info, txns)
+}
+
+// Verify a transfer output.
+// Checks that sender and receiver in a peer to peer transaction are in proper
+// state after a successful transfer.
+// The transaction arguments are provided in txn_args.
+// Apply the WriteSet to the data store.
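+// For example (illustration, not from the original source): a sender at
+// balance B and sequence number s that transfers t at gas cost g should end at
+// balance B - t - g with sequence number s + 1, while the receiver gains
+// exactly t.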
+fn check_and_apply_transfer_output( + executor: &mut FakeExecutor, + txn_args: &[TxnInfo], + output: &[TransactionOutput], +) { + let count = output.len(); + for i in 0..count { + let txn_info = &txn_args[i]; + let sender = &txn_info.sender; + let receiver = &txn_info.receiver; + let transfer_amount = txn_info.transfer_amount; + let sender_resource = executor + .read_account_resource(&sender) + .expect("sender must exist"); + let sender_initial_balance = AccountResource::read_balance(&sender_resource); + let sender_seq_num = AccountResource::read_sequence_number(&sender_resource); + let receiver_resource = executor + .read_account_resource(&receiver) + .expect("receiver must exist"); + let receiver_initial_balance = AccountResource::read_balance(&receiver_resource); + + // apply single transaction to DB + let txn_output = &output[i]; + executor.apply_write_set(txn_output.write_set()); + + // check that numbers stored in DB are correct + let gas = txn_output.gas_used(); + let sender_balance = sender_initial_balance - transfer_amount - gas; + let receiver_balance = receiver_initial_balance + transfer_amount; + let updated_sender = executor + .read_account_resource(&sender) + .expect("sender must exist"); + let updated_receiver = executor + .read_account_resource(&receiver) + .expect("receiver must exist"); + assert_eq!( + receiver_balance, + AccountResource::read_balance(&updated_receiver) + ); + assert_eq!( + sender_balance, + AccountResource::read_balance(&updated_sender) + ); + assert_eq!( + sender_seq_num + 1, + AccountResource::read_sequence_number(&updated_sender) + ); + } +} + +// simple utility to print all account to visually inspect account data +fn print_accounts(executor: &FakeExecutor, accounts: &[Account]) { + for account in accounts { + let account_resource = executor + .read_account_resource(&account) + .expect("sender must exist"); + println!("{:?}", account_resource); + } +} + +#[test] +fn cycle_peer_to_peer() { + // create a FakeExecutor with a genesis from file + let mut executor = FakeExecutor::from_genesis_file(); + + // create and publish accounts with 1_000_000 coins + let account_size = 100usize; + let initial_balance = 1_000_000u64; + let initial_seq_num = 10u64; + let accounts = executor.create_accounts(account_size, initial_balance, initial_seq_num); + + // set up the transactions + let transfer_amount = 1_000; + let (txns_info, txns) = create_cyclic_transfers(&executor, &accounts, transfer_amount); + + // execute transaction + let mut execution_time = 0u128; + let now = Instant::now(); + let output = executor.execute_block(txns); + execution_time += now.elapsed().as_nanos(); + println!("EXECUTION TIME: {}", execution_time); + for txn_output in &output { + assert_eq!( + txn_output.status(), + &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)) + ); + } + assert_eq!(accounts.len(), output.len()); + + check_and_apply_transfer_output(&mut executor, &txns_info, &output); + print_accounts(&executor, &accounts); +} + +#[test] +fn cycle_peer_to_peer_multi_block() { + // create a FakeExecutor with a genesis from file + let mut executor = FakeExecutor::from_genesis_file(); + + // create and publish accounts with 1_000_000 coins + let account_size = 100usize; + let initial_balance = 1_000_000u64; + let initial_seq_num = 10u64; + let accounts = executor.create_accounts(account_size, initial_balance, initial_seq_num); + + // set up the transactions + let transfer_amount = 1_000; + let block_count = 5u64; + let cycle = account_size / (block_count as 
usize); + let mut range_left = 0usize; + let mut execution_time = 0u128; + for _i in 0..block_count { + range_left = if range_left + cycle >= account_size { + account_size - cycle + } else { + range_left + }; + let (txns_info, txns) = create_cyclic_transfers( + &executor, + &accounts[range_left..range_left + cycle], + transfer_amount, + ); + + // execute transaction + let now = Instant::now(); + let output = executor.execute_block(txns); + execution_time += now.elapsed().as_nanos(); + for txn_output in &output { + assert_eq!( + txn_output.status(), + &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)) + ); + } + assert_eq!(cycle, output.len()); + check_and_apply_transfer_output(&mut executor, &txns_info, &output); + range_left = (range_left + cycle) % account_size; + } + println!("EXECUTION TIME: {}", execution_time); + print_accounts(&executor, &accounts); +} + +#[test] +fn one_to_many_peer_to_peer() { + // create a FakeExecutor with a genesis from file + let mut executor = FakeExecutor::from_genesis_file(); + + // create and publish accounts with 1_000_000 coins + let account_size = 100usize; + let initial_balance = 1_000_000u64; + let initial_seq_num = 10u64; + let accounts = executor.create_accounts(account_size, initial_balance, initial_seq_num); + + // set up the transactions + let transfer_amount = 1_000; + let block_count = 2u64; + let cycle = account_size / (block_count as usize); + let mut range_left = 0usize; + let mut execution_time = 0u128; + for _i in 0..block_count { + range_left = if range_left + cycle >= account_size { + account_size - cycle + } else { + range_left + }; + let (txns_info, txns) = create_one_to_many_transfers( + &executor, + &accounts[range_left..range_left + cycle], + transfer_amount, + ); + + // execute transaction + let now = Instant::now(); + let output = executor.execute_block(txns); + execution_time += now.elapsed().as_nanos(); + for txn_output in &output { + assert_eq!( + txn_output.status(), + &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)) + ); + } + assert_eq!(cycle - 1, output.len()); + check_and_apply_transfer_output(&mut executor, &txns_info, &output); + range_left = (range_left + cycle) % account_size; + } + println!("EXECUTION TIME: {}", execution_time); + print_accounts(&executor, &accounts); +} + +#[test] +fn many_to_one_peer_to_peer() { + // create a FakeExecutor with a genesis from file + let mut executor = FakeExecutor::from_genesis_file(); + + // create and publish accounts with 1_000_000 coins + let account_size = 100usize; + let initial_balance = 1_000_000u64; + let initial_seq_num = 10u64; + let accounts = executor.create_accounts(account_size, initial_balance, initial_seq_num); + + // set up the transactions + let transfer_amount = 1_000; + let block_count = 2u64; + let cycle = account_size / (block_count as usize); + let mut range_left = 0usize; + let mut execution_time = 0u128; + for _i in 0..block_count { + range_left = if range_left + cycle >= account_size { + account_size - cycle + } else { + range_left + }; + let (txns_info, txns) = create_many_to_one_transfers( + &executor, + &accounts[range_left..range_left + cycle], + transfer_amount, + ); + + // execute transaction + let now = Instant::now(); + let output = executor.execute_block(txns); + execution_time += now.elapsed().as_nanos(); + for txn_output in &output { + assert_eq!( + txn_output.status(), + &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)) + ); + } + assert_eq!(cycle - 1, output.len()); + 
check_and_apply_transfer_output(&mut executor, &txns_info, &output); + range_left = (range_left + cycle) % account_size; + } + println!("EXECUTION TIME: {}", execution_time); + print_accounts(&executor, &accounts); +} diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/rotate_key.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/rotate_key.rs new file mode 100644 index 0000000000000..3b43d2ebfb899 --- /dev/null +++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/rotate_key.rs @@ -0,0 +1,65 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + account::{Account, AccountData, AccountResource}, + common_transactions::{create_account_txn, rotate_key_txn}, + executor::FakeExecutor, +}; +use types::{ + account_address::AccountAddress, + transaction::TransactionStatus, + vm_error::{ExecutionStatus, VMStatus, VMValidationStatus}, +}; + +#[test] +fn rotate_key() { + let mut executor = FakeExecutor::from_genesis_file(); + + // create and publish sender + let mut sender = AccountData::new(1_000_000, 10); + executor.add_account_data(&sender); + + let (privkey, pubkey) = crypto::signing::generate_keypair(); + let new_key_hash = AccountAddress::from(pubkey); + let txn = rotate_key_txn(sender.account(), new_key_hash, 10); + + // execute transaction + let output = &executor.execute_block(vec![txn])[0]; + assert_eq!( + output.status(), + &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)), + ); + executor.apply_write_set(output.write_set()); + + // Check that numbers in store are correct. + let gas = output.gas_used(); + let balance = 1_000_000 - gas; + let updated_sender = executor + .read_account_resource(sender.account()) + .expect("sender must exist"); + assert_eq!( + new_key_hash, + AccountResource::read_auth_key(&updated_sender) + ); + assert_eq!(balance, AccountResource::read_balance(&updated_sender)); + assert_eq!(11, AccountResource::read_sequence_number(&updated_sender)); + + // Check that transactions cannot be sent with the old key any more. + let new_account = Account::new(); + let old_key_txn = create_account_txn(sender.account(), &new_account, 11, 100_000); + let old_key_output = &executor.execute_block(vec![old_key_txn])[0]; + assert_eq!( + old_key_output.status(), + &TransactionStatus::Discard(VMStatus::Validation(VMValidationStatus::InvalidAuthKey)), + ); + + // Check that transactions can be sent with the new key. 
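+    // (Illustrative note: `rotate_key` below is assumed to be test-harness
+    // bookkeeping that swaps the locally stored keypair, so later transactions
+    // are signed with the new key and match the authentication key already
+    // rotated on-chain.)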
+    sender.rotate_key(privkey, pubkey);
+    let new_key_txn = create_account_txn(sender.account(), &new_account, 11, 100_000);
+    let new_key_output = &executor.execute_block(vec![new_key_txn])[0];
+    assert_eq!(
+        new_key_output.status(),
+        &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)),
+    );
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/validator_set.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/validator_set.rs
new file mode 100644
index 0000000000000..f5ccccab4751a
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/validator_set.rs
@@ -0,0 +1,21 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::executor::FakeExecutor;
+use canonical_serialization::SimpleDeserializer;
+use types::{
+    access_path::VALIDATOR_SET_ACCESS_PATH, validator_public_keys::ValidatorPublicKeys,
+    validator_set::ValidatorSet,
+};
+
+#[test]
+fn load_genesis_validator_set() {
+    let executor = FakeExecutor::from_genesis_file();
+    let validator_set_bytes = executor
+        .read_from_access_path(&VALIDATOR_SET_ACCESS_PATH)
+        .unwrap();
+    let validator_set: ValidatorSet =
+        SimpleDeserializer::deserialize(&validator_set_bytes).unwrap();
+    let expected_payload: Vec<ValidatorPublicKeys> = vec![];
+    assert_eq!(validator_set.payload(), expected_payload.as_slice());
+}
diff --git a/language/vm/vm_runtime/vm_runtime_tests/src/tests/verify_txn.rs b/language/vm/vm_runtime/vm_runtime_tests/src/tests/verify_txn.rs
new file mode 100644
index 0000000000000..6c56887b615fe
--- /dev/null
+++ b/language/vm/vm_runtime/vm_runtime_tests/src/tests/verify_txn.rs
@@ -0,0 +1,552 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    account::AccountData,
+    assert_prologue_disparity, assert_prologue_parity,
+    common_transactions::*,
+    compile::{compile_program_with_address, compile_script},
+    executor::FakeExecutor,
+};
+use assert_matches::assert_matches;
+use config::config::{NodeConfigHelpers, VMPublishingOption};
+use crypto::signing::KeyPair;
+use std::collections::HashSet;
+use tiny_keccak::Keccak;
+use types::{
+    test_helpers::transaction_test_helpers,
+    transaction::{
+        TransactionArgument, TransactionStatus, MAX_TRANSACTION_SIZE_IN_BYTES, SCRIPT_HASH_LENGTH,
+    },
+    vm_error::{
+        ExecutionStatus, VMStatus, VMValidationStatus, VMVerificationError, VMVerificationStatus,
+    },
+};
+use vm::gas_schedule;
+use vm_genesis::encode_transfer_program;
+
+#[test]
+fn verify_signature() {
+    let mut executor = FakeExecutor::from_genesis_file();
+    let sender = AccountData::new(900_000, 10);
+    executor.add_account_data(&sender);
+    // Generate a new key pair to try and sign things with.
+    let other_keypair = KeyPair::new(::crypto::signing::generate_keypair().0);
+    let program = encode_transfer_program(sender.address(), 100);
+    let signed_txn = transaction_test_helpers::get_unverified_test_signed_txn(
+        *sender.address(),
+        0,
+        other_keypair.private_key().clone(),
+        sender.account().pubkey,
+        Some(program),
+    );
+
+    assert_prologue_parity!(
+        executor.verify_transaction(signed_txn.clone()),
+        executor.execute_transaction(signed_txn).status(),
+        VMStatus::Validation(VMValidationStatus::InvalidSignature)
+    );
+}
+
+#[test]
+fn verify_rejected_write_set() {
+    let mut executor = FakeExecutor::from_genesis_file();
+    let sender = AccountData::new(900_000, 10);
+    executor.add_account_data(&sender);
+    let signed_txn = transaction_test_helpers::get_write_set_txn(
+        *sender.address(),
+        0,
+        sender.account().privkey.clone(),
+        sender.account().pubkey,
+        None,
+    );
+
+    assert_prologue_parity!(
+        executor.verify_transaction(signed_txn.clone()),
+        executor.execute_transaction(signed_txn).status(),
+        VMStatus::Validation(VMValidationStatus::RejectedWriteSet)
+    );
+}
+
+#[test]
+fn verify_whitelist() {
+    // Make sure the whitelist's hashes match the currently compiled scripts. If this fails,
+    // please try running `cargo run` under vm_genesis and update the vm_config in
+    // node.config.toml and in config.rs in the libra/config crate.
+    let programs: HashSet<_> = vec![
+        PEER_TO_PEER.clone(),
+        MINT.clone(),
+        ROTATE_KEY.clone(),
+        CREATE_ACCOUNT.clone(),
+    ]
+    .into_iter()
+    .map(|s| {
+        let mut hash = [0u8; SCRIPT_HASH_LENGTH];
+        let mut keccak = Keccak::new_sha3_256();
+
+        keccak.update(&s);
+        keccak.finalize(&mut hash);
+        hash
+    })
+    .collect();
+
+    let config = NodeConfigHelpers::get_single_node_test_config(false);
+
+    assert_eq!(
+        Some(&programs),
+        config.vm_config.publishing_options.get_whitelist_set()
+    )
+}
+
+#[test]
+fn verify_simple_payment() {
+    // create a FakeExecutor with a genesis from file
+    let mut executor = FakeExecutor::from_genesis_file();
+
+    // create and publish a sender with 900_000 coins and a receiver with 100_000 coins
+    let sender = AccountData::new(900_000, 10);
+    let receiver = AccountData::new(100_000, 10);
+    executor.add_account_data(&sender);
+    executor.add_account_data(&receiver);
+
+    // define the arguments to the peer to peer transaction
+    let transfer_amount = 1_000;
+    let mut args: Vec<TransactionArgument> = Vec::new();
+    args.push(TransactionArgument::Address(*receiver.address()));
+    args.push(TransactionArgument::U64(transfer_amount));
+
+    // Create a new transaction that has the exact right sequence number.
+    let txn = sender.account().create_signed_txn_with_args(
+        PEER_TO_PEER.clone(),
+        args.clone(),
+        10, // this should be programmable but for now is 1 more than the setup
+        10_000,
+        1,
+    );
+    assert_eq!(executor.verify_transaction(txn), None);
+
+    // Create a new transaction that has a bad auth key.
+    let txn = sender.account().create_signed_txn_with_args_and_sender(
+        *receiver.address(),
+        PEER_TO_PEER.clone(),
+        args.clone(),
+        10, // this should be programmable but for now is 1 more than the setup
+        10_000,
+        1,
+    );
+    assert_prologue_parity!(
+        executor.verify_transaction(txn.clone()),
+        executor.execute_transaction(txn).status(),
+        VMStatus::Validation(VMValidationStatus::InvalidAuthKey)
+    );
+
+    // Create a new transaction that has an old sequence number.
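+    // (Illustrative note: a stale sequence number is rejected by validator and
+    // executor alike, while a too-new one passes validation and is only
+    // discarded at execution time; hence the parity assertion for this
+    // transaction and the disparity assertion for the next one.)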
+ let txn = sender.account().create_signed_txn_with_args( + PEER_TO_PEER.clone(), + args.clone(), + 1, + 10_000, + 1, + ); + assert_prologue_parity!( + executor.verify_transaction(txn.clone()), + executor.execute_transaction(txn).status(), + VMStatus::Validation(VMValidationStatus::SequenceNumberTooOld) + ); + + // Create a new transaction that has a too new sequence number. + let txn = sender.account().create_signed_txn_with_args( + PEER_TO_PEER.clone(), + args.clone(), + 11, + 10_000, + 1, + ); + assert_prologue_disparity!( + executor.verify_transaction(txn.clone()) => None, + executor.execute_transaction(txn).status() => + TransactionStatus::Discard(VMStatus::Validation( + VMValidationStatus::SequenceNumberTooNew + )) + ); + + // Create a new transaction that doesn't have enough balance to pay for gas. + let txn = sender.account().create_signed_txn_with_args( + PEER_TO_PEER.clone(), + args.clone(), + 10, + 1_000_000, + 1, + ); + assert_prologue_parity!( + executor.verify_transaction(txn.clone()), + executor.execute_transaction(txn).status(), + VMStatus::Validation(VMValidationStatus::InsufficientBalanceForTransactionFee) + ); + + // XXX TZ: TransactionExpired + + // RejectedWriteSet is tested in `verify_rejected_write_set` + // InvalidWriteSet is tested in genesis.rs + + // Create a new transaction from a bogus account that doesn't exist + let bogus_account = AccountData::new(100_000, 10); + let txn = bogus_account.account().create_signed_txn_with_args( + PEER_TO_PEER.clone(), + args.clone(), + 10, + 10_000, + 1, + ); + assert_prologue_parity!( + executor.verify_transaction(txn.clone()), + executor.execute_transaction(txn).status(), + VMStatus::Validation(VMValidationStatus::SendingAccountDoesNotExist(_)) + ); + + // RejectedWriteSet is tested in `verify_rejected_write_set` + // InvalidWriteSet is tested in genesis.rs + + // The next couple tests test transaction size, and bounds on gas price and the number of gas + // units that can be submitted with a transaction. + // + // We test these in the reverse order that they appear in verify_transaction, and build up the + // errors one-by-one to make sure that we are both catching all of them, and that we are doing + // so in the specified order. + let txn = sender.account().create_signed_txn_with_args( + PEER_TO_PEER.clone(), + args.clone(), + 10, + 1_000_000, + gas_schedule::MAX_PRICE_PER_GAS_UNIT + 1, + ); + assert_prologue_parity!( + executor.verify_transaction(txn.clone()), + executor.execute_transaction(txn).status(), + VMStatus::Validation(VMValidationStatus::GasUnitPriceAboveMaxBound(_)) + ); + + // Note: We can't test this at the moment since MIN_PRICE_PER_GAS_UNIT is set to 0 for testnet. + // Uncomment this test once we have a non-zero MIN_PRICE_PER_GAS_UNIT. 
+    // let txn = sender.account().create_signed_txn_with_args(
+    //     PEER_TO_PEER.clone(),
+    //     args.clone(),
+    //     10,
+    //     1_000_000,
+    //     gas_schedule::MIN_PRICE_PER_GAS_UNIT - 1,
+    // );
+    // assert_eq!(
+    //     executor.verify_transaction(txn),
+    //     Some(VMStatus::Validation(
+    //         VMValidationStatus::GasUnitPriceBelowMinBound
+    //     ))
+    // );
+
+    let txn = sender.account().create_signed_txn_with_args(
+        PEER_TO_PEER.clone(),
+        args.clone(),
+        10,
+        1,
+        gas_schedule::MAX_PRICE_PER_GAS_UNIT,
+    );
+    assert_prologue_parity!(
+        executor.verify_transaction(txn.clone()),
+        executor.execute_transaction(txn).status(),
+        VMStatus::Validation(VMValidationStatus::MaxGasUnitsBelowMinTransactionGasUnits(_))
+    );
+
+    let txn = sender.account().create_signed_txn_with_args(
+        PEER_TO_PEER.clone(),
+        args.clone(),
+        10,
+        gas_schedule::MIN_TRANSACTION_GAS_UNITS - 1,
+        gas_schedule::MAX_PRICE_PER_GAS_UNIT,
+    );
+    assert_prologue_parity!(
+        executor.verify_transaction(txn.clone()),
+        executor.execute_transaction(txn).status(),
+        VMStatus::Validation(VMValidationStatus::MaxGasUnitsBelowMinTransactionGasUnits(_))
+    );
+
+    let txn = sender.account().create_signed_txn_with_args(
+        PEER_TO_PEER.clone(),
+        args.clone(),
+        10,
+        gas_schedule::MAXIMUM_NUMBER_OF_GAS_UNITS + 1,
+        gas_schedule::MAX_PRICE_PER_GAS_UNIT,
+    );
+    assert_prologue_parity!(
+        executor.verify_transaction(txn.clone()),
+        executor.execute_transaction(txn).status(),
+        VMStatus::Validation(VMValidationStatus::MaxGasUnitsExceedsMaxGasUnitsBound(_))
+    );
+
+    let txn = sender.account().create_signed_txn_with_args(
+        PEER_TO_PEER.clone(),
+        vec![TransactionArgument::U64(42); MAX_TRANSACTION_SIZE_IN_BYTES],
+        10,
+        gas_schedule::MAXIMUM_NUMBER_OF_GAS_UNITS + 1,
+        gas_schedule::MAX_PRICE_PER_GAS_UNIT,
+    );
+    assert_prologue_parity!(
+        executor.verify_transaction(txn.clone()),
+        executor.execute_transaction(txn).status(),
+        VMStatus::Validation(VMValidationStatus::ExceededMaxTransactionSize(_))
+    );
+
+    // Create a new transaction that swaps the two arguments.
+    let mut args: Vec<TransactionArgument> = Vec::new();
+    args.push(TransactionArgument::U64(transfer_amount));
+    args.push(TransactionArgument::Address(*receiver.address()));
+
+    let txn =
+        sender
+            .account()
+            .create_signed_txn_with_args(PEER_TO_PEER.clone(), args, 10, 10_000, 1);
+    assert_eq!(
+        executor.verify_transaction(txn),
+        Some(VMStatus::Verification(vec![VMVerificationStatus::Script(
+            VMVerificationError::TypeMismatch("Actual Type Mismatch".to_string())
+        )]))
+    );
+
+    // Create a new transaction that has no argument.
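+    // (Illustrative note: an empty argument list fails the script's signature
+    // check the same way the swapped arguments above do, surfacing as a
+    // verification `TypeMismatch` rather than a validation error.)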
+    let txn =
+        sender
+            .account()
+            .create_signed_txn_with_args(PEER_TO_PEER.clone(), vec![], 10, 10_000, 1);
+    assert_eq!(
+        executor.verify_transaction(txn),
+        Some(VMStatus::Verification(vec![VMVerificationStatus::Script(
+            VMVerificationError::TypeMismatch("Actual Type Mismatch".to_string())
+        )]))
+    );
+}
+
+#[test]
+pub fn test_whitelist() {
+    // create a FakeExecutor with a genesis from file
+    let mut executor = FakeExecutor::from_genesis_file();
+
+    // create an empty transaction
+    let sender = AccountData::new(1_000_000, 10);
+    executor.add_account_data(&sender);
+
+    let random_script = compile_script("main() {return;}");
+    let txn = sender
+        .account()
+        .create_signed_txn_with_args(random_script, vec![], 10, 10_000, 1);
+    assert_prologue_parity!(
+        executor.verify_transaction(txn.clone()),
+        executor.execute_transaction(txn).status(),
+        VMStatus::Validation(VMValidationStatus::UnknownScript)
+    );
+}
+
+#[test]
+pub fn test_arbitrary_script_execution() {
+    // create a FakeExecutor with a genesis from file
+    let mut executor = FakeExecutor::from_genesis_with_options(VMPublishingOption::CustomScripts);
+
+    // create an empty transaction
+    let sender = AccountData::new(1_000_000, 10);
+    executor.add_account_data(&sender);
+
+    let random_script = compile_script("main() {return;}");
+    let txn = sender
+        .account()
+        .create_signed_txn_with_args(random_script, vec![], 10, 10_000, 1);
+    assert_eq!(executor.verify_transaction(txn.clone()), None);
+    assert_eq!(
+        executor.execute_transaction(txn).status(),
+        &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed))
+    );
+}
+
+#[test]
+pub fn test_no_publishing() {
+    // create a FakeExecutor with a genesis from file
+    let mut executor = FakeExecutor::from_genesis_with_options(VMPublishingOption::CustomScripts);
+
+    // create a transaction trying to publish a new module.
+    let sender = AccountData::new(1_000_000, 10);
+    let receiver = AccountData::new(100_000, 10);
+    executor.add_account_data(&sender);
+    executor.add_account_data(&receiver);
+
+    let program = String::from(
+        "
+        modules:
+        module M {
+            public max(a: u64, b: u64): u64 {
+                if (copy(a) > copy(b)) {
+                    return copy(a);
+                } else {
+                    return copy(b);
+                }
+                return 0;
+            }
+
+            public sum(a: u64, b: u64): u64 {
+                let c: u64;
+                c = copy(a) + copy(b);
+                return copy(c);
+            }
+        }
+        script:
+        import 0x0.LibraAccount;
+        main (payee: address, amount: u64) {
+            LibraAccount.pay_from_sender(move(payee), move(amount));
+            return;
+        }
+        ",
+    );
+
+    let mut args: Vec<TransactionArgument> = Vec::new();
+    args.push(TransactionArgument::Address(*receiver.address()));
+    args.push(TransactionArgument::U64(100));
+
+    let random_script = compile_program_with_address(sender.address(), &program, args);
+    let txn =
+        sender
+            .account()
+            .create_signed_txn_impl(*sender.address(), random_script, 10, 10_000, 1);
+    assert_prologue_parity!(
+        executor.verify_transaction(txn.clone()),
+        executor.execute_transaction(txn).status(),
+        VMStatus::Validation(VMValidationStatus::UnknownModule)
+    );
+}
+
+#[test]
+pub fn test_open_publishing_invalid_address() {
+    // create a FakeExecutor with a genesis from file
+    let mut executor = FakeExecutor::from_genesis_with_options(VMPublishingOption::Open);
+
+    // create a transaction trying to publish a new module.
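+    // (Illustrative note: even with `VMPublishingOption::Open`, a module must
+    // be published under the sender's own address; compiling the program with
+    // the receiver's address below should fail verification with
+    // `ModuleAddressDoesNotMatchSender`.)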
+    let sender = AccountData::new(1_000_000, 10);
+    let receiver = AccountData::new(100_000, 10);
+    executor.add_account_data(&sender);
+    executor.add_account_data(&receiver);
+
+    let program = String::from(
+        "
+        modules:
+        module M {
+            public max(a: u64, b: u64): u64 {
+                if (copy(a) > copy(b)) {
+                    return copy(a);
+                } else {
+                    return copy(b);
+                }
+                return 0;
+            }
+
+            public sum(a: u64, b: u64): u64 {
+                let c: u64;
+                c = copy(a) + copy(b);
+                return copy(c);
+            }
+        }
+        script:
+        import 0x0.LibraAccount;
+        main (payee: address, amount: u64) {
+            LibraAccount.pay_from_sender(move(payee), move(amount));
+            return;
+        }
+        ",
+    );
+
+    let mut args: Vec<TransactionArgument> = Vec::new();
+    args.push(TransactionArgument::Address(*receiver.address()));
+    args.push(TransactionArgument::U64(100));
+
+    let random_script = compile_program_with_address(receiver.address(), &program, args);
+    let txn =
+        sender
+            .account()
+            .create_signed_txn_impl(*sender.address(), random_script, 10, 10_000, 1);
+
+    // verify and fail because the addresses don't match
+    let vm_status = executor.verify_transaction(txn.clone());
+    let status = match vm_status {
+        Some(VMStatus::Verification(status)) => status,
+        vm_status => panic!("Unexpected verification status: {:?}", vm_status),
+    };
+    match status.as_slice() {
+        &[VMVerificationStatus::Module(
+            0,
+            VMVerificationError::ModuleAddressDoesNotMatchSender(_),
+        )] => {}
+        err => panic!("Unexpected verification error: {:?}", err),
+    };
+
+    // execute and fail for the same reason
+    let output = executor.execute_transaction(txn);
+    let status = match output.status() {
+        TransactionStatus::Discard(VMStatus::Verification(status)) => status,
+        vm_status => panic!("Unexpected verification status: {:?}", vm_status),
+    };
+    match status.as_slice() {
+        &[VMVerificationStatus::Module(
+            0,
+            VMVerificationError::ModuleAddressDoesNotMatchSender(_),
+        )] => {}
+        err => panic!("Unexpected verification error: {:?}", err),
+    };
+}
+
+#[test]
+pub fn test_open_publishing() {
+    // create a FakeExecutor with a genesis from file
+    let mut executor = FakeExecutor::from_genesis_with_options(VMPublishingOption::Open);
+
+    // create a transaction trying to publish a new module.
+    let sender = AccountData::new(1_000_000, 10);
+    let receiver = AccountData::new(100_000, 10);
+    executor.add_account_data(&sender);
+    executor.add_account_data(&receiver);
+
+    let program = String::from(
+        "
+        modules:
+        module M {
+            public max(a: u64, b: u64): u64 {
+                if (copy(a) > copy(b)) {
+                    return copy(a);
+                } else {
+                    return copy(b);
+                }
+                return 0;
+            }
+
+            public sum(a: u64, b: u64): u64 {
+                let c: u64;
+                c = copy(a) + copy(b);
+                return copy(c);
+            }
+        }
+        script:
+        import 0x0.LibraAccount;
+        main (payee: address, amount: u64) {
+            LibraAccount.pay_from_sender(move(payee), move(amount));
+            return;
+        }
+        ",
+    );
+
+    let mut args: Vec<TransactionArgument> = Vec::new();
+    args.push(TransactionArgument::Address(*receiver.address()));
+    args.push(TransactionArgument::U64(100));
+
+    let random_script = compile_program_with_address(sender.address(), &program, args);
+    let txn =
+        sender
+            .account()
+            .create_signed_txn_impl(*sender.address(), random_script, 10, 10_000, 1);
+    assert_eq!(executor.verify_transaction(txn.clone()), None);
+    assert_eq!(
+        executor.execute_transaction(txn).status(),
+        &TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed))
+    );
+}
diff --git a/libra.png b/libra.png
new file mode 100644
index 0000000000000..a159ebf19494f
Binary files /dev/null and b/libra.png differ
diff --git a/libra_node/Cargo.toml b/libra_node/Cargo.toml
new file mode 100644
index 0000000000000..572fec88d530d
--- /dev/null
+++ b/libra_node/Cargo.toml
@@ -0,0 +1,39 @@
+[package]
+name = "libra_node"
+version = "0.1.0"
+authors = ["Libra Association "]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+grpcio = "0.4.3"
+grpcio-sys = "0.4.4"
+signal-hook = "0.1.9"
+tokio = "0.1.16"
+
+admission_control_proto = { path = "../admission_control/admission_control_proto" }
+admission_control_service = { path = "../admission_control/admission_control_service" }
+config = { path = "../config" }
+consensus = { path = "../consensus" }
+crash_handler = { path = "../common/crash_handler" }
+crypto = { path = "../crypto/legacy_crypto" }
+debug_interface = { path = "../common/debug_interface" }
+logger = { path = "../common/logger" }
+executable_helpers = { path = "../common/executable_helpers"}
+execution_proto = { path = "../execution/execution_proto" }
+grpc_helpers = { path = "../common/grpc_helpers" }
+mempool = { path = "../mempool" }
+metrics = { path = "../common/metrics" }
+execution_service = { path = "../execution/execution_service" }
+failure = { path = "../common/failure_ext", package = "failure_ext" }
+network = { path = "../network" }
+proto_conv = { path = "../common/proto_conv" }
+storage_client = { path = "../storage/storage_client" }
+storage_service = { path = "../storage/storage_service" }
+types = { path = "../types" }
+vm_genesis = { path = "../language/vm/vm_genesis" }
+vm_validator = { path = "../vm_validator" }
+
+[dev-dependencies]
+config_builder = { path = "../config/config_builder" }
diff --git a/libra_node/src/lib.rs b/libra_node/src/lib.rs
new file mode 100644
index 0000000000000..366267689a600
--- /dev/null
+++ b/libra_node/src/lib.rs
@@ -0,0 +1,4 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod main_node;
diff --git a/libra_node/src/main.rs b/libra_node/src/main.rs
new file mode 100644
index 0000000000000..f38f977231833
--- /dev/null
+++ b/libra_node/src/main.rs
@@ -0,0 +1,44 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use executable_helpers::helpers::{
+    setup_executable, ARG_CONFIG_PATH, ARG_DISABLE_LOGGING, ARG_PEER_ID,
+};
+use signal_hook;
+use std::sync::{
+    atomic::{AtomicBool, Ordering},
+    Arc,
+};
+
+fn register_signals(term: Arc<AtomicBool>) {
+    for signal in &[
+        signal_hook::SIGTERM,
+        signal_hook::SIGINT,
+        signal_hook::SIGHUP,
+    ] {
+        let term_clone = Arc::clone(&term);
+        let thread = std::thread::current();
+        unsafe {
+            signal_hook::register(*signal, move || {
+                term_clone.store(true, Ordering::Relaxed);
+                thread.unpark();
+            })
+            .expect("failed to register signal handler");
+        }
+    }
+}
+
+fn main() {
+    let (config, _logger, _args) = setup_executable(
+        "Libra single node".to_string(),
+        vec![ARG_PEER_ID, ARG_CONFIG_PATH, ARG_DISABLE_LOGGING],
+    );
+    let (_ac_handle, _node_handle) = libra_node::main_node::setup_environment(&config);
+
+    let term = Arc::new(AtomicBool::new(false));
+    register_signals(Arc::clone(&term));
+
+    while !term.load(Ordering::Relaxed) {
+        std::thread::park();
+    }
+}
diff --git a/libra_node/src/main_node.rs b/libra_node/src/main_node.rs
new file mode 100644
index 0000000000000..0bbb26f49bf6d
--- /dev/null
+++ b/libra_node/src/main_node.rs
@@ -0,0 +1,276 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use admission_control_proto::proto::admission_control_grpc::{
+    create_admission_control, AdmissionControlClient,
+};
+use admission_control_service::admission_control_service::AdmissionControlService;
+use config::config::NodeConfig;
+use consensus::consensus_provider::{make_consensus_provider, ConsensusProvider};
+use debug_interface::{node_debug_service::NodeDebugService, proto::node_debug_interface_grpc};
+use execution_proto::proto::execution_grpc;
+use execution_service::ExecutionService;
+use grpc_helpers::ServerHandle;
+use grpcio::{ChannelBuilder, EnvBuilder, ServerBuilder};
+use grpcio_sys;
+use logger::prelude::*;
+use mempool::{proto::mempool_grpc::MempoolClient, MempoolRuntime};
+use metrics::metric_server;
+use network::{
+    validator_network::{
+        network_builder::{NetworkBuilder, TransportType},
+        ConsensusNetworkEvents, ConsensusNetworkSender, MempoolNetworkEvents, MempoolNetworkSender,
+        CONSENSUS_DIRECT_SEND_PROTOCOL, CONSENSUS_RPC_PROTOCOL, MEMPOOL_DIRECT_SEND_PROTOCOL,
+    },
+    NetworkPublicKeys, ProtocolId,
+};
+use std::{
+    cmp::max,
+    convert::{TryFrom, TryInto},
+    sync::Arc,
+    thread,
+    time::Instant,
+};
+use storage_client::{StorageRead, StorageReadServiceClient, StorageWriteServiceClient};
+use storage_service::start_storage_service;
+use tokio::runtime::{Builder, Runtime};
+use types::account_address::AccountAddress as PeerId;
+use vm_validator::vm_validator::VMValidator;
+
+pub struct LibraHandle {
+    _ac: ServerHandle,
+    _mempool: MempoolRuntime,
+    _network_runtime: Runtime,
+    consensus: Box<dyn ConsensusProvider>,
+    _execution: ServerHandle,
+    _storage: ServerHandle,
+    _debug: ServerHandle,
+}
+
+impl Drop for LibraHandle {
+    fn drop(&mut self) {
+        self.consensus.stop();
+    }
+}
+
+fn setup_ac(config: &NodeConfig) -> (::grpcio::Server, AdmissionControlClient) {
+    let env = Arc::new(
+        EnvBuilder::new()
+            .name_prefix("grpc-ac-")
+            .cq_count(unsafe { max(grpcio_sys::gpr_cpu_num_cores() as usize / 2, 2) })
+            .build(),
+    );
+    let port = config.admission_control.admission_control_service_port;
+
+    // Create mempool client
+    let connection_str = format!("localhost:{}", config.mempool.mempool_service_port);
+    let env2 = Arc::new(EnvBuilder::new().name_prefix("grpc-ac-mem-").build());
+    let mempool_client = Arc::new(MempoolClient::new(
+        ChannelBuilder::new(env2).connect(&connection_str),
+    ));
+
+    // Create storage read client
+    let storage_client: Arc<dyn StorageRead> = Arc::new(StorageReadServiceClient::new(
+        Arc::new(EnvBuilder::new().name_prefix("grpc-ac-sto-").build()),
+        "localhost",
+        config.storage.port,
+    ));
+
+    let vm_validator = Arc::new(VMValidator::new(&config, Arc::clone(&storage_client)));
+
+    let handle = AdmissionControlService::new(
+        mempool_client,
+        storage_client,
+        vm_validator,
+        config
+            .admission_control
+            .need_to_check_mempool_before_validation,
+    );
+    let service = create_admission_control(handle);
+    let server = ServerBuilder::new(Arc::clone(&env))
+        .register_service(service)
+        .bind(config.admission_control.address.clone(), port)
+        .build()
+        .expect("Unable to create grpc server");
+
+    let connection_str = format!("localhost:{}", port);
+    let client = AdmissionControlClient::new(ChannelBuilder::new(env).connect(&connection_str));
+    (server, client)
+}
+
+fn setup_executor(config: &NodeConfig) -> ::grpcio::Server {
+    let client_env = Arc::new(EnvBuilder::new().name_prefix("grpc-exe-sto-").build());
+    let storage_read_client = Arc::new(StorageReadServiceClient::new(
+        Arc::clone(&client_env),
+        &config.storage.address,
+        config.storage.port,
+    ));
+    let storage_write_client = Arc::new(StorageWriteServiceClient::new(
+        Arc::clone(&client_env),
+        &config.storage.address,
+        config.storage.port,
+    ));
+
+    let handle = ExecutionService::new(storage_read_client, storage_write_client, config);
+    let service = execution_grpc::create_execution(handle);
+    ::grpcio::ServerBuilder::new(Arc::new(EnvBuilder::new().name_prefix("grpc-exe-").build()))
+        .register_service(service)
+        .bind(config.execution.address.clone(), config.execution.port)
+        .build()
+        .expect("Unable to create grpc server")
+}
+
+fn setup_debug_interface(config: &NodeConfig) -> ::grpcio::Server {
+    let env = Arc::new(EnvBuilder::new().name_prefix("grpc-debug-").build());
+    // Start Debug interface
+    let debug_service =
+        node_debug_interface_grpc::create_node_debug_interface(NodeDebugService::new());
+    ::grpcio::ServerBuilder::new(env)
+        .register_service(debug_service)
+        .bind(
+            config.debug_interface.address.clone(),
+            config.debug_interface.admission_control_node_debug_port,
+        )
+        .build()
+        .expect("Unable to create grpc server")
+}
+
+pub fn setup_network(
+    config: &NodeConfig,
+) -> (
+    (MempoolNetworkSender, MempoolNetworkEvents),
+    (ConsensusNetworkSender, ConsensusNetworkEvents),
+    Runtime,
+) {
+    let runtime = Builder::new()
+        .name_prefix("network-")
+        .build()
+        .expect("Failed to start runtime.
Won't be able to start networking."); + let peer_id = PeerId::try_from(config.base.peer_id.clone()).expect("Invalid PeerId"); + let listen_addr = config.network.listen_address.clone(); + let advertised_addr = config.network.advertised_address.clone(); + let trusted_peers = config + .base + .trusted_peers + .get_trusted_network_peers() + .clone() + .into_iter() + .map(|(peer_id, (signing_public_key, identity_public_key))| { + ( + peer_id, + NetworkPublicKeys { + signing_public_key, + identity_public_key, + }, + ) + }) + .collect(); + let seed_peers = config + .network + .seed_peers + .seed_peers + .clone() + .into_iter() + .map(|(peer_id, addrs)| (peer_id.try_into().expect("Invalid PeerId"), addrs)) + .collect(); + let network_signing_keypair = config.base.peer_keypairs.get_network_signing_keypair(); + let network_identity_keypair = config.base.peer_keypairs.get_network_identity_keypair(); + let ( + (mempool_network_sender, mempool_network_events), + (consensus_network_sender, consensus_network_events), + _listen_addr, + ) = NetworkBuilder::new(runtime.executor(), peer_id, listen_addr) + .transport(if config.network.enable_encryption_and_authentication { + TransportType::TcpNoise + } else { + TransportType::Tcp + }) + .advertised_address(advertised_addr) + .seed_peers(seed_peers) + .signing_keys(network_signing_keypair) + .identity_keys(network_identity_keypair) + .trusted_peers(trusted_peers) + .discovery_interval_ms(config.network.discovery_interval_ms) + .connectivity_check_interval_ms(config.network.connectivity_check_interval_ms) + .consensus_protocols(vec![ + ProtocolId::from_static(CONSENSUS_RPC_PROTOCOL), + ProtocolId::from_static(CONSENSUS_DIRECT_SEND_PROTOCOL), + ]) + .mempool_protocols(vec![ProtocolId::from_static(MEMPOOL_DIRECT_SEND_PROTOCOL)]) + .direct_send_protocols(vec![ + ProtocolId::from_static(CONSENSUS_DIRECT_SEND_PROTOCOL), + ProtocolId::from_static(MEMPOOL_DIRECT_SEND_PROTOCOL), + ]) + .rpc_protocols(vec![ProtocolId::from_static(CONSENSUS_RPC_PROTOCOL)]) + .build(); + + ( + (mempool_network_sender, mempool_network_events), + (consensus_network_sender, consensus_network_events), + runtime, + ) +} + +pub fn setup_environment(node_config: &NodeConfig) -> (AdmissionControlClient, LibraHandle) { + crash_handler::setup_panic_handler(); + + let mut instant = Instant::now(); + let storage = start_storage_service(&node_config); + debug!( + "Storage service started in {} ms", + instant.elapsed().as_millis() + ); + + instant = Instant::now(); + let execution = ServerHandle::setup(setup_executor(&node_config)); + debug!( + "Execution service started in {} ms", + instant.elapsed().as_millis() + ); + + instant = Instant::now(); + let ( + (mempool_network_sender, mempool_network_events), + (consensus_network_sender, consensus_network_events), + network_runtime, + ) = setup_network(&node_config); + debug!("Network started in {} ms", instant.elapsed().as_millis()); + + instant = Instant::now(); + let (ac_server, ac_client) = setup_ac(&node_config); + let ac = ServerHandle::setup(ac_server); + debug!("AC started in {} ms", instant.elapsed().as_millis()); + + instant = Instant::now(); + let mempool = + MempoolRuntime::boostrap(&node_config, mempool_network_sender, mempool_network_events); + debug!("Mempool started in {} ms", instant.elapsed().as_millis()); + + let debug_if = ServerHandle::setup(setup_debug_interface(&node_config)); + + let metrics_port = node_config.debug_interface.metrics_server_port; + let metric_host = node_config.debug_interface.address.clone(); + thread::spawn(move 
|| metric_server::start_server(metric_host, metrics_port)); + + instant = Instant::now(); + let mut consensus_provider = make_consensus_provider( + &node_config, + consensus_network_sender, + consensus_network_events, + ); + consensus_provider + .start() + .expect("Failed to start consensus. Can't proceed."); + debug!("Consensus started in {} ms", instant.elapsed().as_millis()); + + let libra_handle = LibraHandle { + _ac: ac, + _mempool: mempool, + _network_runtime: network_runtime, + consensus: consensus_provider, + _execution: execution, + _storage: storage, + _debug: debug_if, + }; + (ac_client, libra_handle) +} diff --git a/libra_swarm/Cargo.toml b/libra_swarm/Cargo.toml new file mode 100644 index 0000000000000..2546c05c6db5d --- /dev/null +++ b/libra_swarm/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "libra_swarm" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +bincode = "1.1.1" +grpcio = "0.4.3" +lazy_static = "1.2.0" +structopt = "0.2.15" +tempfile = "3.0.6" + +client_lib = { package = "client", path = "../client" } +config = { path = "../config" } +config_builder = { path = "../config/config_builder" } +crypto = { path = "../crypto/legacy_crypto" } +debug_interface = { path = "../common/debug_interface" } +failure = { path = "../common/failure_ext", package = "failure_ext" } +generate_keypair = { path = "../config/generate_keypair" } +logger = { path = "../common/logger" } +tools = { path = "../common/tools" } diff --git a/libra_swarm/src/client.rs b/libra_swarm/src/client.rs new file mode 100644 index 0000000000000..4c006d84c010a --- /dev/null +++ b/libra_swarm/src/client.rs @@ -0,0 +1,192 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::utils; +use client_lib::{client_proxy::ClientProxy, commands}; +use std::{ + collections::HashMap, + io::{self, Write}, + path::Path, + process::{Child, Command, Output, Stdio}, + sync::Arc, +}; + +pub struct InteractiveClient { + client: Option, +} + +impl Drop for InteractiveClient { + fn drop(&mut self) { + if self.client.is_none() { + return; + } + // Kill client process if still running. 
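+ // If the child already exited on its own, surface a failing exit status
+ // as a panic; otherwise the interactive client is still running and is
+ // terminated here.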
+ let mut client = self.client.take().unwrap();
+ match client.try_wait().unwrap() {
+ Some(status) => {
+ if !status.success() {
+ panic!(
+ "Client terminated with status: {}",
+ status.code().unwrap_or(-1)
+ );
+ }
+ }
+ None => {
+ client.kill().unwrap();
+ }
+ }
+ }
+}
+
+impl InteractiveClient {
+ pub fn new_with_inherit_io(
+ port: u16,
+ faucet_key_file_path: &Path,
+ mnemonic_file_path: &Path,
+ validator_set_file: String,
+ ) -> Self {
+ // We need to call canonicalize on the path because we are running the client from
+ // the workspace root, and the function calling new_with_inherit_io isn't necessarily
+ // running from that location. A relative path therefore wouldn't work
+ // unless we convert it to an absolute path.
+ Self {
+ client: Some(
+ Command::new(utils::get_bin("client"))
+ .current_dir(utils::workspace_root())
+ .arg("-p")
+ .arg(port.to_string())
+ .arg("-m")
+ .arg(
+ faucet_key_file_path
+ .canonicalize()
+ .expect("Unable to get canonical path of faucet key file")
+ .to_str()
+ .unwrap(),
+ )
+ .arg("-n")
+ .arg(
+ mnemonic_file_path
+ .canonicalize()
+ .expect("Unable to get canonical path of mnemonic file")
+ .to_str()
+ .unwrap(),
+ )
+ .arg("-a")
+ .arg("localhost")
+ .arg("-s")
+ .arg(validator_set_file)
+ .stdin(Stdio::inherit())
+ .stdout(Stdio::inherit())
+ .stderr(Stdio::inherit())
+ .spawn()
+ .expect("Failed to spawn client process"),
+ ),
+ }
+ }
+
+ pub fn new_with_piped_io(
+ port: u16,
+ faucet_key_file_path: &Path,
+ mnemonic_file_path: &Path,
+ validator_set_file: String,
+ ) -> Self {
+ Self {
+ // Note: for easier debugging it's convenient to see the output
+ // from the client CLI. Comment out the stdout/stderr lines below
+ // and enjoy pretty Matrix-style output.
+ client: Some(
+ Command::new(utils::get_bin("client"))
+ .current_dir(utils::workspace_root())
+ .arg("-p")
+ .arg(port.to_string())
+ .arg("-m")
+ .arg(
+ faucet_key_file_path
+ .canonicalize()
+ .expect("Unable to get canonical path of faucet key file")
+ .to_str()
+ .unwrap(),
+ )
+ .arg("-n")
+ .arg(
+ mnemonic_file_path
+ .canonicalize()
+ .expect("Unable to get canonical path of mnemonic file")
+ .to_str()
+ .unwrap(),
+ )
+ .arg("-a")
+ .arg("localhost")
+ .arg("-s")
+ .arg(validator_set_file)
+ .stdin(Stdio::piped())
+ .stdout(Stdio::piped())
+ .stderr(Stdio::piped())
+ .spawn()
+ .expect("Failed to spawn client process"),
+ ),
+ }
+ }
+
+ pub fn output(mut self) -> io::Result<Output> {
+ self.client.take().unwrap().wait_with_output()
+ }
+
+ pub fn send_instructions(&mut self, instructions: &[&str]) -> io::Result<()> {
+ let input = self.client.as_mut().unwrap().stdin.as_mut().unwrap();
+ for i in instructions {
+ input.write_all((i.to_string() + "\n").as_bytes())?;
+ input.flush()?;
+ }
+ Ok(())
+ }
+}
+
+pub struct InProcessTestClient {
+ client: ClientProxy,
+ alias_to_cmd: HashMap<&'static str, Arc<dyn commands::Command>>,
+}
+
+impl InProcessTestClient {
+ pub fn new(
+ port: u16,
+ faucet_key_file_path: &Path,
+ mnemonic_file_path: &str,
+ validator_set_file: String,
+ ) -> Self {
+ let (_, alias_to_cmd) = commands::get_commands();
+ Self {
+ client: ClientProxy::new(
+ "localhost",
+ port.to_string().as_str(),
+ &validator_set_file,
+ faucet_key_file_path
+ .canonicalize()
+ .expect("Unable to get canonical path of faucet key file")
+ .to_str()
+ .unwrap(),
+ /* faucet server */ None,
+ Some(mnemonic_file_path.to_string()),
+ )
+ .unwrap(),
+ alias_to_cmd,
+ }
+ }
+
+ pub fn execute_instructions(&mut self, instructions: &[&str]) {
+ for instr in instructions {
+ let to_parse =
&instr.to_string();
+ let params = commands::parse_cmd(to_parse);
+ // filter out empty lines
+ if params.is_empty() || params[0].is_empty() {
+ continue;
+ }
+ let cmd = self.alias_to_cmd.get(params[0]).expect("Cmd not found");
+ cmd.execute(&mut self.client, &params);
+ }
+ }
+
+ pub fn client(&mut self) -> &mut ClientProxy {
+ &mut self.client
+ }
+}
diff --git a/libra_swarm/src/lib.rs b/libra_swarm/src/lib.rs
new file mode 100644
index 0000000000000..3ea925e2464d8
--- /dev/null
+++ b/libra_swarm/src/lib.rs
@@ -0,0 +1,7 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod client;
+mod output_tee;
+pub mod swarm;
+pub mod utils;
diff --git a/libra_swarm/src/main.rs b/libra_swarm/src/main.rs
new file mode 100644
index 0000000000000..5f39f6e9102f8
--- /dev/null
+++ b/libra_swarm/src/main.rs
@@ -0,0 +1,63 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use libra_swarm::{client, swarm::LibraSwarm};
+use std::path::Path;
+use structopt::StructOpt;
+
+#[derive(Debug, StructOpt)]
+#[structopt(
+ name = "libra_swarm",
+ author = "Libra",
+ about = "Libra swarm to start local nodes"
+)]
+struct Args {
+ /// Number of nodes to start (1 by default)
+ #[structopt(short = "n", long = "num_nodes")]
+ pub num_nodes: Option<usize>,
+ /// Disable logging (for performance testing)
+ #[structopt(short = "d", long = "disable_logging")]
+ pub disable_logging: bool,
+ /// Start client
+ #[structopt(short = "s", long = "start_client")]
+ pub start_client: bool,
+ /// Path to the generated keypair for the faucet account. Generated by generate_keypair.
+ /// If not passed, a new keypair will be generated for you and placed in a temp directory
+ #[structopt(short = "m", long = "faucet_key_file_path")]
+ pub faucet_key_file_path: Option<String>,
+}
+
+fn main() {
+ let args = Args::from_args();
+ let num_nodes = args.num_nodes.unwrap_or(1);
+
+ let (faucet_account_keypair, faucet_key_file_path, _temp_dir) =
+ generate_keypair::load_faucet_key_or_create_default(args.faucet_key_file_path);
+
+ println!("Faucet account created in file {:?}", faucet_key_file_path);
+
+ let swarm = LibraSwarm::launch_swarm(
+ num_nodes,
+ args.disable_logging,
+ faucet_account_keypair,
+ false, /* tee_logs */
+ );
+
+ let tmp_mnemonic_file = tempfile::NamedTempFile::new().unwrap();
+ if args.start_client {
+ let client = client::InteractiveClient::new_with_inherit_io(
+ *swarm.get_validators_public_ports().get(0).unwrap(),
+ Path::new(&faucet_key_file_path),
+ &tmp_mnemonic_file.into_temp_path(),
+ swarm.get_trusted_peers_config_path(),
+ );
+ println!("Loading client...");
+ let _output = client.output().expect("Failed to wait on child");
+ println!("Exit client.");
+ } else {
+ println!("CTRL-C to exit.");
+ loop {
+ std::thread::park();
+ }
+ }
+}
diff --git a/libra_swarm/src/output_tee.rs b/libra_swarm/src/output_tee.rs
new file mode 100644
index 0000000000000..ec1b18e1f4b0c
--- /dev/null
+++ b/libra_swarm/src/output_tee.rs
@@ -0,0 +1,112 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use std::{
+ io::{BufRead, BufReader, Read, Write},
+ sync::{Arc, Mutex},
+ thread::JoinHandle,
+};
+use tools::output_capture::OutputCapture;
+
+/// This tee takes the stdout/stderr streams of a spawned process and tees each into two
+/// destinations: a log file and the stdout of the main thread, with a given prefix
+pub struct OutputTee {
+ capture: OutputCapture,
+ log_file: Arc<Mutex<Box<dyn Write + Send>>>,
+ child_stdout: Box<dyn Read + Send>,
+ child_stderr: Box<dyn Read + Send>,
+ prefix: String,
+}
+
+impl
OutputTee { + pub fn new( + capture: OutputCapture, + log_file: Box, + child_stdout: Box, + child_stderr: Box, + prefix: String, + ) -> OutputTee { + OutputTee { + capture, + log_file: Arc::new(Mutex::new(log_file)), + child_stdout, + child_stderr, + prefix, + } + } + + /// Start new threads for teeing output of stdout/err streams + /// Threads will terminate when streams are closed + pub fn start(self) -> OutputTeeGuard { + let capture = self.capture; + let log_file = self.log_file; + let prefix = self.prefix; + + let stdout_handle = TeeThread { + capture: capture.clone(), + log_file: log_file.clone(), + prefix: prefix.clone(), + stream: self.child_stdout, + } + .start(); + + let stderr_handle = TeeThread { + capture: capture.clone(), + log_file: log_file.clone(), + prefix: prefix.clone(), + stream: self.child_stderr, + } + .start(); + + OutputTeeGuard { + stdout_handle, + stderr_handle, + } + } +} + +pub struct OutputTeeGuard { + stdout_handle: JoinHandle<()>, + stderr_handle: JoinHandle<()>, +} + +impl OutputTeeGuard { + // JoinHandle::join returns Result, which signals whether thread panicked or not + // If it panicked, it will print panic anyway, there is no reason to process it here + #[allow(unused_must_use)] + pub fn join(self) { + self.stdout_handle.join(); + self.stderr_handle.join(); + } +} + +struct TeeThread { + capture: OutputCapture, + prefix: String, + stream: Box, + log_file: Arc>, +} + +impl TeeThread { + pub fn start(self) -> JoinHandle<()> { + std::thread::spawn(move || self.run()) + } + + pub fn run(self) { + self.capture.apply(); + let buf_reader = BufReader::new(self.stream); + for line in buf_reader.lines() { + let line = match line { + Err(e) => { + println!("Failed to read line for tee: {}", e); + return; + } + Ok(line) => line, + }; + println!("{}{}", self.prefix, line); + if let Err(e) = writeln!(self.log_file.lock().unwrap(), "{}", line) { + println!("Error teeing to file: {}", e); + } + } + } +} diff --git a/libra_swarm/src/swarm.rs b/libra_swarm/src/swarm.rs new file mode 100644 index 0000000000000..dbdea4892c861 --- /dev/null +++ b/libra_swarm/src/swarm.rs @@ -0,0 +1,453 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + output_tee::{OutputTee, OutputTeeGuard}, + utils, +}; +use config::config::NodeConfig; +use config_builder::swarm_config::{SwarmConfig, SwarmConfigBuilder}; +use crypto::signing::KeyPair; +use debug_interface::NodeDebugClient; +use failure::prelude::*; +use logger::prelude::*; +use std::{ + collections::HashMap, + env, + fs::File, + io::Read, + path::{Path, PathBuf}, + process::{Child, Command, Stdio}, +}; +use tempfile::TempDir; +use tools::output_capture::OutputCapture; + +const LIBRA_NODE_BIN: &str = "libra_node"; + +pub struct LibraNode { + node: Child, + debug_client: NodeDebugClient, + peer_id: String, + log: PathBuf, + output_tee_guard: Option, +} + +impl Drop for LibraNode { + // When the LibraNode struct goes out of scope we need to kill the child process + fn drop(&mut self) { + // check if the process has already been terminated + match self.node.try_wait() { + // The child process has already terminated, perhaps due to a crash + Ok(Some(_)) => {} + + // The node is still running so we need to attempt to kill it + _ => { + if let Err(e) = self.node.kill() { + panic!("LibraNode process could not be killed: '{}'", e); + } + } + } + if let Some(output_tee_guard) = self.output_tee_guard.take() { + output_tee_guard.join(); + } + } +} + +impl LibraNode { + pub fn launch( + config: 
&NodeConfig, + config_path: &Path, + logdir: &Path, + disable_logging: bool, + tee_logs: bool, + ) -> Result { + let peer_id = config.base.peer_id.clone(); + let log = logdir.join(format!("{}.log", peer_id)); + let log_file = File::create(&log)?; + let mut node_command = Command::new(utils::get_bin(LIBRA_NODE_BIN)); + node_command + .current_dir(utils::workspace_root()) + .arg("-f") + .arg(config_path) + .args(&["-p", &peer_id]); + if env::var("RUST_LOG").is_err() { + // Only set our RUST_LOG if its not present in environment + node_command.env("RUST_LOG", "debug"); + } + if disable_logging { + node_command.arg("-d"); + } + + if tee_logs { + node_command.stdout(Stdio::piped()).stderr(Stdio::piped()); + } else { + node_command + .stdout(log_file.try_clone()?) + .stderr(log_file.try_clone()?); + }; + + let mut node = node_command + .spawn() + .context("Error launching node process")?; + + let output_tee_guard = if tee_logs { + let prefix = format!("[{}] ", &peer_id[..8]); + let capture = OutputCapture::grab(); + let tee = OutputTee::new( + capture, + Box::new(log_file), + Box::new(node.stdout.take().expect("Can't get child stdout")), + Box::new(node.stderr.take().expect("Can't get child stderr")), + prefix, + ); + Some(tee.start()) + } else { + None + }; + + let debug_client = NodeDebugClient::new( + "localhost", + config.debug_interface.admission_control_node_debug_port, + ); + Ok(Self { + node, + debug_client, + peer_id, + log, + output_tee_guard, + }) + } + + pub fn get_log_contents(&self) -> Result { + let mut log = File::open(&self.log)?; + let mut contents = String::new(); + log.read_to_string(&mut contents)?; + + Ok(contents) + } + + fn get_metric(&self, metric_name: &str) -> Option { + match self.debug_client.get_node_metric(metric_name) { + Err(e) => { + debug!( + "error getting {} for node: {}; error: {}", + metric_name, self.peer_id, e + ); + None + } + Ok(maybeval) => { + if maybeval.is_none() { + debug!("Node: {} did not report {}", self.peer_id, metric_name); + } + maybeval + } + } + } + + pub fn check_connectivity(&self, expected_peers: i64) -> bool { + if let Some(num_connected_peers) = self.get_metric("network_gauge{op=connected_peers}") { + if num_connected_peers != expected_peers { + debug!( + "Node '{}' Expected peers: {}, found peers: {}", + self.peer_id, expected_peers, num_connected_peers + ); + return false; + } else { + return true; + } + } + false + } + + pub fn health_check(&mut self) -> HealthStatus { + debug!("Health check on node '{}'", self.peer_id); + + // check if the process has terminated + match self.node.try_wait() { + // This would mean the child process has crashed + Ok(Some(status)) => { + debug!("Node '{}' crashed with: {}", self.peer_id, status); + return HealthStatus::Crashed(status); + } + + // This is the case where the node is still running + Ok(None) => {} + + // Some other unknown error + Err(e) => { + panic!("error attempting to query Node: {}", e); + } + } + + match self.debug_client.get_node_metrics() { + Ok(_) => { + debug!("Node '{}' is healthy", self.peer_id); + HealthStatus::Healthy + } + Err(e) => { + debug!("Rpc check error: {:?}", e); + HealthStatus::RpcFailure(e) + } + } + } +} + +pub enum HealthStatus { + Healthy, + Crashed(::std::process::ExitStatus), + RpcFailure(failure::Error), +} + +/// Struct holding instances and information of Libra Swarm +pub struct LibraSwarm { + dir: Option, + nodes: HashMap, + config: SwarmConfig, + tee_logs: bool, +} + +impl LibraSwarm { + pub fn launch_swarm( + num_nodes: usize, + disable_logging: 
bool, + faucet_account_keypair: KeyPair, + tee_logs: bool, + ) -> Self { + let dir = tempfile::tempdir().unwrap(); + let logs_dir_path = &dir.path().join("logs"); + println!("Logs directory: {:?}", logs_dir_path); + std::fs::create_dir(&logs_dir_path).unwrap(); + let base = utils::workspace_root().join("config/data/configs/node.config.toml"); + let mut config_builder = SwarmConfigBuilder::new(); + config_builder + .with_ipv4() + .with_nodes(num_nodes) + .with_base(base) + .with_output_dir(&dir) + .with_faucet_keypair(faucet_account_keypair) + .randomize_ports(); + let config = config_builder.build().unwrap(); + + let mut swarm = Self { + dir: Some(dir), + nodes: HashMap::new(), + config, + tee_logs, + }; + // For each config launch a node + for (path, node_config) in swarm.config.get_configs() { + let node = LibraNode::launch( + &node_config, + &path, + &logs_dir_path, + disable_logging, + tee_logs, + ) + .unwrap(); + swarm.nodes.insert( + node_config.admission_control.admission_control_service_port, + node, + ); + } + + if !swarm.wait_for_startup() { + panic!("Error launching swarm"); + } + if !swarm.wait_for_connectivity() { + // Verify connectivity + panic!("Some nodes not connected to each other"); + } + info!("Successfully launched Swarm"); + + swarm + } + + fn wait_for_connectivity(&self) -> bool { + // Early return if we're only launching a single node + if self.nodes.len() == 1 { + return true; + } + + let num_attempts = 60; + + for i in 0..num_attempts { + debug!("Wait for connectivity attempt: {}", i); + + if self + .nodes + .values() + .all(|node| node.check_connectivity(self.nodes.len() as i64 - 1)) + { + return true; + } + + ::std::thread::sleep(::std::time::Duration::from_millis(1000)); + } + + false + } + + fn wait_for_startup(&mut self) -> bool { + let num_attempts = 120; + let mut done = vec![false; self.nodes.len()]; + + for i in 0..num_attempts { + debug!("Wait for startup attempt: {} of {}", i, num_attempts); + for (node, done) in self.nodes.values_mut().zip(done.iter_mut()) { + if *done { + continue; + } + + match node.health_check() { + HealthStatus::Healthy => *done = true, + HealthStatus::RpcFailure(_) => continue, + HealthStatus::Crashed(status) => { + panic!( + "Libra node '{}' has crashed with status '{}'. Log output: '''{}'''", + node.peer_id, + status, + node.get_log_contents().unwrap() + ); + } + } + } + + // Check if all the nodes have been successfully lauched + if done.iter().all(|status| *status) { + return true; + } + + ::std::thread::sleep(::std::time::Duration::from_millis(1000)); + } + + false + } + + /// This function first checks the last committed round of all the nodes, picks the max + /// value and then waits for all the nodes to catch up to that round. + /// Once done, we can guarantee that all the txns committed before the invocation of this + /// function are now available at all the nodes. 
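+ /// A minimal usage sketch (hypothetical test code; the arguments to
+ /// `launch_swarm` follow the signature above):
+ ///
+ /// ```ignore
+ /// let mut swarm = LibraSwarm::launch_swarm(3, false, faucet_account_keypair, false);
+ /// // ... submit some transactions through a client ...
+ /// assert!(swarm.wait_for_all_nodes_to_catchup());
+ /// ```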
+ pub fn wait_for_all_nodes_to_catchup(&mut self) -> bool { + let num_attempts = 60; + let last_committed_round_str = "consensus{op=committed_blocks_count}"; + let mut done = vec![false; self.nodes.len()]; + + let mut last_committed_round = 0; + // First, try to retrieve the max value across all the committed rounds + debug!("Calculating max committed round across the validators."); + for node in self.nodes.values() { + match node.get_metric(last_committed_round_str) { + Some(val) => { + debug!("\tNode {} last committed round = {}", node.peer_id, val); + last_committed_round = last_committed_round.max(val); + } + None => { + debug!( + "\tNode {} last committed round unknown, assuming 0.", + node.peer_id + ); + } + } + } + + // Now wait for all the nodes to catch up to the max. + for i in 0..num_attempts { + debug!( + "Wait for catchup, target_commit_round = {}, attempt: {} of {}", + last_committed_round, i, num_attempts + ); + for (node, done) in self.nodes.values_mut().zip(done.iter_mut()) { + if *done { + continue; + } + + match node.get_metric(last_committed_round_str) { + Some(val) => { + if val >= last_committed_round { + debug!( + "\tNode {} is caught up with last committed round {}", + node.peer_id, val + ); + *done = true; + } else { + debug!( + "\tNode {} is not caught up yet with last committed round {}", + node.peer_id, val + ); + } + } + None => { + debug!( + "\tNode {} last committed round unknown, assuming 0.", + node.peer_id + ); + } + } + } + + // Check if all the nodes have been successfully caught up + if done.iter().all(|status| *status) { + return true; + } + + ::std::thread::sleep(::std::time::Duration::from_millis(1000)); + } + + false + } + + /// Vector with the public ports of all the validators in the swarm. + pub fn get_validators_public_ports(&self) -> Vec { + self.nodes.keys().cloned().collect() + } + + pub fn kill_node(&mut self, port: u16) { + self.nodes.remove(&port); + } + + pub fn add_node(&mut self, port: u16, disable_logging: bool) -> bool { + let logs_dir_path = self.dir.as_ref().map(|x| x.path().join("logs")).unwrap(); + + for (path, config) in self.config.get_configs() { + if config.admission_control.admission_control_service_port == port { + let mut node = LibraNode::launch( + &config, + &path, + &logs_dir_path, + disable_logging, + self.tee_logs, + ) + .unwrap(); + for _ in 0..60 { + if let HealthStatus::Healthy = node.health_check() { + self.nodes.insert(port, node); + return self.wait_for_connectivity(); + } + ::std::thread::sleep(::std::time::Duration::from_millis(1000)); + } + } + } + false + } + + pub fn get_trusted_peers_config_path(&self) -> String { + let (path, _) = self.config.get_trusted_peers_config(); + path.canonicalize() + .expect("Unable to get canonical path of trusted peers config file") + .to_str() + .unwrap() + .to_string() + } +} + +impl Drop for LibraSwarm { + fn drop(&mut self) { + // If panicking, we don't want to gc the swarm directory. 
+ if std::thread::panicking() { + if let Some(dir) = self.dir.take() { + let logs = dir.into_path(); + println!("logs located: {:?}", logs); + } + } + } +} diff --git a/libra_swarm/src/utils.rs b/libra_swarm/src/utils.rs new file mode 100644 index 0000000000000..9285aa7f1d629 --- /dev/null +++ b/libra_swarm/src/utils.rs @@ -0,0 +1,86 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use lazy_static::lazy_static; +use logger::prelude::*; +use std::{env, path::PathBuf, process::Command}; + +const WORKSPACE_BUILD_ERROR_MSG: &str = r#" + Unable to build all workspace binaries. Cannot continue running tests. + + Try running 'cargo build --all --bins' yourself. +"#; + +lazy_static! { + // Global flag indicating if all binaries in the workspace have been built. + static ref WORKSPACE_BUILT: bool = { + info!("Building project binaries"); + let args = if cfg!(debug_assertions) { + vec!["build", "--all", "--bins"] + } else { + vec!["build", "--all", "--bins", "--release"] + }; + + + let cargo_build = Command::new("cargo") + .current_dir(workspace_root()) + .args(&args) + .output() + .expect(WORKSPACE_BUILD_ERROR_MSG); + if !cargo_build.status.success() { + panic!(WORKSPACE_BUILD_ERROR_MSG); + } + + info!("Finished building project binaries"); + + true + }; +} + +// Path to top level workspace +pub fn workspace_root() -> PathBuf { + let mut path = build_dir(); + while !path.ends_with("target") { + path.pop(); + } + path.pop(); + path +} + +// Path to the directory where build artifacts live. +//TODO maybe add an Environment Variable which points to built binaries +pub fn build_dir() -> PathBuf { + env::current_exe() + .ok() + .map(|mut path| { + path.pop(); + if path.ends_with("deps") { + path.pop(); + } + path + }) + .expect("Can't find the build directory. Cannot continue running tests") +} + +// Path to a specified binary +pub fn get_bin>(bin_name: S) -> PathBuf { + // We have to check to see if the workspace is built first to ensure that the binaries we're + // testing are up to date. + if !*WORKSPACE_BUILT { + panic!(WORKSPACE_BUILD_ERROR_MSG); + } + + let bin_name = bin_name.as_ref(); + let bin_path = build_dir().join(format!("{}{}", bin_name, env::consts::EXE_SUFFIX)); + + // If the binary doesn't exist then either building them failed somehow or the supplied binary + // name doesn't match any binaries this workspace can produce. 
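+ // (For example, `get_bin("libra_node")` would typically resolve to
+ // target/debug/libra_node plus EXE_SUFFIX in a debug workspace build;
+ // this is an illustration, the exact path depends on the build profile.)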
+ if !bin_path.exists() { + panic!(format!( + "Can't find binary '{}' in expected path {:?}", + bin_name, bin_path + )); + } + + bin_path +} diff --git a/mempool/Cargo.toml b/mempool/Cargo.toml new file mode 100644 index 0000000000000..39d051075b94c --- /dev/null +++ b/mempool/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "mempool" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +chrono = "0.4.6" +futures = "0.1.25" +futures-preview = { version = "=0.3.0-alpha.16", package = "futures-preview", features = ["compat"] } +grpcio = "0.4.3" +grpcio-sys = "0.4.4" +lazy_static = "1.3.0" +lru-cache = "0.1.1" +protobuf = "2.6" +tokio = "0.1.16" +ttl_cache = "0.4.2" + +config = { path = "../config" } +crypto = { path = "../crypto/legacy_crypto" } +failure = { path = "../common/failure_ext", package = "failure_ext" } +grpc_helpers = { path = "../common/grpc_helpers" } +logger = { path = "../common/logger" } +metrics = { path = "../common/metrics" } +network = {path = "../network"} +proto_conv = { path = "../common/proto_conv" } +storage_client = { path = "../storage/storage_client" } +types = { path = "../types" } +vm_validator = {path = "../vm_validator"} + +[dev-dependencies] +rand = "0.6.5" +tempfile = "3.0.6" +channel = { path = "../common/channel" } + +storage_service = { path = "../storage/storage_service" } + +[build-dependencies] +build_helpers = { path = "../common/build_helpers" } diff --git a/mempool/README.md b/mempool/README.md new file mode 100644 index 0000000000000..db7f6e393652f --- /dev/null +++ b/mempool/README.md @@ -0,0 +1,53 @@ +--- +id: mempool +title: Mempool +custom_edit_url: https://github.com/libra/libra/edit/master/mempool/README.md +--- +# Mempool + +Mempool is a memory-buffer that holds the transactions that are waiting to be executed. + +## Overview + +Admission control (AC) module sends transactions to mempool. Mempool holds the transactions for a period of time, before consensus commits them. When a new transaction is added, mempool shares this transaction with other validators (validator nodes) in the system. Mempool is a β€œshared mempool,” as transactions between mempools are shared with other validators. This helps maintain a pseudoglobal ordering. + +When a validator receives a transaction from another mempool, the transaction is ordered when it’s added to the ordered queue of the recipient validator. To reduce network consumption in the shared mempool, each validator is responsible for the delivery of its own transactions. We don't rebroadcast transactions originating from a peer validator. + +We only broadcast transactions that have some probability of being included in the next block. This means that either the sequence number of the transaction is the next sequence number of the sender account, or it is sequential to it. For example, if the current sequence number for an account is 2 and local mempool contains transactions with sequence numbers 2, 3, 4, 7, 8, then only transactions 2, 3, and 4 will be broadcast. + +The consensus module pulls transactions from mempool, mempool does not push transactions into consensus. This is to ensure that while consensus is not ready for transactions: + +* Mempool can continue ordering transactions based on gas; and +* Consensus can allow transactions to build up in the mempool. + +This allows transactions to be grouped into a single consensus block, and prioritized by gas price. + +Mempool doesn't keep track of transactions sent to consensus. 
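+
+As a toy illustration of the sequential-window broadcast rule above (a hypothetical helper, not mempool's real API):
+
+```rust
+// Keep only the transactions whose sequence numbers extend the account's
+// current sequence number without a gap.
+fn broadcastable(current_seq: u64, mut pending: Vec<u64>) -> Vec<u64> {
+    pending.sort_unstable();
+    let mut out = Vec::new();
+    let mut next = current_seq;
+    for seq in pending {
+        if seq == next {
+            out.push(seq);
+            next += 1;
+        }
+    }
+    out
+}
+
+// broadcastable(2, vec![2, 3, 4, 7, 8]) yields [2, 3, 4], matching the
+// example from the overview.
+```
+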
+Instead, on each `get_block` request (which pulls a block of transactions from mempool), consensus passes back the set of transactions that it previously pulled from mempool but that are not yet committed. This allows mempool to stay agnostic about the different consensus proposal branches.
+
+When a transaction is fully executed and written to storage, consensus notifies mempool. Mempool then drops this transaction from its internal state.
+
+## Implementation Details
+
+Internally, mempool is modeled as a `HashMap<AccountAddress, AccountTransactions>` with various indexes built on top of it.
+
+The main index, PriorityIndex, is an ordered queue of transactions that are β€œready” to be included in the next block (i.e., they have a sequence number which is sequential to the current sequence number for the account). This queue is ordered by gas price, so that clients willing to pay more per unit of execution than other clients enter consensus earlier.
+
+Note that, even though global ordering is maintained by gas price, transactions for a single account are ordered by sequence number. All transactions that are not ready to be included in the next block are part of a separate ParkingLotIndex. They are moved to the ordered queue once some event unblocks them.
+
+Here is an example: mempool holds a transaction with sequence number 4, while the current sequence number for that account is 3. This transaction is considered β€œnon-ready.” A callback from consensus then reports that a transaction was committed (i.e., transaction 3 was submitted through a different node and has since been committed on chain). This event β€œunblocks” the local transaction, and transaction 4 is moved to the OrderedQueue.
+
+Mempool only holds a limited number of transactions to avoid overwhelming the system and to prevent abuse and attack. Transactions in mempool have two types of expiration: a system TTL and a client-specified expiration time. When either of these is reached, the transaction is removed from mempool.
+
+The system TTL is checked periodically in the background, while the expiration specified by the client is checked on every consensus commit request. We use a separate system TTL to ensure that a transaction doesn't remain stuck in mempool forever, even if consensus doesn't make progress.
+
+## How is this module organized?
+```
+ mempool/src
+ β”œβ”€β”€ core_mempool # main in memory data structure
+ β”œβ”€β”€ proto # protobuf definitions for interactions with mempool
+ β”œβ”€β”€ lib.rs
+ β”œβ”€β”€ mempool_service.rs # gRPC service
+ β”œβ”€β”€ runtime.rs # bundle of shared mempool and gRPC service
+ └── shared_mempool.rs # shared mempool
+```
+
diff --git a/mempool/build.rs b/mempool/build.rs
new file mode 100644
index 0000000000000..1e81dacbd42c8
--- /dev/null
+++ b/mempool/build.rs
@@ -0,0 +1,20 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This compiles all the `.proto` files under the `src/` directory.
+//!
+//! For example, if there is a file `src/a/b/c.proto`, it will generate `src/a/b/c.rs` and
+//! `src/a/b/c_grpc.rs`.
+
+fn main() {
+ let proto_root = "src/proto";
+ let proto_shared_root = "src/proto/shared";
+ let dependent_root = "../types/src/proto";
+ // Build shared directory without further dependencies.
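+ // The shared protos are compiled first, with no include dependencies; the
+ // remaining protos are then compiled with the types and shared proto
+ // directories on the include path. (The trailing flag presumably toggles
+ // gRPC stub generation: the shared status protos define no services.)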
+ build_helpers::build_helpers::compile_proto(proto_shared_root, vec![], false);
+ build_helpers::build_helpers::compile_proto(
+ proto_root,
+ vec![dependent_root, proto_shared_root],
+ true,
+ );
+}
diff --git a/mempool/src/core_mempool/index.rs b/mempool/src/core_mempool/index.rs
new file mode 100644
index 0000000000000..1d5afb881aed1
--- /dev/null
+++ b/mempool/src/core_mempool/index.rs
@@ -0,0 +1,276 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+/// This module provides various indexes used by Mempool
+use crate::core_mempool::transaction::{MempoolTransaction, TimelineState};
+use std::{
+ cmp::Ordering,
+ collections::{btree_set::Iter, BTreeMap, BTreeSet},
+ iter::Rev,
+ ops::Bound,
+ time::Duration,
+};
+use types::account_address::AccountAddress;
+
+pub type AccountTransactions = BTreeMap<u64, MempoolTransaction>;
+
+/// PriorityIndex represents the main priority queue in Mempool.
+/// It is used to form the transaction block for Consensus.
+/// Transactions are ordered by gas price; second-level ordering is done by expiration time.
+///
+/// We don't store the full content of a transaction in the index.
+/// Instead we use `OrderedQueueKey` - a logical reference to the transaction in the main store.
+pub struct PriorityIndex {
+ data: BTreeSet<OrderedQueueKey>,
+}
+
+pub type PriorityQueueIter<'a> = Rev<Iter<'a, OrderedQueueKey>>;
+
+impl PriorityIndex {
+ pub(crate) fn new() -> Self {
+ Self {
+ data: BTreeSet::new(),
+ }
+ }
+
+ /// add transaction to index
+ pub(crate) fn insert(&mut self, txn: &MempoolTransaction) {
+ self.data.insert(self.make_key(&txn));
+ }
+
+ /// remove transaction from index
+ pub(crate) fn remove(&mut self, txn: &MempoolTransaction) {
+ self.data.remove(&self.make_key(&txn));
+ }
+
+ pub(crate) fn contains(&self, txn: &MempoolTransaction) -> bool {
+ self.data.contains(&self.make_key(txn))
+ }
+
+ fn make_key(&self, txn: &MempoolTransaction) -> OrderedQueueKey {
+ OrderedQueueKey {
+ gas_price: txn.get_gas_price(),
+ expiration_time: txn.expiration_time,
+ address: txn.get_sender(),
+ sequence_number: txn.get_sequence_number(),
+ }
+ }
+
+ /// returns iterator over priority queue
+ pub(crate) fn iter(&self) -> PriorityQueueIter {
+ self.data.iter().rev()
+ }
+}
+
+#[derive(Eq, PartialEq, Clone, Debug, Hash)]
+pub struct OrderedQueueKey {
+ pub gas_price: u64,
+ pub expiration_time: Duration,
+ pub address: AccountAddress,
+ pub sequence_number: u64,
+}
+
+impl PartialOrd for OrderedQueueKey {
+ fn partial_cmp(&self, other: &OrderedQueueKey) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl Ord for OrderedQueueKey {
+ fn cmp(&self, other: &OrderedQueueKey) -> Ordering {
+ match self.gas_price.cmp(&other.gas_price) {
+ Ordering::Equal => {}
+ ordering => return ordering,
+ }
+ match self.expiration_time.cmp(&other.expiration_time).reverse() {
+ Ordering::Equal => {}
+ ordering => return ordering,
+ }
+ match self.address.cmp(&other.address) {
+ Ordering::Equal => {}
+ ordering => return ordering,
+ }
+ self.sequence_number.cmp(&other.sequence_number).reverse()
+ }
+}
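+
+// Worked example of the composite ordering (illustrative values): since
+// `iter()` walks the BTreeSet in reverse, a key with gas_price 10 is yielded
+// before one with gas_price 1; among equal gas prices the *earliest*
+// expiration comes first (note the `.reverse()`); and for a single sender,
+// lower sequence numbers come first, so transactions stay executable in order.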
+
+/// TTLIndex is used to perform garbage collection of old transactions in Mempool.
+/// Periodically, a separate GC-like job queries this index to find the transactions
+/// that have to be removed. The index is represented as a `BTreeSet<TTLOrderingKey>`,
+/// where `TTLOrderingKey` is a logical reference to TxnInfo.
+/// The index is ordered by `TTLOrderingKey::expiration_time`.
+pub struct TTLIndex {
+ data: BTreeSet<TTLOrderingKey>,
+ get_expiration_time: Box<dyn Fn(&MempoolTransaction) -> Duration + Send + Sync>,
+}
+
+impl TTLIndex {
+ pub(crate) fn new<F>(get_expiration_time: Box<F>) -> Self
+ where
+ F: Fn(&MempoolTransaction) -> Duration + 'static + Send + Sync,
+ {
+ Self {
+ data: BTreeSet::new(),
+ get_expiration_time,
+ }
+ }
+
+ /// add transaction to index
+ pub(crate) fn insert(&mut self, txn: &MempoolTransaction) {
+ self.data.insert(self.make_key(&txn));
+ }
+
+ /// remove transaction from index
+ pub(crate) fn remove(&mut self, txn: &MempoolTransaction) {
+ self.data.remove(&self.make_key(&txn));
+ }
+
+ /// GC all old transactions
+ pub(crate) fn gc(&mut self, now: Duration) -> Vec<TTLOrderingKey> {
+ let ttl_key = TTLOrderingKey {
+ expiration_time: now,
+ address: AccountAddress::default(),
+ sequence_number: 0,
+ };
+
+ let mut active = self.data.split_off(&ttl_key);
+ let ttl_transactions = self.data.iter().cloned().collect();
+ self.data.clear();
+ self.data.append(&mut active);
+ ttl_transactions
+ }
+
+ fn make_key(&self, txn: &MempoolTransaction) -> TTLOrderingKey {
+ TTLOrderingKey {
+ expiration_time: (self.get_expiration_time)(txn),
+ address: txn.get_sender(),
+ sequence_number: txn.get_sequence_number(),
+ }
+ }
+
+ pub(crate) fn size(&self) -> usize {
+ self.data.len()
+ }
+}
+
+#[derive(Eq, PartialEq, PartialOrd, Clone, Debug)]
+pub struct TTLOrderingKey {
+ pub expiration_time: Duration,
+ pub address: AccountAddress,
+ pub sequence_number: u64,
+}
+
+impl Ord for TTLOrderingKey {
+ fn cmp(&self, other: &TTLOrderingKey) -> Ordering {
+ match self.expiration_time.cmp(&other.expiration_time) {
+ Ordering::Equal => {
+ (&self.address, self.sequence_number).cmp(&(&other.address, other.sequence_number))
+ }
+ ordering => ordering,
+ }
+ }
+}
+
+/// TimelineIndex is an ordered log of all transactions that are "ready" for broadcast.
+/// We only add a transaction to the index if it has a chance to be included in the next
+/// consensus block, meaning its status is not `NotReady`, or it is sequential to another
+/// "ready" transaction.
+///
+/// It is represented as a map of `<timeline_id, (AccountAddress, sequence_number)>`,
+/// where timeline_id is an auto-increment unique id of a "ready" transaction in the local
+/// Mempool, and (AccountAddress, sequence_number) is a logical reference to the transaction
+/// content in the main storage.
+pub struct TimelineIndex {
+ timeline_id: u64,
+ timeline: BTreeMap<u64, (AccountAddress, u64)>,
+}
+
+impl TimelineIndex {
+ pub(crate) fn new() -> Self {
+ Self {
+ timeline_id: 1,
+ timeline: BTreeMap::new(),
+ }
+ }
+
+ /// read all transactions from the timeline since `timeline_id`
+ pub(crate) fn read_timeline(
+ &mut self,
+ timeline_id: u64,
+ count: usize,
+ ) -> Vec<(AccountAddress, u64)> {
+ let mut batch = vec![];
+ for (_, &(address, sequence_number)) in self
+ .timeline
+ .range((Bound::Excluded(timeline_id), Bound::Unbounded))
+ {
+ batch.push((address, sequence_number));
+ if batch.len() == count {
+ break;
+ }
+ }
+ batch
+ }
+
+ /// add transaction to index
+ pub(crate) fn insert(&mut self, txn: &mut MempoolTransaction) {
+ self.timeline.insert(
+ self.timeline_id,
+ (txn.get_sender(), txn.get_sequence_number()),
+ );
+ txn.timeline_state = TimelineState::Ready(self.timeline_id);
+ self.timeline_id += 1;
+ }
+
+ /// remove transaction from index
+ pub(crate) fn remove(&mut self, txn: &MempoolTransaction) {
+ if let TimelineState::Ready(timeline_id) = txn.timeline_state {
+ self.timeline.remove(&timeline_id);
+ }
+ }
+}
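+
+// Illustrative walkthrough: after three inserts the entries occupy timeline
+// ids 1..=3, so `read_timeline(0, 2)` returns the first two (address, seq)
+// pairs and a follow-up `read_timeline(2, 2)` resumes with the third;
+// callers page through the log by remembering the last id they have seen.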
+
+/// ParkingLotIndex keeps track of "not_ready" transactions, e.g. transactions that
+/// can't be included in the next block because their sequence number is too high.
+/// We keep a separate index to be able to efficiently evict them when Mempool is full.
+pub struct ParkingLotIndex {
+ data: BTreeSet<TxnPointer>,
+}
+
+impl ParkingLotIndex {
+ pub(crate) fn new() -> Self {
+ Self {
+ data: BTreeSet::new(),
+ }
+ }
+
+ /// add transaction to index
+ pub(crate) fn insert(&mut self, txn: &MempoolTransaction) {
+ self.data.insert(TxnPointer::from(txn));
+ }
+
+ /// remove transaction from index
+ pub(crate) fn remove(&mut self, txn: &MempoolTransaction) {
+ self.data.remove(&TxnPointer::from(txn));
+ }
+
+ /// returns an arbitrary "non-ready" transaction to evict: the last entry in the
+ /// set, i.e. the one with the highest (address, sequence number) key
+ pub(crate) fn pop(&mut self) -> Option<TxnPointer> {
+ self.data.iter().rev().next().cloned()
+ }
+}
+
+/// Logical pointer to `MempoolTransaction`.
+/// Includes the account's address and the transaction sequence number.
+pub type TxnPointer = (AccountAddress, u64);
+
+impl From<&MempoolTransaction> for TxnPointer {
+ fn from(transaction: &MempoolTransaction) -> Self {
+ (transaction.get_sender(), transaction.get_sequence_number())
+ }
+}
+
+impl From<&OrderedQueueKey> for TxnPointer {
+ fn from(key: &OrderedQueueKey) -> Self {
+ (key.address, key.sequence_number)
+ }
+}
diff --git a/mempool/src/core_mempool/mempool.rs b/mempool/src/core_mempool/mempool.rs
new file mode 100644
index 0000000000000..1cb52bb311e60
--- /dev/null
+++ b/mempool/src/core_mempool/mempool.rs
@@ -0,0 +1,238 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! mempool is used to track transactions which have been submitted but not yet
+//! agreed upon.
+use std::time::{Duration, SystemTime, UNIX_EPOCH};
+
+use crate::{
+ core_mempool::{
+ index::TxnPointer,
+ transaction::{MempoolAddTransactionStatus, MempoolTransaction, TimelineState},
+ transaction_store::TransactionStore,
+ },
+ OP_COUNTERS,
+};
+use chrono::Utc;
+use config::config::NodeConfig;
+use logger::prelude::*;
+use lru_cache::LruCache;
+use std::{
+ cmp::{max, min},
+ collections::HashSet,
+};
+use ttl_cache::TtlCache;
+use types::{account_address::AccountAddress, transaction::SignedTransaction};
+
+pub struct Mempool {
+ // stores metadata of all transactions in mempool (of all states)
+ transactions: TransactionStore,
+
+ sequence_number_cache: LruCache<AccountAddress, u64>,
+ // temporary DS.
TODO: eventually retire it + // for each transaction, entry with timestamp is added when transaction enters mempool + // used to measure e2e latency of transaction in system, as well as time it takes to pick it up + // by consensus + metrics_cache: TtlCache<(AccountAddress, u64), i64>, + pub system_transaction_timeout: Duration, +} + +impl Mempool { + pub(crate) fn new(config: &NodeConfig) -> Self { + Mempool { + transactions: TransactionStore::new(&config.mempool), + sequence_number_cache: LruCache::new(config.mempool.sequence_cache_capacity), + metrics_cache: TtlCache::new(config.mempool.capacity), + system_transaction_timeout: Duration::from_secs( + config.mempool.system_transaction_timeout_secs, + ), + } + } + + /// This function will be called once the transaction has been stored + pub(crate) fn remove_transaction( + &mut self, + sender: &AccountAddress, + sequence_number: u64, + is_rejected: bool, + ) { + debug!( + "[Mempool] Removing transaction from mempool: {}:{}", + sender, sequence_number + ); + self.log_latency(sender.clone(), sequence_number, "e2e.latency"); + self.metrics_cache.remove(&(*sender, sequence_number)); + + // update current cached sequence number for account + let cached_value = self + .sequence_number_cache + .remove(sender) + .unwrap_or_default(); + + let new_sequence_number = if is_rejected { + min(sequence_number, cached_value) + } else { + max(cached_value, sequence_number + 1) + }; + self.sequence_number_cache + .insert(sender.clone(), new_sequence_number); + + self.transactions + .commit_transaction(&sender, sequence_number); + } + + fn log_latency(&mut self, account: AccountAddress, sequence_number: u64, metric: &str) { + if let Some(&creation_time) = self.metrics_cache.get(&(account, sequence_number)) { + OP_COUNTERS.observe( + metric, + (Utc::now().timestamp_millis() - creation_time) as f64, + ); + } + } + + fn check_balance(&mut self, txn: &SignedTransaction, balance: u64, gas_amount: u64) -> bool { + let required_balance = txn.gas_unit_price() * gas_amount + + self.transactions.get_required_balance(&txn.sender()); + balance >= required_balance + } + + /// Used to add a transaction to the Mempool + /// Performs basic validation: checks account's balance and sequence number + pub(crate) fn add_txn( + &mut self, + txn: SignedTransaction, + gas_amount: u64, + db_sequence_number: u64, + balance: u64, + timeline_state: TimelineState, + ) -> MempoolAddTransactionStatus { + debug!( + "[Mempool] Adding transaction to mempool: {}:{}", + &txn.sender(), + db_sequence_number + ); + if !self.check_balance(&txn, balance, gas_amount) { + return MempoolAddTransactionStatus::InsufficientBalance; + } + + let cached_value = self.sequence_number_cache.get_mut(&txn.sender()); + let sequence_number = match cached_value { + Some(value) => max(*value, db_sequence_number), + None => db_sequence_number, + }; + self.sequence_number_cache + .insert(txn.sender(), sequence_number); + + // don't accept old transactions (e.g. 
seq is less than account's current seq_number)
+ if txn.sequence_number() < sequence_number {
+ return MempoolAddTransactionStatus::InvalidSeqNumber;
+ }
+
+ let expiration_time = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .expect("init timestamp failure")
+ + self.system_transaction_timeout;
+ self.metrics_cache.insert(
+ (txn.sender(), txn.sequence_number()),
+ Utc::now().timestamp_millis(),
+ Duration::from_secs(100),
+ );
+
+ let txn_info = MempoolTransaction::new(txn, expiration_time, gas_amount, timeline_state);
+
+ let status = self.transactions.insert(txn_info, sequence_number);
+ OP_COUNTERS.inc(&format!("insert.{:?}", status));
+ status
+ }
+
+ /// Fetches the next block of transactions for consensus.
+ /// `batch_size` - size of the requested block
+ /// `seen` - transactions that were already sent to Consensus but not yet committed;
+ /// mempool should filter out such transactions
+ pub(crate) fn get_block(
+ &mut self,
+ batch_size: u64,
+ mut seen: HashSet<TxnPointer>,
+ ) -> Vec<SignedTransaction> {
+ let mut result = vec![];
+ // Helper DS. Helps to mitigate scenarios where an account submits several transactions
+ // with increasing gas price (e.g. a user submits transactions with sequence numbers 1, 2
+ // and gas_price 1, 10 respectively).
+ // The later txn has a higher gas price and is observed first in the priority index
+ // iterator, but it can't be executed before the first txn. Once observed, such a txn is
+ // saved in the `skipped` DS and rechecked once its ancestor becomes available.
+ let mut skipped = HashSet::new();
+
+ // iterate over the queue of transactions based on gas price
+ 'main: for txn in self.transactions.iter_queue() {
+ if seen.contains(&TxnPointer::from(txn)) {
+ continue;
+ }
+ let mut seq = txn.sequence_number;
+ let account_sequence_number = self.sequence_number_cache.get_mut(&txn.address);
+ let seen_previous = seq > 0 && seen.contains(&(txn.address, seq - 1));
+ // include transaction if it's "next" for the given account or
+ // we've already sent its ancestor to Consensus
+ if seen_previous || account_sequence_number == Some(&mut seq) {
+ let ptr = TxnPointer::from(txn);
+ seen.insert(ptr);
+ result.push(ptr);
+ if (result.len() as u64) == batch_size {
+ break;
+ }
+
+ // check if we can now include some transactions
+ // that were skipped before for the given account
+ let mut skipped_txn = (txn.address, seq + 1);
+ while skipped.contains(&skipped_txn) {
+ seen.insert(skipped_txn);
+ result.push(skipped_txn);
+ if (result.len() as u64) == batch_size {
+ break 'main;
+ }
+ skipped_txn = (txn.address, skipped_txn.1 + 1);
+ }
+ } else {
+ skipped.insert(TxnPointer::from(txn));
+ }
+ }
+ // convert transaction pointers to real values
+ let block: Vec<_> = result
+ .into_iter()
+ .filter_map(|(address, seq)| self.transactions.get(&address, seq))
+ .collect();
+ for transaction in &block {
+ self.log_latency(
+ transaction.sender(),
+ transaction.sequence_number(),
+ "txn_pre_consensus_ms",
+ );
+ }
+ block
+ }
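+
+ // Worked example of the `skipped` bookkeeping above (illustrative): an
+ // account sitting at sequence number 1 submits txns (seq 1, gas 1) and
+ // (seq 2, gas 10). The gas-ordered iterator yields seq 2 first, which is
+ // parked in `skipped`; when seq 1 is reached it is included, and the
+ // while-loop immediately drains seq 2 into the block as well.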
+
+ /// TTL-based garbage collection. Removes all transactions that have expired.
+ pub(crate) fn gc_by_system_ttl(&mut self) {
+ self.transactions.gc_by_system_ttl();
+ }
+
+ /// Garbage collection based on client-specified expiration time
+ pub(crate) fn gc_by_expiration_time(&mut self, block_time: Duration) {
+ self.transactions.gc_by_expiration_time(block_time);
+ }
+
+ /// Read `count` transactions from the timeline since `timeline_id`.
+ /// Returns a block of transactions and the new last_timeline_id.
+ pub(crate) fn read_timeline(
+ &mut self,
+ timeline_id: u64,
+ count: usize,
+ ) -> (Vec<SignedTransaction>, u64) {
+ self.transactions.read_timeline(timeline_id, count)
+ }
+
+ /// Check the health of core mempool.
+ pub(crate) fn health_check(&self) -> bool {
+ self.transactions.health_check()
+ }
+}
diff --git a/mempool/src/core_mempool/mod.rs b/mempool/src/core_mempool/mod.rs
new file mode 100644
index 0000000000000..9ac882d8ca95c
--- /dev/null
+++ b/mempool/src/core_mempool/mod.rs
@@ -0,0 +1,16 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+mod index;
+mod mempool;
+mod transaction;
+mod transaction_store;
+
+pub use self::{
+ index::TxnPointer,
+ mempool::Mempool as CoreMempool,
+ transaction::{MempoolAddTransactionStatus, TimelineState},
+};
+
+#[cfg(test)]
+mod unit_tests;
diff --git a/mempool/src/core_mempool/transaction.rs b/mempool/src/core_mempool/transaction.rs
new file mode 100644
index 0000000000000..01ddda862cd16
--- /dev/null
+++ b/mempool/src/core_mempool/transaction.rs
@@ -0,0 +1,123 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::proto::shared::mempool_status::MempoolAddTransactionStatus as ProtoMempoolAddTransactionStatus;
+use failure::prelude::*;
+use proto_conv::{FromProto, IntoProto};
+use std::time::Duration;
+use types::{account_address::AccountAddress, transaction::SignedTransaction};
+
+#[derive(Clone)]
+pub struct MempoolTransaction {
+ pub txn: SignedTransaction,
+ // system expiration time of the transaction; it should be removed from mempool by that time
+ pub expiration_time: Duration,
+ pub gas_amount: u64,
+ pub timeline_state: TimelineState,
+}
+
+impl MempoolTransaction {
+ pub(crate) fn new(
+ txn: SignedTransaction,
+ expiration_time: Duration,
+ gas_amount: u64,
+ timeline_state: TimelineState,
+ ) -> Self {
+ Self {
+ txn,
+ gas_amount,
+ expiration_time,
+ timeline_state,
+ }
+ }
+ pub(crate) fn get_sequence_number(&self) -> u64 {
+ self.txn.sequence_number()
+ }
+ pub(crate) fn get_sender(&self) -> AccountAddress {
+ self.txn.sender()
+ }
+ pub(crate) fn get_gas_price(&self) -> u64 {
+ self.txn.gas_unit_price()
+ }
+}
+
+#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
+pub enum TimelineState {
+ // transaction is ready for broadcast;
+ // the associated integer represents its position in the log of such transactions
+ Ready(u64),
+ // transaction is not yet ready for broadcast,
+ // but that might change in the future
+ NotReady,
+ // transaction will never qualify for broadcasting;
+ // currently we don't broadcast transactions that originated on other peers
+ NonQualified,
+}
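+
+// Illustrative lifecycle: a locally submitted transaction typically enters as
+// NotReady, becomes Ready(timeline_id) once TimelineIndex::insert assigns it a
+// slot in the broadcast log, while transactions received from peer validators
+// are marked NonQualified and are never re-broadcast.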
too old) + InvalidSeqNumber, + /// Mempool is full (reached max global capacity) + MempoolIsFull, + /// Account reached max capacity per account + TooManyTransactions, + /// Invalid update. Only gas price increase is allowed + InvalidUpdate, +} + +impl IntoProto for MempoolAddTransactionStatus { + type ProtoType = crate::proto::shared::mempool_status::MempoolAddTransactionStatus; + + fn into_proto(self) -> Self::ProtoType { + match self { + MempoolAddTransactionStatus::Valid => ProtoMempoolAddTransactionStatus::Valid, + MempoolAddTransactionStatus::InsufficientBalance => { + ProtoMempoolAddTransactionStatus::InsufficientBalance + } + MempoolAddTransactionStatus::InvalidSeqNumber => { + ProtoMempoolAddTransactionStatus::InvalidSeqNumber + } + MempoolAddTransactionStatus::InvalidUpdate => { + ProtoMempoolAddTransactionStatus::InvalidUpdate + } + MempoolAddTransactionStatus::MempoolIsFull => { + ProtoMempoolAddTransactionStatus::MempoolIsFull + } + MempoolAddTransactionStatus::TooManyTransactions => { + ProtoMempoolAddTransactionStatus::TooManyTransactions + } + } + } +} + +impl FromProto for MempoolAddTransactionStatus { + type ProtoType = crate::proto::shared::mempool_status::MempoolAddTransactionStatus; + + fn from_proto(object: Self::ProtoType) -> Result { + let ret = match object { + ProtoMempoolAddTransactionStatus::Valid => MempoolAddTransactionStatus::Valid, + ProtoMempoolAddTransactionStatus::InsufficientBalance => { + MempoolAddTransactionStatus::InsufficientBalance + } + ProtoMempoolAddTransactionStatus::InvalidSeqNumber => { + MempoolAddTransactionStatus::InvalidSeqNumber + } + ProtoMempoolAddTransactionStatus::InvalidUpdate => { + MempoolAddTransactionStatus::InvalidUpdate + } + ProtoMempoolAddTransactionStatus::MempoolIsFull => { + MempoolAddTransactionStatus::MempoolIsFull + } + ProtoMempoolAddTransactionStatus::TooManyTransactions => { + MempoolAddTransactionStatus::TooManyTransactions + } + }; + Ok(ret) + } +} diff --git a/mempool/src/core_mempool/transaction_store.rs b/mempool/src/core_mempool/transaction_store.rs new file mode 100644 index 0000000000000..9e155d7237ccd --- /dev/null +++ b/mempool/src/core_mempool/transaction_store.rs @@ -0,0 +1,302 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + core_mempool::{ + index::{ + AccountTransactions, ParkingLotIndex, PriorityIndex, PriorityQueueIter, TTLIndex, + TimelineIndex, + }, + transaction::{MempoolAddTransactionStatus, MempoolTransaction, TimelineState}, + }, + OP_COUNTERS, +}; +use config::config::MempoolConfig; +use std::{ + collections::HashMap, + ops::Bound, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; +use types::{account_address::AccountAddress, transaction::SignedTransaction}; + +/// TransactionStore is in-memory storage for all transactions in mempool +pub struct TransactionStore { + // main DS + transactions: HashMap, + + // indexes + priority_index: PriorityIndex, + // TTLIndex based on client-specified expiration time + expiration_time_index: TTLIndex, + // TTLIndex based on system expiration time + // we keep it separate from `expiration_time_index` so Mempool can't be clogged + // by old transactions even if it hasn't received commit callbacks for a while + system_ttl_index: TTLIndex, + timeline_index: TimelineIndex, + // keeps track of "non-ready" txns (transactions that can't be included in next block) + parking_lot_index: ParkingLotIndex, + + // configuration + capacity: usize, + capacity_per_user: usize, +} + +impl TransactionStore { + pub(crate) 
fn new(config: &MempoolConfig) -> Self { + Self { + // main DS + transactions: HashMap::new(), + + // various indexes + system_ttl_index: TTLIndex::new(Box::new(|t: &MempoolTransaction| t.expiration_time)), + expiration_time_index: TTLIndex::new(Box::new(|t: &MempoolTransaction| { + t.txn.expiration_time() + })), + priority_index: PriorityIndex::new(), + timeline_index: TimelineIndex::new(), + parking_lot_index: ParkingLotIndex::new(), + + // configuration + capacity: config.capacity, + capacity_per_user: config.capacity_per_user, + } + } + + /// fetch transaction by account address + sequence_number + pub(crate) fn get( + &self, + address: &AccountAddress, + sequence_number: u64, + ) -> Option { + if let Some(txns) = self.transactions.get(&address) { + if let Some(txn) = txns.get(&sequence_number) { + return Some(txn.txn.clone()); + } + } + None + } + + /// insert transaction into TransactionStore + /// performs validation checks and updates indexes + pub(crate) fn insert( + &mut self, + txn: MempoolTransaction, + current_sequence_number: u64, + ) -> MempoolAddTransactionStatus { + let (is_update, status) = self.check_for_update(&txn); + if is_update { + return status; + } + if self.check_if_full() { + return MempoolAddTransactionStatus::MempoolIsFull; + } + + let address = txn.get_sender(); + let sequence_number = txn.get_sequence_number(); + + self.transactions + .entry(address) + .or_insert_with(AccountTransactions::new); + + if let Some(txns) = self.transactions.get_mut(&address) { + // capacity check + if txns.len() >= self.capacity_per_user { + return MempoolAddTransactionStatus::TooManyTransactions; + } + + // insert into storage and other indexes + self.system_ttl_index.insert(&txn); + self.expiration_time_index.insert(&txn); + txns.insert(sequence_number, txn); + OP_COUNTERS.set("txn.system_ttl_index", self.system_ttl_index.size()); + } + self.process_ready_transactions(&address, current_sequence_number); + MempoolAddTransactionStatus::Valid + } + + /// Check whether the queue size >= threshold in config. + pub(crate) fn health_check(&self) -> bool { + self.system_ttl_index.size() <= self.capacity + } + + /// checks if Mempool is full + /// If it's full, tries to free some space by evicting transactions from ParkingLot + fn check_if_full(&mut self) -> bool { + if self.system_ttl_index.size() >= self.capacity { + // try to free some space in Mempool from ParkingLot + if let Some((address, sequence_number)) = self.parking_lot_index.pop() { + if let Some(txns) = self.transactions.get_mut(&address) { + if let Some(txn) = txns.remove(&sequence_number) { + self.index_remove(&txn); + } + } + } + } + self.system_ttl_index.size() >= self.capacity + } + + /// check if transaction is already present in Mempool + /// e.g. 
given request is update + /// we allow increase in gas price to speed up process + fn check_for_update( + &mut self, + txn: &MempoolTransaction, + ) -> (bool, MempoolAddTransactionStatus) { + let mut is_update = false; + let mut status = MempoolAddTransactionStatus::Valid; + + if let Some(txns) = self.transactions.get_mut(&txn.get_sender()) { + if let Some(current_version) = txns.get_mut(&txn.get_sequence_number()) { + is_update = true; + // TODO: do we need to ensure the rest of content hasn't changed + if txn.get_gas_price() <= current_version.get_gas_price() { + status = MempoolAddTransactionStatus::InvalidUpdate; + } else { + self.priority_index.remove(¤t_version); + current_version.txn = txn.txn.clone(); + self.priority_index.insert(¤t_version); + } + } + } + (is_update, status) + } + + /// fixes following invariants: + /// all transactions of given account that are sequential to current sequence number + /// supposed to be included in both PriorityIndex (ordering for Consensus) and + /// TimelineIndex (txns for SharedMempool) + /// Other txns are considered to be "non-ready" and should be added to ParkingLotIndex + fn process_ready_transactions( + &mut self, + address: &AccountAddress, + current_sequence_number: u64, + ) { + if let Some(txns) = self.transactions.get_mut(&address) { + let mut sequence_number = current_sequence_number; + while let Some(txn) = txns.get_mut(&sequence_number) { + self.priority_index.insert(txn); + + if txn.timeline_state == TimelineState::NotReady { + self.timeline_index.insert(txn); + } + sequence_number += 1; + } + for (_, txn) in txns.range_mut((Bound::Excluded(sequence_number), Bound::Unbounded)) { + match txn.timeline_state { + TimelineState::Ready(_) => {} + _ => { + self.parking_lot_index.insert(&txn); + } + } + } + } + } + + /// handles transaction commit + /// it includes deletion of all transactions with sequence number <= `sequence_number` + /// and potential promotion of sequential txns to PriorityIndex/TimelineIndex + pub(crate) fn commit_transaction(&mut self, account: &AccountAddress, sequence_number: u64) { + if let Some(txns) = self.transactions.get_mut(&account) { + // remove all previous seq number transactions for this account + // This can happen if transactions are sent to multiple nodes and one of + // nodes has sent the transaction to consensus but this node still has the + // transaction sitting in mempool + let mut active = txns.split_off(&(sequence_number + 1)); + let txns_for_removal = txns.clone(); + txns.clear(); + txns.append(&mut active); + + for transaction in txns_for_removal.values() { + self.index_remove(transaction); + } + } + self.process_ready_transactions(account, sequence_number + 1); + } + + /// removes transaction from all indexes + fn index_remove(&mut self, txn: &MempoolTransaction) { + self.system_ttl_index.remove(&txn); + self.expiration_time_index.remove(&txn); + self.priority_index.remove(&txn); + self.timeline_index.remove(&txn); + self.parking_lot_index.remove(&txn); + OP_COUNTERS.set("txn.system_ttl_index", self.system_ttl_index.size()); + } + + /// returns gas amount required to process all transactions for given account + pub(crate) fn get_required_balance(&mut self, address: &AccountAddress) -> u64 { + match self.transactions.get_mut(&address) { + Some(txns) => txns.iter().fold(0, |acc, (_, txn)| { + acc + txn.txn.gas_unit_price() * txn.gas_amount + }), + None => 0, + } + } + + /// Read `count` transactions from timeline since `timeline_id` + /// Returns block of transactions and new last_timeline_id 
+ pub(crate) fn read_timeline( + &mut self, + timeline_id: u64, + count: usize, + ) -> (Vec, u64) { + let mut batch = vec![]; + let mut last_timeline_id = timeline_id; + for (address, sequence_number) in self.timeline_index.read_timeline(timeline_id, count) { + if let Some(txns) = self.transactions.get_mut(&address) { + if let Some(txn) = txns.get(&sequence_number) { + batch.push(txn.txn.clone()); + if let TimelineState::Ready(timeline_id) = txn.timeline_state { + last_timeline_id = timeline_id; + } + } + } + } + (batch, last_timeline_id) + } + + /// GC old transactions + pub(crate) fn gc_by_system_ttl(&mut self) { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("init timestamp failure"); + + self.gc(now, true); + } + + /// GC old transactions based on client-specified expiration time + pub(crate) fn gc_by_expiration_time(&mut self, block_time: Duration) { + self.gc(block_time, false); + } + + fn gc(&mut self, now: Duration, by_system_ttl: bool) { + let (index_name, index) = if by_system_ttl { + ("gc.system_ttl_index", &mut self.system_ttl_index) + } else { + ("gc.expiration_time_index", &mut self.expiration_time_index) + }; + OP_COUNTERS.inc(index_name); + + for key in index.gc(now) { + if let Some(txns) = self.transactions.get_mut(&key.address) { + // mark all following transactions as non-ready + for (_, t) in txns.range((Bound::Excluded(key.sequence_number), Bound::Unbounded)) { + self.parking_lot_index.insert(&t); + self.priority_index.remove(&t); + self.timeline_index.remove(&t); + } + if let Some(txn) = txns.remove(&key.sequence_number) { + let is_active = self.priority_index.contains(&txn); + let status = if is_active { "active" } else { "parked" }; + OP_COUNTERS.inc(&format!("{}.{}", index_name, status)); + self.index_remove(&txn); + } + } + } + OP_COUNTERS.set("txn.system_ttl_index", self.system_ttl_index.size()); + } + + pub(crate) fn iter_queue(&self) -> PriorityQueueIter { + self.priority_index.iter() + } +} diff --git a/mempool/src/core_mempool/unit_tests/common.rs b/mempool/src/core_mempool/unit_tests/common.rs new file mode 100644 index 0000000000000..264ed497780ba --- /dev/null +++ b/mempool/src/core_mempool/unit_tests/common.rs @@ -0,0 +1,137 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::core_mempool::{CoreMempool, MempoolAddTransactionStatus, TimelineState, TxnPointer}; +use config::config::NodeConfigHelpers; +use crypto::signing::generate_keypair_for_testing; +use failure::prelude::*; +use lazy_static::lazy_static; +use rand::{rngs::StdRng, SeedableRng}; +use std::{collections::HashSet, iter::FromIterator}; +use types::{ + account_address::AccountAddress, + transaction::{Program, RawTransaction, SignedTransaction}, +}; + +pub(crate) fn setup_mempool() -> (CoreMempool, ConsensusMock) { + ( + CoreMempool::new(&NodeConfigHelpers::get_single_node_test_config(true)), + ConsensusMock::new(), + ) +} + +lazy_static! 
{ + static ref ACCOUNTS: Vec = + vec![AccountAddress::random(), AccountAddress::random()]; +} + +#[derive(Clone)] +pub struct TestTransaction { + address: usize, + sequence_number: u64, + gas_price: u64, +} + +impl TestTransaction { + pub(crate) fn new(address: usize, sequence_number: u64, gas_price: u64) -> Self { + Self { + address, + sequence_number, + gas_price, + } + } + + pub(crate) fn make_signed_transaction_with_expiration_time( + &self, + exp_time: std::time::Duration, + ) -> SignedTransaction { + self.make_signed_transaction_impl(100, exp_time) + } + + pub(crate) fn make_signed_transaction_with_max_gas_amount( + &self, + max_gas_amount: u64, + ) -> SignedTransaction { + self.make_signed_transaction_impl( + max_gas_amount, + std::time::Duration::from_secs(u64::max_value()), + ) + } + + pub(crate) fn make_signed_transaction(&self) -> SignedTransaction { + self.make_signed_transaction_impl(100, std::time::Duration::from_secs(u64::max_value())) + } + + fn make_signed_transaction_impl( + &self, + max_gas_amount: u64, + exp_time: std::time::Duration, + ) -> SignedTransaction { + let raw_txn = RawTransaction::new( + TestTransaction::get_address(self.address), + self.sequence_number, + Program::new(vec![], vec![], vec![]), + max_gas_amount, + self.gas_price, + exp_time, + ); + let mut seed: [u8; 32] = [0u8; 32]; + seed[..4].copy_from_slice(&[1, 2, 3, 4]); + let mut rng: StdRng = StdRng::from_seed(seed); + let (privkey, pubkey) = generate_keypair_for_testing(&mut rng); + raw_txn + .sign(&privkey, pubkey) + .expect("Failed to sign raw transaction.") + } + + pub(crate) fn get_address(address: usize) -> AccountAddress { + ACCOUNTS[address] + } +} + +// adds transactions to mempool +pub(crate) fn add_txns_to_mempool( + pool: &mut CoreMempool, + txns: Vec, +) -> Vec { + let mut transactions = vec![]; + for transaction in txns { + let txn = transaction.make_signed_transaction(); + pool.add_txn(txn.clone(), 0, 0, 1000, TimelineState::NotReady); + transactions.push(txn); + } + transactions +} + +pub(crate) fn add_txn(pool: &mut CoreMempool, transaction: TestTransaction) -> Result<()> { + let txn = transaction.make_signed_transaction(); + match pool.add_txn(txn.clone(), 0, 0, 1000, TimelineState::NotReady) { + MempoolAddTransactionStatus::Valid => Ok(()), + _ => Err(format_err!("insertion failure")), + } +} + +// helper struct that keeps state between `.get_block` calls. 
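The `seen` set threaded through `get_block` is the consensus/mempool contract in miniature: a later pull must exclude everything an earlier pull already handed out, because none of it is committed yet. A sketch in the style of these tests (hypothetical test body, reusing the helpers above):

```rust
use std::collections::HashSet;

#[test]
fn test_seen_filter_sketch() {
    let (mut pool, _) = setup_mempool();
    add_txns_to_mempool(
        &mut pool,
        vec![TestTransaction::new(1, 0, 1), TestTransaction::new(1, 1, 1)],
    );
    // First pull hands out txn 0; nothing is committed yet.
    let first = pool.get_block(1, HashSet::new());
    let seen: HashSet<_> = first
        .iter()
        .map(|t| (t.sender(), t.sequence_number()))
        .collect();
    // The second pull passes back what was already pulled and gets txn 1.
    let second = pool.get_block(1, seen);
    assert_eq!(second[0].sequence_number(), 1);
}
```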
Imitates work of Consensus +pub struct ConsensusMock(HashSet); + +impl ConsensusMock { + pub(crate) fn new() -> Self { + Self(HashSet::new()) + } + + pub(crate) fn get_block( + &mut self, + mempool: &mut CoreMempool, + block_size: u64, + ) -> Vec { + let block = mempool.get_block(block_size, self.0.clone()); + self.0 = self + .0 + .union(&HashSet::from_iter( + block.iter().map(|t| (t.sender(), t.sequence_number())), + )) + .cloned() + .collect(); + block + } +} diff --git a/mempool/src/core_mempool/unit_tests/core_mempool_test.rs b/mempool/src/core_mempool/unit_tests/core_mempool_test.rs new file mode 100644 index 0000000000000..064cbce440f96 --- /dev/null +++ b/mempool/src/core_mempool/unit_tests/core_mempool_test.rs @@ -0,0 +1,321 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::core_mempool::{ + unit_tests::common::{add_txn, add_txns_to_mempool, setup_mempool, TestTransaction}, + CoreMempool, MempoolAddTransactionStatus, TimelineState, +}; +use config::config::NodeConfigHelpers; +use std::{collections::HashSet, time::Duration}; +use types::transaction::SignedTransaction; + +#[test] +fn test_transaction_ordering() { + let (mut mempool, mut consensus) = setup_mempool(); + + // default ordering: gas price + let mut transactions = add_txns_to_mempool( + &mut mempool, + vec![TestTransaction::new(0, 0, 3), TestTransaction::new(1, 0, 5)], + ); + assert_eq!( + consensus.get_block(&mut mempool, 1), + vec!(transactions[1].clone()) + ); + assert_eq!( + consensus.get_block(&mut mempool, 1), + vec!(transactions[0].clone()) + ); + + // second level ordering: expiration time + let (mut mempool, mut consensus) = setup_mempool(); + transactions = add_txns_to_mempool( + &mut mempool, + vec![TestTransaction::new(0, 0, 1), TestTransaction::new(1, 0, 1)], + ); + for transaction in &transactions { + assert_eq!( + consensus.get_block(&mut mempool, 1), + vec![transaction.clone()] + ); + } + + // last level: for same account it should be by sequence number + let (mut mempool, mut consensus) = setup_mempool(); + transactions = add_txns_to_mempool( + &mut mempool, + vec![ + TestTransaction::new(1, 0, 7), + TestTransaction::new(1, 1, 5), + TestTransaction::new(1, 2, 1), + TestTransaction::new(1, 3, 6), + ], + ); + for transaction in &transactions { + assert_eq!( + consensus.get_block(&mut mempool, 1), + vec![transaction.clone()] + ); + } +} + +#[test] +fn test_update_transaction_in_mempool() { + let (mut mempool, mut consensus) = setup_mempool(); + let txns = add_txns_to_mempool( + &mut mempool, + vec![TestTransaction::new(0, 0, 1), TestTransaction::new(1, 0, 2)], + ); + let fixed_txns = add_txns_to_mempool(&mut mempool, vec![TestTransaction::new(0, 0, 5)]); + + // check that first transactions pops up first + assert_eq!( + consensus.get_block(&mut mempool, 1), + vec![fixed_txns[0].clone()] + ); + assert_eq!(consensus.get_block(&mut mempool, 1), vec![txns[1].clone()]); +} + +#[test] +fn test_remove_transaction() { + let (mut pool, mut consensus) = setup_mempool(); + + // test normal flow + let txns = add_txns_to_mempool( + &mut pool, + vec![TestTransaction::new(0, 0, 1), TestTransaction::new(0, 1, 2)], + ); + for txn in txns { + pool.remove_transaction(&txn.sender(), txn.sequence_number(), false); + } + let new_txns = add_txns_to_mempool( + &mut pool, + vec![TestTransaction::new(1, 0, 3), TestTransaction::new(1, 1, 4)], + ); + // should return only txns from new_txns + assert_eq!(consensus.get_block(&mut pool, 1), vec!(new_txns[0].clone())); + 
assert_eq!(consensus.get_block(&mut pool, 1), vec!(new_txns[1].clone())); +} + +#[test] +fn test_balance_check() { + let mut pool = setup_mempool().0; + let address = 0; + + let transaction1 = TestTransaction::new(address, 0, 1); + assert_eq!( + pool.add_txn( + transaction1.make_signed_transaction(), + 1, + 0, + 2, + TimelineState::NotReady + ), + MempoolAddTransactionStatus::Valid + ); + + assert_eq!( + pool.add_txn( + TestTransaction::new(address, 1, 1).make_signed_transaction(), + 10, + 1, + 5, + TimelineState::NotReady + ), + MempoolAddTransactionStatus::InsufficientBalance + ); + + // check that gas unit price is taking into account for balance check + assert_eq!( + pool.add_txn( + TestTransaction::new(address, 1, /* gas price */ 2).make_signed_transaction(), + /* gas amount */ 3, + 1, + 5, + TimelineState::NotReady + ), + MempoolAddTransactionStatus::InsufficientBalance + ); +} + +#[test] +fn test_system_ttl() { + // created mempool with system_transaction_timeout = 0 + // All transactions are supposed to be evicted on next gc run + let mut config = NodeConfigHelpers::get_single_node_test_config(true); + config.mempool.system_transaction_timeout_secs = 0; + let mut mempool = CoreMempool::new(&config); + + add_txn(&mut mempool, TestTransaction::new(0, 0, 10)).unwrap(); + + // reset system ttl timeout + mempool.system_transaction_timeout = Duration::from_secs(10); + // add new transaction. Should be valid for 10 seconds + let transaction = TestTransaction::new(1, 0, 1); + add_txn(&mut mempool, transaction.clone()).unwrap(); + + // gc routine should clear transaction from first insert but keep last one + mempool.gc_by_system_ttl(); + let batch = mempool.get_block(1, HashSet::new()); + assert_eq!(vec![transaction.make_signed_transaction()], batch); +} + +#[test] +fn test_commit_callback() { + // consensus commit callback should unlock txns in parking lot + let mut pool = setup_mempool().0; + // insert transaction with sequence number 6 to pool(while last known executed transaction is 0) + let txns = add_txns_to_mempool(&mut pool, vec![TestTransaction::new(1, 6, 1)]); + + // check that pool is empty + assert!(pool.get_block(1, HashSet::new()).is_empty()); + // transaction 5 got back from consensus + pool.remove_transaction(&TestTransaction::get_address(1), 5, false); + // verify that we can execute transaction 6 + assert_eq!(pool.get_block(1, HashSet::new())[0], txns[0]); +} + +#[test] +fn test_sequence_number_cache() { + // checks potential race where StateDB is lagging + let mut pool = setup_mempool().0; + // callback from consensus should set current sequence number for account + pool.remove_transaction(&TestTransaction::get_address(1), 5, false); + + // try to add transaction with sequence number 6 to pool(while last known executed transaction + // for AC is 0) + add_txns_to_mempool(&mut pool, vec![TestTransaction::new(1, 6, 1)]); + // verify that we can execute transaction 6 + assert_eq!(pool.get_block(1, HashSet::new()).len(), 1); +} + +#[test] +fn test_reset_sequence_number_on_failure() { + let mut pool = setup_mempool().0; + // add two transactions for account + add_txns_to_mempool( + &mut pool, + vec![TestTransaction::new(1, 0, 1), TestTransaction::new(1, 1, 1)], + ); + + // notify mempool about failure in arbitrary order + pool.remove_transaction(&TestTransaction::get_address(1), 0, true); + pool.remove_transaction(&TestTransaction::get_address(1), 1, true); + + // verify that new transaction for this account can be added + assert!(add_txn(&mut pool, TestTransaction::new(1, 0, 
1)).is_ok()); +} + +#[test] +fn test_timeline() { + let mut pool = setup_mempool().0; + add_txns_to_mempool( + &mut pool, + vec![ + TestTransaction::new(1, 0, 1), + TestTransaction::new(1, 1, 1), + TestTransaction::new(1, 3, 1), + TestTransaction::new(1, 5, 1), + ], + ); + let view = |txns: Vec| -> Vec { + txns.iter() + .map(SignedTransaction::sequence_number) + .collect() + }; + let (timeline, _) = pool.read_timeline(0, 10); + assert_eq!(view(timeline), vec![0, 1]); + + // add txn 2 to unblock txn3 + add_txns_to_mempool(&mut pool, vec![TestTransaction::new(1, 2, 1)]); + let (timeline, _) = pool.read_timeline(0, 10); + assert_eq!(view(timeline), vec![0, 1, 2, 3]); + + // try different start read position + let (timeline, _) = pool.read_timeline(2, 10); + assert_eq!(view(timeline), vec![2, 3]); + + // simulate callback from consensus to unblock txn 5 + pool.remove_transaction(&TestTransaction::get_address(1), 4, false); + let (timeline, _) = pool.read_timeline(0, 10); + assert_eq!(view(timeline), vec![5]); +} + +#[test] +fn test_capacity() { + let mut config = NodeConfigHelpers::get_single_node_test_config(true); + config.mempool.capacity = 1; + config.mempool.system_transaction_timeout_secs = 0; + let mut pool = CoreMempool::new(&config); + + // error on exceeding limit + add_txn(&mut pool, TestTransaction::new(1, 0, 1)).unwrap(); + assert!(add_txn(&mut pool, TestTransaction::new(1, 1, 1)).is_err()); + + // commit transaction and free space + pool.remove_transaction(&TestTransaction::get_address(1), 0, false); + assert!(add_txn(&mut pool, TestTransaction::new(1, 1, 1)).is_ok()); + + // fill it up and check that GC routine will clear space + assert!(add_txn(&mut pool, TestTransaction::new(1, 2, 1)).is_err()); + pool.gc_by_system_ttl(); + assert!(add_txn(&mut pool, TestTransaction::new(1, 2, 1)).is_ok()); +} + +#[test] +fn test_parking_lot_eviction() { + let mut config = NodeConfigHelpers::get_single_node_test_config(true); + config.mempool.capacity = 5; + let mut pool = CoreMempool::new(&config); + // add transactions with following sequence numbers to Mempool + for seq in &[0, 1, 2, 9, 10] { + add_txn(&mut pool, TestTransaction::new(1, *seq, 1)).unwrap(); + } + // Mempool is full. 
Insert few txns for other account + for seq in &[0, 1] { + add_txn(&mut pool, TestTransaction::new(0, *seq, 1)).unwrap(); + } + // Make sure that we have correct txns in Mempool + let mut txns: Vec<_> = pool + .get_block(5, HashSet::new()) + .iter() + .map(SignedTransaction::sequence_number) + .collect(); + txns.sort(); + assert_eq!(txns, vec![0, 0, 1, 1, 2]); + + // Make sure we can't insert any new transactions, cause parking lot supposed to be empty by now + assert!(add_txn(&mut pool, TestTransaction::new(0, 2, 1)).is_err()); +} + +#[test] +fn test_gc_ready_transaction() { + let mut pool = setup_mempool().0; + add_txn(&mut pool, TestTransaction::new(1, 0, 1)).unwrap(); + + // insert in the middle transaction that's going to be expired + let txn = TestTransaction::new(1, 1, 1) + .make_signed_transaction_with_expiration_time(Duration::from_secs(0)); + pool.add_txn(txn, 0, 0, 100, TimelineState::NotReady); + + // insert few transactions after it + // They supposed to be ready because there's sequential path from 0 to them + add_txn(&mut pool, TestTransaction::new(1, 2, 1)).unwrap(); + add_txn(&mut pool, TestTransaction::new(1, 3, 1)).unwrap(); + + // chack that all txns are ready + let (timeline, _) = pool.read_timeline(0, 10); + assert_eq!(timeline.len(), 4); + + // gc expired transaction + pool.gc_by_expiration_time(Duration::from_secs(1)); + + // make sure txns 2 and 3 became not ready and we can't read them from any API + let block = pool.get_block(10, HashSet::new()); + assert_eq!(block.len(), 1); + assert_eq!(block[0].sequence_number(), 0); + + let (timeline, _) = pool.read_timeline(0, 10); + assert_eq!(timeline.len(), 1); + assert_eq!(timeline[0].sequence_number(), 0); +} diff --git a/mempool/src/core_mempool/unit_tests/mod.rs b/mempool/src/core_mempool/unit_tests/mod.rs new file mode 100644 index 0000000000000..fb24e7b22bb96 --- /dev/null +++ b/mempool/src/core_mempool/unit_tests/mod.rs @@ -0,0 +1,6 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod common; +mod core_mempool_test; +mod shared_mempool_test; diff --git a/mempool/src/core_mempool/unit_tests/shared_mempool_test.rs b/mempool/src/core_mempool/unit_tests/shared_mempool_test.rs new file mode 100644 index 0000000000000..8fe058aac7849 --- /dev/null +++ b/mempool/src/core_mempool/unit_tests/shared_mempool_test.rs @@ -0,0 +1,282 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + core_mempool::{unit_tests::common::TestTransaction, CoreMempool, TimelineState}, + shared_mempool::{start_shared_mempool, SharedMempoolNotification, SyncEvent}, +}; +use channel; +use config::config::{NodeConfig, NodeConfigHelpers}; +use failure::prelude::*; +use futures::{ + sync::mpsc::{unbounded, UnboundedReceiver, UnboundedSender}, + Stream, +}; +use futures_preview::{ + compat::Stream01CompatExt, executor::block_on, SinkExt, StreamExt, TryStreamExt, +}; +use network::{ + interface::{NetworkNotification, NetworkRequest}, + proto::MempoolSyncMsg, + validator_network::{MempoolNetworkEvents, MempoolNetworkSender}, +}; +use proto_conv::FromProto; +use std::{ + collections::{HashMap, HashSet}, + sync::{Arc, Mutex}, +}; +use storage_service::mocks::mock_storage_client::MockStorageReadClient; +use tokio::runtime::Runtime; +use types::{transaction::SignedTransaction, PeerId}; +use vm_validator::mocks::mock_vm_validator::MockVMValidator; + +#[derive(Default)] +struct SharedMempoolNetwork { + mempools: HashMap>>, + network_reqs_rxs: HashMap>, + 
network_notifs_txs: HashMap>, + runtimes: HashMap, + subscribers: HashMap>, + timers: HashMap>, +} + +impl SharedMempoolNetwork { + fn bootstrap_with_config(peers: Vec, mut config: NodeConfig) -> Self { + let mut smp = Self::default(); + config.mempool.shared_mempool_batch_size = 1; + + for peer in peers { + let mempool = Arc::new(Mutex::new(CoreMempool::new(&config))); + let (network_reqs_tx, network_reqs_rx) = channel::new_test(8); + let (network_notifs_tx, network_notifs_rx) = channel::new_test(8); + let network_sender = MempoolNetworkSender::new(network_reqs_tx); + let network_events = MempoolNetworkEvents::new(network_notifs_rx); + let (sender, subscriber) = unbounded(); + let (timer_sender, timer_receiver) = unbounded(); + + let runtime = start_shared_mempool( + &config, + Arc::clone(&mempool), + network_sender, + network_events, + Arc::new(MockStorageReadClient), + Arc::new(MockVMValidator), + vec![sender], + Some( + timer_receiver + .compat() + .map_err(|_| format_err!("test")) + .boxed(), + ), + ); + + smp.mempools.insert(peer, mempool); + smp.network_reqs_rxs.insert(peer, network_reqs_rx); + smp.network_notifs_txs.insert(peer, network_notifs_tx); + smp.subscribers.insert(peer, subscriber); + smp.timers.insert(peer, timer_sender); + smp.runtimes.insert(peer, runtime); + } + smp + } + + fn bootstrap(peers: Vec) -> Self { + Self::bootstrap_with_config(peers, NodeConfigHelpers::get_single_node_test_config(true)) + } + + fn add_txns(&mut self, peer_id: &PeerId, txns: Vec) { + let mut mempool = self.mempools.get(peer_id).unwrap().lock().unwrap(); + for txn in txns { + let transaction = txn.make_signed_transaction_with_max_gas_amount(5); + mempool.add_txn(transaction, 0, 0, 10, TimelineState::NotReady); + } + } + + fn send_event(&mut self, peer: &PeerId, notif: NetworkNotification) { + let network_notifs_tx = self.network_notifs_txs.get_mut(peer).unwrap(); + block_on(network_notifs_tx.send(notif)).unwrap(); + self.wait_for_event(peer, SharedMempoolNotification::PeerStateChange); + } + + fn wait_for_event(&mut self, peer_id: &PeerId, event: SharedMempoolNotification) { + let subscriber = self.subscribers.get_mut(peer_id).unwrap(); + while subscriber.wait().next().unwrap().unwrap() != event { + continue; + } + } + + /// deliveres next message from given node to it's peer + fn deliver_message(&mut self, peer: &PeerId) -> (SignedTransaction, PeerId) { + // emulate timer tick + self.timers + .get(peer) + .unwrap() + .unbounded_send(SyncEvent) + .unwrap(); + + // await next message from node + let network_reqs_rx = self.network_reqs_rxs.get_mut(peer).unwrap(); + let network_req = block_on(network_reqs_rx.next()).unwrap(); + + match network_req { + NetworkRequest::SendMessage(peer_id, msg) => { + let mut sync_msg: MempoolSyncMsg = + ::protobuf::parse_from_bytes(msg.mdata.as_ref()).unwrap(); + let transaction = + SignedTransaction::from_proto(sync_msg.take_transactions().pop().unwrap()) + .unwrap(); + // send it to peer + let receiver_network_notif_tx = self.network_notifs_txs.get_mut(&peer_id).unwrap(); + block_on( + receiver_network_notif_tx.send(NetworkNotification::RecvMessage(*peer, msg)), + ) + .unwrap(); + + // await message delivery + self.wait_for_event(&peer_id, SharedMempoolNotification::NewTransactions); + + // verify transaction was inserted into Mempool + let mempool = self.mempools.get(&peer).unwrap(); + let block = mempool.lock().unwrap().get_block(100, HashSet::new()); + assert!(block.iter().any(|t| t == &transaction)); + (transaction, peer_id) + } + _ => panic!("peer {:?} 
didn't broadcast transaction", peer), + } + } +} + +#[test] +fn test_basic_flow() { + let (peer_a, peer_b) = (PeerId::random(), PeerId::random()); + + let mut smp = SharedMempoolNetwork::bootstrap(vec![peer_a, peer_b]); + smp.add_txns( + &peer_a, + vec![ + TestTransaction::new(1, 0, 1), + TestTransaction::new(1, 1, 1), + TestTransaction::new(1, 2, 1), + ], + ); + + // A discovers new peer B + smp.send_event(&peer_a, NetworkNotification::NewPeer(peer_b)); + + for seq in 0..3 { + // A attempts to send message + let transaction = smp.deliver_message(&peer_a).0; + assert_eq!(transaction.sequence_number(), seq); + } +} + +#[test] +fn test_interruption_in_sync() { + let (peer_a, peer_b, peer_c) = (PeerId::random(), PeerId::random(), PeerId::random()); + let mut smp = SharedMempoolNetwork::bootstrap(vec![peer_a, peer_b, peer_c]); + smp.add_txns(&peer_a, vec![TestTransaction::new(1, 0, 1)]); + + // A discovers 2 peers + smp.send_event(&peer_a, NetworkNotification::NewPeer(peer_b)); + smp.send_event(&peer_a, NetworkNotification::NewPeer(peer_c)); + + // make sure it delivered first transaction to both nodes + let mut peers = vec![ + smp.deliver_message(&peer_a).1, + smp.deliver_message(&peer_a).1, + ]; + peers.sort(); + let mut expected_peers = vec![peer_b, peer_c]; + expected_peers.sort(); + assert_eq!(peers, expected_peers); + + // A loses connection to B + smp.send_event(&peer_a, NetworkNotification::LostPeer(peer_b)); + + // only C receives following transactions + smp.add_txns(&peer_a, vec![TestTransaction::new(1, 1, 1)]); + let (txn, peer_id) = smp.deliver_message(&peer_a); + assert_eq!(peer_id, peer_c); + assert_eq!(txn.sequence_number(), 1); + + smp.add_txns(&peer_a, vec![TestTransaction::new(1, 2, 1)]); + let (txn, peer_id) = smp.deliver_message(&peer_a); + assert_eq!(peer_id, peer_c); + assert_eq!(txn.sequence_number(), 2); + + // A reconnects to B + smp.send_event(&peer_a, NetworkNotification::NewPeer(peer_b)); + + // B should receive transaction 2 + let (txn, peer_id) = smp.deliver_message(&peer_a); + assert_eq!(peer_id, peer_b); + assert_eq!(txn.sequence_number(), 1); +} + +#[test] +fn test_ready_transactions() { + let (peer_a, peer_b) = (PeerId::random(), PeerId::random()); + let mut smp = SharedMempoolNetwork::bootstrap(vec![peer_a, peer_b]); + smp.add_txns( + &peer_a, + vec![TestTransaction::new(1, 0, 1), TestTransaction::new(1, 2, 1)], + ); + // first message delivery + smp.send_event(&peer_a, NetworkNotification::NewPeer(peer_b)); + smp.deliver_message(&peer_a); + + // add txn1 to Mempool + smp.add_txns(&peer_a, vec![TestTransaction::new(1, 1, 1)]); + // txn1 unlocked txn2. 
Now all transactions can go through in correct order + let txn = &smp.deliver_message(&peer_a).0; + assert_eq!(txn.sequence_number(), 1); + let txn = &smp.deliver_message(&peer_a).0; + assert_eq!(txn.sequence_number(), 2); +} + +#[test] +fn test_broadcast_self_transactions() { + let (peer_a, peer_b) = (PeerId::random(), PeerId::random()); + let mut smp = SharedMempoolNetwork::bootstrap(vec![peer_a, peer_b]); + smp.add_txns(&peer_a, vec![TestTransaction::new(0, 0, 1)]); + + // A and B discover each other + smp.send_event(&peer_a, NetworkNotification::NewPeer(peer_b)); + smp.send_event(&peer_b, NetworkNotification::NewPeer(peer_a)); + + // A sends txn to B + smp.deliver_message(&peer_a); + + // add new txn to B + smp.add_txns(&peer_b, vec![TestTransaction::new(1, 0, 1)]); + + // verify that A will receive only second transaction from B + let (txn, _) = smp.deliver_message(&peer_b); + assert_eq!(txn.sender(), TestTransaction::get_address(1)); +} + +#[test] +fn test_broadcast_dependencies() { + let (peer_a, peer_b) = (PeerId::random(), PeerId::random()); + let mut smp = SharedMempoolNetwork::bootstrap(vec![peer_a, peer_b]); + // Peer A has transactions with sequence numbers 0 and 2 + smp.add_txns( + &peer_a, + vec![TestTransaction::new(0, 0, 1), TestTransaction::new(0, 2, 1)], + ); + // Peer B has txn1 + smp.add_txns(&peer_b, vec![TestTransaction::new(0, 1, 1)]); + + // A and B discover each other + smp.send_event(&peer_a, NetworkNotification::NewPeer(peer_b)); + smp.send_event(&peer_b, NetworkNotification::NewPeer(peer_a)); + + // B receives 0 + smp.deliver_message(&peer_a); + // now B can broadcast 1 + let txn = smp.deliver_message(&peer_b).0; + assert_eq!(txn.sequence_number(), 1); + // now A can broadcast 2 + let txn = smp.deliver_message(&peer_a).0; + assert_eq!(txn.sequence_number(), 2); +} diff --git a/mempool/src/lib.rs b/mempool/src/lib.rs new file mode 100644 index 0000000000000..b2830fbcec727 --- /dev/null +++ b/mempool/src/lib.rs @@ -0,0 +1,72 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![feature(async_await)] +#![deny(missing_docs)] +//! Mempool is used to hold transactions that have been submitted but not yet agreed upon and +//! executed. +//! +//! **Flow**: AC sends transactions into mempool which holds them for a period of time before +//! sending them into consensus. When a new transaction is added, Mempool shares this transaction +//! with other nodes in the system. This is a form of β€œshared mempool” in that transactions between +//! mempools are shared with other validators. This helps maintain a pseudo global ordering since +//! when a validator receives a transaction from another mempool, it will be ordered when added in +//! the ordered queue of the recipient validator. To reduce network consumption, in β€œshared mempool” +//! each validator is responsible for delivery of its own transactions (we don't rebroadcast +//! transactions originated on a different peer). Also we only broadcast transactions that have some +//! chance to be included in next block: their sequence number equals to the next sequence number of +//! account or sequential to it. For example, if the current sequence number for an account is 2 and +//! local mempool contains transactions with sequence numbers 2,3,4,7,8, then only transactions 2, 3 +//! and 4 will be broadcast. +//! +//! Consensus pulls transactions from mempool rather than mempool pushing into consensus. This is +//! 
done so that while consensus is not yet ready for transactions, we keep ordering based on gas +//! and consensus can let transactions build up. This allows for batching of transactions into a +//! single consensus block as well as prioritizing by gas price. Mempool doesn't keep track of +//! transactions that were sent to Consensus. On each get_block request, Consensus additionally +//! sends a set of transactions that were pulled from Mempool so far but were not committed yet. +//! This is done so Mempool can be agnostic about different Consensus proposal branches. Once a +//! transaction is fully executed and written to storage, Consensus notifies Mempool about it which +//! later drops it from its internal state. +//! +//! **Internals**: Internally Mempool is modeled as `HashMap` +//! with various indexes built on top of it. The main index `PriorityIndex` is an ordered queue of +//! transactions that are β€œready” to be included in next block(i.e. have sequence number sequential +//! to current for account). This queue is ordered by gas price so that if a client is willing to +//! pay more (than other clients) per unit of execution, then they can enter consensus earlier. Note +//! that although global ordering is maintained by gas price, for a single account, transactions are +//! ordered by sequence number. +//! +//! All transactions that are not ready to be included in the next block are part of separate +//! `ParkingLotIndex`. They will be moved to the ordered queue once some event unblocks them. For +//! example, Mempool has transaction with sequence number 4, while current sequence number for that +//! account is 3. Such transaction is considered to be β€œnon-ready”. Then callback from Consensus +//! notifies that transaction was committed(i.e. transaction 3 was submitted to different node). +//! Such event β€œunblocks” local transaction and txn4 will be moved to OrderedQueue. +//! +//! Mempool only holds a limited number of transactions to prevent OOMing the system. Additionally +//! there's a limit of number of transactions per account to prevent different abuses/attacks +//! +//! Transactions in Mempool have two types of expirations: systemTTL and client-specified +//! expiration. Once we hit either of those, the transaction is removed from Mempool. SystemTTL is +//! checked periodically in the background, while the client-specified expiration is checked on +//! every Consensus commit request. We use a separate system TTL to ensure that a transaction won't +//! remain stuck in Mempool forever, even if Consensus doesn't make progress +pub mod proto; +pub use runtime::MempoolRuntime; + +mod core_mempool; +mod mempool_service; +mod runtime; +mod shared_mempool; + +// module op counters +use lazy_static::lazy_static; +use metrics::OpMetrics; +lazy_static! 
{ + static ref OP_COUNTERS: OpMetrics = OpMetrics::new_and_registered("mempool"); +} +pub use crate::core_mempool::MempoolAddTransactionStatus; + +#[cfg(test)] +mod unit_tests; diff --git a/mempool/src/mempool_service.rs b/mempool/src/mempool_service.rs new file mode 100644 index 0000000000000..fcfaf0fabef6a --- /dev/null +++ b/mempool/src/mempool_service.rs @@ -0,0 +1,157 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + core_mempool::{CoreMempool, TimelineState, TxnPointer}, + proto::mempool_grpc::Mempool, + OP_COUNTERS, +}; +use futures::Future; +use grpc_helpers::{create_grpc_invalid_arg_status, default_reply_error_logger}; +use logger::prelude::*; +use metrics::counters::SVC_COUNTERS; +use proto_conv::{FromProto, IntoProto}; +use std::{ + cmp, + collections::HashSet, + convert::TryFrom, + sync::{Arc, Mutex}, + time::Duration, +}; +use types::{ + account_address::AccountAddress, proto::transaction::SignedTransactionsBlock, + transaction::SignedTransaction, +}; + +#[derive(Clone)] +pub(crate) struct MempoolService { + pub(crate) core_mempool: Arc>, +} + +impl Mempool for MempoolService { + fn add_transaction_with_validation( + &mut self, + ctx: ::grpcio::RpcContext<'_>, + mut req: crate::proto::mempool::AddTransactionWithValidationRequest, + sink: ::grpcio::UnarySink, + ) { + trace!("[GRPC] Mempool::add_transaction_with_validation"); + let _timer = SVC_COUNTERS.req(&ctx); + let mut success = true; + let proto_transaction = req.take_signed_txn(); + match SignedTransaction::from_proto(proto_transaction) { + Err(e) => { + success = false; + ctx.spawn( + sink.fail(create_grpc_invalid_arg_status( + "add_transaction_with_validation", + e, + )) + .map_err(default_reply_error_logger), + ); + } + Ok(transaction) => { + let insertion_result = self + .core_mempool + .lock() + .expect("[add txn] acquire mempool lock") + .add_txn( + transaction, + req.max_gas_cost, + req.latest_sequence_number, + req.account_balance, + TimelineState::NotReady, + ); + + let mut response = + crate::proto::mempool::AddTransactionWithValidationResponse::new(); + response.set_status(insertion_result.into_proto()); + ctx.spawn(sink.success(response).map_err(default_reply_error_logger)) + } + } + SVC_COUNTERS.resp(&ctx, success); + } + + fn get_block( + &mut self, + ctx: ::grpcio::RpcContext<'_>, + req: super::proto::mempool::GetBlockRequest, + sink: ::grpcio::UnarySink, + ) { + trace!("[GRPC] Mempool::get_block"); + let _timer = SVC_COUNTERS.req(&ctx); + + let block_size = cmp::max(req.get_max_block_size(), 1); + OP_COUNTERS.inc_by("get_block.requested", block_size as usize); + let exclude_transactions: HashSet = req + .get_transactions() + .iter() + .map(|t| (AccountAddress::try_from(t.get_sender()), t.sequence_number)) + .filter(|(address, _)| address.is_ok()) + .map(|(address, seq)| (address.unwrap(), seq)) + .collect(); + + let mut txns = self + .core_mempool + .lock() + .expect("[get_block] acquire mempool lock") + .get_block(block_size, exclude_transactions); + + let transactions = txns.drain(..).map(SignedTransaction::into_proto).collect(); + + let mut block = SignedTransactionsBlock::new(); + block.set_transactions(::protobuf::RepeatedField::from_vec(transactions)); + OP_COUNTERS.inc_by("get_block.returned", block.get_transactions().len()); + let mut response = crate::proto::mempool::GetBlockResponse::new(); + response.set_block(block); + ctx.spawn(sink.success(response).map_err(default_reply_error_logger)); + SVC_COUNTERS.resp(&ctx, true); + } + + fn 
commit_transactions( + &mut self, + ctx: ::grpcio::RpcContext<'_>, + req: crate::proto::mempool::CommitTransactionsRequest, + sink: ::grpcio::UnarySink, + ) { + trace!("[GRPC] Mempool::commit_transaction"); + let _timer = SVC_COUNTERS.req(&ctx); + OP_COUNTERS.inc_by( + "commit_transactions.requested", + req.get_transactions().len(), + ); + let mut pool = self + .core_mempool + .lock() + .expect("[update status] acquire mempool lock"); + for transaction in req.get_transactions() { + if let Ok(address) = AccountAddress::try_from(transaction.get_sender()) { + let sequence_number = transaction.get_sequence_number(); + pool.remove_transaction(&address, sequence_number, transaction.get_is_rejected()); + } + } + let block_timestamp_usecs = req.get_block_timestamp_usecs(); + if block_timestamp_usecs > 0 { + pool.gc_by_expiration_time(Duration::from_micros(block_timestamp_usecs)); + } + let response = crate::proto::mempool::CommitTransactionsResponse::new(); + ctx.spawn(sink.success(response).map_err(default_reply_error_logger)); + SVC_COUNTERS.resp(&ctx, true); + } + + fn health_check( + &mut self, + ctx: ::grpcio::RpcContext<'_>, + _req: crate::proto::mempool::HealthCheckRequest, + sink: ::grpcio::UnarySink, + ) { + trace!("[GRPC] Mempool::health_check"); + let pool = self + .core_mempool + .lock() + .expect("[health_check] acquire mempool lock"); + let mut response = crate::proto::mempool::HealthCheckResponse::new(); + response.set_is_healthy(pool.health_check()); + ctx.spawn(sink.success(response).map_err(default_reply_error_logger)); + } +} diff --git a/mempool/src/proto/mempool.proto b/mempool/src/proto/mempool.proto new file mode 100644 index 0000000000000..3ef2440d73496 --- /dev/null +++ b/mempool/src/proto/mempool.proto @@ -0,0 +1,105 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package mempool; + +import "transaction.proto"; +import "shared/mempool_status.proto"; + +// ----------------------------------------------------------------------------- +// ---------------- Mempool Service Definition +// ----------------------------------------------------------------------------- +service Mempool { + // Adds a new transaction to the mempool with validation against existing + // transactions in the mempool. Note that this function performs additional + // validation that AC is unable to perform. (because AC knows only about a + // single transaction, but mempool potentially knows about multiple pending + // transactions) + rpc AddTransactionWithValidation(AddTransactionWithValidationRequest) + returns (AddTransactionWithValidationResponse) {} + + // Fetch ordered block of transactions + rpc GetBlock(GetBlockRequest) returns (GetBlockResponse) {} + + // Remove committed transactions from Mempool + rpc CommitTransactions(CommitTransactionsRequest) + returns (CommitTransactionsResponse) {} + + // Check the health of mempool + rpc HealthCheck(HealthCheckRequest) + returns (HealthCheckResponse) {} +} + +// ----------------------------------------------------------------------------- +// ---------------- AddTransactionWithValidation +// ----------------------------------------------------------------------------- + +message AddTransactionWithValidationRequest { + // Transaction from a wallet + types.SignedTransaction signed_txn = 1; + // Max amount of gas required to execute the transaction + // Without running the program, it is very difficult to determine this number, + // so we use the max gas specified by the signed transaction. 
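For orientation, a caller such as AC would reach this service through the generated bindings. The following is a hedged sketch only: the client type, method name, setters, and port are assumed from grpcio/rust-protobuf codegen conventions rather than shown in this diff.

```rust
use crate::proto::{
    mempool::AddTransactionWithValidationRequest, mempool_client::MempoolClient,
};
use grpcio::{ChannelBuilder, EnvBuilder};
use proto_conv::IntoProto;
use std::sync::Arc;

// Hypothetical helper: submit one signed transaction to a local mempool.
fn submit(signed_txn: types::transaction::SignedTransaction) -> grpcio::Result<()> {
    let env = Arc::new(EnvBuilder::new().build());
    let channel = ChannelBuilder::new(env).connect("localhost:9999"); // port is made up
    let client = MempoolClient::new(channel);

    let mut req = AddTransactionWithValidationRequest::new();
    req.set_max_gas_cost(signed_txn.max_gas_amount());
    req.set_latest_sequence_number(0); // would come from state DB in practice
    req.set_account_balance(0); // would come from state DB in practice
    req.set_signed_txn(signed_txn.into_proto());

    let resp = client.add_transaction_with_validation(&req)?;
    println!("mempool status: {:?}", resp.get_status());
    Ok(())
}
```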
+ // This field is still included separately from the signed transaction so that + // if we have a better methodology in the future, we can more accurately + // specify the max gas. + uint64 max_gas_cost = 2; + // Latest sequence number of the involved account from state db. + uint64 latest_sequence_number = 3; + // Latest account balance of the involved account from state db. + uint64 account_balance = 4; +} + +message AddTransactionWithValidationResponse { + // The ledger version at the time of the transaction submitted. The submitted + // transaction will have version bigger than this 'current_version' + uint64 current_version = 1; + // The result of the transaction submission + MempoolAddTransactionStatus status = 2; +} + +// ----------------------------------------------------------------------------- +// ---------------- GetBlock +// ----------------------------------------------------------------------------- +message GetBlockRequest { + uint64 max_block_size = 1; + repeated TransactionExclusion transactions = 2; +} + +message GetBlockResponse { types.SignedTransactionsBlock block = 1; } + +message TransactionExclusion { + bytes sender = 1; + uint64 sequence_number = 2; +} + +// ----------------------------------------------------------------------------- +// ---------------- CommitTransactions +// ----------------------------------------------------------------------------- +message CommitTransactionsRequest { + repeated CommittedTransaction transactions = 1; + // agreed monotonic timestamp microseconds since the epoch for a committed block + // used by Mempool to GC expired transactions + uint64 block_timestamp_usecs = 2; +} + +message CommitTransactionsResponse {} + +message CommittedTransaction { + bytes sender = 1; + uint64 sequence_number = 2; + bool is_rejected = 3; +} + +// ----------------------------------------------------------------------------- +// ---------------- HealthCheck +// ----------------------------------------------------------------------------- +message HealthCheckRequest { +} + +message HealthCheckResponse { + // Indicate whether Mempool is in healthy condition. + bool is_healthy = 1; +} diff --git a/mempool/src/proto/mod.rs b/mempool/src/proto/mod.rs new file mode 100644 index 0000000000000..12eabee9803e8 --- /dev/null +++ b/mempool/src/proto/mod.rs @@ -0,0 +1,11 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![allow(missing_docs)] +use crate::proto::shared::*; +use types::proto::*; + +pub mod mempool; +pub mod mempool_client; +pub mod mempool_grpc; +pub mod shared; diff --git a/mempool/src/proto/shared/mempool_status.proto b/mempool/src/proto/shared/mempool_status.proto new file mode 100644 index 0000000000000..040503980c901 --- /dev/null +++ b/mempool/src/proto/shared/mempool_status.proto @@ -0,0 +1,21 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package mempool; + +enum MempoolAddTransactionStatus { + // Transaction was sent to Mempool + Valid = 0; + // The sender does not have enough balance for the transaction. + InsufficientBalance = 1; + // Sequence number is old, etc. + InvalidSeqNumber = 2; + // Mempool is full (reached max global capacity) + MempoolIsFull = 3; + // Account reached max capacity per account + TooManyTransactions = 4; + // Invalid update. 
Only gas price increase is allowed + InvalidUpdate = 5; +} diff --git a/mempool/src/proto/shared/mod.rs b/mempool/src/proto/shared/mod.rs new file mode 100644 index 0000000000000..e10755b81bd41 --- /dev/null +++ b/mempool/src/proto/shared/mod.rs @@ -0,0 +1,4 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +pub mod mempool_status; diff --git a/mempool/src/runtime.rs b/mempool/src/runtime.rs new file mode 100644 index 0000000000000..1ea6eb89ae31e --- /dev/null +++ b/mempool/src/runtime.rs @@ -0,0 +1,80 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + core_mempool::CoreMempool, mempool_service::MempoolService, proto::mempool_grpc, + shared_mempool::start_shared_mempool, +}; +use config::config::NodeConfig; +use grpc_helpers::ServerHandle; +use grpcio::EnvBuilder; +use grpcio_sys; +use network::validator_network::{MempoolNetworkEvents, MempoolNetworkSender}; +use std::{ + cmp::max, + sync::{Arc, Mutex}, +}; +use storage_client::{StorageRead, StorageReadServiceClient}; +use tokio::runtime::Runtime; +use vm_validator::vm_validator::VMValidator; + +/// Handle for Mempool Runtime +pub struct MempoolRuntime { + /// gRPC server to serve request from AC and Consensus + pub grpc_server: ServerHandle, + /// separate shared mempool runtime + pub shared_mempool: Runtime, +} + +impl MempoolRuntime { + /// setup Mempool runtime + pub fn boostrap( + config: &NodeConfig, + network_sender: MempoolNetworkSender, + network_events: MempoolNetworkEvents, + ) -> Self { + let mempool = Arc::new(Mutex::new(CoreMempool::new(&config))); + + // setup grpc server + let env = Arc::new( + EnvBuilder::new() + .name_prefix("grpc-mempool-") + .cq_count(unsafe { max(grpcio_sys::gpr_cpu_num_cores() as usize / 2, 2) }) + .build(), + ); + let handle = MempoolService { + core_mempool: Arc::clone(&mempool), + }; + let service = mempool_grpc::create_mempool(handle); + let grpc_server = ::grpcio::ServerBuilder::new(env) + .register_service(service) + .bind( + config.mempool.address.clone(), + config.mempool.mempool_service_port, + ) + .build() + .expect("[mempool] unable to create grpc server"); + + // setup shared mempool + let storage_client: Arc = Arc::new(StorageReadServiceClient::new( + Arc::new(EnvBuilder::new().name_prefix("grpc-mem-sto-").build()), + "localhost", + config.storage.port, + )); + let vm_validator = Arc::new(VMValidator::new(&config, Arc::clone(&storage_client))); + let shared_mempool = start_shared_mempool( + config, + mempool, + network_sender, + network_events, + storage_client, + vm_validator, + vec![], + None, + ); + Self { + grpc_server: ServerHandle::setup(grpc_server), + shared_mempool, + } + } +} diff --git a/mempool/src/shared_mempool.rs b/mempool/src/shared_mempool.rs new file mode 100644 index 0000000000000..540498396bfd4 --- /dev/null +++ b/mempool/src/shared_mempool.rs @@ -0,0 +1,450 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + core_mempool::{CoreMempool, MempoolAddTransactionStatus, TimelineState}, + OP_COUNTERS, +}; +use config::config::{MempoolConfig, NodeConfig}; +use failure::prelude::*; +use futures::sync::mpsc::UnboundedSender; +use futures_preview::{ + compat::{Future01CompatExt, Stream01CompatExt}, + future::{self, join_all}, + FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, +}; +use logger::prelude::*; +use network::{ + proto::MempoolSyncMsg, + validator_network::{Event, MempoolNetworkEvents, MempoolNetworkSender}, +}; 
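Bringing the runtime up from a node looks roughly as follows. This is a hypothetical wiring sketch: the two network handles are produced by the validator network builder elsewhere in the codebase.

```rust
use config::config::NodeConfig;
use mempool::MempoolRuntime;
use network::validator_network::{MempoolNetworkEvents, MempoolNetworkSender};

// Hypothetical wiring of the mempool runtime into a node.
fn start_mempool(
    config: &NodeConfig,
    network_sender: MempoolNetworkSender,
    network_events: MempoolNetworkEvents,
) -> MempoolRuntime {
    // note: the constructor is spelled `boostrap` in this diff
    MempoolRuntime::boostrap(config, network_sender, network_events)
}
```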
+use proto_conv::{FromProto, IntoProto}; +use std::{ + collections::HashMap, + ops::Deref, + pin::Pin, + sync::{Arc, Mutex}, + time::Duration, +}; +use storage_client::StorageRead; +use tokio::{ + runtime::{Builder, Runtime}, + timer::Interval, +}; +use types::{transaction::SignedTransaction, PeerId}; +use vm_validator::vm_validator::{get_account_state, TransactionValidation}; + +/// state of last sync with peer +/// `timeline_id` is position in log of ready transactions +/// `is_alive` - is connection healthy +#[derive(Clone)] +struct PeerSyncState { + timeline_id: u64, + is_alive: bool, +} + +type PeerInfo = HashMap; + +/// Outbound peer syncing event emitted by [`IntervalStream`]. +#[derive(Debug)] +pub(crate) struct SyncEvent; + +type IntervalStream = Pin> + Send + 'static>>; + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum SharedMempoolNotification { + Sync, + PeerStateChange, + NewTransactions, +} + +/// Struct that owns all dependencies required by shared mempool routines +struct SharedMempool +where + V: TransactionValidation + 'static, +{ + mempool: Arc>, + network_sender: MempoolNetworkSender, + config: MempoolConfig, + storage_read_client: Arc, + validator: Arc, + peer_info: Arc>, + subscribers: Vec>, +} + +// TODO(gzh): Cannot derive `Clone`. +// Issue: https://github.com/rust-lang/rust/issues/26925 +impl Clone for SharedMempool +where + V: TransactionValidation + 'static, +{ + fn clone(&self) -> Self { + Self { + mempool: Arc::clone(&self.mempool), + network_sender: self.network_sender.clone(), + config: self.config.clone(), + storage_read_client: Arc::clone(&self.storage_read_client), + validator: Arc::clone(&self.validator), + peer_info: self.peer_info.clone(), + subscribers: self.subscribers.clone(), + } + } +} + +fn notify_subscribers( + event: SharedMempoolNotification, + subscribers: &[UnboundedSender], +) { + for subscriber in subscribers { + let _ = subscriber.unbounded_send(event); + } +} + +fn default_timer(tick_ms: u64) -> IntervalStream { + Interval::new_interval(Duration::from_millis(tick_ms)) + .compat() + .map_ok(|_| SyncEvent) + .map_err(|_| format_err!("[shared mempool] timer tick error")) + .boxed() +} + +/// new peer discovery handler +/// adds new entry to `peer_info` +fn new_peer(peer_info: &Mutex, peer_id: PeerId) { + peer_info + .lock() + .expect("[shared mempool] failed to acquire peer_info lock") + .entry(peer_id) + .or_insert(PeerSyncState { + timeline_id: 0, + is_alive: true, + }) + .is_alive = true; +} + +/// lost peer handler. Marks connection as dead +fn lost_peer(peer_info: &Mutex, peer_id: PeerId) { + if let Some(state) = peer_info + .lock() + .expect("[shared mempool] failed to acquire peer_info lock") + .get_mut(&peer_id) + { + state.is_alive = false; + } +} + +/// sync routine +/// used to periodically broadcast ready to go transactions to peers +async fn sync_with_peers<'a>( + peer_info: &'a Mutex, + mempool: &'a Mutex, + network_sender: &'a mut MempoolNetworkSender, + batch_size: usize, +) { + // Clone the underlying peer_info map and use this to sync and collect + // state updates. We do this instead of holding the lock for the whole + // function since that would hold the lock across await points which is bad. 
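The comment above names the key constraint: a `std::sync::Mutex` guard must not be held across an `.await`. The resulting clone, work, reapply shape is worth spelling out; here is a minimal self-contained sketch with toy types standing in for `PeerInfo`:

```rust
use std::{collections::HashMap, sync::Mutex};

// Toy stand-in for the peer map: peer id -> timeline position.
async fn sync_once(shared: &Mutex<HashMap<u64, u64>>) {
    // 1. Copy the state out; the guard is dropped at the end of this statement.
    let snapshot = shared.lock().unwrap().clone();

    // 2. Do the slow/async work on the copy, collecting updates.
    let mut updates = Vec::new();
    for (peer, timeline_id) in snapshot {
        // ... `.await` network sends would go here, with no lock held ...
        updates.push((peer, timeline_id + 1));
    }

    // 3. Re-acquire briefly to apply updates; no `.await` in this section.
    let mut map = shared.lock().unwrap();
    for (peer, t) in updates {
        map.insert(peer, t);
    }
}
```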
+    let peer_info_copy = peer_info
+        .lock()
+        .expect("[shared mempool] failed to acquire peer_info lock")
+        .deref()
+        .clone();
+
+    let mut state_updates = vec![];
+
+    for (peer_id, peer_state) in peer_info_copy.into_iter() {
+        if peer_state.is_alive {
+            let timeline_id = peer_state.timeline_id;
+
+            let (transactions, new_timeline_id) = mempool
+                .lock()
+                .expect("[shared mempool] failed to acquire mempool lock")
+                .read_timeline(timeline_id, batch_size);
+
+            if !transactions.is_empty() {
+                OP_COUNTERS.inc_by("smp.sync_with_peers", transactions.len());
+                let mut msg = MempoolSyncMsg::new();
+                msg.set_peer_id(peer_id.into());
+                msg.set_transactions(
+                    transactions
+                        .into_iter()
+                        .map(IntoProto::into_proto)
+                        .collect(),
+                );
+
+                debug!(
+                    "MempoolNetworkSender.send_to peer {} msg {:?}",
+                    peer_id, msg
+                );
+                // Since this is a direct-send, this will only error if the network
+                // module has unexpectedly crashed or shutdown.
+                network_sender
+                    .send_to(peer_id, msg)
+                    .await
+                    .expect("[shared mempool] failed to direct-send mempool sync message");
+            }
+
+            state_updates.push((peer_id, new_timeline_id));
+        }
+    }
+
+    // Lock the shared peer_info and apply state updates.
+    let mut peer_info = peer_info
+        .lock()
+        .expect("[shared mempool] failed to acquire peer_info lock");
+    for (peer_id, new_timeline_id) in state_updates {
+        peer_info
+            .entry(peer_id)
+            .and_modify(|t| t.timeline_id = new_timeline_id);
+    }
+}
+
+/// used to validate incoming transactions and add them to local Mempool
+async fn process_incoming_transactions<V>(
+    smp: SharedMempool<V>,
+    peer_id: PeerId,
+    transactions: Vec<SignedTransaction>,
+) where
+    V: TransactionValidation,
+{
+    let validations = join_all(
+        transactions
+            .iter()
+            .map(|t| smp.validator.validate_transaction(t.clone()).compat()),
+    )
+    .await;
+
+    let account_states = join_all(
+        transactions
+            .iter()
+            .map(|t| get_account_state(smp.storage_read_client.clone(), t.sender())),
+    )
+    .await;
+
+    let mut mempool = smp
+        .mempool
+        .lock()
+        .expect("[shared mempool] failed to acquire mempool lock");
+
+    for (idx, transaction) in transactions.into_iter().enumerate() {
+        if let Ok(None) = validations[idx] {
+            if let Ok((sequence_number, balance)) = account_states[idx] {
+                let gas_cost = transaction.max_gas_amount();
+                let insertion_result = mempool.add_txn(
+                    transaction,
+                    gas_cost,
+                    sequence_number,
+                    balance,
+                    TimelineState::NonQualified,
+                );
+                if insertion_result == MempoolAddTransactionStatus::Valid {
+                    OP_COUNTERS.inc(&format!("smp.transactions.success.{:?}", peer_id));
+                }
+            }
+        }
+    }
+    notify_subscribers(SharedMempoolNotification::NewTransactions, &smp.subscribers);
+}
+
+/// This task handles [`SyncEvent`], which is periodically emitted for us to
+/// broadcast ready to go transactions to peers.
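+/// Each tick runs `sync_with_peers`, which reads at most `batch_size`
+/// transactions per live peer from the mempool timeline, starting at that
+/// peer's last synced `timeline_id`.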
+async fn outbound_sync_task<V>(smp: SharedMempool<V>, mut interval: IntervalStream)
+where
+    V: TransactionValidation,
+{
+    let peer_info = smp.peer_info;
+    let mempool = smp.mempool;
+    let mut network_sender = smp.network_sender;
+    let batch_size = smp.config.shared_mempool_batch_size;
+    let subscribers = smp.subscribers;
+
+    while let Some(sync_event) = interval.next().await {
+        trace!("SyncEvent: {:?}", sync_event);
+        match sync_event {
+            Ok(_) => {
+                sync_with_peers(&peer_info, &mempool, &mut network_sender, batch_size).await;
+                notify_subscribers(SharedMempoolNotification::Sync, &subscribers);
+            }
+            Err(e) => {
+                error!("Error in outbound_sync_task timer interval: {:?}", e);
+                break;
+            }
+        }
+    }
+
+    crit!("SharedMempool outbound_sync_task terminated");
+}
+
+/// This task handles inbound network events.
+async fn inbound_network_task<V>(smp: SharedMempool<V>, network_events: MempoolNetworkEvents)
+where
+    V: TransactionValidation,
+{
+    let peer_info = smp.peer_info.clone();
+    let subscribers = smp.subscribers.clone();
+    let max_inbound_syncs = smp.config.shared_mempool_max_concurrent_inbound_syncs;
+
+    // Handle the NewPeer/LostPeer events immediately, since they are not async
+    // and we don't want to buffer them or let them get reordered. The inbound
+    // direct-send messages are placed in a bounded FuturesUnordered queue and
+    // allowed to execute concurrently. The `.for_each_concurrent()` limit also
+    // correctly handles back-pressure, so if mempool is slow the back-pressure
+    // will propagate down to network.
+    let f_inbound_network_task = network_events
+        .filter_map(move |network_event| {
+            trace!("SharedMempoolEvent::NetworkEvent::{:?}", network_event);
+            match network_event {
+                Ok(network_event) => match network_event {
+                    Event::NewPeer(peer_id) => {
+                        OP_COUNTERS.inc("smp.event.new_peer");
+                        new_peer(&peer_info, peer_id);
+                        notify_subscribers(
+                            SharedMempoolNotification::PeerStateChange,
+                            &subscribers,
+                        );
+                        future::ready(None)
+                    }
+                    Event::LostPeer(peer_id) => {
+                        OP_COUNTERS.inc("smp.event.lost_peer");
+                        lost_peer(&peer_info, peer_id);
+                        notify_subscribers(
+                            SharedMempoolNotification::PeerStateChange,
+                            &subscribers,
+                        );
+                        future::ready(None)
+                    }
+                    // Pass through messages to next combinator
+                    Event::Message((peer_id, msg)) => future::ready(Some((peer_id, msg))),
+                    _ => {
+                        security_log(SecurityEvent::InvalidNetworkEventMP)
+                            .error("UnexpectedNetworkEvent")
+                            .data(&network_event)
+                            .log();
+                        unreachable!("Unexpected network event")
+                    }
+                },
+                Err(e) => {
+                    security_log(SecurityEvent::InvalidNetworkEventMP)
+                        .error(&e)
+                        .log();
+                    future::ready(None)
+                }
+            }
+        })
+        // Run max_inbound_syncs number of `process_incoming_transactions` concurrently
+        .for_each_concurrent(
+            max_inbound_syncs, /* limit */
+            move |(peer_id, mut msg)| {
+                OP_COUNTERS.inc("smp.event.message");
+                let transactions: Vec<_> = msg
+                    .take_transactions()
+                    .into_iter()
+                    .filter_map(|txn| match SignedTransaction::from_proto(txn) {
+                        Ok(t) => Some(t),
+                        Err(e) => {
+                            security_log(SecurityEvent::InvalidTransactionMP)
+                                .error(&e)
+                                .data(&msg)
+                                .log();
+                            None
+                        }
+                    })
+                    .collect();
+                OP_COUNTERS.inc_by(
+                    &format!("smp.transactions.received.{:?}", peer_id),
+                    transactions.len(),
+                );
+
+                process_incoming_transactions(smp.clone(), peer_id, transactions)
+            },
+        );
+
+    // drive the inbound futures to completion
+    f_inbound_network_task.await;
+
+    crit!("SharedMempool inbound_network_task terminated");
+}
+
+/// GC all expired transactions by SystemTTL
+async fn gc_task(mempool: Arc<Mutex<CoreMempool>>, gc_interval_ms: u64) {
+    let mut interval = Interval::new_interval(Duration::from_millis(gc_interval_ms)).compat();
+    while let Some(res) = interval.next().await {
+        match res {
+            Ok(_) => {
+                mempool
+                    .lock()
+                    .expect("[shared mempool] failed to acquire mempool lock")
+                    .gc_by_system_ttl();
+            }
+            Err(e) => {
+                error!("Error in gc_task timer interval: {:?}", e);
+                break;
+            }
+        }
+    }
+
+    crit!("SharedMempool gc_task terminated");
+}
+
+/// bootstrap of SharedMempool
+/// creates a separate Tokio Runtime that runs the following routines:
+///   - outbound_sync_task (task that periodically broadcasts transactions to peers)
+///   - inbound_network_task (task that handles inbound mempool messages and network events)
+///   - gc_task (task that performs GC of all expired transactions by SystemTTL)
+pub(crate) fn start_shared_mempool<V>(
+    config: &NodeConfig,
+    mempool: Arc<Mutex<CoreMempool>>,
+    network_sender: MempoolNetworkSender,
+    network_events: MempoolNetworkEvents,
+    storage_read_client: Arc<dyn StorageRead>,
+    validator: Arc<V>,
+    subscribers: Vec<UnboundedSender<SharedMempoolNotification>>,
+    timer: Option<IntervalStream>,
+) -> Runtime
+where
+    V: TransactionValidation + 'static,
+{
+    let runtime = Builder::new()
+        .name_prefix("shared-mem-")
+        .build()
+        .expect("[shared mempool] failed to create runtime");
+    let executor = runtime.executor();
+
+    let peer_info = Arc::new(Mutex::new(PeerInfo::new()));
+
+    let smp = SharedMempool {
+        mempool: mempool.clone(),
+        config: config.mempool.clone(),
+        network_sender,
+        storage_read_client,
+        validator,
+        peer_info,
+        subscribers,
+    };
+
+    let interval =
+        timer.unwrap_or_else(|| default_timer(config.mempool.shared_mempool_tick_interval_ms));
+
+    executor.spawn(
+        outbound_sync_task(smp.clone(), interval)
+            .boxed()
+            .unit_error()
+            .compat(),
+    );
+
+    executor.spawn(
+        inbound_network_task(smp, network_events)
+            .boxed()
+            .unit_error()
+            .compat(),
+    );
+
+    executor.spawn(
+        gc_task(mempool, config.mempool.system_transaction_gc_interval_ms)
+            .boxed()
+            .unit_error()
+            .compat(),
+    );
+
+    runtime
+}
diff --git a/mempool/src/unit_tests/mod.rs b/mempool/src/unit_tests/mod.rs
new file mode 100644
index 0000000000000..f46e1711f7b01
--- /dev/null
+++ b/mempool/src/unit_tests/mod.rs
@@ -0,0 +1,4 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+mod service_test;
diff --git a/mempool/src/unit_tests/service_test.rs b/mempool/src/unit_tests/service_test.rs
new file mode 100644
index 0000000000000..451a5be0e13d6
--- /dev/null
+++ b/mempool/src/unit_tests/service_test.rs
@@ -0,0 +1,155 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    core_mempool::CoreMempool,
+    mempool_service::MempoolService,
+    proto::{
+        mempool::*,
+        mempool_grpc::{self, *},
+        shared::mempool_status::*,
+    },
+};
+use config::config::NodeConfigHelpers;
+use crypto::signing::generate_keypair;
+use grpc_helpers::ServerHandle;
+use grpcio::{ChannelBuilder, EnvBuilder};
+use proto_conv::FromProto;
+use std::{
+    sync::{Arc, Mutex},
+    time::Duration,
+};
+use types::{
+    account_address::AccountAddress,
+    test_helpers::transaction_test_helpers::get_test_signed_transaction,
+    transaction::SignedTransaction,
+};
+
+fn setup_mempool() -> (::grpcio::Server, MempoolClient) {
+    let node_config = NodeConfigHelpers::get_single_node_test_config(true);
+
+    let env = Arc::new(EnvBuilder::new().build());
+    let core_mempool = Arc::new(Mutex::new(CoreMempool::new(&node_config)));
+    let handle = MempoolService { core_mempool };
+    let service = mempool_grpc::create_mempool(handle);
+
+    let server = ::grpcio::ServerBuilder::new(env.clone())
+        .register_service(service)
+        .bind("localhost", 0)
+        .build()
+        .expect("Unable to create grpc server");
+    let (_, port) = server.bind_addrs()[0];
+    let connection_str = format!("localhost:{}", port);
+    let client = MempoolClient::new(ChannelBuilder::new(env).connect(&connection_str));
+    (server, client)
+}
+
+fn create_add_transaction_request(expiration_time: u64) -> AddTransactionWithValidationRequest {
+    let mut req = AddTransactionWithValidationRequest::new();
+    let sender = AccountAddress::random();
+    let (private_key, public_key) = generate_keypair();
+
+    let transaction = get_test_signed_transaction(
+        sender,
+        0,
+        private_key,
+        public_key,
+        None,
+        expiration_time,
+        1,
+        None,
+    );
+    req.set_signed_txn(transaction.clone());
+    req.set_max_gas_cost(10);
+    req.set_account_balance(1000);
+    req
+}
+
+#[test]
+fn test_add_transaction() {
+    let (server, client) = setup_mempool();
+    let _handle = ServerHandle::setup(server);
+    // create request
+    let mut req = create_add_transaction_request(0);
+    req.set_account_balance(100);
+    let response = client.add_transaction_with_validation(&req).unwrap();
+    // check status
+    assert_eq!(response.get_status(), MempoolAddTransactionStatus::Valid);
+}
+
+#[test]
+fn test_get_block() {
+    let (server, client) = setup_mempool();
+    let _handle = ServerHandle::setup(server);
+
+    // add transaction to mempool
+    let mut req = create_add_transaction_request(0);
+    req.set_account_balance(100);
+    client.add_transaction_with_validation(&req).unwrap();
+
+    // get next block
+    let response = client.get_block(&GetBlockRequest::new()).unwrap();
+    let block = response.get_block();
+    assert_eq!(block.get_transactions().len(), 1);
+    assert_eq!(
+        block.get_transactions()[0].raw_txn_bytes,
+        req.get_signed_txn().raw_txn_bytes
+    );
+}
+
+#[test]
+fn test_consensus_callbacks() {
+    let (server, client) = setup_mempool();
+    let _handle = ServerHandle::setup(server);
+
+    // add transaction
+    let add_req = create_add_transaction_request(0);
+    client.add_transaction_with_validation(&add_req).unwrap();
+
+    let mut response = client.get_block(&GetBlockRequest::new()).unwrap();
+    assert_eq!(response.get_block().get_transactions().len(), 1);
+
+    // remove: transaction is committed
+    let mut transaction = CommittedTransaction::new();
+    let signed_txn = SignedTransaction::from_proto(add_req.get_signed_txn().clone()).unwrap();
+    let sender = signed_txn.sender().as_ref().to_vec();
+    transaction.set_sender(sender);
+    transaction.set_sequence_number(0);
+
+    let mut req = CommitTransactionsRequest::new();
+    req.set_transactions(::protobuf::RepeatedField::from_vec(vec![transaction]));
+    client.commit_transactions(&req).unwrap();
+    response = client.get_block(&GetBlockRequest::new()).unwrap();
+    assert!(response.get_block().get_transactions().is_empty());
+}
+
+#[test]
+fn test_gc_by_expiration_time() {
+    let (server, client) = setup_mempool();
+    let _handle = ServerHandle::setup(server);
+
+    // add transaction with expiration time 1
+    let add_req = create_add_transaction_request(1);
+    client.add_transaction_with_validation(&add_req).unwrap();
+
+    // commit empty block with block_time 2
+    let mut req = CommitTransactionsRequest::new();
+    req.set_block_timestamp_usecs(Duration::from_secs(2).as_micros() as u64);
+    client.commit_transactions(&req).unwrap();
+
+    // verify that transaction is evicted from Mempool
+    let response = client.get_block(&GetBlockRequest::new()).unwrap();
+    assert!(response.get_block().get_transactions().is_empty());
+
+    // add transaction with expiration time 3
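+    // (Eviction is strict: a transaction whose expiration time equals the
+    // committed block time is expected to survive GC, which is what the rest
+    // of this test asserts.)
+    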
let add_req = create_add_transaction_request(3); + client.add_transaction_with_validation(&add_req).unwrap(); + // commit empty block with block_time 3 + let mut req = CommitTransactionsRequest::new(); + req.set_block_timestamp_usecs(Duration::from_secs(3).as_micros() as u64); + client.commit_transactions(&req).unwrap(); + + // verify that transaction is still in Mempool + let response = client.get_block(&GetBlockRequest::new()).unwrap(); + assert_eq!(response.get_block().get_transactions().len(), 1); +} diff --git a/network/Cargo.toml b/network/Cargo.toml new file mode 100644 index 0000000000000..258a0c3c7310c --- /dev/null +++ b/network/Cargo.toml @@ -0,0 +1,44 @@ +[package] +name = "network" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +build = "build.rs" +publish = false +edition = "2018" + +[dependencies] +bytes = "0.4.12" +futures = { version = "=0.3.0-alpha.16", package = "futures-preview", features = ["async-await", "nightly", "io-compat", "compat"] } +lazy_static = "1.3.0" +parity-multiaddr = "0.4.0" +pin-utils = "=0.1.0-alpha.4" +protobuf = { version = "2.6", features = ["with-bytes"] } +rand = "0.6.5" +tokio = "0.1.16" +tokio-timer = "0.2.10" +unsigned-varint = { version = "0.2.2", features = ["codec"] } + +channel = { path = "../common/channel" } +crypto = { path = "../crypto/legacy_crypto" } +failure = { package = "failure_ext", path = "../common/failure_ext" } +logger = { path = "../common/logger" } +memsocket = { path = "memsocket" } +metrics = { path = "../common/metrics" } +netcore = { path = "netcore" } +noise = { path = "noise" } +types = { path = "../types" } + +[dev-dependencies] +criterion = { version = "0.2.11", features = ["real_blackbox"] } + +[build-dependencies] +protoc-rust = "2.5.0" + +[[bench]] +name = "socket_muxer_bench" +harness = false + +[[bench]] +name = "network_bench" +harness = false diff --git a/network/README.md b/network/README.md new file mode 100644 index 0000000000000..23ad271ff2a42 --- /dev/null +++ b/network/README.md @@ -0,0 +1,129 @@ +--- +id: network +title: Network +custom_edit_url: https://github.com/libra/libra/edit/master/network/README.md +--- + +# Network + +The network component provides peer-to-peer communication primitives to other +components of a validator. + +## Overview + +The network component is specifically designed to facilitate the consensus and +shared mempool protocols. Currently, it provides these consumers with two +primary interfaces: +* RPC, for Remote Procedure Calls; and +* DirectSend, for fire-and-forget style message delivery to a single receiver. + +The network component uses: +* [Multiaddr](https://multiformats.io/multiaddr/) scheme for peer addressing. +* TCP for reliable transport. +* [Noise](https://noiseprotocol.org/noise.html) for authentication and full + end-to-end encryption. +* [Yamux](https://github.com/hashicorp/yamux/blob/master/spec.md) for +multiplexing substreams over a single connection; and +* Push-style [gossip](https://en.wikipedia.org/wiki/Gossip_protocol) for peer +discovery. + +Each new substream is assigned a *protocol* supported by both the sender and +the receiver. Each RPC and DirectSend type corresponds to one such protocol. + +Only eligible members are allowed to join the inter-validator network. Their +identity and public key information is provided by the consensus +component at initialization and on updates to system membership. A new +validator also needs the network addresses of a few *seed* peers to help it +bootstrap connectivity to the network. 
The seed peers first authenticate the
+joining validator as an eligible member and then share their network state
+with it.
+
+Each member of the network maintains a full membership view and connects
+directly to any validator it needs to communicate with. A validator that cannot
+be connected to directly is assumed to fall in the quota of Byzantine faults
+tolerated by the system.
+
+Validator health information, determined using periodic liveness probes, is not
+shared between validators; instead, each validator directly monitors its peers
+for liveness.
+
+This approach should scale up to a few hundred validators before requiring
+partial membership views, sophisticated failure detectors, or network overlays.
+
+## Implementation Details
+
+### System Architecture
+
+    +---------------------+---------------------+
+    |      Consensus      |       Mempool       |
+    +---------------------+---------------------+
+    |             Validator Network             |
+    +-------------------------------------------+
+    |              NetworkProvider              |
+    +------------------------+-----+------------+
+    | Discovery, health, etc | RPC | DirectSend |
+    +------------------------+-----+------------+
+    |                Peer Manager               |
+    +-------------------------------------------+
+
+The network component is implemented in the
+[Actor](https://en.wikipedia.org/wiki/Actor_model) programming model,
+i.e., it uses message-passing to communicate between different subcomponents
+running as independent "tasks." The [tokio](https://tokio.rs/) framework is
+used as the task runtime. The different subcomponents in the network component
+are:
+
+* **NetworkProvider** — Exposes the network API to clients. It forwards
+requests from upstream clients to appropriate downstream components and sends
+incoming RPC and DirectSend requests to appropriate upstream handlers.
+* **Peer Manager** — Listens for incoming connections and dials other
+peers on the network. It also notifies other components about new/lost
+connection events and demultiplexes incoming substreams to appropriate protocol
+handlers.
+* **Connectivity Manager** — Ensures that we remain connected to a node
+if and only if it is an eligible member of the network. Connectivity Manager
+receives addresses of peers from the Discovery component and issues
+dial/disconnect requests to the Peer Manager.
+* **Discovery** — Uses push-style gossip for discovering new peers and
+updates to addresses of existing peers. On every *tick*, it opens a new
+substream with a randomly selected peer and sends its view of the network to
+this peer. It informs the connectivity manager of any changes to the network
+detected from inbound discovery messages.
+* **Health Checker** — Performs periodic liveness probes to ensure the
+health of a peer/connection. It resets the connection with the peer if a
+configurable number of probes fail in succession. Probes currently fail on a
+configurable static timeout.
+* **Direct Send** — Allows sending/receiving messages to/from remote
+peers. It notifies upstream handlers of inbound messages.
+* **RPC** — Allows sending/receiving RPCs to/from other peers. It notifies
+upstream handlers about inbound RPCs. The upstream handler is passed a channel
+through which it can send a serialized response to the caller.
+
+In addition to the subcomponents described above, the network component
+consists of utilities to perform encryption, transport multiplexing, protocol
+negotiation, etc.
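+
+To make the client-facing surface concrete, here is a minimal sketch of a
+consumer of the DirectSend API, using the mempool protocol types for
+concreteness. The helper `send_and_listen` is hypothetical; how the sender and
+event-stream handles are obtained from `NetworkBuilder` can be seen in
+`network/benches/network_bench.rs`:
+
+```rust
+use futures::stream::StreamExt;
+use network::{
+    proto::MempoolSyncMsg,
+    validator_network::{Event, MempoolNetworkEvents, MempoolNetworkSender},
+};
+use types::PeerId;
+
+// Hypothetical client task: send one message, then react to network events.
+async fn send_and_listen(
+    mut sender: MempoolNetworkSender,
+    mut events: MempoolNetworkEvents,
+    remote_peer: PeerId,
+    msg: MempoolSyncMsg,
+) {
+    // DirectSend: fire-and-forget delivery to a single peer. This only
+    // fails if the network module itself has shut down.
+    sender
+        .send_to(remote_peer, msg)
+        .await
+        .expect("failed to queue message");
+
+    // Connection changes and inbound messages arrive on one event stream.
+    while let Some(Ok(event)) = events.next().await {
+        match event {
+            Event::NewPeer(_peer) => { /* peer connected */ }
+            Event::LostPeer(_peer) => { /* peer disconnected */ }
+            Event::Message((_peer, _msg)) => { /* handle inbound message */ }
+            _ => (),
+        }
+    }
+}
+```
+
+## How is this module organized?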
+ + network + β”œβ”€β”€ benches # network benchmarks + β”œβ”€β”€ memsocket # In-memory transport for tests + β”œβ”€β”€ netcore + β”‚Β Β  └── src + β”‚Β Β  β”œβ”€β”€ multiplexing # substream multiplexing over a transport + β”‚Β Β  β”œβ”€β”€ negotiate # protocol negotiation + β”‚Β Β  └── transport # composable transport API + β”œβ”€β”€ noise # noise framework for authentication and encryption + └── src + β”œβ”€β”€ channel # mpsc channel wrapped in IntGauge + β”œβ”€β”€ connectivity_manager # component to ensure connectivity to peers + β”œβ”€β”€ interface # generic network API + β”œβ”€β”€ peer_manager # component to dial/listen for connections + β”œβ”€β”€ proto # protobuf definitions for network messages + β”œβ”€β”€ protocols # message protocols + β”‚Β Β  β”œβ”€β”€ direct_send # protocol for fire-and-forget style message delivery + β”‚Β Β  β”œβ”€β”€ discovery # protocol for peer discovery and gossip + β”‚Β Β  β”œβ”€β”€ health_checker # protocol for health probing + β”‚Β Β  └── rpc # protocol for remote procedure calls + β”œβ”€β”€ sink # utilities over message sinks + └── validator_network # network API for consensus and mempool diff --git a/network/benches/network_bench.rs b/network/benches/network_bench.rs new file mode 100644 index 0000000000000..26bca23c3eac2 --- /dev/null +++ b/network/benches/network_bench.rs @@ -0,0 +1,327 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![feature(async_await)] +// Allow KiB, MiB consts +#![allow(non_upper_case_globals, non_snake_case)] +// Allow fns to take &usize, since criterion only passes parameters by ref +#![allow(clippy::trivially_copy_pass_by_ref)] +// Allow writing 1 * KiB or 1 * MiB +#![allow(clippy::identity_op)] + +use bytes::Bytes; +use core::str::FromStr; +use criterion::{ + criterion_group, criterion_main, AxisScale, Bencher, Criterion, ParameterizedBenchmark, + PlotConfiguration, Throughput, +}; +use crypto::{signing, x25519}; +use futures::{ + channel::mpsc, + compat::Future01CompatExt, + executor::block_on, + future::{FutureExt, TryFutureExt}, + sink::SinkExt, + stream::{FuturesUnordered, StreamExt}, +}; +use network::{ + proto::{Block, ConsensusMsg, RequestBlock, RespondBlock}, + protocols::rpc::error::RpcError, + validator_network::{ + network_builder::{NetworkBuilder, TransportType}, + ConsensusNetworkSender, Event, CONSENSUS_DIRECT_SEND_PROTOCOL, CONSENSUS_RPC_PROTOCOL, + }, + NetworkPublicKeys, ProtocolId, +}; +use parity_multiaddr::Multiaddr; +use protobuf::Message; +use std::{collections::HashMap, time::Duration}; +use tokio::runtime::Runtime; +use types::PeerId; + +const KiB: usize = 1 << 10; +const MiB: usize = 1 << 20; +const NUM_MSGS: u32 = 100; +const TOLERANCE: u32 = 5; +const HOUR_IN_MS: u64 = 60 * 60 * 1000; + +fn direct_send_bench(b: &mut Bencher, msg_len: &usize) { + let mut runtime = Runtime::new().unwrap(); + let (dialer_peer_id, dialer_addr) = ( + PeerId::random(), + Multiaddr::from_str("/ip4/127.0.0.1/tcp/0").unwrap(), + ); + let (listener_peer_id, listener_addr) = ( + PeerId::random(), + Multiaddr::from_str("/ip4/127.0.0.1/tcp/0").unwrap(), + ); + + // Setup keys for dialer. + let (dialer_signing_private_key, dialer_signing_public_key) = signing::generate_keypair(); + let (dialer_identity_private_key, dialer_identity_public_key) = x25519::generate_keypair(); + + // Setup keys for listener. 
+    let (listener_signing_private_key, listener_signing_public_key) = signing::generate_keypair();
+    let (listener_identity_private_key, listener_identity_public_key) = x25519::generate_keypair();
+
+    // Setup trusted peers.
+    let trusted_peers: HashMap<_, _> = vec![
+        (
+            dialer_peer_id,
+            NetworkPublicKeys {
+                signing_public_key: dialer_signing_public_key,
+                identity_public_key: dialer_identity_public_key,
+            },
+        ),
+        (
+            listener_peer_id,
+            NetworkPublicKeys {
+                signing_public_key: listener_signing_public_key,
+                identity_public_key: listener_identity_public_key,
+            },
+        ),
+    ]
+    .into_iter()
+    .collect();
+
+    // Set up the listener network
+    let ((_, _), (_listener_sender, mut listener_events), listen_addr) =
+        NetworkBuilder::new(runtime.executor(), listener_peer_id, listener_addr)
+            .transport(TransportType::TcpNoise)
+            .trusted_peers(trusted_peers.clone())
+            .identity_keys((listener_identity_private_key, listener_identity_public_key))
+            .signing_keys((listener_signing_private_key, listener_signing_public_key))
+            .discovery_interval_ms(HOUR_IN_MS)
+            .consensus_protocols(vec![ProtocolId::from_static(
+                CONSENSUS_DIRECT_SEND_PROTOCOL,
+            )])
+            .direct_send_protocols(vec![ProtocolId::from_static(
+                CONSENSUS_DIRECT_SEND_PROTOCOL,
+            )])
+            .build();
+
+    // Set up the dialer network
+    let ((_, _), (mut dialer_sender, mut dialer_events), _) =
+        NetworkBuilder::new(runtime.executor(), dialer_peer_id, dialer_addr)
+            .transport(TransportType::TcpNoise)
+            .trusted_peers(trusted_peers.clone())
+            .identity_keys((dialer_identity_private_key, dialer_identity_public_key))
+            .signing_keys((dialer_signing_private_key, dialer_signing_public_key))
+            .seed_peers(
+                [(listener_peer_id, vec![listen_addr])]
+                    .iter()
+                    .cloned()
+                    .collect(),
+            )
+            .discovery_interval_ms(HOUR_IN_MS)
+            .consensus_protocols(vec![ProtocolId::from_static(
+                CONSENSUS_DIRECT_SEND_PROTOCOL,
+            )])
+            .direct_send_protocols(vec![ProtocolId::from_static(
+                CONSENSUS_DIRECT_SEND_PROTOCOL,
+            )])
+            .build();
+
+    // Wait for establishing connection
+    let first_dialer_event = block_on(dialer_events.next()).unwrap().unwrap();
+    assert_eq!(first_dialer_event, Event::NewPeer(listener_peer_id));
+    let first_listener_event = block_on(listener_events.next()).unwrap().unwrap();
+    assert_eq!(first_listener_event, Event::NewPeer(dialer_peer_id));
+
+    // Compose Proposal message with `msg_len` bytes payload
+    let msg = compose_proposal(*msg_len);
+
+    let (mut tx, mut rx) = mpsc::channel(0);
+    // The listener side keeps receiving messages and sends a signal back to the bencher to finish
+    // the iteration once NUM_MSGS messages are received.
+    let f_listener = async move {
+        let mut counter = 0u32;
+        while let Some(_) = listener_events.next().await {
+            counter += 1;
+            // By the nature of DirectSend protocol, some messages may be lost when a connection is
+            // broken temporarily.
+            if counter == NUM_MSGS - TOLERANCE {
+                tx.send(()).await.unwrap();
+                counter = 0;
+            }
+        }
+    };
+    runtime.spawn(f_listener.boxed().unit_error().compat());
+
+    // The dialer side keeps sending messages. In each iteration of the benchmark, it sends
+    // NUM_MSGS messages and waits until the listener side sends a signal back.
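+    // (The listener signals at NUM_MSGS - TOLERANCE rather than NUM_MSGS so
+    // that a handful of dropped messages cannot leave the bencher waiting
+    // forever for a signal that never comes.)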
+ b.iter(|| { + for _ in 0..NUM_MSGS { + block_on(dialer_sender.send_to(listener_peer_id, msg.clone())).unwrap(); + } + block_on(rx.next()).unwrap(); + }); + block_on(runtime.shutdown_now().compat()).unwrap(); +} + +fn compose_proposal(msg_len: usize) -> ConsensusMsg { + let mut msg = ConsensusMsg::new(); + let proposal = msg.mut_proposal(); + proposal.set_proposer(PeerId::random().into()); + let block = proposal.mut_proposed_block(); + block.set_payload(vec![0u8; msg_len].into()); + msg +} + +fn rpc_bench(b: &mut Bencher, msg_len: &usize) { + let mut runtime = Runtime::new().unwrap(); + let (dialer_peer_id, dialer_addr) = ( + PeerId::random(), + Multiaddr::from_str("/ip4/127.0.0.1/tcp/0").unwrap(), + ); + let (listener_peer_id, listener_addr) = ( + PeerId::random(), + Multiaddr::from_str("/ip4/127.0.0.1/tcp/0").unwrap(), + ); + + // Setup keys for dialer. + let (dialer_signing_private_key, dialer_signing_public_key) = signing::generate_keypair(); + let (dialer_identity_private_key, dialer_identity_public_key) = x25519::generate_keypair(); + + // Setup keys for listener. + let (listener_signing_private_key, listener_signing_public_key) = signing::generate_keypair(); + let (listener_identity_private_key, listener_identity_public_key) = x25519::generate_keypair(); + + // Setup trusted peers. + let trusted_peers: HashMap<_, _> = vec![ + ( + dialer_peer_id, + NetworkPublicKeys { + signing_public_key: dialer_signing_public_key, + identity_public_key: dialer_identity_public_key, + }, + ), + ( + listener_peer_id, + NetworkPublicKeys { + signing_public_key: listener_signing_public_key, + identity_public_key: listener_identity_public_key, + }, + ), + ] + .into_iter() + .collect(); + + // Set up the listener network + let ((_, _), (_listener_sender, mut listener_events), listen_addr) = + NetworkBuilder::new(runtime.executor(), listener_peer_id, listener_addr) + .transport(TransportType::TcpNoise) + .trusted_peers(trusted_peers.clone()) + .identity_keys((listener_identity_private_key, listener_identity_public_key)) + .signing_keys((listener_signing_private_key, listener_signing_public_key)) + .discovery_interval_ms(HOUR_IN_MS) + .consensus_protocols(vec![ProtocolId::from_static(CONSENSUS_RPC_PROTOCOL)]) + .rpc_protocols(vec![ProtocolId::from_static(CONSENSUS_RPC_PROTOCOL)]) + .build(); + + // Set up the dialer network + let ((_, _), (dialer_sender, mut dialer_events), _) = + NetworkBuilder::new(runtime.executor(), dialer_peer_id, dialer_addr) + .transport(TransportType::TcpNoise) + .trusted_peers(trusted_peers.clone()) + .identity_keys((dialer_identity_private_key, dialer_identity_public_key)) + .signing_keys((dialer_signing_private_key, dialer_signing_public_key)) + .seed_peers( + [(listener_peer_id, vec![listen_addr])] + .iter() + .cloned() + .collect(), + ) + .discovery_interval_ms(HOUR_IN_MS) + .consensus_protocols(vec![ProtocolId::from_static(CONSENSUS_RPC_PROTOCOL)]) + .rpc_protocols(vec![ProtocolId::from_static(CONSENSUS_RPC_PROTOCOL)]) + .build(); + + // Wait for establishing connection + let first_dialer_event = block_on(dialer_events.next()).unwrap().unwrap(); + assert_eq!(first_dialer_event, Event::NewPeer(listener_peer_id)); + let first_listener_event = block_on(listener_events.next()).unwrap().unwrap(); + assert_eq!(first_listener_event, Event::NewPeer(dialer_peer_id)); + + // Compose RequestBlock message and RespondBlock message with `msg_len` bytes payload + let req = compose_request_block(); + let res = compose_respond_block(*msg_len); + + // The listener side keeps receiving RPC 
requests and sending responses back
+    let f_listener = async move {
+        while let Some(Ok(event)) = listener_events.next().await {
+            match event {
+                Event::RpcRequest((_, _, res_tx)) => res_tx
+                    .send(Ok(Bytes::from(
+                        res.clone()
+                            .write_to_bytes()
+                            .expect("fail to serialize proto"),
+                    )))
+                    .expect("fail to send rpc response to network"),
+                event => panic!("Unexpected event: {:?}", event),
+            }
+        }
+    };
+    runtime.spawn(f_listener.boxed().unit_error().compat());
+
+    // The dialer side keeps sending RPC requests. In each iteration of the benchmark, it sends
+    // NUM_MSGS requests and blocks on getting the responses.
+    b.iter(|| {
+        let mut requests = FuturesUnordered::new();
+        for _ in 0..NUM_MSGS {
+            requests.push(request_block(
+                dialer_sender.clone(),
+                listener_peer_id,
+                req.clone(),
+            ));
+        }
+        while let Some(res) = block_on(requests.next()) {
+            let _ = res.unwrap();
+        }
+    });
+    block_on(runtime.shutdown_now().compat()).unwrap();
+}
+
+async fn request_block(
+    mut sender: ConsensusNetworkSender,
+    recipient: PeerId,
+    req_msg: RequestBlock,
+) -> Result<ConsensusMsg, RpcError> {
+    sender
+        .request_block(recipient, req_msg, Duration::from_secs(15))
+        .await
+}
+
+fn compose_request_block() -> RequestBlock {
+    let mut req = RequestBlock::new();
+    req.set_block_id(vec![0u8; 32].into());
+    req
+}
+
+fn compose_respond_block(msg_len: usize) -> ConsensusMsg {
+    let mut msg = ConsensusMsg::new();
+    let res = msg.mut_respond_block();
+    let mut block = Block::new();
+    block.set_payload(vec![0u8; msg_len].into());
+    res.mut_blocks().push(block);
+    msg
+}
+
+fn network_crate_benchmark(c: &mut Criterion) {
+    ::logger::try_init_for_testing();
+
+    // Parameterize benchmarks over the message length.
+    let msg_lens = vec![32usize, 256, 1 * KiB, 4 * KiB, 64 * KiB, 256 * KiB, 1 * MiB];
+
+    c.bench(
+        "network_crate_benchmark",
+        ParameterizedBenchmark::new("direct_send", direct_send_bench, msg_lens)
+            .with_function("rpc", rpc_bench)
+            .sample_size(10)
+            .plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic))
+            .throughput(|msg_len| Throughput::Bytes((*msg_len as u32) * NUM_MSGS)),
+    );
+}
+
+criterion_group!(benches, network_crate_benchmark);
+criterion_main!(benches);
diff --git a/network/benches/socket_muxer_bench.rs b/network/benches/socket_muxer_bench.rs
new file mode 100644
index 0000000000000..3dc5d0f8b573d
--- /dev/null
+++ b/network/benches/socket_muxer_bench.rs
@@ -0,0 +1,506 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+#![feature(async_await)]
+// Allow KiB, MiB consts
+#![allow(non_upper_case_globals, non_snake_case)]
+// Allow fns to take &usize, since criterion only passes parameters by ref
+#![allow(clippy::trivially_copy_pass_by_ref)]
+// Allow writing 1 * KiB or 1 * MiB
+#![allow(clippy::identity_op)]
+
+//! Network Benchmarks
+//! ==================
+//!
+//! The `socket_muxer_bench` benchmarks measure the throughput of sending
+//! messages over a single stream.
+//!
+//! # Run the benchmarks
+//!
+//! `cargo bench -p network`
+//!
+//! # View the report
+//!
+//! `open network/target/criterion/report/index.html`
+//!
+//! Note: gnuplot must be installed to generate benchmark plots.
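+//!
+//! # Message framing
+//!
+//! All benchmarks below frame messages with a varint length prefix. This is a
+//! minimal sketch of the pattern used throughout this file, where `socket`
+//! stands in for whatever IO object the transport produced:
+//!
+//! ```ignore
+//! // Wrap the raw socket in a length-prefixed codec, then adapt the
+//! // futures 0.1 `Framed` sink/stream to the 0.3 interfaces used here.
+//! let mut stream = Framed::new(socket.compat(), UviBytes::<Bytes>::default()).sink_compat();
+//! stream.send(Bytes::from(vec![0u8; 32])).await?; // length-prefixed write
+//! while let Some(msg) = stream.next().await { /* length-prefixed read */ }
+//! ```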
+
+use bytes::Bytes;
+use criterion::{
+    criterion_group, criterion_main, AxisScale, Bencher, Criterion, ParameterizedBenchmark,
+    PlotConfiguration, Throughput,
+};
+use futures::{
+    channel::oneshot,
+    compat::Sink01CompatExt,
+    executor::block_on,
+    future::{Future, FutureExt, TryFutureExt},
+    io::{AsyncRead, AsyncReadExt, AsyncWrite},
+    sink::{Sink, SinkExt},
+    stream::{self, Stream, StreamExt},
+};
+use memsocket::MemorySocket;
+use netcore::{
+    multiplexing::{yamux::Yamux, StreamMultiplexer},
+    transport::{
+        memory::MemoryTransport,
+        tcp::{TcpSocket, TcpTransport},
+        Transport, TransportExt,
+    },
+};
+use noise::{NoiseConfig, NoiseSocket};
+use parity_multiaddr::Multiaddr;
+use std::{fmt::Debug, sync::Arc, time::Duration};
+use tokio::{
+    codec::Framed,
+    runtime::{Runtime, TaskExecutor},
+};
+use unsigned_varint::codec::UviBytes;
+
+const KiB: usize = 1 << 10;
+const MiB: usize = 1 << 20;
+
+// The number of messages to send per `Bencher::iter`. We also flush to ensure
+// we measure all the messages being sent.
+const SENDS_PER_ITER: usize = 100;
+
+/// Build a MemorySocket + Noise transport
+fn build_memsocket_noise_transport() -> impl Transport<Output = NoiseSocket<MemorySocket>> {
+    MemoryTransport::default().and_then(move |socket, origin| {
+        async move {
+            let noise_config = Arc::new(NoiseConfig::new_random());
+            let (_remote_static_key, socket) =
+                noise_config.upgrade_connection(socket, origin).await?;
+            Ok(socket)
+        }
+    })
+}
+
+/// Build a MemorySocket + Muxer transport
+fn build_memsocket_muxer_transport() -> impl Transport<Output = impl StreamMultiplexer> {
+    MemoryTransport::default().and_then(Yamux::upgrade_connection)
+}
+
+/// Build a MemorySocket + Noise + Muxer transport
+fn build_memsocket_noise_muxer_transport() -> impl Transport<Output = impl StreamMultiplexer> {
+    MemoryTransport::default()
+        .and_then(move |socket, origin| {
+            async move {
+                let noise_config = Arc::new(NoiseConfig::new_random());
+                let (_remote_static_key, socket) =
+                    noise_config.upgrade_connection(socket, origin).await?;
+                Ok(socket)
+            }
+        })
+        .and_then(Yamux::upgrade_connection)
+}
+
+/// Build a Tcp + Noise transport
+fn build_tcp_noise_transport() -> impl Transport<Output = NoiseSocket<TcpSocket>> {
+    TcpTransport::default().and_then(move |socket, origin| {
+        async move {
+            let noise_config = Arc::new(NoiseConfig::new_random());
+            let (_remote_static_key, socket) =
+                noise_config.upgrade_connection(socket, origin).await?;
+            Ok(socket)
+        }
+    })
+}
+
+/// Build a Tcp + Muxer transport
+fn build_tcp_muxer_transport() -> impl Transport<Output = impl StreamMultiplexer> {
+    TcpTransport::default().and_then(Yamux::upgrade_connection)
+}
+
+/// Build a Tcp + Noise + Muxer transport
+fn build_tcp_noise_muxer_transport() -> impl Transport<Output = impl StreamMultiplexer> {
+    TcpTransport::default()
+        .and_then(move |socket, origin| {
+            async move {
+                let noise_config = Arc::new(NoiseConfig::new_random());
+                let (_remote_static_key, socket) =
+                    noise_config.upgrade_connection(socket, origin).await?;
+                Ok(socket)
+            }
+        })
+        .and_then(Yamux::upgrade_connection)
+}
+
+/// Spawn a Future on an executor, but send the output over a oneshot channel.
+fn spawn_with_handle<F>(executor: &TaskExecutor, f: F) -> oneshot::Receiver<F::Output>
+where
+    F: Future + Send + 'static,
+    F::Output: Send,
+{
+    let (tx, rx) = oneshot::channel();
+    let f_send = async move {
+        let out = f.await;
+        let _ = tx.send(out);
+    };
+    executor.spawn(f_send.boxed().unit_error().compat());
+    rx
+}
+
+/// Server side handler for send throughput benchmark when the messages are sent
+/// over a simple stream (tcp or in-memory).
+async fn server_stream_handler(mut server_listener: L) -> impl Stream +where + L: Stream> + Unpin, + I: Future>, + S: AsyncRead + AsyncWrite + Unpin, + E: ::std::error::Error, +{ + // Wait for next inbound connection + let (f_stream, _) = server_listener.next().await.unwrap().unwrap(); + let stream = f_stream.await.unwrap(); + let mut stream = Framed::new(stream.compat(), UviBytes::::default()).sink_compat(); + + // Drain all messages from the client. + while let Some(_) = stream.next().await {} + stream.close().await.unwrap(); + + // Return stream so we drop after runtime shuts down to avoid race + stream +} + +/// Server side handler for send throughput benchmark when the messages are sent +/// over a muxer substream. +async fn server_muxer_handler(mut server_listener: L) -> (M, impl Stream) +where + L: Stream> + Unpin, + I: Future>, + M: StreamMultiplexer, + E: ::std::error::Error, +{ + // Wait for next inbound connection + let (f_muxer, _) = server_listener.next().await.unwrap().unwrap(); + let muxer = f_muxer.await.unwrap(); + + // Wait for inbound client substream + let mut muxer_inbounds = muxer.listen_for_inbound(); + let substream = muxer_inbounds.next().await.unwrap().unwrap(); + let mut stream = Framed::new(substream.compat(), UviBytes::::default()).sink_compat(); + + // Drain all messages from the client. + while let Some(_) = stream.next().await {} + stream.close().await.unwrap(); + + // Return muxer and stream so we drop after runtime shuts down to avoid race + (muxer, stream) +} + +/// The tight inner loop we're actually benchmarking. In this benchmark, we simply +/// measure the throughput of sending many messages of size `msg_len` over +/// `client_stream`. +fn bench_client_send(b: &mut Bencher, msg_len: usize, client_stream: &mut S) +where + S: Sink + Unpin, + S::SinkError: Debug, +{ + // Benchmark sending over the in-memory stream. + let data = Bytes::from(vec![0u8; msg_len]); + b.iter(|| { + // Create a stream of messages to send + let mut data_stream = stream::repeat(data.clone()).take(SENDS_PER_ITER as u64); + // Send the batch of messages. Note that `Sink::send_all` will flush the + // sink after exhausting the `data_stream`, which is necessary to ensure + // we measure sending all of the messages. + block_on(client_stream.send_all(&mut data_stream)).unwrap(); + }); + + // Shutdown client stream. + block_on(client_stream.close()).unwrap(); +} + +/// Setup and benchmark the client side for the simple stream case +/// (tcp or in-memory). +fn bench_client_stream_send( + b: &mut Bencher, + msg_len: usize, + runtime: &mut Runtime, + server_addr: Multiaddr, + client_transport: T, +) -> impl Stream +where + T: Transport + 'static, + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + // Client dials the server. Some of our transports have timeouts built in, + // which means the futures must be run on a tokio Runtime. + let client_socket = runtime + .block_on(client_transport.dial(server_addr).unwrap().boxed().compat()) + .unwrap(); + let mut client_stream = + Framed::new(client_socket.compat(), UviBytes::::default()).sink_compat(); + + // Benchmark client sending data to server. + bench_client_send(b, msg_len, &mut client_stream); + + // Return the stream so we can drop it after the bench completes + client_stream +} + +/// Setup and benchmark the client side for the muxer substream case (yamux). 
+fn bench_client_muxer_send( + b: &mut Bencher, + msg_len: usize, + runtime: &mut Runtime, + server_addr: Multiaddr, + client_transport: T, +) -> (M, impl Stream) +where + T: Transport + Send + 'static, + M: StreamMultiplexer + 'static, +{ + // Client dials the server. Some of our transports have timeouts built in, + // which means the futures must be run on a tokio Runtime. + let f_client = async move { + let client_muxer = client_transport.dial(server_addr).unwrap().await.unwrap(); + let client_substream = client_muxer.open_outbound().await.unwrap(); + (client_muxer, client_substream) + }; + let (client_muxer, client_substream) = runtime + .block_on(f_client.boxed().unit_error().compat()) + .unwrap(); + let mut client_stream = + Framed::new(client_substream.compat(), UviBytes::::default()).sink_compat(); + + // Benchmark client sending data to server. + bench_client_send(b, msg_len, &mut client_stream); + + // Return the muxer and stream so we can drop them after the bench completes + (client_muxer, client_stream) +} + +/// Benchmark the throughput of sending messages of size `msg_len` over an +/// in-memory socket. +fn bench_memsocket_send(b: &mut Bencher, msg_len: &usize) { + let mut runtime = Runtime::new().unwrap(); + let executor = runtime.executor(); + + let client_transport = MemoryTransport::default(); + let server_transport = MemoryTransport::default(); + let (server_listener, server_addr) = server_transport + .listen_on("/memory/0".parse().unwrap()) + .unwrap(); + + // Server waits for client connection then reads all messages. + let f_server = spawn_with_handle(&executor, server_stream_handler(server_listener)); + + // Benchmark sending some data to the server. + let _client_stream = + bench_client_stream_send(b, *msg_len, &mut runtime, server_addr, client_transport); + + // Wait for server task to finish. + let _server_stream = block_on(f_server).unwrap(); +} + +/// Benchmark the throughput of sending messages of size `msg_len` over an +/// in-memory socket with Noise encryption. +fn bench_memsocket_noise_send(b: &mut Bencher, msg_len: &usize) { + let mut runtime = Runtime::new().unwrap(); + let executor = runtime.executor(); + + let client_transport = build_memsocket_noise_transport(); + let server_transport = build_memsocket_noise_transport(); + let (server_listener, server_addr) = server_transport + .listen_on("/memory/0".parse().unwrap()) + .unwrap(); + + // Server waits for client connection then reads all messages. + let f_server = spawn_with_handle(&executor, server_stream_handler(server_listener)); + + // Benchmark sending some data to the server. + let _client_stream = + bench_client_stream_send(b, *msg_len, &mut runtime, server_addr, client_transport); + + // Wait for server task to finish. + let _server_stream = block_on(f_server).unwrap(); +} + +/// Benchmark the throughput of sending messages of size `msg_len` over a muxer +/// over an in-memory socket. +fn bench_memsocket_muxer_send(b: &mut Bencher, msg_len: &usize) { + let mut runtime = Runtime::new().unwrap(); + let executor = runtime.executor(); + + let client_transport = build_memsocket_muxer_transport(); + let server_transport = build_memsocket_muxer_transport(); + let (server_listener, server_addr) = server_transport + .listen_on("/memory/0".parse().unwrap()) + .unwrap(); + + // Server waits for client connection and substream, then reads all messages. + let f_server = spawn_with_handle(&executor, server_muxer_handler(server_listener)); + + // Benchmark sending some data to the server. 
+ let (_client_muxer, _client_stream) = + bench_client_muxer_send(b, *msg_len, &mut runtime, server_addr, client_transport); + + // Wait for server task to finish. + let (_server_muxer, _server_stream) = block_on(f_server).unwrap(); +} + +/// Benchmark the throughput of sending messages of size`msg_len` over a muxer +/// over an in-memory socket with noise encryption +fn bench_memsocket_noise_muxer_send(b: &mut Bencher, msg_len: &usize) { + let mut runtime = Runtime::new().unwrap(); + let executor = runtime.executor(); + + let client_transport = build_memsocket_noise_muxer_transport(); + let server_transport = build_memsocket_noise_muxer_transport(); + let (server_listener, server_addr) = server_transport + .listen_on("/memory/0".parse().unwrap()) + .unwrap(); + + // Server waits for client connection and substream, then reads all messages. + let f_server = spawn_with_handle(&executor, server_muxer_handler(server_listener)); + + // Benchmark sending some data to the server. + let (_client_muxer, _client_stream) = + bench_client_muxer_send(b, *msg_len, &mut runtime, server_addr, client_transport); + + // Wait for server task to finish. + let (_server_muxer, _server_stream) = block_on(f_server).unwrap(); +} + +/// Benchmark the throughput of sending messages of size `msg_len` over tcp +/// loopback. +fn bench_tcp_send(b: &mut Bencher, msg_len: &usize) { + let mut runtime = Runtime::new().unwrap(); + let executor = runtime.executor(); + + let client_transport = TcpTransport::default(); + let server_transport = TcpTransport::default(); + let (server_listener, server_addr) = server_transport + .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .unwrap(); + + // Server waits for client connection then reads all messages. + let f_server = spawn_with_handle(&executor, server_stream_handler(server_listener)); + + // Benchmark sending some data to the server. + let _client_stream = + bench_client_stream_send(b, *msg_len, &mut runtime, server_addr, client_transport); + + // Wait for server task to finish. + let _server_stream = block_on(f_server).unwrap(); +} + +/// Benchmark the throughput of sending messages of size `msg_len` over tcp +/// loopback with Noise encryption. +fn bench_tcp_noise_send(b: &mut Bencher, msg_len: &usize) { + let mut runtime = Runtime::new().unwrap(); + let executor = runtime.executor(); + + let client_transport = build_tcp_noise_transport(); + let server_transport = build_tcp_noise_transport(); + let (server_listener, server_addr) = server_transport + .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .unwrap(); + + // Server waits for client connection then reads all messages. + let f_server = spawn_with_handle(&executor, server_stream_handler(server_listener)); + + // Benchmark sending some data to the server. + let _client_stream = + bench_client_stream_send(b, *msg_len, &mut runtime, server_addr, client_transport); + + // Wait for server task to finish. + let _server_stream = block_on(f_server).unwrap(); +} + +/// Benchmark the throughput of sending messages of size `msg_len` over a muxer +/// over tcp loopback. +fn bench_tcp_muxer_send(b: &mut Bencher, msg_len: &usize) { + let mut runtime = Runtime::new().unwrap(); + let executor = runtime.executor(); + + let client_transport = build_tcp_muxer_transport(); + let server_transport = build_tcp_muxer_transport(); + let (server_listener, server_addr) = server_transport + .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) + .unwrap(); + + // Server waits for client connection and substream, then reads all messages. 
+    let f_server = spawn_with_handle(&executor, server_muxer_handler(server_listener));
+
+    // Benchmark sending some data to the server.
+    let (_client_muxer, _client_stream) =
+        bench_client_muxer_send(b, *msg_len, &mut runtime, server_addr, client_transport);
+
+    // Wait for server task to finish.
+    let (_server_muxer, _server_stream) = block_on(f_server).unwrap();
+}
+
+/// Benchmark the throughput of sending messages of size `msg_len` over a muxer over tcp loopback
+/// with noise encryption.
+fn bench_tcp_noise_muxer_send(b: &mut Bencher, msg_len: &usize) {
+    let mut runtime = Runtime::new().unwrap();
+    let executor = runtime.executor();
+
+    let client_transport = build_tcp_noise_muxer_transport();
+    let server_transport = build_tcp_noise_muxer_transport();
+    let (server_listener, server_addr) = server_transport
+        .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap())
+        .unwrap();
+
+    // Server waits for client connection and substream, then reads all messages.
+    let f_server = spawn_with_handle(&executor, server_muxer_handler(server_listener));
+
+    // Benchmark sending some data to the server.
+    let (_client_muxer, _client_stream) =
+        bench_client_muxer_send(b, *msg_len, &mut runtime, server_addr, client_transport);
+
+    // Wait for server task to finish.
+    let (_server_muxer, _server_stream) = block_on(f_server).unwrap();
+}
+
+/// Measure sending messages of varying sizes over:
+/// 1. in-memory transport
+/// 2. in-memory transport + noise encryption
+/// 3. in-memory transport + yamux
+/// 4. in-memory transport + noise encryption + yamux
+/// 5. tcp transport
+/// 6. tcp transport + noise encryption
+/// 7. tcp transport + yamux
+/// 8. tcp transport + noise encryption + yamux
+///
+/// Important:
+/// 1. Measures single-threaded send since only one sending task is used, so any
+///    muxer lock contention is likely not measured.
+/// 2. We use a `UviBytes` codec to frame the benchmark messages since this is
+///    what we currently use in the codebase; however, this seems to add
+///    non-trivial overhead and might change in the near future.
+/// 3. TCP benchmarks are only over loopback.
+/// 4. Socket buffer sizes and buffering strategies are not yet optimized.
+fn socket_muxer_bench(c: &mut Criterion) {
+    ::logger::try_init_for_testing();
+
+    // Parameterize benchmarks over the message length.
+ let msg_lens = vec![32usize, 256, 1 * KiB, 4 * KiB, 64 * KiB, 256 * KiB, 1 * MiB]; + + c.bench( + "socket_muxer_send_throughput", + ParameterizedBenchmark::new("memsocket", bench_memsocket_send, msg_lens) + .with_function("memsocket+noise", bench_memsocket_noise_send) + .with_function("memsocket+muxer", bench_memsocket_muxer_send) + .with_function("memsocket+noise+muxer", bench_memsocket_noise_muxer_send) + .with_function("tcp", bench_tcp_send) + .with_function("tcp+noise", bench_tcp_noise_send) + .with_function("tcp+muxer", bench_tcp_muxer_send) + .with_function("tcp+noise+muxer", bench_tcp_noise_muxer_send) + .warm_up_time(Duration::from_secs(2)) + .measurement_time(Duration::from_secs(2)) + .sample_size(10) + .plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)) + .throughput(|msg_len| { + let msg_len = *msg_len as u32; + let num_msgs = SENDS_PER_ITER as u32; + Throughput::Bytes(msg_len * num_msgs) + }), + ); +} + +criterion_group!(network_benches, socket_muxer_bench); +criterion_main!(network_benches); diff --git a/network/build.rs b/network/build.rs new file mode 100644 index 0000000000000..5222adebd43dc --- /dev/null +++ b/network/build.rs @@ -0,0 +1,27 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +/// Builds the proto files needed for the network crate. +fn main() { + let proto_files = [ + "src/proto/network.proto", + "src/proto/mempool.proto", + "src/proto/consensus.proto", + ]; + + for file in &proto_files { + println!("cargo:rerun-if-changed={}", file); + } + + protoc_rust::run(protoc_rust::Args { + out_dir: "src/proto", + input: &proto_files, + includes: &["../types/src/proto", "src/proto"], + customize: protoc_rust::Customize { + carllerche_bytes_for_bytes: Some(true), + carllerche_bytes_for_string: Some(true), + ..Default::default() + }, + }) + .expect("protoc"); +} diff --git a/network/memsocket/Cargo.toml b/network/memsocket/Cargo.toml new file mode 100644 index 0000000000000..066a38e3e1e2f --- /dev/null +++ b/network/memsocket/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "memsocket" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +edition = "2018" +publish = false + +[dependencies] +bytes = "0.4.12" +lazy_static = "1.3.0" +futures = { version = "=0.3.0-alpha.16", package = "futures-preview" } diff --git a/network/memsocket/src/lib.rs b/network/memsocket/src/lib.rs new file mode 100644 index 0000000000000..5db2d3f7a6758 --- /dev/null +++ b/network/memsocket/src/lib.rs @@ -0,0 +1,424 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use bytes::{Buf, Bytes, IntoBuf}; +use futures::{ + channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, + io::{AsyncRead, AsyncWrite, Error, ErrorKind, Result}, + ready, + stream::{FusedStream, Stream}, + task::{Context, Poll}, +}; +use lazy_static::lazy_static; +use std::{collections::HashMap, num::NonZeroU16, pin::Pin, sync::Mutex}; + +lazy_static! { + static ref SWITCHBOARD: Mutex = Mutex::new(SwitchBoard(HashMap::default(), 1)); +} + +struct SwitchBoard(HashMap>, u16); + +/// An in-memory socket server, listening for connections. +/// +/// After creating a `MemoryListener` by [`bind`]ing it to a socket address, it listens +/// for incoming connections. These can be accepted by awaiting elements from the +/// async stream of incoming connections, [`incoming`][`MemoryListener::incoming`]. +/// +/// The socket will be closed when the value is dropped. 
+///
+/// [`bind`]: #method.bind
+/// [`MemoryListener::incoming`]: #method.incoming
+///
+/// # Examples
+///
+/// ```rust,no_run
+/// #![feature(async_await)]
+/// use std::io::Result;
+///
+/// use memsocket::{MemoryListener, MemorySocket};
+/// use futures::prelude::*;
+///
+/// async fn write_stormlight(mut stream: MemorySocket) -> Result<()> {
+///     let msg = b"The most important step a person can take is always the next one.";
+///     stream.write_all(msg).await?;
+///     stream.flush().await
+/// }
+///
+/// async fn listen() -> Result<()> {
+///     let mut listener = MemoryListener::bind(16)?;
+///     let mut incoming = listener.incoming();
+///
+///     // accept connections and process them serially
+///     while let Some(stream) = incoming.next().await {
+///         write_stormlight(stream?).await?;
+///     }
+///     Ok(())
+/// }
+/// ```
+#[derive(Debug)]
+pub struct MemoryListener {
+    incoming: UnboundedReceiver<MemorySocket>,
+    port: NonZeroU16,
+}
+
+impl Drop for MemoryListener {
+    fn drop(&mut self) {
+        let mut switchboard = (&*SWITCHBOARD).lock().unwrap();
+        // Remove the Sending side of the channel in the switchboard when
+        // MemoryListener is dropped
+        switchboard.0.remove(&self.port);
+    }
+}
+
+impl MemoryListener {
+    /// Creates a new `MemoryListener` which will be bound to the specified
+    /// port.
+    ///
+    /// The returned listener is ready for accepting connections.
+    ///
+    /// Binding with a port number of 0 will request that a port be assigned
+    /// to this listener. The port allocated can be queried via the
+    /// [`local_addr`] method.
+    ///
+    /// # Examples
+    /// Create a MemoryListener bound to port 16:
+    ///
+    /// ```rust,no_run
+    /// use memsocket::MemoryListener;
+    ///
+    /// # fn main () -> ::std::io::Result<()> {
+    /// let listener = MemoryListener::bind(16)?;
+    /// # Ok(())}
+    /// ```
+    ///
+    /// [`local_addr`]: #method.local_addr
+    pub fn bind(port: u16) -> Result<Self> {
+        let mut switchboard = (&*SWITCHBOARD).lock().unwrap();
+
+        // Get the port we should bind to. If 0 was given, use a random port
+        let port = if let Some(port) = NonZeroU16::new(port) {
+            if switchboard.0.contains_key(&port) {
+                return Err(ErrorKind::AddrInUse.into());
+            }
+            port
+        } else {
+            loop {
+                let port = match NonZeroU16::new(switchboard.1) {
+                    Some(p) => p,
+                    None => unreachable!(),
+                };
+                switchboard.1 += 1;
+                if !switchboard.0.contains_key(&port) {
+                    break port;
+                }
+            }
+        };
+
+        let (sender, receiver) = mpsc::unbounded();
+        switchboard.0.insert(port, sender);
+
+        Ok(Self {
+            incoming: receiver,
+            port,
+        })
+    }
+
+    /// Returns the local address that this listener is bound to.
+    ///
+    /// This can be useful, for example, when binding to port 0 to figure out
+    /// which port was actually bound.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use memsocket::MemoryListener;
+    ///
+    /// # fn main () -> ::std::io::Result<()> {
+    /// let listener = MemoryListener::bind(16)?;
+    ///
+    /// assert_eq!(listener.local_addr(), 16);
+    /// # Ok(())}
+    /// ```
+    pub fn local_addr(&self) -> u16 {
+        self.port.get()
+    }
+
+    /// Consumes this listener, returning a stream of the sockets this listener
+    /// accepts.
+    ///
+    /// This method returns an implementation of the `Stream` trait which
+    /// resolves to the sockets that are accepted on this listener.
+ /// + /// # Examples + /// + /// ```rust,no_run + /// #![feature(async_await)] + /// use futures::prelude::*; + /// use memsocket::MemoryListener; + /// + /// # async fn work () -> ::std::io::Result<()> { + /// let mut listener = MemoryListener::bind(16)?; + /// let mut incoming = listener.incoming(); + /// + /// // accept connections and process them serially + /// while let Some(stream) = incoming.next().await { + /// match stream { + /// Ok(stream) => { + /// println!("new connection!"); + /// }, + /// Err(e) => { /* connection failed */ } + /// } + /// } + /// # Ok(())} + /// ``` + pub fn incoming(&mut self) -> Incoming<'_> { + Incoming { inner: self } + } + + fn poll_accept(&mut self, context: &mut Context) -> Poll> { + match Pin::new(&mut self.incoming).poll_next(context) { + Poll::Ready(Some(socket)) => Poll::Ready(Ok(socket)), + Poll::Ready(None) => { + let err = Error::new(ErrorKind::Other, "MemoryListener unknown error"); + Poll::Ready(Err(err)) + } + Poll::Pending => Poll::Pending, + } + } +} + +/// Stream returned by the `MemoryListener::incoming` function representing the +/// stream of sockets received from a listener. +#[must_use = "streams do nothing unless polled"] +#[derive(Debug)] +pub struct Incoming<'a> { + inner: &'a mut MemoryListener, +} + +impl<'a> Stream for Incoming<'a> { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, context: &mut Context) -> Poll> { + let socket = ready!(self.inner.poll_accept(context)?); + Poll::Ready(Some(Ok(socket))) + } +} + +/// An in-memory stream between two local sockets. +/// +/// A `MemorySocket` can either be created by connecting to an endpoint, via the +/// [`connect`] method, or by [accepting] a connection from a [listener]. +/// It can be read or written to using the `AsyncRead`, `AsyncWrite`, and related +/// extension traits in `futures::io`. +/// +/// # Examples +/// +/// ```rust, no_run +/// #![feature(async_await)] +/// use futures::prelude::*; +/// use memsocket::MemorySocket; +/// +/// # async fn run() -> ::std::io::Result<()> { +/// let (mut socket_a, mut socket_b) = MemorySocket::new_pair(); +/// +/// socket_a.write_all(b"stormlight").await?; +/// socket_a.flush().await?; +/// +/// let mut buf = [0; 10]; +/// socket_b.read_exact(&mut buf).await?; +/// assert_eq!(&buf, b"stormlight"); +/// +/// # Ok(())} +/// ``` +/// +/// [`connect`]: struct.MemorySocket.html#method.connect +/// [accepting]: struct.MemoryListener.html#method.accept +/// [listener]: struct.MemoryListener.html +#[derive(Debug)] +pub struct MemorySocket { + incoming: UnboundedReceiver, + outgoing: UnboundedSender, + current_buffer: Option<::Buf>, + seen_eof: bool, +} + +impl MemorySocket { + /// Construct both sides of an in-memory socket. + /// + /// # Examples + /// + /// ```rust + /// use memsocket::MemorySocket; + /// + /// # fn main() { + /// let (socket_a, socket_b) = MemorySocket::new_pair(); + /// # } + /// ``` + pub fn new_pair() -> (Self, Self) { + let (a_tx, a_rx) = mpsc::unbounded(); + let (b_tx, b_rx) = mpsc::unbounded(); + let a = Self { + incoming: a_rx, + outgoing: b_tx, + current_buffer: None, + seen_eof: false, + }; + let b = Self { + incoming: b_rx, + outgoing: a_tx, + current_buffer: None, + seen_eof: false, + }; + + (a, b) + } + + /// Create a new in-memory Socket connected to the specified port. + /// + /// This function will create a new MemorySocket socket and attempt to connect it to + /// the `port` provided. 
+    ///
+    /// # Examples
+    ///
+    /// ```rust,no_run
+    /// use memsocket::MemorySocket;
+    ///
+    /// # fn main () -> ::std::io::Result<()> {
+    /// let socket = MemorySocket::connect(16)?;
+    /// # Ok(())}
+    /// ```
+    pub fn connect(port: u16) -> Result<Self> {
+        let mut switchboard = (&*SWITCHBOARD).lock().unwrap();
+
+        // Find the port to connect to
+        let port = if let Some(port) = NonZeroU16::new(port) {
+            port
+        } else {
+            return Err(ErrorKind::AddrNotAvailable.into());
+        };
+
+        let sender = if let Some(s) = switchboard.0.get_mut(&port) {
+            s
+        } else {
+            return Err(ErrorKind::AddrNotAvailable.into());
+        };
+
+        let (socket_a, socket_b) = Self::new_pair();
+        // Send one half of the new socket pair to the listener
+        if let Err(e) = sender.unbounded_send(socket_a) {
+            if e.is_disconnected() {
+                return Err(ErrorKind::AddrNotAvailable.into());
+            }
+
+            unreachable!();
+        }
+
+        Ok(socket_b)
+    }
+}
+
+impl AsyncRead for MemorySocket {
+    /// Attempt to read from the `AsyncRead` into `buf`.
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        mut context: &mut Context,
+        buf: &mut [u8],
+    ) -> Poll<Result<usize>> {
+        if self.incoming.is_terminated() {
+            if self.seen_eof {
+                return Poll::Ready(Err(ErrorKind::UnexpectedEof.into()));
+            } else {
+                self.seen_eof = true;
+                return Poll::Ready(Ok(0));
+            }
+        }
+
+        let mut bytes_read = 0;
+
+        loop {
+            // If we've already filled up the buffer then we can return
+            if bytes_read == buf.len() {
+                return Poll::Ready(Ok(bytes_read));
+            }
+
+            match self.current_buffer {
+                // We have data to copy to buf
+                Some(ref mut current_buffer) if current_buffer.has_remaining() => {
+                    let bytes_to_read =
+                        ::std::cmp::min(buf.len() - bytes_read, current_buffer.remaining());
+                    debug_assert!(bytes_to_read > 0);
+
+                    current_buffer
+                        .take(bytes_to_read)
+                        .copy_to_slice(&mut buf[bytes_read..(bytes_read + bytes_to_read)]);
+                    bytes_read += bytes_to_read;
+                }
+
+                // Either we've exhausted our current buffer or don't have one
+                _ => {
+                    self.current_buffer = {
+                        match Pin::new(&mut self.incoming).poll_next(&mut context) {
+                            Poll::Pending => {
+                                // If we've read anything up to this point return the bytes read
+                                if bytes_read > 0 {
+                                    return Poll::Ready(Ok(bytes_read));
+                                } else {
+                                    return Poll::Pending;
+                                }
+                            }
+                            Poll::Ready(Some(buf)) => Some(buf.into_buf()),
+                            Poll::Ready(None) => return Poll::Ready(Ok(bytes_read)),
+                        }
+                    };
+                }
+            }
+        }
+    }
+}
+
+impl AsyncWrite for MemorySocket {
+    /// Attempt to write bytes from `buf` into the outgoing channel.
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        context: &mut Context,
+        buf: &[u8],
+    ) -> Poll<Result<usize>> {
+        let len = buf.len();
+
+        match self.outgoing.poll_ready(context) {
+            Poll::Ready(Ok(())) => {
+                if let Err(e) = self.outgoing.start_send(buf.into()) {
+                    if e.is_disconnected() {
+                        return Poll::Ready(Err(Error::new(ErrorKind::BrokenPipe, e)));
+                    }
+
+                    // Unbounded channels should only ever have "Disconnected" errors
+                    unreachable!();
+                }
+            }
+            Poll::Ready(Err(e)) => {
+                if e.is_disconnected() {
+                    return Poll::Ready(Err(Error::new(ErrorKind::BrokenPipe, e)));
+                }
+
+                // Unbounded channels should only ever have "Disconnected" errors
+                unreachable!();
+            }
+            Poll::Pending => return Poll::Pending,
+        }
+
+        Poll::Ready(Ok(len))
+    }
+
+    /// Attempt to flush the channel. Cannot fail.
+    fn poll_flush(self: Pin<&mut Self>, _context: &mut Context) -> Poll<Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+
+    /// Attempt to close the channel. Cannot fail.
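+    /// Closing drops the outgoing channel; once the peer drains any frames
+    /// still buffered, its reads observe EOF.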
+    fn poll_close(self: Pin<&mut Self>, _context: &mut Context) -> Poll<Result<()>> {
+        self.outgoing.close_channel();
+
+        Poll::Ready(Ok(()))
+    }
+}
diff --git a/network/memsocket/tests/memory_listener.rs b/network/memsocket/tests/memory_listener.rs
new file mode 100644
index 0000000000000..48102c5b0e824
--- /dev/null
+++ b/network/memsocket/tests/memory_listener.rs
@@ -0,0 +1,83 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use futures::{
+    executor::block_on,
+    io::{AsyncReadExt, AsyncWriteExt},
+    stream::StreamExt,
+};
+use memsocket::{MemoryListener, MemorySocket};
+use std::io::Result;
+
+#[test]
+fn listener_bind() -> Result<()> {
+    let listener = MemoryListener::bind(42)?;
+    assert_eq!(listener.local_addr(), 42);
+
+    Ok(())
+}
+
+#[test]
+fn simple_connect() -> Result<()> {
+    let mut listener = MemoryListener::bind(10)?;
+
+    let mut dialer = MemorySocket::connect(10)?;
+    let mut listener_socket = block_on(listener.incoming().next()).unwrap()?;
+
+    block_on(dialer.write_all(b"foo"))?;
+    block_on(dialer.flush())?;
+
+    let mut buf = [0; 3];
+    block_on(listener_socket.read_exact(&mut buf))?;
+    assert_eq!(&buf, b"foo");
+
+    Ok(())
+}
+
+#[test]
+fn listen_on_port_zero() -> Result<()> {
+    let mut listener = MemoryListener::bind(0)?;
+    let listener_addr = listener.local_addr();
+
+    let mut dialer = MemorySocket::connect(listener_addr)?;
+    let mut listener_socket = block_on(listener.incoming().next()).unwrap()?;
+
+    block_on(dialer.write_all(b"foo"))?;
+    block_on(dialer.flush())?;
+
+    let mut buf = [0; 3];
+    block_on(listener_socket.read_exact(&mut buf))?;
+    assert_eq!(&buf, b"foo");
+
+    block_on(listener_socket.write_all(b"bar"))?;
+    block_on(listener_socket.flush())?;
+
+    let mut buf = [0; 3];
+    block_on(dialer.read_exact(&mut buf))?;
+    assert_eq!(&buf, b"bar");
+
+    Ok(())
+}
+
+#[test]
+fn listener_correctly_frees_port_on_drop() -> Result<()> {
+    fn connect_on_port(port: u16) -> Result<()> {
+        let mut listener = MemoryListener::bind(port)?;
+        let mut dialer = MemorySocket::connect(port)?;
+        let mut listener_socket = block_on(listener.incoming().next()).unwrap()?;
+
+        block_on(dialer.write_all(b"foo"))?;
+        block_on(dialer.flush())?;
+
+        let mut buf = [0; 3];
+        block_on(listener_socket.read_exact(&mut buf))?;
+        assert_eq!(&buf, b"foo");
+
+        Ok(())
+    }
+
+    connect_on_port(9)?;
+    connect_on_port(9)?;
+
+    Ok(())
+}
diff --git a/network/memsocket/tests/memory_socket.rs b/network/memsocket/tests/memory_socket.rs
new file mode 100644
index 0000000000000..f4b62b938d89c
--- /dev/null
+++ b/network/memsocket/tests/memory_socket.rs
@@ -0,0 +1,114 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use futures::{
+    executor::block_on,
+    io::{AsyncReadExt, AsyncWriteExt},
+};
+use memsocket::MemorySocket;
+use std::io::Result;
+
+#[test]
+fn simple_write_read() -> Result<()> {
+    let (mut a, mut b) = MemorySocket::new_pair();
+
+    block_on(a.write_all(b"hello world"))?;
+    block_on(a.flush())?;
+    drop(a);
+
+    let mut v = Vec::new();
+    block_on(b.read_to_end(&mut v))?;
+    assert_eq!(v, b"hello world");
+
+    Ok(())
+}
+
+#[test]
+fn partial_read() -> Result<()> {
+    let (mut a, mut b) = MemorySocket::new_pair();
+
+    block_on(a.write_all(b"foobar"))?;
+    block_on(a.flush())?;
+
+    let mut buf = [0; 3];
+    block_on(b.read_exact(&mut buf))?;
+    assert_eq!(&buf, b"foo");
+    block_on(b.read_exact(&mut buf))?;
+    assert_eq!(&buf, b"bar");
+
+    Ok(())
+}
+
+#[test]
+fn partial_read_write_both_sides() -> Result<()> {
+    let (mut a, mut b) = 
MemorySocket::new_pair(); + + block_on(a.write_all(b"foobar"))?; + block_on(a.flush())?; + block_on(b.write_all(b"stormlight"))?; + block_on(b.flush())?; + + let mut buf_a = [0; 5]; + let mut buf_b = [0; 3]; + block_on(a.read_exact(&mut buf_a))?; + assert_eq!(&buf_a, b"storm"); + block_on(b.read_exact(&mut buf_b))?; + assert_eq!(&buf_b, b"foo"); + + block_on(a.read_exact(&mut buf_a))?; + assert_eq!(&buf_a, b"light"); + block_on(b.read_exact(&mut buf_b))?; + assert_eq!(&buf_b, b"bar"); + + Ok(()) +} + +#[test] +fn many_small_writes() -> Result<()> { + let (mut a, mut b) = MemorySocket::new_pair(); + + block_on(a.write_all(b"words"))?; + block_on(a.write_all(b" "))?; + block_on(a.write_all(b"of"))?; + block_on(a.write_all(b" "))?; + block_on(a.write_all(b"radiance"))?; + block_on(a.flush())?; + drop(a); + + let mut buf = [0; 17]; + block_on(b.read_exact(&mut buf))?; + assert_eq!(&buf, b"words of radiance"); + + Ok(()) +} + +#[test] +fn read_zero_bytes() -> Result<()> { + let (mut a, mut b) = MemorySocket::new_pair(); + + block_on(a.write_all(b"way of kings"))?; + block_on(a.flush())?; + + let mut buf = [0; 12]; + block_on(b.read_exact(&mut buf[0..0]))?; + assert_eq!(buf, [0; 12]); + block_on(b.read_exact(&mut buf))?; + assert_eq!(&buf, b"way of kings"); + + Ok(()) +} + +#[test] +fn read_bytes_with_large_buffer() -> Result<()> { + let (mut a, mut b) = MemorySocket::new_pair(); + + block_on(a.write_all(b"way of kings"))?; + block_on(a.flush())?; + + let mut buf = [0; 20]; + let bytes_read = block_on(b.read(&mut buf))?; + assert_eq!(bytes_read, 12); + assert_eq!(&buf[0..12], b"way of kings"); + + Ok(()) +} diff --git a/network/netcore/Cargo.toml b/network/netcore/Cargo.toml new file mode 100644 index 0000000000000..97c9394d5424e --- /dev/null +++ b/network/netcore/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "netcore" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +bytes = "0.4.12" +futures = { version = "=0.3.0-alpha.16", package = "futures-preview", features = ["io-compat", "compat"] } +futures_01 = { version = "0.1.25", package = "futures" } +pin-utils = "=0.1.0-alpha.4" +tokio = "0.1.16" +yamux = "0.2.1" +parity-multiaddr = "0.4.0" + +memsocket = { path = "../memsocket" } diff --git a/network/netcore/src/lib.rs b/network/netcore/src/lib.rs new file mode 100644 index 0000000000000..8b6e7a63673fc --- /dev/null +++ b/network/netcore/src/lib.rs @@ -0,0 +1,15 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Core types and traits for building Peer to Peer networks. +//! +//! The `netcore` crate contains all of the core functionality needed to build a Peer to Peer +//! network from building `Transport`s and `StreamMultiplexer`s to negotiating protocols on a +//! socket. + +#![feature(async_await)] + +pub mod multiplexing; +pub mod negotiate; +pub mod transport; +mod utils; diff --git a/network/netcore/src/multiplexing/mod.rs b/network/netcore/src/multiplexing/mod.rs new file mode 100644 index 0000000000000..b029af7ae6373 --- /dev/null +++ b/network/netcore/src/multiplexing/mod.rs @@ -0,0 +1,61 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Module responsible for defining and implementing Stream Multiplexing +//! +//! The main component of this module is the [`StreamMultiplexer`] trait, which +//! provides an interface for multiplexing multiple [`AsyncRead`]/[`AsyncWrite`] substreams over a +//! 
single underlying [`AsyncRead`]/[`AsyncWrite`] stream. [`Yamux`], an implementation of this
+//! trait over [`TcpStream`], is also provided.
+//!
+//! [`StreamMultiplexer`]: crate::multiplexing::StreamMultiplexer
+//! [`AsyncRead`]: futures::io::AsyncRead
+//! [`AsyncWrite`]: futures::io::AsyncWrite
+//! [`TcpStream`]: tokio::net::tcp::TcpStream
+//! [`Yamux`]: crate::multiplexing::yamux::Yamux
+
+use futures::{
+    future::Future,
+    io::{AsyncRead, AsyncWrite},
+    stream::Stream,
+};
+use std::{fmt::Debug, io};
+
+pub mod yamux;
+
+/// A StreamMultiplexer is responsible for multiplexing multiple [`AsyncRead`]/[`AsyncWrite`]
+/// streams over a single underlying [`AsyncRead`]/[`AsyncWrite`] stream.
+///
+/// New substreams are opened either by [listening](StreamMultiplexer::listen_for_inbound) for
+/// inbound substreams opened by the remote side or by [opening](StreamMultiplexer::open_outbound)
+/// an outbound substream locally.
+pub trait StreamMultiplexer: Debug + Send + Sync {
+    /// The type of substreams opened by this Multiplexer.
+    ///
+    /// Must implement both AsyncRead and AsyncWrite.
+    type Substream: AsyncRead + AsyncWrite + Send + Debug + Unpin;
+
+    /// A stream of new [`Substreams`](StreamMultiplexer::Substream) opened by the remote side.
+    type Listener: Stream<Item = io::Result<Self::Substream>> + Send + Unpin;
+
+    /// A pending [`Substream`](StreamMultiplexer::Substream) to be opened on the underlying
+    /// connection, obtained from [requesting a new substream](StreamMultiplexer::open_outbound).
+    type Outbound: Future<Output = io::Result<Self::Substream>> + Send;
+
+    /// A pending request to shut down the underlying connection, obtained from
+    /// [closing](StreamMultiplexer::close).
+    type Close: Future<Output = io::Result<()>> + Send;
+
+    /// Returns a stream of new Substreams opened by the remote side.
+    fn listen_for_inbound(&self) -> Self::Listener;
+
+    /// Requests that a new Substream be opened.
+    fn open_outbound(&self) -> Self::Outbound;
+
+    /// Close and shutdown this [`StreamMultiplexer`].
+    ///
+    /// After the returned future has resolved this multiplexer will be shut down. All subsequent
+    /// reads or writes to any still existing handles to substreams opened through this multiplexer
+    /// must return EOF (in the case of a read), or an error.
+    fn close(&self) -> Self::Close;
+}
diff --git a/network/netcore/src/multiplexing/yamux.rs b/network/netcore/src/multiplexing/yamux.rs
new file mode 100644
index 0000000000000..17f810e778207
--- /dev/null
+++ b/network/netcore/src/multiplexing/yamux.rs
@@ -0,0 +1,347 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Implementation of [`StreamMultiplexer`] using the [`yamux`] protocol over TCP
+//!
+//! [`StreamMultiplexer`]: crate::multiplexing::StreamMultiplexer
+//! [`yamux`]: https://github.com/hashicorp/yamux/blob/master/spec.md
+
+use crate::{
+    multiplexing::StreamMultiplexer,
+    negotiate::{negotiate_inbound, negotiate_outbound_interactive},
+    transport::ConnectionOrigin,
+};
+use futures::{
+    compat::{Compat, Compat01As03},
+    future::{self, Future},
+    io::{AsyncRead, AsyncWrite},
+    stream::Stream,
+};
+use futures_01::future::poll_fn as poll_fn_01;
+use std::{
+    fmt::Debug,
+    io,
+    pin::Pin,
+    task::{Context, Poll},
+};
+use yamux;
+
+/// Re-export `Mode` from the yamux crate
+pub use yamux::Mode;
+
+/// The substream type produced by the `Yamux` multiplexer
+pub type StreamHandle<TSocket> = Compat01As03<yamux::StreamHandle<Compat<TSocket>>>;
+
+const YAMUX_PROTOCOL_NAME: &[u8] = b"/yamux/1.0.0";
+
+#[derive(Debug)]
+pub struct Yamux<TSocket> {
+    inner: yamux::Connection<Compat<TSocket>>,
+}
+
+const MAX_BUFFER_SIZE: u32 = 8 * 1024 * 1024; // 8MB
+const RECEIVE_WINDOW: u32 = 4 * 1024 * 1024; // 4MB
+
+impl<TSocket> Yamux<TSocket>
+where
+    TSocket: AsyncRead + AsyncWrite + Send + Debug + Unpin,
+{
+    pub fn new(socket: TSocket, mode: Mode) -> Self {
+        let mut config = yamux::Config::default();
+        // Use OnRead mode instead of OnReceive mode to provide back pressure to the sending side.
+        // Caveat: the OnRead mode has the risk of deadlock, where both sides send data larger than
+        // the receive window and don't read before finishing their writes. This doesn't apply to
+        // our use cases. Some of our streams are unidirectional, where only one end writes data,
+        // e.g., Direct Send. Some of our streams are bidirectional, but only one end writes data
+        // at a time, e.g., RPC. Some of our streams may write at the same time, but the frames are
+        // shorter than the receive window, e.g., protocol negotiation.
+        config.set_window_update_mode(yamux::WindowUpdateMode::OnRead);
+        // Because OnRead mode increases the RTT of window updates, a bigger buffer size and
+        // receive window size perform better.
+        config.set_max_buffer_size(MAX_BUFFER_SIZE as usize);
+        config
+            .set_receive_window(RECEIVE_WINDOW)
+            .expect("Invalid receive window size");
+        let socket = Compat::new(socket);
+        Self {
+            inner: yamux::Connection::new(socket, config, mode),
+        }
+    }
+
+    pub async fn upgrade_connection(socket: TSocket, origin: ConnectionOrigin) -> io::Result<Self> {
+        // Perform protocol negotiation
+        let (socket, proto) = match origin {
+            ConnectionOrigin::Inbound => negotiate_inbound(socket, [YAMUX_PROTOCOL_NAME]).await?,
+            ConnectionOrigin::Outbound => {
+                negotiate_outbound_interactive(socket, [YAMUX_PROTOCOL_NAME]).await?
+ } + }; + + assert_eq!(proto, YAMUX_PROTOCOL_NAME); + + let mode = match origin { + ConnectionOrigin::Inbound => Mode::Server, + ConnectionOrigin::Outbound => Mode::Client, + }; + + Ok(Yamux::new(socket, mode)) + } +} + +impl StreamMultiplexer for Yamux +where + TSocket: AsyncRead + AsyncWrite + Send + Debug + Unpin, +{ + type Substream = StreamHandle; + type Listener = Listener; + type Outbound = future::Ready>; + type Close = Close; + + fn listen_for_inbound(&self) -> Self::Listener { + Listener::new(self.inner.clone()) + } + + fn open_outbound(&self) -> Self::Outbound { + let output = match self.inner.open_stream() { + Ok(Some(substream)) => Ok(Compat01As03::new(substream)), + Ok(None) => Err(io::Error::new( + io::ErrorKind::Other, + "Unable to open substream", + )), + Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)), + }; + + future::ready(output) + } + + fn close(&self) -> Self::Close { + Close::new(self.inner.clone()) + } +} + +#[derive(Debug)] +#[must_use = "futures do nothing unless polled"] +pub struct Close { + inner: yamux::Connection>, +} + +impl Close +where + TSocket: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(connection: yamux::Connection>) -> Self { + Self { inner: connection } + } +} + +impl Future for Close +where + TSocket: AsyncRead + AsyncWrite + Unpin, +{ + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, context: &mut Context) -> Poll { + let mut close_fut = Compat01As03::new(poll_fn_01(|| { + self.inner + .close() + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + })); + Pin::new(&mut close_fut).poll(context) + } +} + +#[derive(Debug)] +#[must_use = "streams do nothing unless polled"] +pub struct Listener { + inner: Compat01As03>>, +} + +impl Listener +where + TSocket: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(connection: yamux::Connection>) -> Self { + Self { + inner: Compat01As03::new(connection), + } + } +} + +impl Stream for Listener +where + TSocket: AsyncRead + AsyncWrite + Unpin, +{ + type Item = io::Result>>>; + + fn poll_next(mut self: Pin<&mut Self>, context: &mut Context) -> Poll> { + match Pin::new(&mut self.inner).poll_next(context) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Err(e))) => { + Poll::Ready(Some(Err(io::Error::new(io::ErrorKind::Other, e)))) + } + Poll::Ready(Some(Ok(substream))) => Poll::Ready(Some(Ok(Compat01As03::new(substream)))), + } + } +} + +#[cfg(test)] +mod test { + use crate::multiplexing::{ + yamux::{Mode, Yamux}, + StreamMultiplexer, + }; + use futures::{ + executor::block_on, + future::join, + io::{AsyncReadExt, AsyncWriteExt}, + stream::StreamExt, + }; + use memsocket::MemorySocket; + use std::io; + + #[test] + fn open_substream() -> io::Result<()> { + let (dialer, listener) = MemorySocket::new_pair(); + let msg = b"The Way of Kings"; + + let dialer = async move { + let muxer = Yamux::new(dialer, Mode::Client); + + let mut substream = muxer.open_outbound().await?; + + substream.write_all(msg).await?; + substream.flush().await?; + + // Force return type of the async block + let result: io::Result<()> = Ok(()); + result + }; + + let listener = async move { + let muxer = Yamux::new(listener, Mode::Server); + + let (maybe_substream, _listener) = muxer.listen_for_inbound().into_future().await; + let mut substream = maybe_substream + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no substream"))??; + + let mut buf = Vec::new(); + substream.read_to_end(&mut buf).await?; + assert_eq!(buf, msg); + + // Force return type of the async block 
+ let result: io::Result<()> = Ok(()); + result + }; + + let (dialer_result, listener_result) = block_on(join(dialer, listener)); + dialer_result?; + listener_result?; + Ok(()) + } + + #[test] + fn close() -> io::Result<()> { + let (dialer, listener) = MemorySocket::new_pair(); + let msg = b"Words of Radiance"; + + let dialer = async move { + let muxer = Yamux::new(dialer, Mode::Client); + + let mut substream = muxer.open_outbound().await?; + + substream.write_all(msg).await?; + substream.flush().await?; + + let mut buf = Vec::new(); + substream.read_to_end(&mut buf).await?; + assert_eq!(buf, b""); + + // Force return type of the async block + let result: io::Result<()> = Ok(()); + result + }; + + let listener = async move { + let muxer = Yamux::new(listener, Mode::Server); + + let (maybe_substream, _listener) = muxer.listen_for_inbound().into_future().await; + let mut substream = maybe_substream + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no substream"))??; + + let mut buf = vec![0; msg.len()]; + substream.read_exact(&mut buf).await?; + assert_eq!(buf, msg); + + // Close the muxer and then try to write to it + muxer.close().await?; + + let result = substream.write_all(b"ignored message").await; + match result { + Ok(()) => panic!("Write should have failed"), + Err(e) => assert_eq!(e.kind(), io::ErrorKind::WriteZero), + } + + // Force return type of the async block + let result: io::Result<()> = Ok(()); + result + }; + + let (dialer_result, listener_result) = block_on(join(dialer, listener)); + dialer_result?; + listener_result?; + Ok(()) + } + + #[test] + fn send_big_message() -> io::Result<()> { + #[allow(non_snake_case)] + let MiB: usize = 1 << 20; + let msg_len = 16 * MiB; + + let (dialer, listener) = MemorySocket::new_pair(); + + let dialer = async move { + let muxer = Yamux::new(dialer, Mode::Client); + let mut substream = muxer.open_outbound().await?; + + let msg = vec![0x55u8; msg_len]; + substream.write_all(msg.as_slice()).await?; + + let mut buf = vec![0u8; msg_len]; + substream.read_exact(&mut buf).await?; + substream.close().await?; + + assert_eq!(buf.len(), msg_len); + assert_eq!(buf, vec![0xAAu8; msg_len]); + + // Force return type of the async block + let result: io::Result> = Ok(muxer); + result + }; + + let listener = async move { + let muxer = Yamux::new(listener, Mode::Server); + let (maybe_substream, _listener) = muxer.listen_for_inbound().into_future().await; + let mut substream = maybe_substream + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no substream"))??; + + let mut buf = vec![0u8; msg_len]; + substream.read_exact(&mut buf).await?; + assert_eq!(buf, vec![0x55u8; msg_len]); + + let msg = vec![0xAAu8; msg_len]; + substream.write_all(msg.as_slice()).await?; + substream.close().await?; + + // Force return type of the async block + let result: io::Result> = Ok(muxer); + result + }; + + let (dialer_result, listener_result) = block_on(join(dialer, listener)); + let _ = dialer_result?; + let _ = listener_result?; + Ok(()) + } +} diff --git a/network/netcore/src/negotiate/framing.rs b/network/netcore/src/negotiate/framing.rs new file mode 100644 index 0000000000000..dce345e06f191 --- /dev/null +++ b/network/netcore/src/negotiate/framing.rs @@ -0,0 +1,173 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::utils::Captures; +use bytes::BytesMut; +use futures::{ + future::Future, + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, +}; +use std::{convert::TryInto, io::Result}; + +/// Read a u16 
length prefixed frame from `Stream` into `buf`.
+pub fn read_u16frame<'stream, 'buf, 'c, TSocket>(
+    mut stream: &'stream mut TSocket,
+    buf: &'buf mut BytesMut,
+) -> impl Future<Output = Result<()>> + Captures<'stream> + Captures<'buf> + 'c
+where
+    'stream: 'c,
+    'buf: 'c,
+    TSocket: AsyncRead + AsyncWrite + Unpin,
+{
+    async move {
+        let len = read_u16frame_len(&mut stream).await?;
+        buf.resize(len as usize, 0);
+        stream.read_exact(buf.as_mut()).await?;
+        Ok(())
+    }
+}
+
+/// Read a u16 (encoded as BE bytes) from `Stream` and return the length.
+async fn read_u16frame_len<TSocket>(stream: &mut TSocket) -> Result<u16>
+where
+    TSocket: AsyncRead + AsyncWrite + Unpin,
+{
+    let mut len_buf = [0, 0];
+    stream.read_exact(&mut len_buf).await?;
+
+    Ok(u16::from_be_bytes(len_buf))
+}
+
+/// Write the contents of `buf` to `stream` prefixed with a u16 length.
+/// The length of `buf` must be less than or equal to u16::max_value().
+///
+/// Caller is responsible for flushing the write to `stream`.
+pub fn write_u16frame<'stream, 'buf, 'c, TSocket>(
+    mut stream: &'stream mut TSocket,
+    buf: &'buf [u8],
+) -> impl Future<Output = Result<()>> + Captures<'stream> + Captures<'buf> + 'c
+where
+    'stream: 'c,
+    'buf: 'c,
+    TSocket: AsyncRead + AsyncWrite + Unpin,
+{
+    async move {
+        let len = buf
+            .len()
+            .try_into()
+            // TODO Maybe use our own Error Type?
+            .map_err(|_e| std::io::Error::new(std::io::ErrorKind::Other, "Too big"))?;
+        write_u16frame_len(&mut stream, len).await?;
+        stream.write_all(buf).await?;
+
+        Ok(())
+    }
+}
+
+/// Write a u16 `len` as BE bytes to `stream`.
+///
+/// Caller is responsible for flushing the write to `stream`.
+async fn write_u16frame_len<TSocket>(stream: &mut TSocket, len: u16) -> Result<()>
+where
+    TSocket: AsyncRead + AsyncWrite + Unpin,
+{
+    let len = u16::to_be_bytes(len);
+    stream.write_all(&len).await?;
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod test {
+    use super::{read_u16frame, read_u16frame_len, write_u16frame, write_u16frame_len};
+    use bytes::BytesMut;
+    use futures::{executor::block_on, io::AsyncWriteExt};
+    use memsocket::MemorySocket;
+    use std::io::Result;
+
+    #[test]
+    fn write_read_u16frame_len() -> Result<()> {
+        let (mut a, mut b) = MemorySocket::new_pair();
+
+        block_on(write_u16frame_len(&mut a, 17))?;
+        block_on(a.flush())?;
+        let len = block_on(read_u16frame_len(&mut b))?;
+        assert_eq!(len, 17);
+
+        Ok(())
+    }
+
+    #[test]
+    fn read_u16frame_len_eof() -> Result<()> {
+        let (mut a, mut b) = MemorySocket::new_pair();
+
+        block_on(a.write_all(&[42]))?;
+        block_on(a.flush())?;
+        drop(a);
+
+        let result = block_on(read_u16frame_len(&mut b));
+        assert!(result.is_err());
+
+        Ok(())
+    }
+
+    #[test]
+    fn write_u16frame_len_eof() -> Result<()> {
+        let (mut a, b) = MemorySocket::new_pair();
+        drop(b);
+
+        let result = block_on(a.write_all(&[42]));
+        assert!(result.is_err());
+
+        Ok(())
+    }
+
+    #[test]
+    fn write_read_u16frame() -> Result<()> {
+        let (mut a, mut b) = MemorySocket::new_pair();
+
+        block_on(write_u16frame(&mut a, b"The Name of the Wind"))?;
+        block_on(a.flush())?;
+
+        let mut buf = BytesMut::new();
+        block_on(read_u16frame(&mut b, &mut buf))?;
+
+        assert_eq!(buf.as_ref(), b"The Name of the Wind");
+
+        Ok(())
+    }
+
+    #[test]
+    fn write_read_multiple_u16frames() -> Result<()> {
+        let (mut a, mut b) = MemorySocket::new_pair();
+
+        block_on(write_u16frame(&mut a, b"The Name of the Wind"))?;
+        block_on(write_u16frame(&mut b, b"The Wise Man's Fear"))?;
+        block_on(b.flush())?;
+        block_on(write_u16frame(&mut a, b"The Doors of Stone"))?;
+        block_on(a.flush())?;
+
+        let mut buf = BytesMut::new();
+        
block_on(read_u16frame(&mut b, &mut buf))?; + assert_eq!(buf.as_ref(), b"The Name of the Wind"); + block_on(read_u16frame(&mut b, &mut buf))?; + assert_eq!(buf.as_ref(), b"The Doors of Stone"); + block_on(read_u16frame(&mut a, &mut buf))?; + assert_eq!(buf.as_ref(), b"The Wise Man's Fear"); + + Ok(()) + } + + #[test] + fn write_large_u16frame() -> Result<()> { + let (mut a, _b) = MemorySocket::new_pair(); + + let mut buf = Vec::new(); + buf.resize((u16::max_value() as usize) * 2, 0); + + let result = block_on(write_u16frame(&mut a, &buf)); + assert!(result.is_err(), true); + + Ok(()) + } +} diff --git a/network/netcore/src/negotiate/inbound.rs b/network/netcore/src/negotiate/inbound.rs new file mode 100644 index 0000000000000..92b0cd56290a0 --- /dev/null +++ b/network/netcore/src/negotiate/inbound.rs @@ -0,0 +1,268 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + negotiate::{ + framing::{read_u16frame, write_u16frame}, + PROTOCOL_INTERACTIVE, PROTOCOL_NOT_SUPPORTED, PROTOCOL_SELECT, + }, + utils::Captures, +}; +use bytes::BytesMut; +use futures::{ + future::Future, + io::{AsyncRead, AsyncWrite, AsyncWriteExt}, +}; +use std::io::Result; + +/// Perform protocol negotiation on an inbound `stream` attempting to match +/// against the provided `supported_protocols`. Protocol negotiation is done +/// using either an interactive or optimistic negotiation, selected by the +/// remote end. +pub async fn negotiate_inbound( + mut stream: TSocket, + supported_protocols: TProtocols, +) -> Result<(TSocket, TProto)> +where + TSocket: AsyncRead + AsyncWrite + Unpin, + TProto: AsRef<[u8]> + Clone, + TProtocols: AsRef<[TProto]>, +{ + let mut buf = BytesMut::new(); + read_u16frame(&mut stream, &mut buf).await?; + + if buf.as_ref() == PROTOCOL_INTERACTIVE { + let selected_proto = + negotiate_inbound_interactive(&mut stream, supported_protocols, buf).await?; + + Ok((stream, selected_proto)) + } else if buf.as_ref() == PROTOCOL_SELECT { + let selected_proto = + negotiate_inbound_select(&mut stream, supported_protocols, buf).await?; + + Ok((stream, selected_proto)) + } else { + // TODO Maybe have our own Error type here? + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Unable to negotiate protocol - unexpected inbound message", + )) + } +} + +fn negotiate_inbound_interactive<'stream, 'c, TSocket, TProto, TProtocols>( + mut stream: &'stream mut TSocket, + supported_protocols: TProtocols, + mut buf: BytesMut, +) -> impl Future> + Captures<'stream> + 'c +where + 'stream: 'c, + TSocket: AsyncRead + AsyncWrite + Unpin, + TProto: AsRef<[u8]> + Clone, + TProtocols: AsRef<[TProto]> + 'c, +{ + async move { + // ACK that we are speaking PROTOCOL_INTERACTIVE + write_u16frame(&mut stream, PROTOCOL_INTERACTIVE).await?; + stream.flush().await?; + + // We make upto 10 attempts to negotiate a protocol. + for _ in 0..10 { + // Read in the Protocol they want to speak and attempt to match + // it against our supported protocols + read_u16frame(&mut stream, &mut buf).await?; + for proto in supported_protocols.as_ref() { + // Found a match! 
+ if buf.as_ref() == proto.as_ref() { + // Echo back the selected protocol + write_u16frame(&mut stream, proto.as_ref()).await?; + stream.flush().await?; + return Ok(proto.clone()); + } + } + // If the desired protocol doesn't match any of our supported + // ones then send PROTOCOL_NOT_SUPPORTED + write_u16frame(&mut stream, PROTOCOL_NOT_SUPPORTED).await?; + stream.flush().await?; + } + + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Unable to negotiate protocol - all attempts failed", + )) + } +} + +fn negotiate_inbound_select<'stream, 'c, TSocket, TProto, TProtocols>( + mut stream: &'stream mut TSocket, + supported_protocols: TProtocols, + mut buf: BytesMut, +) -> impl Future> + Captures<'stream> + 'c +where + 'stream: 'c, + TSocket: AsyncRead + AsyncWrite + Unpin, + TProto: AsRef<[u8]> + Clone, + TProtocols: AsRef<[TProto]> + 'c, +{ + async move { + // Read in the Protocol they want to speak and attempt to match + // it against our supported protocols + read_u16frame(&mut stream, &mut buf).await?; + for proto in supported_protocols.as_ref() { + // Found a match! + if buf.as_ref() == proto.as_ref() { + return Ok(proto.clone()); + } + } + + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Unable to negotiate Protocol - protocol not supported", + )) + } +} + +#[cfg(test)] +mod test { + use crate::negotiate::{ + framing::{read_u16frame, write_u16frame}, + inbound::{negotiate_inbound_interactive, negotiate_inbound_select}, + PROTOCOL_INTERACTIVE, PROTOCOL_NOT_SUPPORTED, + }; + use bytes::BytesMut; + use futures::{executor::block_on, future::join, io::AsyncWriteExt}; + use memsocket::MemorySocket; + use std::io::Result; + + #[test] + fn test_negotiate_inbound_interactive() -> Result<()> { + let (mut a, mut b) = MemorySocket::new_pair(); + let test_protocol = b"/hello/1.0.0"; + + let outbound = async move { + write_u16frame(&mut a, test_protocol).await?; + a.flush().await?; + + let mut buf = BytesMut::new(); + read_u16frame(&mut a, &mut buf).await?; + assert_eq!(buf.as_ref(), PROTOCOL_INTERACTIVE); + read_u16frame(&mut a, &mut buf).await?; + assert_eq!(buf.as_ref(), test_protocol); + + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + + let buf = BytesMut::new(); + let inbound = negotiate_inbound_interactive(&mut b, [test_protocol], buf); + + let (result_outbound, result_inbound) = block_on(join(outbound, inbound)); + assert_eq!(result_outbound.is_ok(), true); + assert_eq!(result_inbound?, test_protocol); + + Ok(()) + } + + #[test] + fn test_negotiate_inbound_interactive_unsupported() -> Result<()> { + let (mut a, mut b) = MemorySocket::new_pair(); + let protocol_supported = b"/hello/1.0.0"; + let protocol_unsupported = b"/hello/2.0.0"; + + let outbound = async move { + write_u16frame(&mut a, protocol_unsupported).await?; + a.flush().await?; + + let mut buf = BytesMut::new(); + read_u16frame(&mut a, &mut buf).await?; + assert_eq!(buf.as_ref(), PROTOCOL_INTERACTIVE); + read_u16frame(&mut a, &mut buf).await?; + assert_eq!(buf.as_ref(), PROTOCOL_NOT_SUPPORTED); + + write_u16frame(&mut a, protocol_supported).await?; + a.flush().await?; + + read_u16frame(&mut a, &mut buf).await?; + assert_eq!(buf.as_ref(), protocol_supported); + + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + + let buf = BytesMut::new(); + let inbound = negotiate_inbound_interactive(&mut b, [protocol_supported], buf); + + let (result_outbound, result_inbound) = block_on(join(outbound, inbound)); + 
assert_eq!(result_outbound.is_ok(), true); + assert_eq!(result_inbound?, protocol_supported); + + Ok(()) + } + + #[test] + fn test_negotiate_inbound_select() -> Result<()> { + let (mut a, mut b) = MemorySocket::new_pair(); + let test_protocol = b"/hello/1.0.0"; + let hello_request = b"Hello World!"; + + let outbound = async move { + write_u16frame(&mut a, test_protocol).await?; + a.flush().await?; + + let mut buf = BytesMut::new(); + read_u16frame(&mut a, &mut buf).await?; + assert_eq!(buf.as_ref(), hello_request); + + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + + let inbound = async move { + let buf = BytesMut::new(); + let selected_proto = negotiate_inbound_select(&mut b, [test_protocol], buf).await?; + assert_eq!(selected_proto, test_protocol); + + write_u16frame(&mut b, hello_request).await?; + b.flush().await?; + + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + + let (result_outbound, result_inbound) = block_on(join(outbound, inbound)); + assert_eq!(result_outbound.is_ok(), true); + assert_eq!(result_inbound.is_ok(), true); + + Ok(()) + } + + #[test] + fn test_negotiate_inbound_select_unsupported() -> Result<()> { + let (mut a, mut b) = MemorySocket::new_pair(); + let protocol_supported = b"/hello/1.0.0"; + let protocol_unsupported = b"/hello/2.0.0"; + + let outbound = async move { + write_u16frame(&mut a, protocol_unsupported).await?; + a.flush().await?; + + let mut buf = BytesMut::new(); + read_u16frame(&mut a, &mut buf).await + }; + + let inbound = async move { + let buf = BytesMut::new(); + negotiate_inbound_select(&mut b, [protocol_supported], buf).await + }; + + let (result_outbound, result_inbound) = block_on(join(outbound, inbound)); + assert_eq!(result_outbound.is_err(), true); + assert_eq!(result_inbound.is_err(), true); + + Ok(()) + } +} diff --git a/network/netcore/src/negotiate/mod.rs b/network/netcore/src/negotiate/mod.rs new file mode 100644 index 0000000000000..6889bae88b463 --- /dev/null +++ b/network/netcore/src/negotiate/mod.rs @@ -0,0 +1,22 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Protocol negotiation on AsyncRead/AsyncWrite streams +//! +//! Upgrading a stream to a particular protocol can be done either using 'protocol-interactive' or +//! 'protocol-select', both of which use u16 length prefix framing. 
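+//!
+//! As an illustrative sketch of the wire exchange (the "/hello/1.0.0" name
+//! below is an example, not something fixed by this module), an interactive
+//! negotiation looks roughly like:
+//!
+//! ```text
+//! dialer   -> listener: "/libra/protocol-interactive/1.0.0"
+//! dialer   -> listener: "/hello/1.0.0"                       (first preference)
+//! listener -> dialer:   "/libra/protocol-interactive/1.0.0"  (header ACK)
+//! listener -> dialer:   "/hello/1.0.0"                       (ACK) or "not supported"
+//! ```
+//!
+//! 'protocol-select' instead sends its header and the single chosen protocol
+//! optimistically, without waiting for any ACK.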
+
+mod framing;
+mod inbound;
+mod outbound;
+#[cfg(test)]
+mod test;
+
+pub use self::{
+    inbound::negotiate_inbound,
+    outbound::{negotiate_outbound_interactive, negotiate_outbound_select},
+};
+
+static PROTOCOL_INTERACTIVE: &'static [u8] = b"/libra/protocol-interactive/1.0.0";
+static PROTOCOL_SELECT: &'static [u8] = b"/libra/protocol-select/1.0.0";
+static PROTOCOL_NOT_SUPPORTED: &'static [u8] = b"not supported";
diff --git a/network/netcore/src/negotiate/outbound.rs b/network/netcore/src/negotiate/outbound.rs
new file mode 100644
index 0000000000000..bb84cd8260208
--- /dev/null
+++ b/network/netcore/src/negotiate/outbound.rs
@@ -0,0 +1,264 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::negotiate::{
+    framing::{read_u16frame, write_u16frame},
+    PROTOCOL_INTERACTIVE, PROTOCOL_NOT_SUPPORTED, PROTOCOL_SELECT,
+};
+use bytes::BytesMut;
+use futures::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
+use std::io::Result;
+
+/// Perform protocol negotiation on an outbound `stream` using interactive
+/// negotiation, waiting for the remote end to send back an ACK of the agreed
+/// upon protocol. The protocols provided in `supported_protocols` are tried in
+/// order, preferring protocols first in the provided list.
+pub async fn negotiate_outbound_interactive<TSocket, TProto, TProtocols>(
+    mut stream: TSocket,
+    supported_protocols: TProtocols,
+) -> Result<(TSocket, TProto)>
+where
+    TSocket: AsyncRead + AsyncWrite + Unpin,
+    TProto: AsRef<[u8]> + Clone,
+    TProtocols: AsRef<[TProto]>,
+{
+    write_u16frame(&mut stream, PROTOCOL_INTERACTIVE).await?;
+
+    let mut buf = BytesMut::new();
+    let mut received_header_ack = false;
+    for proto in supported_protocols.as_ref() {
+        write_u16frame(&mut stream, proto.as_ref()).await?;
+        stream.flush().await?;
+
+        // Read the ACK that we're speaking PROTOCOL_INTERACTIVE if we still haven't done so.
+        // Note that we do this after sending the first protocol id, allowing for the negotiation to
+        // happen in a single round trip in case the remote node indeed speaks our preferred
+        // protocol.
+        if !received_header_ack {
+            read_u16frame(&mut stream, &mut buf).await?;
+            if buf.as_ref() != PROTOCOL_INTERACTIVE {
+                // Remote side doesn't understand us, give up
+                return Err(std::io::Error::new(
+                    std::io::ErrorKind::Other,
+                    "Unable to negotiate protocol - PROTOCOL_INTERACTIVE not supported",
+                ));
+            }
+
+            received_header_ack = true;
+        }
+
+        read_u16frame(&mut stream, &mut buf).await?;
+
+        if buf.as_ref() == proto.as_ref() {
+            // We received an ACK on the protocol!
+            return Ok((stream, proto.clone()));
+        } else if buf.as_ref() != PROTOCOL_NOT_SUPPORTED {
+            // We received an unexpected message from the remote, give up
+            return Err(std::io::Error::new(
+                std::io::ErrorKind::Other,
+                "Unable to negotiate protocol - unexpected interactive response",
+            ));
+        }
+    }
+
+    // We weren't able to find a matching protocol, give up
+    Err(std::io::Error::new(
+        std::io::ErrorKind::Other,
+        "Unable to negotiate protocol - no matching protocol",
+    ))
+}
+
+/// Perform an optimistic protocol negotiation on `stream` using the provided
+/// `protocol`.
+///
+/// The negotiation frames are only enqueued and not yet flushed (assuming the
+/// underlying transport is buffered). It's up to the protocol that handles this
+/// new outbound substream to decide when it should flush these frames.
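+///
+/// # Examples
+///
+/// A minimal sketch over an in-memory socket pair (the "/hello/1.0.0" name is
+/// illustrative, not part of the original patch):
+///
+/// ```rust,no_run
+/// use futures::executor::block_on;
+/// use memsocket::MemorySocket;
+/// use netcore::negotiate::negotiate_outbound_select;
+///
+/// let (a, _b) = MemorySocket::new_pair();
+/// // The select frames are only enqueued here; the first flush on the
+/// // returned stream pushes them to the listener side.
+/// let _stream = block_on(negotiate_outbound_select(a, b"/hello/1.0.0")).unwrap();
+/// ```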
+pub async fn negotiate_outbound_select( + mut stream: TSocket, + protocol: TProto, +) -> Result +where + TSocket: AsyncRead + AsyncWrite + Unpin, + TProto: AsRef<[u8]> + Clone, +{ + write_u16frame(&mut stream, PROTOCOL_SELECT).await?; + // Write out the protocol we're optimistically selecting and return + write_u16frame(&mut stream, protocol.as_ref()).await?; + // We do not wait for any ACK from the listener. This is OK because in case the listener does + // not want to speak this protocol, it can simply close the stream leading to an upstream + // failure in the dialer when it tries to read/write to the stream. + Ok(stream) +} + +#[cfg(test)] +mod test { + use crate::negotiate::{ + framing::{read_u16frame, write_u16frame}, + outbound::{negotiate_outbound_interactive, negotiate_outbound_select}, + PROTOCOL_INTERACTIVE, PROTOCOL_NOT_SUPPORTED, PROTOCOL_SELECT, + }; + use bytes::BytesMut; + use futures::{executor::block_on, future::join, io::AsyncWriteExt}; + use memsocket::MemorySocket; + use std::io::Result; + + #[test] + fn test_negotiate_outbound_interactive() -> Result<()> { + let (a, mut b) = MemorySocket::new_pair(); + let test_protocol = b"/hello/1.0.0"; + + let outbound = async move { + let (_stream, proto) = negotiate_outbound_interactive(a, [test_protocol]).await?; + + assert_eq!(proto, test_protocol); + + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + + let inbound = async move { + let mut buf = BytesMut::new(); + read_u16frame(&mut b, &mut buf).await?; + assert_eq!(buf.as_ref(), PROTOCOL_INTERACTIVE); + read_u16frame(&mut b, &mut buf).await?; + assert_eq!(buf.as_ref(), test_protocol); + + write_u16frame(&mut b, PROTOCOL_INTERACTIVE).await?; + write_u16frame(&mut b, test_protocol).await?; + b.flush().await?; + + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + + let (result_outbound, result_inbound) = block_on(join(outbound, inbound)); + assert_eq!(result_outbound.is_ok(), true); + assert_eq!(result_inbound.is_ok(), true); + + Ok(()) + } + + #[test] + fn test_negotiate_outbound_interactive_unsupported() -> Result<()> { + let (a, mut b) = MemorySocket::new_pair(); + let protocol_supported = b"/hello/1.0.0"; + let protocol_unsupported = b"/hello/2.0.0"; + + let outbound = async move { + let (_stream, proto) = + negotiate_outbound_interactive(a, [protocol_unsupported, protocol_supported]) + .await?; + + assert_eq!(proto, protocol_supported); + + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + + let inbound = async move { + let mut buf = BytesMut::new(); + read_u16frame(&mut b, &mut buf).await?; + assert_eq!(buf.as_ref(), PROTOCOL_INTERACTIVE); + read_u16frame(&mut b, &mut buf).await?; + assert_eq!(buf.as_ref(), protocol_unsupported); + + write_u16frame(&mut b, PROTOCOL_INTERACTIVE).await?; + write_u16frame(&mut b, PROTOCOL_NOT_SUPPORTED).await?; + b.flush().await?; + + read_u16frame(&mut b, &mut buf).await?; + assert_eq!(buf.as_ref(), protocol_supported); + + write_u16frame(&mut b, protocol_supported).await?; + b.flush().await?; + + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + + let (result_outbound, result_inbound) = block_on(join(outbound, inbound)); + assert_eq!(result_outbound.is_ok(), true); + assert_eq!(result_inbound.is_ok(), true); + + Ok(()) + } + + #[test] + fn test_negotiate_outbound_select() -> Result<()> { + let (a, mut b) = MemorySocket::new_pair(); + let test_protocol = b"/hello/1.0.0"; + let 
hello_request = b"Hello World!"; + + let outbound = async move { + let mut stream = negotiate_outbound_select(a, test_protocol).await?; + + write_u16frame(&mut stream, hello_request).await?; + stream.flush().await?; + + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + + let inbound = async move { + let mut buf = BytesMut::new(); + read_u16frame(&mut b, &mut buf).await?; + assert_eq!(buf.as_ref(), PROTOCOL_SELECT); + read_u16frame(&mut b, &mut buf).await?; + assert_eq!(buf.as_ref(), test_protocol); + + read_u16frame(&mut b, &mut buf).await?; + assert_eq!(buf.as_ref(), hello_request); + + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + + let (result_outbound, result_inbound) = block_on(join(outbound, inbound)); + assert!(result_outbound.is_ok()); + assert!(result_inbound.is_ok()); + + Ok(()) + } + + #[test] + fn test_negotiate_outbound_select_unsupported() -> Result<()> { + let (a, mut b) = MemorySocket::new_pair(); + let protocol_unsupported = b"/hello/2.0.0"; + + let outbound = async move { + let mut stream = negotiate_outbound_select(a, protocol_unsupported).await?; + stream.flush().await?; + + let mut buf = BytesMut::new(); + read_u16frame(&mut stream, &mut buf).await + }; + + let inbound = async move { + let mut buf = BytesMut::new(); + read_u16frame(&mut b, &mut buf).await?; + assert_eq!(buf.as_ref(), PROTOCOL_SELECT); + read_u16frame(&mut b, &mut buf).await?; + assert_eq!(buf.as_ref(), protocol_unsupported); + + // Just drop b to signle that the upgrade failed + drop(b); + + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + + let (result_outbound, result_inbound) = block_on(join(outbound, inbound)); + assert_eq!(result_outbound.is_err(), true); + assert_eq!(result_inbound.is_ok(), true); + + Ok(()) + } +} diff --git a/network/netcore/src/negotiate/test.rs b/network/netcore/src/negotiate/test.rs new file mode 100644 index 0000000000000..4ba934cc3dcec --- /dev/null +++ b/network/netcore/src/negotiate/test.rs @@ -0,0 +1,62 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! 
Integration tests for Protocol negotiation + +use crate::negotiate::{ + inbound::negotiate_inbound, + outbound::{negotiate_outbound_interactive, negotiate_outbound_select}, +}; +use futures::{executor::block_on, future::join}; +use memsocket::MemorySocket; +use std::io::Result; + +#[test] +fn interactive_negotiation() -> Result<()> { + let (a, b) = MemorySocket::new_pair(); + let test_protocol = b"/hello/1.0.0"; + + let outbound = async move { + let (_stream, proto) = negotiate_outbound_interactive(a, [test_protocol]).await?; + assert_eq!(proto, test_protocol); + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + let inbound = async move { + let vec: Vec<&'static [u8]> = vec![b"some", b"stuff", test_protocol]; + let (_stream, proto) = negotiate_inbound(b, &vec).await?; + assert_eq!(proto, test_protocol); + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + + let (result_outbound, result_inbound) = block_on(join(outbound, inbound)); + assert!(result_outbound.is_ok()); + assert!(result_inbound.is_ok()); + + Ok(()) +} + +#[test] +fn optimistic_negotiation() -> Result<()> { + let (a, b) = MemorySocket::new_pair(); + let test_protocol = b"/hello/1.0.0"; + + let outbound = negotiate_outbound_select(a, test_protocol); + let inbound = async move { + let vec: Vec<&'static [u8]> = vec![b"some", b"stuff", test_protocol]; + let (_stream, proto) = negotiate_inbound(b, &vec).await?; + assert_eq!(proto, test_protocol); + // Force return type of the async block + let result: Result<()> = Ok(()); + result + }; + + let (result_outbound, result_inbound) = block_on(join(outbound, inbound)); + assert!(result_outbound.is_ok()); + assert!(result_inbound.is_ok()); + + Ok(()) +} diff --git a/network/netcore/src/transport/and_then.rs b/network/netcore/src/transport/and_then.rs new file mode 100644 index 0000000000000..20438fc8249ff --- /dev/null +++ b/network/netcore/src/transport/and_then.rs @@ -0,0 +1,194 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::transport::{ConnectionOrigin, Transport}; +use futures::{future::Future, stream::Stream}; +use parity_multiaddr::Multiaddr; +use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +/// An [`AndThen`] is a transport which applies a closure (F) to all connections created by the +/// underlying transport. 
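+///
+/// As a minimal sketch (the identity upgrade below is illustrative; an
+/// `AndThen` is normally built through `TransportExt::and_then`):
+///
+/// ```rust,no_run
+/// #![feature(async_await)]
+/// use netcore::transport::{memory::MemoryTransport, ConnectionOrigin, TransportExt};
+///
+/// // Run an async post-processing step over every new connection.
+/// let _transport = MemoryTransport::default().and_then(
+///     |socket, _origin: ConnectionOrigin| async move { Ok::<_, std::io::Error>(socket) },
+/// );
+/// ```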
+pub struct AndThen { + transport: T, + function: F, +} + +impl AndThen { + pub(crate) fn new(transport: T, function: F) -> Self { + Self { + transport, + function, + } + } +} + +impl Transport for AndThen +where + T: Transport, + F: FnOnce(T::Output, ConnectionOrigin) -> Fut + Send + Unpin + Clone, + // Pin the error types to be the same for now + // TODO don't require the error types to be the same + Fut: Future> + Send, +{ + type Output = O; + type Error = T::Error; + type Listener = AndThenStream; + type Inbound = AndThenFuture; + type Outbound = AndThenFuture; + + fn listen_on(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> { + let (listener, addr) = self.transport.listen_on(addr)?; + let listener = AndThenStream::new(listener, self.function.clone()); + + Ok((listener, addr)) + } + + fn dial(&self, addr: Multiaddr) -> Result { + let fut = self.transport.dial(addr)?; + let origin = ConnectionOrigin::Outbound; + let f = self.function.clone(); + + Ok(AndThenFuture::new(fut, f, origin)) + } +} + +/// Listener stream returned by [listen_on](Transport::listen_on) on an AndThen transport. +#[derive(Debug)] +#[must_use = "streams do nothing unless polled"] +pub struct AndThenStream { + stream: St, + f: F, +} + +impl AndThenStream +where + St: Stream>, + Fut1: Future>, + Fut2: Future>, + F: FnOnce(O1, ConnectionOrigin) -> Fut2 + Clone, + E: ::std::error::Error, +{ + // This use of `unsafe_pinned` is safe because: + // 1. This struct does not implement [`Drop`] + // 2. This struct does not implement [`Unpin`] + // 3. This struct is not `#[repr(packed)]` + unsafe_pinned!(stream: St); + + fn new(stream: St, f: F) -> Self { + Self { stream, f } + } +} + +impl Stream for AndThenStream +where + St: Stream>, + Fut1: Future>, + Fut2: Future>, + F: FnOnce(O1, ConnectionOrigin) -> Fut2 + Clone, + E: ::std::error::Error, +{ + type Item = Result<(AndThenFuture, Multiaddr), E>; + + fn poll_next(mut self: Pin<&mut Self>, context: &mut Context) -> Poll> { + match self.as_mut().stream().poll_next(context) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(Some(Ok((fut1, addr)))) => Poll::Ready(Some(Ok(( + AndThenFuture::new(fut1, self.f.clone(), ConnectionOrigin::Inbound), + addr, + )))), + } + } +} + +#[derive(Debug)] +enum AndThenChain { + First(Fut1, Option<(F, ConnectionOrigin)>), + Second(Fut2), + Empty, +} + +/// Future generated by the [`AndThen`] transport. +/// +/// Takes a future (Fut1) generated from an underlying transport, runs it to completion and applies +/// a closure (F) to the result to create another future (Fut2) which is then run to completion. +#[derive(Debug)] +#[must_use = "futures do nothing unless polled"] +pub struct AndThenFuture { + chain: AndThenChain, +} + +impl AndThenFuture +where + Fut1: Future>, + Fut2: Future>, + F: FnOnce(O1, ConnectionOrigin) -> Fut2, + E: ::std::error::Error, +{ + // Ideally we'd want to use `unsafe_pinned` to get a pinned version of the `AndThenChain`, + // unfortunately a Pin<&mut AndThenChain> doesn't let us construct Pin<&mut Fut> pins for the + // interior futures stored in the enum variants; as such `unsafe_unpinned` is used instead with + // great caution: + // + // 1. This struct does not implement [`Drop`] + // 2. This struct does not implement [`Unpin`] + // 3. This struct is not `#[repr(packed)]` + // 4. We take care to never move `chain` or its interior Futures + // 5. 
When transitioning from First to Second state we first ensure that the `drop` method is + // called on the future stored in First prior to advancing to Second. + unsafe_unpinned!(chain: AndThenChain); + + fn new(fut1: Fut1, f: F, origin: ConnectionOrigin) -> Self { + Self { + chain: AndThenChain::First(fut1, Some((f, origin))), + } + } +} + +// Inspired by: https://github.com/rust-lang-nursery/futures-rs/blob/master/futures-util/src/future/chain.rs +impl Future for AndThenFuture +where + Fut1: Future>, + Fut2: Future>, + F: FnOnce(O1, ConnectionOrigin) -> Fut2, + E: ::std::error::Error, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, mut context: &mut Context) -> Poll { + loop { + let (output, (f, origin)) = match self.as_mut().chain() { + // Step 1: Drive Fut1 to completion + AndThenChain::First(fut1, data) => { + // Safe to construct a Pin of the interior future because + // `self` is pinned (and therefor `chain` is pinned). + match unsafe { Pin::new_unchecked(fut1) }.poll(&mut context) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Ready(Ok(output)) => { + (output, data.take().expect("must be initialized")) + } + } + } + // Step 4: Drive Fut2 to completion + AndThenChain::Second(fut2) => { + // Safe to construct a Pin of the interior future because + // `self` is pinned (and therefor `chain` is pinned). + return unsafe { Pin::new_unchecked(fut2) }.poll(&mut context); + } + AndThenChain::Empty => unreachable!(), + }; + + // Step 2: Ensure that Fut1 is dropped + *self.as_mut().chain() = AndThenChain::Empty; + // Step 3: Run F on the output of Fut1 to create Fut2 + let fut2 = f(output, origin); + *self.as_mut().chain() = AndThenChain::Second(fut2) + } + } +} diff --git a/network/netcore/src/transport/boxed.rs b/network/netcore/src/transport/boxed.rs new file mode 100644 index 0000000000000..a4db26162d9f1 --- /dev/null +++ b/network/netcore/src/transport/boxed.rs @@ -0,0 +1,81 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::transport::Transport; +use futures::{ + future::{Future, FutureExt}, + stream::{Stream, StreamExt}, +}; +use parity_multiaddr::Multiaddr; +use std::pin::Pin; + +pub type Listener = Pin, Multiaddr), E>> + Send>>; +pub type Inbound = Pin> + Send>>; +pub type Outbound = Pin> + Send>>; + +trait AbstractBoxedTransport { + fn listen_on(&self, addr: Multiaddr) -> Result<(Listener, Multiaddr), E>; + fn dial(&self, addr: Multiaddr) -> Result, E>; +} + +impl AbstractBoxedTransport for T +where + T: Transport + Send + 'static, + T::Listener: Send + 'static, + T::Inbound: Send + 'static, + T::Outbound: Send + 'static, + E: ::std::error::Error + Send + Sync + 'static, +{ + fn listen_on(&self, addr: Multiaddr) -> Result<(Listener, Multiaddr), E> { + let (listener, addr) = self.listen_on(addr)?; + let listener = listener + .map(|result| result.map(|(incoming, addr)| (incoming.boxed() as Inbound, addr))); + Ok((listener.boxed() as Listener, addr)) + } + + fn dial(&self, addr: Multiaddr) -> Result, E> { + let outgoing = self.dial(addr)?; + Ok(outgoing.boxed() as Outbound) + } +} + +/// See the [boxed](crate::transport::TransportExt::boxed) method for more information. 
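+///
+/// A minimal sketch (type annotations spelled out for illustration):
+///
+/// ```rust,no_run
+/// use memsocket::MemorySocket;
+/// use netcore::transport::{boxed::BoxedTransport, memory::MemoryTransport, TransportExt};
+///
+/// // Erase the concrete transport type so differently-composed transport
+/// // stacks can share a single type.
+/// let _t: BoxedTransport<MemorySocket, std::io::Error> =
+///     MemoryTransport::default().boxed();
+/// ```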
+pub struct BoxedTransport { + inner: Box + Send + 'static>, +} + +impl BoxedTransport +where + E: ::std::error::Error + Send + Sync + 'static, +{ + pub(crate) fn new(transport: T) -> Self + where + T: Transport + Send + 'static, + T::Listener: Send + 'static, + T::Inbound: Send + 'static, + T::Outbound: Send + 'static, + { + Self { + inner: Box::new(transport) as Box<_>, + } + } +} + +impl Transport for BoxedTransport +where + E: ::std::error::Error + Send + Sync + 'static, +{ + type Output = O; + type Error = E; + type Listener = Listener; + type Inbound = Inbound; + type Outbound = Outbound; + + fn listen_on(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> { + self.inner.listen_on(addr) + } + + fn dial(&self, addr: Multiaddr) -> Result { + self.inner.dial(addr) + } +} diff --git a/network/netcore/src/transport/memory.rs b/network/netcore/src/transport/memory.rs new file mode 100644 index 0000000000000..7d781eef93404 --- /dev/null +++ b/network/netcore/src/transport/memory.rs @@ -0,0 +1,136 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::transport::Transport; +use futures::{future, stream::Stream}; +use memsocket::{MemoryListener, MemorySocket}; +use parity_multiaddr::{Multiaddr, Protocol}; +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; + +/// Transport to build in-memory connections +#[derive(Debug, Default)] +pub struct MemoryTransport; + +impl Transport for MemoryTransport { + type Output = MemorySocket; + type Error = io::Error; + type Listener = Listener; + type Inbound = future::Ready>; + type Outbound = future::Ready>; + + fn listen_on(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> { + let port = parse_addr(&addr)?; + let listener = MemoryListener::bind(port)?; + let actual_port = listener.local_addr(); + let mut actual_addr = Multiaddr::empty(); + actual_addr.push(Protocol::Memory(u64::from(actual_port))); + + Ok((Listener { inner: listener }, actual_addr)) + } + + fn dial(&self, addr: Multiaddr) -> Result { + let port = parse_addr(&addr)?; + let socket = MemorySocket::connect(port)?; + Ok(future::ready(Ok(socket))) + } +} + +fn parse_addr(addr: &Multiaddr) -> io::Result { + let mut iter = addr.iter(); + + let port = if let Some(Protocol::Memory(port)) = iter.next() { + port + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid Multiaddr '{:?}'", addr), + )); + }; + + if iter.next().is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid Multiaddr '{:?}'", addr), + )); + } + + Ok(port as u16) +} + +#[must_use = "streams do nothing unless polled"] +#[derive(Debug)] +pub struct Listener { + inner: MemoryListener, +} + +impl Stream for Listener { + type Item = io::Result<(future::Ready>, Multiaddr)>; + + fn poll_next(mut self: Pin<&mut Self>, context: &mut Context) -> Poll> { + let mut incoming = self.inner.incoming(); + match Pin::new(&mut incoming).poll_next(context) { + Poll::Ready(Some(Ok(socket))) => { + // Dialer addresses for MemoryTransport don't make a ton of sense, + // so use port 0 to ensure they aren't used as an address to dial. 
+ let dialer_addr = Protocol::Memory(0).into(); + Poll::Ready(Some(Ok((future::ready(Ok(socket)), dialer_addr)))) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +#[cfg(test)] +mod test { + use crate::transport::{memory::MemoryTransport, Transport}; + use futures::{ + executor::block_on, + future::join, + io::{AsyncReadExt, AsyncWriteExt}, + stream::StreamExt, + }; + + #[test] + fn simple_listen_and_dial() -> Result<(), ::std::io::Error> { + let t = MemoryTransport::default(); + + let (listener, addr) = t.listen_on("/memory/0".parse().unwrap())?; + + let listener = async move { + let (item, _listener) = listener.into_future().await; + let (inbound, _addr) = item.unwrap().unwrap(); + let mut socket = inbound.await.unwrap(); + + let mut buf = Vec::new(); + socket.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"hello world"); + }; + let outbound = t.dial(addr)?; + + let dialer = async move { + let mut socket = outbound.await.unwrap(); + socket.write_all(b"hello world").await.unwrap(); + socket.flush().await.unwrap(); + }; + + block_on(join(dialer, listener)); + Ok(()) + } + + #[test] + fn unsupported_multiaddrs() { + let t = MemoryTransport::default(); + + let result = t.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()); + assert!(result.is_err()); + + let result = t.dial("/ip4/127.0.0.1/tcp/22".parse().unwrap()); + assert!(result.is_err()); + } +} diff --git a/network/netcore/src/transport/mod.rs b/network/netcore/src/transport/mod.rs new file mode 100644 index 0000000000000..8949f408ae915 --- /dev/null +++ b/network/netcore/src/transport/mod.rs @@ -0,0 +1,143 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Low-level module for establishing connections with peers +//! +//! The main component of this module is the [`Transport`] trait, which provides an interface for +//! establishing both inbound and outbound connections with remote peers. The [`TransportExt`] +//! trait contains a variety of combinators for modifying a transport allowing composability and +//! layering of additional transports or protocols. +//! +//! [`Transport`]: crate::transport::Transport +//! [`TransportExt`]: crate::transport::TransportExt + +use futures::{future::Future, stream::Stream}; +use parity_multiaddr::Multiaddr; +use std::time::Duration; + +pub mod and_then; +pub mod boxed; +pub mod memory; +pub mod tcp; +pub mod timeout; + +/// Origin of how a Connection was established. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ConnectionOrigin { + /// `Inbound` indicates that we are the listener for this connection. + Inbound, + /// `Outbound` indicates that we are the dialer for this connection. + Outbound, +} + +/// A Transport is responsible for establishing connections with remote Peers. +/// +/// Connections are established either by [listening](Transport::listen_on) +/// or [dialing](Transport::dial) on a [`Transport`]. A peer that +/// obtains a connection by listening is often referred to as the *listener* and the +/// peer that initiated the connection through dialing as the *dialer*. +/// +/// Additional protocols can be layered on top of the connections established +/// by a [`Transport`] through utilizing the combinators in the [`TransportExt`] trait. +pub trait Transport { + /// The result of establishing a connection. 
+    ///
+    /// Generally this would include a socket-like stream which allows for sending and receiving
+    /// data through the connection.
+    type Output;
+
+    /// The type of errors which can happen while establishing a connection.
+    type Error: ::std::error::Error + Send + Sync + 'static;
+
+    /// A stream of [`Inbound`](Transport::Inbound) connections and the address of the dialer.
+    ///
+    /// An item should be produced whenever a connection is received at the lowest level of the
+    /// transport stack. Each item is an [`Inbound`](Transport::Inbound) future
+    /// that resolves to an [`Output`](Transport::Output) value once all protocol upgrades
+    /// have been applied.
+    type Listener: Stream<Item = Result<(Self::Inbound, Multiaddr), Self::Error>> + Send + Unpin;
+
+    /// A pending [`Output`](Transport::Output) for an inbound connection,
+    /// obtained from the [`Listener`](Transport::Listener) stream.
+    ///
+    /// After a connection has been accepted by the transport, it may need to go through
+    /// asynchronous post-processing (i.e. protocol upgrade negotiations). Such
+    /// post-processing should not block the `Listener` from producing the next
+    /// connection, hence further connection setup proceeds asynchronously.
+    /// Once an `Inbound` future resolves it yields the [`Output`](Transport::Output)
+    /// of the connection setup process.
+    type Inbound: Future<Output = Result<Self::Output, Self::Error>> + Send;
+
+    /// A pending [`Output`](Transport::Output) for an outbound connection,
+    /// obtained by [dialing](Transport::dial).
+    type Outbound: Future<Output = Result<Self::Output, Self::Error>> + Send;
+
+    /// Listens on the given [`Multiaddr`], returning a stream of incoming connections.
+    ///
+    /// The returned [`Multiaddr`] is the actual listening address; this is done to take into
+    /// account OS-assigned port numbers (e.g. when listening on port 0).
+    fn listen_on(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error>
+    where
+        Self: Sized;
+
+    /// Dials the given [`Multiaddr`], returning a future for a pending outbound connection.
+    fn dial(&self, addr: Multiaddr) -> Result<Self::Outbound, Self::Error>
+    where
+        Self: Sized;
+}
+
+impl<T> TransportExt for T where T: Transport {}
+
+/// An extension trait for [`Transport`]s that provides a variety of convenient
+/// combinators.
+///
+/// Additional protocols or functionality can be layered on top of an existing
+/// [`Transport`] by using this extension trait. For example, one might want to
+/// take a raw connection and upgrade it to a secure transport followed by a
+/// stream multiplexer by chaining calls to [`and_then`](TransportExt::and_then).
+/// Each method yields a new [`Transport`] whose connection setup incorporates
+/// all earlier upgrades followed by the new upgrade, i.e. the order of the
+/// upgrades is significant.
+pub trait TransportExt: Transport {
+    /// Turns a [`Transport`] into an abstract boxed transport.
+    fn boxed(self) -> boxed::BoxedTransport<Self::Output, Self::Error>
+    where
+        Self: Sized + Send + 'static,
+        Self::Listener: Send + 'static,
+        Self::Inbound: Send + 'static,
+        Self::Outbound: Send + 'static,
+    {
+        boxed::BoxedTransport::new(self)
+    }
+
+    /// Applies a function producing an asynchronous result to every connection
+    /// created by this transport.
+    ///
+    /// This function can be used for ad-hoc protocol upgrades on a transport
+    /// or for processing or adapting the output of an earlier upgrade. The
+    /// provided function must take as input the output from the existing
+    /// transport and a [`ConnectionOrigin`] which can be used to identify the
+    /// origin of the connection (inbound vs outbound).
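+    ///
+    /// A hedged sketch of the intended shape (`perform_handshake` is hypothetical
+    /// and stands in for any asynchronous upgrade):
+    ///
+    /// ```ignore
+    /// let upgraded = transport.and_then(|socket, origin| async move {
+    ///     // Asymmetric upgrades can branch on whether we dialed (Outbound)
+    ///     // or accepted (Inbound) this connection.
+    ///     perform_handshake(socket, origin).await
+    /// });
+    /// ```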
+ fn and_then(self, f: F) -> and_then::AndThen + where + Self: Sized, + F: FnOnce(Self::Output, ConnectionOrigin) -> Fut + Clone, + // Pin the error types to be the same for now + // TODO don't require the error types to be the same + Fut: Future>, + { + and_then::AndThen::new(self, f) + } + + /// Wraps a [`Transport`] with a timeout to the + /// [Inbound](Transport::Inbound) and [Outbound](Transport::Outbound) + /// connection futures. + /// + /// Note: The timeout does not apply to the [Listener](Transport::Listener) stream. + fn with_timeout(self, timeout: Duration) -> timeout::TimeoutTransport + where + Self: Sized, + { + timeout::TimeoutTransport::new(self, timeout) + } +} diff --git a/network/netcore/src/transport/tcp.rs b/network/netcore/src/transport/tcp.rs new file mode 100644 index 0000000000000..0ea0d8e8fe6ee --- /dev/null +++ b/network/netcore/src/transport/tcp.rs @@ -0,0 +1,308 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! TCP Transport +use crate::transport::Transport; +use futures::{ + compat::{Compat01As03, Future01CompatExt}, + future::{self, Future}, + io::{AsyncRead, AsyncWrite}, + ready, + stream::Stream, +}; +use parity_multiaddr::{Multiaddr, Protocol}; +use std::{ + io, + net::{Shutdown, SocketAddr}, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use tokio::net::tcp::{ConnectFuture, Incoming, TcpListener, TcpStream}; + +/// Transport to build TCP connections +#[derive(Debug, Clone, Default)] +pub struct TcpTransport { + /// Size of the recv buffer size to set for opened sockets, or `None` to keep default. + recv_buffer_size: Option, + /// Size of the send buffer size to set for opened sockets, or `None` to keep default. + send_buffer_size: Option, + /// TTL to set for opened sockets, or `None` to keep default. + ttl: Option, + /// Keep alive duration to set for opened sockets, or `None` to keep default. + #[allow(clippy::option_option)] + keepalive: Option>, + /// `TCP_NODELAY` to set for opened sockets, or `None` to keep default. 
+ nodelay: Option, +} + +impl TcpTransport { + pub fn set_recv_buffer_size(mut self, size: usize) -> Self { + self.recv_buffer_size = Some(size); + self + } + + pub fn set_send_buffer_size(mut self, size: usize) -> Self { + self.send_buffer_size = Some(size); + self + } + + pub fn set_ttl(mut self, ttl: u32) -> Self { + self.ttl = Some(ttl); + self + } + + pub fn set_keepalive(mut self, keepalive: Option) -> Self { + self.keepalive = Some(keepalive); + self + } + + pub fn set_nodelay(mut self, nodelay: bool) -> Self { + self.nodelay = Some(nodelay); + self + } + + fn apply_config(&self, stream: &TcpStream) -> ::std::io::Result<()> { + if let Some(size) = self.recv_buffer_size { + stream.set_recv_buffer_size(size)?; + } + + if let Some(size) = self.send_buffer_size { + stream.set_send_buffer_size(size)?; + } + + if let Some(ttl) = self.ttl { + stream.set_ttl(ttl)?; + } + + if let Some(keepalive) = self.keepalive { + stream.set_keepalive(keepalive)?; + } + + if let Some(nodelay) = self.nodelay { + stream.set_nodelay(nodelay)?; + } + + Ok(()) + } +} + +impl Transport for TcpTransport { + type Output = TcpSocket; + type Error = ::std::io::Error; + type Listener = TcpListenerStream; + type Inbound = future::Ready>; + type Outbound = TcpOutbound; + + fn listen_on(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> { + let socket_addr = multiaddr_to_socketaddr(&addr)?; + let config = self.clone(); + let listener = TcpListener::bind(&socket_addr)?; + let local_addr = socketaddr_to_multiaddr(listener.local_addr()?); + Ok(( + TcpListenerStream { + inner: Compat01As03::new(listener.incoming()), + config, + }, + local_addr, + )) + } + + fn dial(&self, addr: Multiaddr) -> Result { + let socket_addr = multiaddr_to_socketaddr(&addr)?; + let config = self.clone(); + let f = TcpStream::connect(&socket_addr).compat(); + Ok(TcpOutbound { inner: f, config }) + } +} + +#[derive(Debug)] +#[must_use = "streams do nothing unless polled"] +pub struct TcpListenerStream { + inner: Compat01As03, + config: TcpTransport, +} + +impl Stream for TcpListenerStream { + type Item = io::Result<(future::Ready>, Multiaddr)>; + + fn poll_next(mut self: Pin<&mut Self>, context: &mut Context) -> Poll> { + match Pin::new(&mut self.inner).poll_next(context) { + Poll::Ready(Some(Ok(socket))) => { + if let Err(e) = self.config.apply_config(&socket) { + return Poll::Ready(Some(Err(e))); + } + let dialer_addr = match socket.peer_addr() { + Ok(addr) => socketaddr_to_multiaddr(addr), + Err(e) => return Poll::Ready(Some(Err(e))), + }; + Poll::Ready(Some(Ok(( + future::ready(Ok(TcpSocket::new(socket))), + dialer_addr, + )))) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +#[derive(Debug)] +#[must_use = "futures do nothing unless polled"] +pub struct TcpOutbound { + inner: Compat01As03, + config: TcpTransport, +} + +impl Future for TcpOutbound { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, context: &mut Context) -> Poll { + let socket = ready!(Pin::new(&mut self.inner).poll(context))?; + self.config.apply_config(&socket)?; + Poll::Ready(Ok(TcpSocket::new(socket))) + } +} + +/// A wrapper around a tokio TcpStream +/// +/// In order to properly implement the AsyncRead/AsyncWrite traits we need to wrap a TcpStream to +/// ensure that the "close" method actually closes the write half of the TcpStream. 
This is +/// because the "close" method on a TcpStream just performs a no-op instead of actually shutting +/// down the write side of the TcpStream. +//TODO Probably should add some tests for this +#[derive(Debug)] +pub struct TcpSocket { + inner: Compat01As03, +} + +impl TcpSocket { + fn new(socket: TcpStream) -> Self { + Self { + inner: Compat01As03::new(socket), + } + } +} + +impl AsyncRead for TcpSocket { + fn poll_read( + mut self: Pin<&mut Self>, + context: &mut Context, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_read(context, buf) + } +} + +impl AsyncWrite for TcpSocket { + fn poll_write( + mut self: Pin<&mut Self>, + context: &mut Context, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(context, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, context: &mut Context) -> Poll> { + Pin::new(&mut self.inner).poll_flush(context) + } + + fn poll_close(self: Pin<&mut Self>, _context: &mut Context) -> Poll> { + Poll::Ready(self.inner.get_ref().shutdown(Shutdown::Write)) + } +} + +fn socketaddr_to_multiaddr(socketaddr: SocketAddr) -> Multiaddr { + let ipaddr: Multiaddr = socketaddr.ip().into(); + ipaddr.with(Protocol::Tcp(socketaddr.port())) +} + +fn multiaddr_to_socketaddr(addr: &Multiaddr) -> ::std::io::Result { + let mut iter = addr.iter(); + let proto1 = iter.next().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid Multiaddr '{:?}'", addr), + ) + })?; + let proto2 = iter.next().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid Multiaddr '{:?}'", addr), + ) + })?; + + if iter.next().is_some() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid Multiaddr '{:?}'", addr), + )); + } + + match (proto1, proto2) { + (Protocol::Ip4(ip), Protocol::Tcp(port)) => Ok(SocketAddr::new(ip.into(), port)), + (Protocol::Ip6(ip), Protocol::Tcp(port)) => Ok(SocketAddr::new(ip.into(), port)), + _ => Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid Multiaddr '{:?}'", addr), + )), + } +} + +#[cfg(test)] +mod test { + use crate::transport::{tcp::TcpTransport, ConnectionOrigin, Transport, TransportExt}; + use futures::{ + executor::block_on, + future::{join, FutureExt}, + io::{AsyncReadExt, AsyncWriteExt}, + stream::StreamExt, + }; + + #[test] + fn simple_listen_and_dial() -> Result<(), ::std::io::Error> { + let t = TcpTransport::default().and_then(|mut out, connection| { + async move { + match connection { + ConnectionOrigin::Inbound => { + out.write_all(b"Earth").await?; + let mut buf = [0; 3]; + out.read_exact(&mut buf).await?; + assert_eq!(&buf, b"Air"); + } + ConnectionOrigin::Outbound => { + let mut buf = [0; 5]; + out.read_exact(&mut buf).await?; + assert_eq!(&buf, b"Earth"); + out.write_all(b"Air").await?; + } + } + Ok(()) + } + }); + + let (listener, addr) = t.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap())?; + + let dial = t.dial(addr)?; + let listener = listener.into_future().then(|(maybe_result, _stream)| { + let (incoming, _addr) = maybe_result.unwrap().unwrap(); + incoming.map(Result::unwrap) + }); + + let (outgoing, _incoming) = block_on(join(dial, listener)); + assert!(outgoing.is_ok()); + Ok(()) + } + + #[test] + fn unsupported_multiaddrs() { + let t = TcpTransport::default(); + + let result = t.listen_on("/memory/0".parse().unwrap()); + assert!(result.is_err()); + + let result = t.dial("/memory/22".parse().unwrap()); + assert!(result.is_err()); + } +} diff --git a/network/netcore/src/transport/timeout.rs 
b/network/netcore/src/transport/timeout.rs new file mode 100644 index 0000000000000..97ee08456b693 --- /dev/null +++ b/network/netcore/src/transport/timeout.rs @@ -0,0 +1,208 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +// Timeout Transport + +use crate::transport::Transport; +use futures::{compat::Compat01As03, future::Future, stream::Stream}; +use parity_multiaddr::Multiaddr; +use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use std::{ + pin::Pin, + task::{Context, Poll}, + time::{Duration, Instant}, +}; +use tokio::{executor::Executor, timer::Delay}; + +/// A [`TimeoutTransport`] is a transport which wraps another transport with a timeout on all +/// inbound and outbound connection setup. +/// +/// Note: The [Listener](Transport::Listener) stream is not subject to the provided timeout. +#[derive(Debug)] +pub struct TimeoutTransport { + transport: T, + timeout: Duration, +} + +impl TimeoutTransport { + /// Wraps around a [`Transport`] and adds timeouts to all inbound and outbound connections + /// created by it. + pub(crate) fn new(transport: T, timeout: Duration) -> Self { + Self { transport, timeout } + } +} + +impl Transport for TimeoutTransport +where + T: Transport, + T::Error: 'static, +{ + type Output = T::Output; + type Error = TimeoutTransportError; + type Listener = TimeoutStream; + type Inbound = TimeoutFuture; + type Outbound = TimeoutFuture; + + fn listen_on(&self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> { + let (listener, addr) = self.transport.listen_on(addr)?; + let listener = TimeoutStream::new(listener, self.timeout); + + Ok((listener, addr)) + } + + fn dial(&self, addr: Multiaddr) -> Result { + let fut = self.transport.dial(addr)?; + + Ok(TimeoutFuture::new(fut, self.timeout)) + } +} + +/// Listener stream returned by [listen_on](Transport::listen_on) on a TimeoutTransport. +#[derive(Debug)] +#[must_use = "streams do nothing unless polled"] +pub struct TimeoutStream { + inner: St, + timeout: Duration, +} + +impl TimeoutStream +where + St: Stream, +{ + // This use of `unsafe_pinned` is safe because: + // 1. This struct does not implement [`Drop`] + // 2. This struct does not implement [`Unpin`] + // 3. This struct is not `#[repr(packed)]` + unsafe_pinned!(inner: St); + + fn new(stream: St, timeout: Duration) -> Self { + Self { + inner: stream, + timeout, + } + } +} + +impl Stream for TimeoutStream +where + St: Stream>, + Fut: Future>, + E: ::std::error::Error, +{ + type Item = Result<(TimeoutFuture, Multiaddr), TimeoutTransportError>; + + fn poll_next(mut self: Pin<&mut Self>, context: &mut Context) -> Poll> { + match self.as_mut().inner().poll_next(context) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Err(e))) => { + Poll::Ready(Some(Err(TimeoutTransportError::TransportError(e)))) + } + Poll::Ready(Some(Ok((fut, addr)))) => { + let fut = TimeoutFuture::new(fut, self.timeout); + Poll::Ready(Some(Ok((fut, addr)))) + } + } + } +} + +/// Future which wraps an inner Future with a timeout. +#[derive(Debug)] +#[must_use = "futures do nothing unless polled"] +pub struct TimeoutFuture { + future: F, + timeout: Compat01As03, +} + +impl TimeoutFuture +where + F: Future, +{ + // This use of `unsafe_pinned` is safe because: + // 1. This struct does not implement [`Drop`] + // 2. This struct does not implement [`Unpin`] + // 3. This struct is not `#[repr(packed)]` + unsafe_pinned!(future: F); + + // This use of `unsafe_unpinned` is safe because: + // 1. 
`timeout` implements `Unpin` + // 2. We only use the generated `timeout()` getter to construct a Pin with Pin::new. + unsafe_unpinned!(timeout: Compat01As03); + + fn new(future: F, timeout: Duration) -> Self { + let deadline = Instant::now() + timeout; + Self { + future, + timeout: Compat01As03::new(Delay::new(deadline)), + } + } +} + +impl Future for TimeoutFuture +where + F: Future>, + E: ::std::error::Error, +{ + type Output = Result>; + + fn poll(mut self: Pin<&mut Self>, mut context: &mut Context) -> Poll { + // Make sure we're inside of a Tokio Runtime since Tokio Timers + // don't work outside of a Tokio context. + assert!(tokio::executor::DefaultExecutor::current().status().is_ok()); + + // Try polling the inner future first + match self.as_mut().future().poll(&mut context) { + Poll::Pending => {} + Poll::Ready(Err(e)) => { + return Poll::Ready(Err(TimeoutTransportError::TransportError(e))) + } + Poll::Ready(Ok(output)) => return Poll::Ready(Ok(output)), + } + + // Now check to see if we've overshot the timeout + match Pin::new(self.as_mut().timeout()).poll(&mut context) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(err)) => Poll::Ready(Err(TimeoutTransportError::TimerError(err))), + Poll::Ready(Ok(())) => Poll::Ready(Err(TimeoutTransportError::Timeout)), + } + } +} + +#[derive(Debug)] +pub enum TimeoutTransportError { + Timeout, + TimerError(::tokio::timer::Error), + TransportError(E), +} + +impl ::std::convert::From for TimeoutTransportError { + fn from(error: E) -> Self { + TimeoutTransportError::TransportError(error) + } +} + +impl ::std::fmt::Display for TimeoutTransportError +where + E: ::std::fmt::Display, +{ + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + match self { + TimeoutTransportError::Timeout => write!(f, "Timeout has been reached"), + TimeoutTransportError::TimerError(err) => write!(f, "Error in the timer: '{}'", err), + TimeoutTransportError::TransportError(err) => write!(f, "{}", err), + } + } +} + +impl ::std::error::Error for TimeoutTransportError +where + E: ::std::error::Error + 'static, +{ + fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> { + match self { + TimeoutTransportError::Timeout => None, + TimeoutTransportError::TimerError(err) => Some(err), + TimeoutTransportError::TransportError(err) => Some(err), + } + } +} diff --git a/network/netcore/src/utils.rs b/network/netcore/src/utils.rs new file mode 100644 index 0000000000000..082ee84c68900 --- /dev/null +++ b/network/netcore/src/utils.rs @@ -0,0 +1,6 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +pub trait Captures<'a> {} + +impl<'a, T> Captures<'a> for T {} diff --git a/network/noise/Cargo.toml b/network/noise/Cargo.toml new file mode 100644 index 0000000000000..ac85d8bc4bcf6 --- /dev/null +++ b/network/noise/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "noise" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +futures = { version = "=0.3.0-alpha.16", package = "futures-preview" } +snow = { version = "0.5.2", features=["ring-accelerated"]} + +crypto = { path = "../../crypto/legacy_crypto" } +netcore = { path = "../netcore" } +logger = { path = "../../common/logger" } + +[dev-dependencies] +memsocket = { path = "../memsocket" } diff --git a/network/noise/src/lib.rs b/network/noise/src/lib.rs new file mode 100644 index 0000000000000..bbd9bedc532a4 --- /dev/null +++ b/network/noise/src/lib.rs @@ -0,0 +1,105 @@ +// Copyright 
(c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![feature(async_await)] + +//! [Noise protocol framework][noise] support for use in Libra. +//! +//! The main feature of this module is [`NoiseSocket`](crate::socket::NoiseSocket) which +//! provides wire-framing for noise payloads. Currently the only handshake pattern supported is IX. +//! +//! [noise]: http://noiseprotocol.org/ + +use crypto::x25519::{X25519PrivateKey, X25519PublicKey}; +use futures::io::{AsyncRead, AsyncWrite}; +use netcore::{ + negotiate::{negotiate_inbound, negotiate_outbound_interactive}, + transport::ConnectionOrigin, +}; +use snow::{self, params::NoiseParams, Keypair}; +use std::io; + +mod socket; + +pub use self::socket::NoiseSocket; + +const NOISE_IX_25519_AESGCM_SHA256_PROTOCOL_NAME: &[u8] = b"/noise_ix_25519_aesgcm_sha256/1.0.0"; +const NOISE_IX_PARAMETER: &str = "Noise_IX_25519_AESGCM_SHA256"; + +/// The Noise protocol configuration to be used to perform a protocol upgrade on an underlying +/// socket. +pub struct NoiseConfig { + keypair: Keypair, + parameters: NoiseParams, +} + +impl NoiseConfig { + /// Create a new NoiseConfig with the provided keypair + pub fn new(keypair: (X25519PrivateKey, X25519PublicKey)) -> Self { + let parameters: NoiseParams = NOISE_IX_PARAMETER.parse().expect("Invalid protocol name"); + let keypair = Keypair { + private: keypair.0.to_bytes().to_vec(), + public: keypair.1.as_bytes().to_vec(), + }; + Self { + keypair, + parameters, + } + } + + /// Create a new NoiseConfig with an ephemeral static key. + pub fn new_random() -> Self { + let parameters: NoiseParams = NOISE_IX_PARAMETER.parse().expect("Invalid protocol name"); + let keypair = snow::Builder::new(parameters.clone()) + .generate_keypair() + .unwrap(); + Self { + keypair, + parameters, + } + } + + /// Perform a protocol upgrade on an underlying connection. In addition perform the noise IX + /// handshake to establish a noise session and exchange static public keys. Upon success, + /// returns the static public key of the remote as well as a NoiseSocket. + pub async fn upgrade_connection( + &self, + socket: TSocket, + origin: ConnectionOrigin, + ) -> io::Result<(Vec, NoiseSocket)> + where + TSocket: AsyncRead + AsyncWrite + Unpin, + { + // Perform protocol negotiation + let (socket, proto) = match origin { + ConnectionOrigin::Inbound => { + negotiate_inbound(socket, [NOISE_IX_25519_AESGCM_SHA256_PROTOCOL_NAME]).await? + } + ConnectionOrigin::Outbound => { + negotiate_outbound_interactive(socket, [NOISE_IX_25519_AESGCM_SHA256_PROTOCOL_NAME]) + .await? + } + }; + + assert_eq!(proto, NOISE_IX_25519_AESGCM_SHA256_PROTOCOL_NAME); + + // Instantiate the snow session + // Note: We need to scope the Builder struct so that the compiler doesn't over eagerly + // capture it into the Async State-machine. + let session = { + let builder = snow::Builder::new(self.parameters.clone()) + .local_private_key(&self.keypair.private); + match origin { + ConnectionOrigin::Inbound => builder.build_responder(), + ConnectionOrigin::Outbound => builder.build_initiator(), + } + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))? 
+ }; + + let handshake = socket::Handshake::new(socket, session); + + let socket = handshake.handshake_1rt().await?; + let remote_static_key = socket.get_remote_static().unwrap().to_owned(); + Ok((remote_static_key, socket)) + } +} diff --git a/network/noise/src/socket.rs b/network/noise/src/socket.rs new file mode 100644 index 0000000000000..feca2dceb70f8 --- /dev/null +++ b/network/noise/src/socket.rs @@ -0,0 +1,712 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Noise Socket + +use futures::{ + future::poll_fn, + io::{AsyncRead, AsyncWrite}, + ready, try_ready, +}; +use logger::prelude::*; +use std::{ + convert::TryInto, + io, + pin::Pin, + task::{Context, Poll}, +}; + +const MAX_PAYLOAD_LENGTH: usize = u16::max_value() as usize; // 65535 + +// The maximum number of bytes that we can buffer is 16 bytes less than u16::max_value() because +// encrypted messages include a tag along with the payload. +const MAX_WRITE_BUFFER_LENGTH: usize = u16::max_value() as usize - 16; // 65519 + +/// Collection of buffers used for buffering data during the various read/write states of a +/// NoiseSocket +struct NoiseBuffers { + /// Encrypted frame read from the wire + read_encrypted: [u8; MAX_PAYLOAD_LENGTH], + /// Decrypted data read from the wire (produced by having snow decrypt the `read_encrypted` + /// buffer) + read_decrypted: [u8; MAX_PAYLOAD_LENGTH], + /// Unencrypted data intended to be written to the wire + write_decrypted: [u8; MAX_WRITE_BUFFER_LENGTH], + /// Encrypted data to write to the wire (produced by having snow encrypt the `write_decrypted` + /// buffer) + write_encrypted: [u8; MAX_PAYLOAD_LENGTH], +} + +impl NoiseBuffers { + fn new() -> Self { + Self { + read_encrypted: [0; MAX_PAYLOAD_LENGTH], + read_decrypted: [0; MAX_PAYLOAD_LENGTH], + write_decrypted: [0; MAX_WRITE_BUFFER_LENGTH], + write_encrypted: [0; MAX_PAYLOAD_LENGTH], + } + } +} + +/// Hand written Debug implementation in order to omit the printing of huge buffers of data +impl ::std::fmt::Debug for NoiseBuffers { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + f.debug_struct("NoiseBuffers").finish() + } +} + +/// Possible read states for a [NoiseSocket] +#[derive(Debug)] +enum ReadState { + /// Initial State + Init, + /// Read frame length + ReadFrameLen { buf: [u8; 2], offset: usize }, + /// Read encrypted frame + ReadFrame { frame_len: u16, offset: usize }, + /// Copy decrypted frame to provided buffer + CopyDecryptedFrame { decrypted_len: usize, offset: usize }, + /// End of file reached, result indicated if EOF was expected or not + Eof(Result<(), ()>), + /// Decryption Error + DecryptionError(snow::SnowError), +} + +/// Possible write states for a [NoiseSocket] +#[derive(Debug)] +enum WriteState { + /// Initial State + Init, + /// Buffer provided data + BufferData { offset: usize }, + /// Write frame length to the wire + WriteFrameLen { + frame_len: u16, + buf: [u8; 2], + offset: usize, + }, + /// Write encrypted frame to the wire + WriteEncryptedFrame { frame_len: u16, offset: usize }, + /// Flush the underlying socket + Flush, + /// End of file reached + Eof, + /// Encryption Error + EncryptionError(snow::SnowError), +} + +/// A Noise session with a remote +/// +/// Encrypts data to be written to and decrypts data that is read from the underlying socket using +/// the noise protocol. This is done by wrapping noise payloads in u16 (big endian) length prefix +/// frames. 
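+///
+/// A worked example of the framing (a sketch; it assumes the 16-byte AES-GCM tag used
+/// by this crate's noise configuration): a 5-byte payload encrypts to 21 bytes of
+/// ciphertext, which is written to the wire as
+///
+/// ```text
+/// 0x00 0x15 | ciphertext (5-byte payload + 16-byte tag = 21 bytes)
+/// ```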
+#[derive(Debug)] +pub struct NoiseSocket { + socket: TSocket, + session: snow::Session, + buffers: Box, + read_state: ReadState, + write_state: WriteState, +} + +impl NoiseSocket { + fn new(socket: TSocket, session: snow::Session) -> Self { + Self { + socket, + session, + buffers: Box::new(NoiseBuffers::new()), + read_state: ReadState::Init, + write_state: WriteState::Init, + } + } + + /// Pull out the static public key of the remote + pub fn get_remote_static(&self) -> Option<&[u8]> { + self.session.get_remote_static() + } +} + +fn poll_write_all( + mut context: &mut Context, + mut socket: Pin<&mut TSocket>, + buf: &[u8], + offset: &mut usize, +) -> Poll> +where + TSocket: AsyncWrite, +{ + loop { + let n = ready!(socket.as_mut().poll_write(&mut context, &buf[*offset..]))?; + trace!("poll_write_all: wrote {}/{} bytes", *offset + n, buf.len()); + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + *offset += n; + assert!(*offset <= buf.len()); + + if *offset == buf.len() { + return Poll::Ready(Ok(())); + } + } +} + +/// Read a u16 frame length from `socket`. +/// +/// Can result in the following output: +/// 1) Ok(None) => EOF; remote graceful shutdown +/// 2) Err(UnexpectedEOF) => read 1 byte then hit EOF; remote died +/// 3) Ok(Some(n)) => new frame of length n +fn poll_read_u16frame_len( + context: &mut Context, + socket: Pin<&mut TSocket>, + buf: &mut [u8; 2], + offset: &mut usize, +) -> Poll>> +where + TSocket: AsyncRead, +{ + match ready!(poll_read_exact(context, socket, buf, offset)) { + Ok(()) => Poll::Ready(Ok(Some(u16::from_be_bytes(*buf)))), + Err(e) => { + if *offset == 0 && e.kind() == io::ErrorKind::UnexpectedEof { + return Poll::Ready(Ok(None)); + } + Poll::Ready(Err(e)) + } + } +} + +fn poll_read_exact( + mut context: &mut Context, + mut socket: Pin<&mut TSocket>, + buf: &mut [u8], + offset: &mut usize, +) -> Poll> +where + TSocket: AsyncRead, +{ + loop { + let n = ready!(socket.as_mut().poll_read(&mut context, &mut buf[*offset..]))?; + trace!("poll_read_exact: read {}/{} bytes", *offset + n, buf.len()); + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())); + } + *offset += n; + assert!(*offset <= buf.len()); + + if *offset == buf.len() { + return Poll::Ready(Ok(())); + } + } +} + +impl NoiseSocket +where + TSocket: AsyncRead + Unpin, +{ + fn poll_read(&mut self, mut context: &mut Context, buf: &mut [u8]) -> Poll> { + loop { + trace!("NoiseSocket ReadState::{:?}", self.read_state); + match self.read_state { + ReadState::Init => { + self.read_state = ReadState::ReadFrameLen { + buf: [0, 0], + offset: 0, + }; + } + ReadState::ReadFrameLen { + ref mut buf, + ref mut offset, + } => { + match ready!(poll_read_u16frame_len( + &mut context, + Pin::new(&mut self.socket), + buf, + offset + )) { + Ok(Some(frame_len)) => { + // Empty Frame + if frame_len == 0 { + self.read_state = ReadState::Init; + } else { + self.read_state = ReadState::ReadFrame { + frame_len, + offset: 0, + }; + } + } + Ok(None) => { + self.read_state = ReadState::Eof(Ok(())); + } + Err(e) => { + if e.kind() == io::ErrorKind::UnexpectedEof { + self.read_state = ReadState::Eof(Err(())); + } + return Poll::Ready(Err(e)); + } + } + } + ReadState::ReadFrame { + frame_len, + ref mut offset, + } => { + match ready!(poll_read_exact( + &mut context, + Pin::new(&mut self.socket), + &mut self.buffers.read_encrypted[..(frame_len as usize)], + offset + )) { + Ok(()) => { + match self.session.read_message( + &self.buffers.read_encrypted[..(frame_len as usize)], + &mut 
self.buffers.read_decrypted, + ) { + Ok(decrypted_len) => { + self.read_state = ReadState::CopyDecryptedFrame { + decrypted_len, + offset: 0, + }; + } + Err(e) => { + error!("Decryption Error: {}", e); + self.read_state = ReadState::DecryptionError(e); + } + } + } + Err(e) => { + if e.kind() == io::ErrorKind::UnexpectedEof { + self.read_state = ReadState::Eof(Err(())); + } + return Poll::Ready(Err(e)); + } + } + } + ReadState::CopyDecryptedFrame { + decrypted_len, + ref mut offset, + } => { + let bytes_to_copy = + ::std::cmp::min(decrypted_len as usize - *offset, buf.len()); + buf[..bytes_to_copy].copy_from_slice( + &self.buffers.read_decrypted[*offset..(*offset + bytes_to_copy)], + ); + trace!( + "CopyDecryptedFrame: copied {}/{} bytes", + *offset + bytes_to_copy, + decrypted_len + ); + *offset += bytes_to_copy; + if *offset == decrypted_len as usize { + self.read_state = ReadState::Init; + } + return Poll::Ready(Ok(bytes_to_copy)); + } + ReadState::Eof(Ok(())) => return Poll::Ready(Ok(0)), + ReadState::Eof(Err(())) => { + return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())) + } + ReadState::DecryptionError(ref e) => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("DecryptionError: {}", e), + ))) + } + } + } + } +} + +impl AsyncRead for NoiseSocket +where + TSocket: AsyncRead + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + context: &mut Context, + buf: &mut [u8], + ) -> Poll> { + self.get_mut().poll_read(context, buf) + } +} + +impl NoiseSocket +where + TSocket: AsyncWrite + Unpin, +{ + fn poll_write_or_flush( + &mut self, + mut context: &mut Context, + buf: Option<&[u8]>, + ) -> Poll>> { + loop { + trace!( + "NoiseSocket {} WriteState::{:?}", + if buf.is_some() { + "poll_write" + } else { + "poll_flush" + }, + self.write_state, + ); + match self.write_state { + WriteState::Init => { + if buf.is_some() { + self.write_state = WriteState::BufferData { offset: 0 }; + } else { + return Poll::Ready(Ok(None)); + } + } + WriteState::BufferData { ref mut offset } => { + let bytes_buffered = if let Some(buf) = buf { + let bytes_to_copy = + ::std::cmp::min(MAX_WRITE_BUFFER_LENGTH - *offset, buf.len()); + self.buffers.write_decrypted[*offset..(*offset + bytes_to_copy)] + .copy_from_slice(&buf[..bytes_to_copy]); + trace!("BufferData: buffered {}/{} bytes", bytes_to_copy, buf.len()); + *offset += bytes_to_copy; + Some(bytes_to_copy) + } else { + None + }; + + if buf.is_none() || *offset == MAX_WRITE_BUFFER_LENGTH { + match self.session.write_message( + &self.buffers.write_decrypted[..*offset], + &mut self.buffers.write_encrypted, + ) { + Ok(encrypted_len) => { + let frame_len = encrypted_len + .try_into() + .expect("offset should be able to fit in u16"); + self.write_state = WriteState::WriteFrameLen { + frame_len, + buf: u16::to_be_bytes(frame_len), + offset: 0, + }; + } + Err(e) => { + error!("Encryption Error: {}", e); + let err = io::Error::new( + io::ErrorKind::InvalidData, + format!("EncryptionError: {}", e), + ); + self.write_state = WriteState::EncryptionError(e); + return Poll::Ready(Err(err)); + } + } + } + + if let Some(bytes_buffered) = bytes_buffered { + return Poll::Ready(Ok(Some(bytes_buffered))); + } + } + WriteState::WriteFrameLen { + frame_len, + ref buf, + ref mut offset, + } => { + match ready!(poll_write_all( + &mut context, + Pin::new(&mut self.socket), + buf, + offset + )) { + Ok(()) => { + self.write_state = WriteState::WriteEncryptedFrame { + frame_len, + offset: 0, + }; + } + Err(e) => { + if e.kind() == 
io::ErrorKind::WriteZero { + self.write_state = WriteState::Eof; + } + return Poll::Ready(Err(e)); + } + } + } + WriteState::WriteEncryptedFrame { + frame_len, + ref mut offset, + } => { + match ready!(poll_write_all( + &mut context, + Pin::new(&mut self.socket), + &self.buffers.write_encrypted[..(frame_len as usize)], + offset + )) { + Ok(()) => { + self.write_state = WriteState::Flush; + } + Err(e) => { + if e.kind() == io::ErrorKind::WriteZero { + self.write_state = WriteState::Eof; + } + return Poll::Ready(Err(e)); + } + } + } + WriteState::Flush => { + try_ready!(Pin::new(&mut self.socket).poll_flush(&mut context)); + self.write_state = WriteState::Init; + } + WriteState::Eof => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())), + WriteState::EncryptionError(ref e) => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("EncryptionError: {}", e), + ))) + } + } + } + } + + fn poll_write(&mut self, context: &mut Context, buf: &[u8]) -> Poll> { + if let Some(bytes_written) = try_ready!(self.poll_write_or_flush(context, Some(buf))) { + Poll::Ready(Ok(bytes_written)) + } else { + unreachable!(); + } + } + + fn poll_flush(&mut self, context: &mut Context) -> Poll> { + if try_ready!(self.poll_write_or_flush(context, None)).is_none() { + Poll::Ready(Ok(())) + } else { + unreachable!(); + } + } +} + +impl AsyncWrite for NoiseSocket +where + TSocket: AsyncWrite + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + context: &mut Context, + buf: &[u8], + ) -> Poll> { + self.get_mut().poll_write(context, buf) + } + + fn poll_flush(self: Pin<&mut Self>, context: &mut Context) -> Poll> { + self.get_mut().poll_flush(context) + } + + fn poll_close(mut self: Pin<&mut Self>, context: &mut Context) -> Poll> { + Pin::new(&mut self.socket).poll_close(context) + } +} + +/// Represents a noise session which still needs to have a handshake performed. +pub(super) struct Handshake(NoiseSocket); + +impl Handshake { + /// Build a new `Handshake` struct given a socket and a new snow Session + pub fn new(socket: TSocket, session: snow::Session) -> Self { + let noise_socket = NoiseSocket::new(socket, session); + Self(noise_socket) + } +} + +impl Handshake +where + TSocket: AsyncRead + AsyncWrite + Unpin, +{ + /// Perform a Single Round-Trip noise IX handshake returning the underlying [NoiseSocket] + /// (switched to transport mode) upon success. + pub async fn handshake_1rt(mut self) -> io::Result> { + // The Dialer + if self.0.session.is_initiator() { + // -> e, s + self.send().await?; + self.flush().await?; + + // <- e, ee, se, s, es + self.receive().await?; + } else { + // -> e, s + self.receive().await?; + + // <- e, ee, se, s, es + self.send().await?; + self.flush().await?; + } + + self.finish() + } + + /// Send handshake message to remote. + async fn send(&mut self) -> io::Result<()> { + poll_fn(|context| self.0.poll_write(context, &[])) + .await + .map(|_| ()) + } + + /// Flush handshake message to remote. + async fn flush(&mut self) -> io::Result<()> { + poll_fn(|context| self.0.poll_flush(context)).await + } + + /// Receive handshake message from remote. + async fn receive(&mut self) -> io::Result<()> { + poll_fn(|context| self.0.poll_read(context, &mut [])) + .await + .map(|_| ()) + } + + /// Finish the handshake. + /// + /// Converts the noise session into transport mode and returns the NoiseSocket. 
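+    ///
+    /// A sketch of where this sits in the overall flow (names as in this module):
+    ///
+    /// ```ignore
+    /// let handshake = Handshake::new(socket, session);  // handshake-mode session
+    /// let socket = handshake.handshake_1rt().await?;    // finish() is called last
+    /// ```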
+ fn finish(self) -> io::Result> { + let session = self + .0 + .session + .into_transport_mode() + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Noise error: {}", e)))?; + Ok(NoiseSocket { session, ..self.0 }) + } +} + +#[cfg(test)] +mod test { + use crate::{ + socket::{Handshake, NoiseSocket, MAX_PAYLOAD_LENGTH}, + NOISE_IX_PARAMETER, + }; + use futures::{ + executor::block_on, + future::join, + io::{AsyncReadExt, AsyncWriteExt}, + }; + use memsocket::MemorySocket; + use snow::{params::NoiseParams, Builder, Keypair, SnowError}; + use std::io; + + fn build_test_connection() -> Result< + ( + (Keypair, Handshake), + (Keypair, Handshake), + ), + SnowError, + > { + let parameters: NoiseParams = NOISE_IX_PARAMETER.parse().expect("Invalid protocol name"); + + let dialer_keypair = Builder::new(parameters.clone()).generate_keypair()?; + let listener_keypair = Builder::new(parameters.clone()).generate_keypair()?; + + let dialer_session = Builder::new(parameters.clone()) + .local_private_key(&dialer_keypair.private) + .build_initiator()?; + let listener_session = Builder::new(parameters.clone()) + .local_private_key(&listener_keypair.private) + .build_responder()?; + + let (dialer_socket, listener_socket) = MemorySocket::new_pair(); + let (dialer, listener) = ( + NoiseSocket::new(dialer_socket, dialer_session), + NoiseSocket::new(listener_socket, listener_session), + ); + + Ok(( + (dialer_keypair, Handshake(dialer)), + (listener_keypair, Handshake(listener)), + )) + } + + fn perform_handshake( + dialer: Handshake, + listener: Handshake, + ) -> io::Result<(NoiseSocket, NoiseSocket)> { + let (dialer_result, listener_result) = + block_on(join(dialer.handshake_1rt(), listener.handshake_1rt())); + + Ok((dialer_result?, listener_result?)) + } + + #[test] + fn test_handshake() { + let ((dialer_keypair, dialer), (listener_keypair, listener)) = + build_test_connection().unwrap(); + + let (dialer_socket, listener_socket) = perform_handshake(dialer, listener).unwrap(); + + assert_eq!( + dialer_socket.get_remote_static(), + Some(listener_keypair.public.as_ref()) + ); + assert_eq!( + listener_socket.get_remote_static(), + Some(dialer_keypair.public.as_ref()) + ); + } + + #[test] + fn simple_test() -> io::Result<()> { + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = + build_test_connection().unwrap(); + + let (mut dialer_socket, mut listener_socket) = perform_handshake(dialer, listener)?; + + block_on(dialer_socket.write_all(b"stormlight"))?; + block_on(dialer_socket.write_all(b" "))?; + block_on(dialer_socket.write_all(b"archive"))?; + block_on(dialer_socket.flush())?; + block_on(dialer_socket.close())?; + + let mut buf = Vec::new(); + block_on(listener_socket.read_to_end(&mut buf))?; + + assert_eq!(buf, b"stormlight archive"); + + Ok(()) + } + + #[test] + fn interleaved_writes() -> io::Result<()> { + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = + build_test_connection().unwrap(); + + let (mut a, mut b) = perform_handshake(dialer, listener)?; + + block_on(a.write_all(b"The Name of the Wind"))?; + block_on(a.flush())?; + block_on(a.write_all(b"The Wise Man's Fear"))?; + block_on(a.flush())?; + + block_on(b.write_all(b"The Doors of Stone"))?; + block_on(b.flush())?; + + let mut buf = [0; 20]; + block_on(b.read_exact(&mut buf))?; + assert_eq!(&buf, b"The Name of the Wind"); + let mut buf = [0; 19]; + block_on(b.read_exact(&mut buf))?; + assert_eq!(&buf, b"The Wise Man's Fear"); + + let mut buf = [0; 18]; + block_on(a.read_exact(&mut buf))?; + assert_eq!(&buf, b"The 
Doors of Stone"); + + Ok(()) + } + + #[test] + fn u16_max_writes() -> io::Result<()> { + let ((_dialer_keypair, dialer), (_listener_keypair, listener)) = + build_test_connection().unwrap(); + + let (mut a, mut b) = perform_handshake(dialer, listener)?; + + let buf_send = [1; MAX_PAYLOAD_LENGTH]; + block_on(a.write_all(&buf_send))?; + block_on(a.flush())?; + + let mut buf_receive = [0; MAX_PAYLOAD_LENGTH]; + block_on(b.read_exact(&mut buf_receive))?; + assert_eq!(&buf_receive[..], &buf_send[..]); + + Ok(()) + } +} diff --git a/network/src/common.rs b/network/src/common.rs new file mode 100644 index 0000000000000..22e052980b6c8 --- /dev/null +++ b/network/src/common.rs @@ -0,0 +1,34 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::ProtocolId; +use crypto::{x25519::X25519PublicKey, PublicKey}; +use std::fmt; + +/// A Negotiated substream encapsulates a protocol and a substream for which that protocol has been +/// negotiated. +pub struct NegotiatedSubstream { + /// Protocol we have negotiated to use on the substream. + pub protocol: ProtocolId, + /// Opened substream. + pub substream: TSubstream, +} + +impl fmt::Debug for NegotiatedSubstream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "NegotiatedSubstream {{ protocol: {:?}, substream: ... }}", + self.protocol, + ) + } +} + +/// Public keys used at the network layer +#[derive(Debug, Clone)] +pub struct NetworkPublicKeys { + /// This key can validate signed messages at the network layer. + pub signing_public_key: PublicKey, + /// This key establishes a node's identity in the p2p network. + pub identity_public_key: X25519PublicKey, +} diff --git a/network/src/connectivity_manager/mod.rs b/network/src/connectivity_manager/mod.rs new file mode 100644 index 0000000000000..ecde2bae118e1 --- /dev/null +++ b/network/src/connectivity_manager/mod.rs @@ -0,0 +1,219 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! The ConnectivityManager actor is responsible for ensuring that we are connected to a node +//! if and only if it is an eligible node. +//! A list of eligible nodes is received at initialization, and updates are received on changes +//! to system membership. +//! +//! In our current system design, the Consensus actor informs the ConnectivityManager of +//! eligible nodes, and the Discovery actor infroms it about updates to addresses of eligible +//! nodes. +use crate::{ + common::NetworkPublicKeys, + peer_manager::{PeerManagerError, PeerManagerNotification, PeerManagerRequestSender}, +}; +use channel; +use futures::stream::{FusedStream, Stream, StreamExt}; +use logger::prelude::*; +use parity_multiaddr::Multiaddr; +use std::{ + collections::HashMap, + fmt::Debug, + sync::{Arc, RwLock}, +}; +use types::PeerId; + +#[cfg(test)] +mod test; + +/// The ConnectivityManager actor. +pub struct ConnectivityManager { + /// Nodes which are eligible to join the network. + eligible: Arc>>, + /// Nodes we are connected to, and the address we are connected at. + connected: HashMap, + /// Addresses of peers received from Discovery module. + peer_addresses: HashMap>, + /// Ticker to trigger connectivity checks to provide the guarantees stated above. + ticker: TTicker, + /// Channel to send requests to PeerManager. + peer_mgr_reqs_tx: PeerManagerRequestSender, + /// Channel to receive notifications from PeerManager. + peer_mgr_notifs_rx: channel::Receiver>, + /// Channel over which we receive requests from other actors. 
+    requests_rx: channel::Receiver<ConnectivityRequest>,
+    /// A local counter incremented on receiving an incoming message. Including it in log
+    /// output makes debugging easier.
+    event_id: u32,
+}
+
+/// Requests received by the [`ConnectivityManager`] actor from upstream modules.
+#[derive(Debug)]
+pub enum ConnectivityRequest {
+    /// Request to update known addresses of peer with id `PeerId` to given list.
+    UpdateAddresses(PeerId, Vec<Multiaddr>),
+    /// Update set of nodes eligible to join the network.
+    UpdateEligibleNodes(HashMap<PeerId, NetworkPublicKeys>),
+}
+
+impl<TTicker, TSubstream> ConnectivityManager<TTicker, TSubstream>
+where
+    TTicker: Stream + FusedStream + Unpin + 'static,
+    TSubstream: Debug,
+{
+    /// Creates a new instance of the [`ConnectivityManager`] actor.
+    pub fn new(
+        eligible: Arc<RwLock<HashMap<PeerId, NetworkPublicKeys>>>,
+        ticker: TTicker,
+        peer_mgr_reqs_tx: PeerManagerRequestSender<TSubstream>,
+        peer_mgr_notifs_rx: channel::Receiver<PeerManagerNotification<TSubstream>>,
+        requests_rx: channel::Receiver<ConnectivityRequest>,
+    ) -> Self {
+        Self {
+            eligible,
+            connected: HashMap::new(),
+            peer_addresses: HashMap::new(),
+            ticker,
+            peer_mgr_reqs_tx,
+            peer_mgr_notifs_rx,
+            requests_rx,
+            event_id: 0,
+        }
+    }
+
+    /// Starts the [`ConnectivityManager`] actor.
+    pub async fn start(mut self) {
+        // The ConnectivityManager actor is interested in 3 kinds of events:
+        // 1. Ticks to trigger a connectivity check. These are implemented using a clock based
+        //    trigger in production.
+        // 2. Incoming requests to connect to or disconnect from a peer.
+        // 3. Notifications from PeerManager when we establish a new connection or lose an existing
+        //    connection with a peer.
+        loop {
+            self.event_id += 1;
+            ::futures::select! {
+                _ = self.ticker.select_next_some() => {
+                    trace!("Event Id: {}, type: Tick", self.event_id);
+                    self.check_connectivity().await;
+                }
+                req = self.requests_rx.select_next_some() => {
+                    trace!("Event Id: {}, type: ConnectivityRequest, req: {:?}", self.event_id, req);
+                    self.handle_request(req);
+                }
+                notif = self.peer_mgr_notifs_rx.select_next_some() => {
+                    trace!("Event Id: {}, type: PeerManagerNotification, notif: {:?}", self.event_id, notif);
+                    self.handle_peer_mgr_notification(notif);
+                }
+                complete => {
+                    crit!("Connectivity manager actor terminated");
+                    break;
+                }
+            }
+        }
+    }
+
+    // Note: We do not check that the connections to older incarnations of a node are broken, and
+    // instead rely on the node moving to a new epoch to break connections made from older
+    // incarnations.
+    async fn check_connectivity(&mut self) {
+        // Ensure we are only connected to eligible peers.
+        let stale_connections: Vec<_> = self
+            .connected
+            .keys()
+            .filter(|peer_id| !self.eligible.read().unwrap().contains_key(peer_id))
+            .cloned()
+            .collect();
+        for p in stale_connections.into_iter() {
+            info!("Should no longer be connected to peer: {}", p.short_str());
+            match self.peer_mgr_reqs_tx.disconnect_peer(p).await {
+                // If we disconnect successfully, or if the connection has already broken, we
+                // eagerly remove the peer from the list of connected peers, without waiting
+                // for the LostPeer notification from PeerManager.
+                Ok(_) | Err(PeerManagerError::NotConnected(_)) => {
+                    self.connected.remove(&p);
+                }
+                e => {
+                    info!(
+                        "Failed to disconnect from peer: {}. Error: {:?}",
+                        p.short_str(),
+                        e
+                    );
+                }
+            }
+        }
+        // If we have the address of an eligible peer, but are not connected to it, we dial it.
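+        // Note: only the first known address, addrs[0], is dialed below; any additional
+        // addresses reported for a peer are currently ignored.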
+        let pending_connections: Vec<_> = self
+            .peer_addresses
+            .iter()
+            .filter(|(peer_id, _)| {
+                self.eligible.read().unwrap().contains_key(peer_id)
+                    && self.connected.get(peer_id).is_none()
+            })
+            .collect();
+        for (p, addrs) in pending_connections.into_iter() {
+            info!(
+                "Should be connected to peer: {} at addr(s): {:?}",
+                p.short_str(),
+                addrs,
+            );
+            match self.peer_mgr_reqs_tx.dial_peer(*p, addrs[0].clone()).await {
+                // If the dial succeeded, or if we are somehow already connected to the peer by
+                // the time we make the dial request, we eagerly add the peer to the list of
+                // connected peers, without waiting for the NewPeer notification from PeerManager.
+                Ok(_) => {
+                    self.connected.insert(*p, addrs[0].clone());
+                }
+                Err(PeerManagerError::AlreadyConnected(a)) => {
+                    // We ignore whether `a` is actually the address we dialed.
+                    self.connected.insert(*p, a);
+                }
+                e => {
+                    info!(
+                        "Failed to connect to peer: {} at address: {}. Error: {:?}",
+                        p.short_str(),
+                        addrs[0].clone(),
+                        e
+                    );
+                }
+            }
+        }
+    }
+
+    fn handle_request(&mut self, req: ConnectivityRequest) {
+        match req {
+            ConnectivityRequest::UpdateAddresses(peer_id, addrs) => {
+                self.peer_addresses.insert(peer_id, addrs);
+            }
+            ConnectivityRequest::UpdateEligibleNodes(nodes) => {
+                *self.eligible.write().unwrap() = nodes;
+            }
+        }
+    }
+
+    fn handle_peer_mgr_notification(&mut self, notif: PeerManagerNotification<TSubstream>) {
+        match notif {
+            PeerManagerNotification::NewPeer(peer_id, addr) => {
+                self.connected.insert(peer_id, addr);
+            }
+            PeerManagerNotification::LostPeer(peer_id, addr) => {
+                match self.connected.get(&peer_id) {
+                    Some(curr_addr) if *curr_addr == addr => {
+                        // Remove node from connected peers list.
+                        self.connected.remove(&peer_id);
+                    }
+                    _ => {
+                        debug!(
+                            "Ignoring stale lost peer event for peer: {}, addr: {}",
+                            peer_id.short_str(),
+                            addr
+                        );
+                    }
+                }
+            }
+            _ => {
+                panic!("Received unexpected notification from peer manager");
+            }
+        }
+    }
+}
diff --git a/network/src/connectivity_manager/test.rs b/network/src/connectivity_manager/test.rs
new file mode 100644
index 0000000000000..9e633c4ce9cbe
--- /dev/null
+++ b/network/src/connectivity_manager/test.rs
@@ -0,0 +1,461 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use crate::peer_manager::PeerManagerRequest;
+use core::str::FromStr;
+use crypto::{signing, x25519};
+use futures::{FutureExt, SinkExt, TryFutureExt};
+use memsocket::MemorySocket;
+use std::io;
+use tokio::runtime::Runtime;
+
+fn setup_conn_mgr(
+    rt: &mut Runtime,
+    seed_peer_id: PeerId,
+) -> (
+    channel::Receiver<PeerManagerRequest<MemorySocket>>,
+    channel::Sender<PeerManagerNotification<MemorySocket>>,
+    channel::Sender<ConnectivityRequest>,
+    channel::Sender<()>,
+) {
+    let (peer_mgr_reqs_tx, peer_mgr_reqs_rx): (
+        channel::Sender<PeerManagerRequest<MemorySocket>>,
+        _,
+    ) = channel::new_test(0);
+    let (peer_mgr_notifs_tx, peer_mgr_notifs_rx) = channel::new_test(0);
+    let (conn_mgr_reqs_tx, conn_mgr_reqs_rx) = channel::new_test(0);
+    let (ticker_tx, ticker_rx) = channel::new_test(0);
+    let (_, signing_public_key) = signing::generate_keypair();
+    let (_, identity_public_key) = x25519::generate_keypair();
+    let conn_mgr = {
+        ConnectivityManager::new(
+            Arc::new(RwLock::new(
+                vec![(
+                    seed_peer_id,
+                    NetworkPublicKeys {
+                        identity_public_key,
+                        signing_public_key,
+                    },
+                )]
+                .into_iter()
+                .collect(),
+            )),
+            ticker_rx,
+            PeerManagerRequestSender::new(peer_mgr_reqs_tx),
+            peer_mgr_notifs_rx,
+            conn_mgr_reqs_rx,
+        )
+    };
+    rt.spawn(conn_mgr.start().boxed().unit_error().compat());
+    (
+        peer_mgr_reqs_rx,
+        peer_mgr_notifs_tx,
+        conn_mgr_reqs_tx,
+        ticker_tx,
+    )
+} + +async fn expect_disconnect_request( + peer_mgr_reqs_rx: &mut channel::Receiver>, + peer: PeerId, + result: Result<(), PeerManagerError>, +) { + match peer_mgr_reqs_rx.next().await.unwrap() { + PeerManagerRequest::DisconnectPeer(p, error_tx) => { + assert_eq!(peer, p); + error_tx.send(result).unwrap(); + } + _ => { + panic!("unexpected request to peer manager"); + } + } +} + +async fn expect_dial_request( + peer_mgr_reqs_rx: &mut channel::Receiver>, + peer: PeerId, + address: Multiaddr, + result: Result<(), PeerManagerError>, +) { + match peer_mgr_reqs_rx.next().await.unwrap() { + PeerManagerRequest::DialPeer(p, addr, error_tx) => { + assert_eq!(peer, p); + assert_eq!(address, addr); + error_tx.send(result).unwrap(); + } + _ => { + panic!("unexpected request to peer manager"); + } + } +} + +#[test] +fn addr_change() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + let seed_peer_id = PeerId::random(); + let (mut peer_mgr_reqs_rx, mut peer_mgr_notifs_tx, mut conn_mgr_reqs_tx, mut ticker_tx) = + setup_conn_mgr(&mut rt, seed_peer_id); + + // Fake peer manager and discovery. + let f_peer_mgr = async move { + let seed_address = Multiaddr::from_str("/ip4/127.0.0.1/tcp/9090").unwrap(); + + // Send request to connect to seed peer. + info!("Sending request to connect to seed peer"); + conn_mgr_reqs_tx + .send(ConnectivityRequest::UpdateAddresses( + seed_peer_id, + vec![seed_address.clone()], + )) + .await + .unwrap(); + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // Peer manager receives a request to connect to the seed peer. + info!("Waiting to receive dial request"); + expect_dial_request( + &mut peer_mgr_reqs_rx, + seed_peer_id, + seed_address.clone(), + Ok(()), + ) + .await; + + // Send request to connect to seed peer at old address. ConnectivityManager should not + // dial, since we are already connected at the new address. The absence of another dial + // attempt is hard to test explicitly. It will get implicitly tested if the dial + // attempt arrives in place of some other expected message in the future. + info!("Sending redundant request to connect to seed peer."); + conn_mgr_reqs_tx + .send(ConnectivityRequest::UpdateAddresses( + seed_peer_id, + vec![seed_address.clone()], + )) + .await + .unwrap(); + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + let seed_address_new = Multiaddr::from_str("/ip4/127.0.1.1/tcp/8080").unwrap(); + // Send request to connect to seed peer at new address. + info!("Sending request to connect to seed peer at new address."); + conn_mgr_reqs_tx + .send(ConnectivityRequest::UpdateAddresses( + seed_peer_id, + vec![seed_address_new.clone()], + )) + .await + .unwrap(); + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // We expect the peer which changed its address to also disconnect. + info!("Sending lost peer notification for seed peer at old address"); + peer_mgr_notifs_tx + .send(PeerManagerNotification::LostPeer( + seed_peer_id, + seed_address, + )) + .await + .unwrap(); + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // Peer manager then receives a request to connect to the seed peer at new address. 
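+        // (This dial is only expected because the LostPeer notification above eagerly
+        // removed the stale entry from the connected set.)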
+ info!("Waiting to receive dial request to seed peer at new address"); + expect_dial_request( + &mut peer_mgr_reqs_rx, + seed_peer_id, + seed_address_new, + Ok(()), + ) + .await; + }; + rt.block_on(f_peer_mgr.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn lost_connection() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + let seed_peer_id = PeerId::random(); + let (mut peer_mgr_reqs_rx, mut peer_mgr_notifs_tx, mut conn_mgr_reqs_tx, mut ticker_tx) = + setup_conn_mgr(&mut rt, seed_peer_id); + + // Fake peer manager and discovery. + let f_peer_mgr = async move { + let seed_address = Multiaddr::from_str("/ip4/127.0.0.1/tcp/9090").unwrap(); + + // Send request to connect to seed peer. + info!("Sending request to connect to seed peer"); + conn_mgr_reqs_tx + .send(ConnectivityRequest::UpdateAddresses( + seed_peer_id, + vec![seed_address.clone()], + )) + .await + .unwrap(); + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // Peer manager receives a request to connect to the seed peer. + info!("Waiting to receive dial request"); + expect_dial_request( + &mut peer_mgr_reqs_rx, + seed_peer_id, + seed_address.clone(), + Ok(()), + ) + .await; + + // Notify connectivity actor of loss of connection to seed_peer. + info!("Sending LostPeer event to signal connection loss"); + peer_mgr_notifs_tx + .send(PeerManagerNotification::LostPeer( + seed_peer_id, + seed_address.clone(), + )) + .await + .unwrap(); + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // Peer manager receives a request to connect to the seed peer after loss of + // connection. + info!("Waiting to receive dial request"); + expect_dial_request( + &mut peer_mgr_reqs_rx, + seed_peer_id, + seed_address.clone(), + Ok(()), + ) + .await; + }; + rt.block_on(f_peer_mgr.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn disconnect() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + let seed_peer_id = PeerId::random(); + let (mut peer_mgr_reqs_rx, _, mut conn_mgr_reqs_tx, mut ticker_tx) = + setup_conn_mgr(&mut rt, seed_peer_id); + + let events_f = async move { + let seed_address = Multiaddr::from_str("/ip4/127.0.0.1/tcp/9090").unwrap(); + + // Send request to connect to seed peer. + info!("Sending request to connect to seed peer"); + conn_mgr_reqs_tx + .send(ConnectivityRequest::UpdateAddresses( + seed_peer_id, + vec![seed_address.clone()], + )) + .await + .unwrap(); + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // Peer manager receives a request to connect to the seed peer. + info!("Waiting to receive dial request"); + expect_dial_request( + &mut peer_mgr_reqs_rx, + seed_peer_id, + seed_address.clone(), + Ok(()), + ) + .await; + + // Send request to make seed peer ineligible. + info!("Sending request to make seed peer ineligible"); + conn_mgr_reqs_tx + .send(ConnectivityRequest::UpdateEligibleNodes(HashMap::new())) + .await + .unwrap(); + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // Peer manager receives a request to connect to the seed peer. 
+ info!("Waiting to receive disconnect request"); + expect_disconnect_request(&mut peer_mgr_reqs_rx, seed_peer_id, Ok(())).await; + }; + rt.block_on(events_f.boxed().unit_error().compat()).unwrap(); +} + +// Tests that connectivity manager retries dials and disconnects on failure. +#[test] +fn retry_on_failure() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + let seed_peer_id = PeerId::random(); + let (mut peer_mgr_reqs_rx, _, mut conn_mgr_reqs_tx, mut ticker_tx) = + setup_conn_mgr(&mut rt, seed_peer_id); + + let events_f = async move { + let seed_address = Multiaddr::from_str("/ip4/127.0.0.1/tcp/9090").unwrap(); + + // Send request to connect to seed peer. + info!("Sending request to connect to seed peer"); + conn_mgr_reqs_tx + .send(ConnectivityRequest::UpdateAddresses( + seed_peer_id, + vec![seed_address.clone()], + )) + .await + .unwrap(); + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // Peer manager receives a request to connect to the seed peer. + info!("Waiting to receive dial request"); + expect_dial_request( + &mut peer_mgr_reqs_rx, + seed_peer_id, + seed_address.clone(), + Err(PeerManagerError::IoError(io::Error::from( + io::ErrorKind::ConnectionRefused, + ))), + ) + .await; + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // Peer manager again receives a request to connect to the seed peer. + info!("Waiting to receive dial request"); + expect_dial_request( + &mut peer_mgr_reqs_rx, + seed_peer_id, + seed_address.clone(), + Ok(()), + ) + .await; + + // Send request to make seed peer ineligible. + info!("Sending request to make seed peer ineligible"); + conn_mgr_reqs_tx + .send(ConnectivityRequest::UpdateEligibleNodes(HashMap::new())) + .await + .unwrap(); + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // Peer manager receives a request to disconnect from the seed peer, which fails. + info!("Waiting to receive disconnect request"); + expect_disconnect_request( + &mut peer_mgr_reqs_rx, + seed_peer_id, + Err(PeerManagerError::IoError(io::Error::from( + io::ErrorKind::Interrupted, + ))), + ) + .await; + + // Trigger connectivity check again. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // Peer manager receives another request to disconnect from the seed peer, which now + // succeeds. + info!("Waiting to receive disconnect request"); + expect_disconnect_request(&mut peer_mgr_reqs_rx, seed_peer_id, Ok(())).await; + }; + rt.block_on(events_f.boxed().unit_error().compat()).unwrap(); +} + +#[test] +// Tests that if we dial a an already connected peer or disconnect from an already disconnected +// peer, connectivity manager does not send any additional dial or disconnect requests. +fn no_op_requests() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + let seed_peer_id = PeerId::random(); + let (mut peer_mgr_reqs_rx, _, mut conn_mgr_reqs_tx, mut ticker_tx) = + setup_conn_mgr(&mut rt, seed_peer_id); + + let events_f = async move { + let seed_address = Multiaddr::from_str("/ip4/127.0.0.1/tcp/9090").unwrap(); + + // Send request to connect to seed peer. 
+ info!("Sending request to connect to seed peer"); + conn_mgr_reqs_tx + .send(ConnectivityRequest::UpdateAddresses( + seed_peer_id, + vec![seed_address.clone()], + )) + .await + .unwrap(); + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // Peer manager receives a request to connect to the seed peer. + info!("Waiting to receive dial request"); + expect_dial_request( + &mut peer_mgr_reqs_rx, + seed_peer_id, + seed_address.clone(), + Err(PeerManagerError::AlreadyConnected(seed_address.clone())), + ) + .await; + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // Send request to make seed peer ineligible. + info!("Sending request to make seed peer ineligible"); + conn_mgr_reqs_tx + .send(ConnectivityRequest::UpdateEligibleNodes(HashMap::new())) + .await + .unwrap(); + + // Trigger connectivity check. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + + // Peer manager receives a request to disconnect from the seed peer, which fails. + info!("Waiting to receive disconnect request"); + expect_disconnect_request( + &mut peer_mgr_reqs_rx, + seed_peer_id, + Err(PeerManagerError::NotConnected(seed_peer_id)), + ) + .await; + + // Trigger connectivity check again. We don't expect connectivity manager to do + // anything - if it does, the task should panic. That may not fail the test (right + // now), but will be easily spotted by someone running the tests locallly. + info!("Sending tick to trigger connectivity check"); + ticker_tx.send(()).await.unwrap(); + }; + rt.block_on(events_f.boxed().unit_error().compat()).unwrap(); +} diff --git a/network/src/counters.rs b/network/src/counters.rs new file mode 100644 index 0000000000000..30953ec5e0f7a --- /dev/null +++ b/network/src/counters.rs @@ -0,0 +1,117 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use lazy_static; +use metrics::{Histogram, IntCounter, IntGauge, OpMetrics}; + +lazy_static::lazy_static! { + pub static ref OP_COUNTERS: OpMetrics = OpMetrics::new_and_registered("network"); +} + +lazy_static::lazy_static! 
{ + /// Counter of currently connected peers + pub static ref CONNECTED_PEERS: IntGauge = OP_COUNTERS.gauge("connected_peers"); + + /// Counter of rpc requests sent + pub static ref RPC_REQUESTS_SENT: IntCounter = OP_COUNTERS.counter("rpc_requests_sent"); + + /// Counter of rpc request bytes sent + pub static ref RPC_REQUEST_BYTES_SENT: IntCounter = OP_COUNTERS.counter("rpc_request_bytes_sent"); + + /// Counter of rpc requests failed + pub static ref RPC_REQUESTS_FAILED: IntCounter = OP_COUNTERS.counter("rpc_requests_failed"); + + /// Counter of rpc requests cancelled + pub static ref RPC_REQUESTS_CANCELLED: IntCounter = OP_COUNTERS.counter("rpc_requests_cancelled"); + + /// Counter of rpc requests received + pub static ref RPC_REQUESTS_RECEIVED: IntCounter = OP_COUNTERS.counter("rpc_requests_received"); + + /// Counter of rpc responses sent + pub static ref RPC_RESPONSES_SENT: IntCounter = OP_COUNTERS.counter("rpc_responses_sent"); + + /// Counter of rpc response bytes sent + pub static ref RPC_RESPONSE_BYTES_SENT: IntCounter = OP_COUNTERS.counter("rpc_response_bytes_sent"); + + /// Counter of rpc responses failed + pub static ref RPC_RESPONSES_FAILED: IntCounter = OP_COUNTERS.counter("rpc_responses_failed"); + + /// Histogram of rpc latency + pub static ref RPC_LATENCY: Histogram = OP_COUNTERS.histogram("rpc_latency"); + + /// Counter of messages sent via the direct send protocol + pub static ref DIRECT_SEND_MESSAGES_SENT: IntCounter = OP_COUNTERS.counter("direct_send_messages_sent"); + + /// Counter of bytes sent via the direct send protocol + pub static ref DIRECT_SEND_BYTES_SENT: IntCounter = OP_COUNTERS.counter("direct_send_bytes_sent"); + + /// Counter of messages dropped via the direct send protocol + pub static ref DIRECT_SEND_MESSAGES_DROPPED: IntCounter = OP_COUNTERS.counter("direct_send_messages_dropped"); + + /// Counter of messages received via the direct send protocol + pub static ref DIRECT_SEND_MESSAGES_RECEIVED: IntCounter = OP_COUNTERS.counter("direct_send_messages_received"); + + /// Counter of bytes received via the direct send protocol + pub static ref DIRECT_SEND_BYTES_RECEIVED: IntCounter = OP_COUNTERS.counter("direct_send_bytes_received"); + + /// + /// Channel Counters + /// + + /// Counter of pending requests in Network Provider + pub static ref PENDING_NETWORK_REQUESTS: IntGauge = OP_COUNTERS.gauge("pending_network_requests"); + + /// Counter of pending network events to Mempool + pub static ref PENDING_MEMPOOL_NETWORK_EVENTS: IntGauge = OP_COUNTERS.gauge("pending_mempool_network_events"); + + /// Counter of pending network events to Consensus + pub static ref PENDING_CONSENSUS_NETWORK_EVENTS: IntGauge = OP_COUNTERS.gauge("pending_consensus_network_events"); + + /// Counter of pending requests in Peer Manager + pub static ref PENDING_PEER_MANAGER_REQUESTS: IntGauge = OP_COUNTERS.gauge("pending_peer_manager_requests"); + + /// Counter of pending Peer Manager notifications in Network Provider + pub static ref PENDING_PEER_MANAGER_NET_NOTIFICATIONS: IntGauge = OP_COUNTERS.gauge("pending_peer_manager_net_notifications"); + + /// Counter of pending requests in Direct Send + pub static ref PENDING_DIRECT_SEND_REQUESTS: IntGauge = OP_COUNTERS.gauge("pending_direct_send_requests"); + + /// Counter of pending Direct Send notifications to Network Provider + pub static ref PENDING_DIRECT_SEND_NOTIFICATIONS: IntGauge = OP_COUNTERS.gauge("pending_direct_send_notifications"); + + /// Counter of pending requests in Connectivity Manager + pub static ref 
PENDING_CONNECTIVITY_MANAGER_REQUESTS: IntGauge = OP_COUNTERS.gauge("pending_connectivity_manager_requests"); + + /// Counter of pending requests in RPC + pub static ref PENDING_RPC_REQUESTS: IntGauge = OP_COUNTERS.gauge("pending_rpc_requests"); + + /// Counter of pending RPC notifications to Network Provider + pub static ref PENDING_RPC_NOTIFICATIONS: IntGauge = OP_COUNTERS.gauge("pending_rpc_notifications"); + + /// Counter of pending Peer Manager notifications to Direct Send + pub static ref PENDING_PEER_MANAGER_DIRECT_SEND_NOTIFICATIONS: IntGauge = OP_COUNTERS.gauge("pending_peer_manager_direct_send_notifications"); + + /// Counter of pending Peer Manager notifications to RPC + pub static ref PENDING_PEER_MANAGER_RPC_NOTIFICATIONS: IntGauge = OP_COUNTERS.gauge("pending_peer_manager_rpc_notifications"); + + /// Counter of pending Peer Manager notifications to Discovery + pub static ref PENDING_PEER_MANAGER_DISCOVERY_NOTIFICATIONS: IntGauge = OP_COUNTERS.gauge("pending_peer_manager_discovery_notifications"); + + /// Counter of pending Peer Manager notifications to Ping + pub static ref PENDING_PEER_MANAGER_PING_NOTIFICATIONS: IntGauge = OP_COUNTERS.gauge("pending_peer_manager_ping_notifications"); + + /// Counter of pending Peer Manager notifications to Connectivity Manager + pub static ref PENDING_PEER_MANAGER_CONNECTIVITY_MANAGER_NOTIFICATIONS: IntGauge = OP_COUNTERS.gauge("pending_peer_manager_connectivity_manager_notifications"); + + /// Counter of pending internal events in Peer Manager + pub static ref PENDING_PEER_MANAGER_INTERNAL_EVENTS: IntGauge = OP_COUNTERS.gauge("pending_peer_manager_internal_events"); + + /// Counter of pending dial requests in Peer Manager + pub static ref PENDING_PEER_MANAGER_DIAL_REQUESTS: IntGauge = OP_COUNTERS.gauge("pending_peer_manager_dial_requests"); + + /// Counter of pending requests for each remote peer + pub static ref PENDING_PEER_REQUESTS: &'static str = "pending_peer_requests"; + + /// Counter of pending outbound messages in Direct Send for each remote peer + pub static ref PENDING_DIRECT_SEND_OUTBOUND_MESSAGES: &'static str = "pending_direct_send_outbound_messages"; +} diff --git a/network/src/error.rs b/network/src/error.rs new file mode 100644 index 0000000000000..7b9bc3bc27a29 --- /dev/null +++ b/network/src/error.rs @@ -0,0 +1,178 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::peer_manager::PeerManagerError; +use failure::{Backtrace, Context, Fail}; +use futures::channel::mpsc; +use protobuf::error::ProtobufError; +use std::{ + fmt::{self, Display}, + io, +}; +use tokio::timer; +use types::validator_verifier::VerifyError; + +/// Errors propagated from the network module. 
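+///
+/// `NetworkError` follows the `failure` crate's context pattern: `NetworkErrorKind` names
+/// the broad category, while the wrapped `Context` preserves the underlying cause and
+/// backtrace. As a hypothetical sketch (not part of this module's API), a caller could
+/// classify errors by kind like so:
+///
+/// ```ignore
+/// fn is_retryable(e: &NetworkError) -> bool {
+///     match e.kind() {
+///         // A timed-out or disconnected peer may come back; treat other kinds as fatal.
+///         NetworkErrorKind::TimedOut | NetworkErrorKind::NotConnected => true,
+///         _ => false,
+///     }
+/// }
+/// ```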
+#[derive(Debug)]
+pub struct NetworkError {
+    inner: Context<NetworkErrorKind>,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq, Debug, Fail)]
+pub enum NetworkErrorKind {
+    #[fail(display = "IO error")]
+    IoError,
+
+    #[fail(display = "Error parsing protobuf message")]
+    ProtobufParseError,
+
+    #[fail(display = "Invalid signature error")]
+    SignatureError,
+
+    #[fail(display = "Failed to parse multiaddrs")]
+    MultiaddrError,
+
+    #[fail(display = "Error sending on mpsc channel")]
+    MpscSendError,
+
+    #[fail(display = "Error setting timeout")]
+    TimerError,
+
+    #[fail(display = "Operation timed out")]
+    TimedOut,
+
+    #[fail(display = "PeerManager error")]
+    PeerManagerError,
+
+    #[fail(display = "Parsing error")]
+    ParsingError,
+
+    #[fail(display = "Peer disconnected")]
+    NotConnected,
+}
+
+impl Fail for NetworkError {
+    fn cause(&self) -> Option<&Fail> {
+        self.inner.cause()
+    }
+
+    fn backtrace(&self) -> Option<&Backtrace> {
+        self.inner.backtrace()
+    }
+}
+
+impl Display for NetworkError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{}", &self.inner)
+    }
+}
+
+impl NetworkError {
+    pub fn kind(&self) -> NetworkErrorKind {
+        *self.inner.get_context()
+    }
+}
+
+impl From<NetworkErrorKind> for NetworkError {
+    fn from(kind: NetworkErrorKind) -> NetworkError {
+        NetworkError {
+            inner: Context::new(kind),
+        }
+    }
+}
+
+impl From<Context<NetworkErrorKind>> for NetworkError {
+    fn from(inner: Context<NetworkErrorKind>) -> NetworkError {
+        NetworkError { inner }
+    }
+}
+
+impl From<io::Error> for NetworkError {
+    fn from(err: io::Error) -> NetworkError {
+        err.context(NetworkErrorKind::IoError).into()
+    }
+}
+
+impl From<VerifyError> for NetworkError {
+    fn from(err: VerifyError) -> NetworkError {
+        err.context(NetworkErrorKind::SignatureError).into()
+    }
+}
+
+impl From<ProtobufError> for NetworkError {
+    fn from(err: ProtobufError) -> NetworkError {
+        err.context(NetworkErrorKind::ProtobufParseError).into()
+    }
+}
+
+impl From<parity_multiaddr::Error> for NetworkError {
+    fn from(err: parity_multiaddr::Error) -> NetworkError {
+        err.context(NetworkErrorKind::MultiaddrError).into()
+    }
+}
+
+impl From<mpsc::SendError> for NetworkError {
+    fn from(err: mpsc::SendError) -> NetworkError {
+        err.context(NetworkErrorKind::MpscSendError).into()
+    }
+}
+
+impl From<PeerManagerError> for NetworkError {
+    fn from(err: PeerManagerError) -> NetworkError {
+        match err {
+            PeerManagerError::IoError(_) => err.context(NetworkErrorKind::IoError).into(),
+            PeerManagerError::NotConnected(_) => err.context(NetworkErrorKind::NotConnected).into(),
+            err => err.context(NetworkErrorKind::PeerManagerError).into(),
+        }
+    }
+}
+
+impl From<timer::timeout::Error<NetworkError>> for NetworkError {
+    fn from(err: timer::timeout::Error<NetworkError>) -> NetworkError {
+        if err.is_elapsed() {
+            Context::new(NetworkErrorKind::TimedOut).into()
+        } else if err.is_timer() {
+            err.into_timer()
+                .unwrap()
+                .context(NetworkErrorKind::TimerError)
+                .into()
+        } else if err.is_inner() {
+            err.into_inner().unwrap()
+        } else {
+            unreachable!("Unrecognized timer error: {}", err)
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+
+    use super::*;
+    use failure::AsFail;
+
+    // This test demos a causal error chain that can be created using the `context` method of
+    // `Fail` types.
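+    // The chain is built innermost-first: each `context` call wraps the previous error, so
+    // iterating with `iter_chain` below yields the outermost context first.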
+    #[test]
+    fn causal_chain() {
+        let base_error = ::failure::err_msg("First error");
+        let first_level_error = base_error.context(NetworkErrorKind::TimedOut);
+        let second_level_error = first_level_error.context(NetworkErrorKind::PeerManagerError);
+        let network_error: NetworkError = second_level_error.into();
+        // When called without RUST_BACKTRACE=1, the debug mode should print the following:
+        // NetworkError { inner: ErrorMessage { msg: "First error" }
+        // Operation timed out
+        // PeerManager error }
+        eprintln!("{:?}", network_error);
+        // The display mode output is just the outermost error:
+        // PeerManager error
+        eprintln!("{}", network_error);
+        // Alternatively, we can iterate over the individual failures in the causal chain to get
+        // the following output:
+        // Error: PeerManager error
+        // Error: Operation timed out
+        // Error: First error
+        for e in network_error.as_fail().iter_chain() {
+            eprintln!("Error: {}", e);
+        }
+    }
+}
diff --git a/network/src/interface/mod.rs b/network/src/interface/mod.rs new file mode 100644 index 0000000000000..4143ffaeabb1d --- /dev/null +++ b/network/src/interface/mod.rs @@ -0,0 +1,261 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Module exposing generic network API
+//!
+//! Unlike the [`validator_network`](crate::validator_network) module, which exposes async
+//! function calls and Stream APIs specific to the consensus and mempool modules, the
+//! `interface` module exposes a generic network API: it receives requests over a channel for
+//! outbound requests and sends notifications to upstream clients for inbound requests and
+//! other events. For example, clients wishing to send an RPC need to send a
+//! [`NetworkRequest::SendRpc`](crate::interface::NetworkRequest::SendRpc) message to the
+//! [`NetworkProvider`] actor. Inbound RPC requests are forwarded to the appropriate
+//! handler, determined using the protocol negotiated on the RPC substream.
+use crate::{
+    common::NetworkPublicKeys,
+    connectivity_manager::ConnectivityRequest,
+    counters,
+    peer_manager::PeerManagerNotification,
+    protocols::{
+        direct_send::{DirectSendNotification, DirectSendRequest, Message},
+        rpc::{InboundRpcRequest, OutboundRpcRequest, RpcNotification, RpcRequest},
+    },
+    ProtocolId,
+};
+use channel;
+use futures::{FutureExt, SinkExt, StreamExt};
+use logger::prelude::*;
+use std::{collections::HashMap, fmt::Debug};
+use types::PeerId;
+
+/// Requests [`NetworkProvider`] receives from the network interface.
+#[derive(Debug)]
+pub enum NetworkRequest {
+    /// Send an RPC request to a remote peer.
+    SendRpc(PeerId, OutboundRpcRequest),
+    /// Fire-and-forget style message send to a remote peer.
+    SendMessage(PeerId, Message),
+    /// Update set of nodes eligible to join the network.
+    UpdateEligibleNodes(HashMap<PeerId, NetworkPublicKeys>),
+}
+
+/// Notifications that [`NetworkProvider`] sends to consumers of its API. The
+/// [`NetworkProvider`] in turn receives these notifications from the PeerManager and other
+/// [`protocols`](crate::protocols).
+#[derive(Debug)]
+pub enum NetworkNotification {
+    /// Connection with a new peer has been established.
+    NewPeer(PeerId),
+    /// Connection to a peer has been terminated. This could have been triggered from either end.
+    LostPeer(PeerId),
+    /// A new RPC request has been received from a remote peer.
+    RecvRpc(PeerId, InboundRpcRequest),
+    /// A new message has been received from a remote peer.
+    RecvMessage(PeerId, Message),
+}
+
+pub struct NetworkProvider<TSubstream> {
+    /// Map from protocol to upstream handlers for events of that protocol type.
+    upstream_handlers: HashMap<ProtocolId, channel::Sender<NetworkNotification>>,
+    /// Channel over which we receive notifications from PeerManager.
+    peer_mgr_notifs_rx: channel::Receiver<PeerManagerNotification<TSubstream>>,
+    /// Channel over which we send requests to RPC actor.
+    rpc_reqs_tx: channel::Sender<RpcRequest>,
+    /// Channel over which we receive notifications from RPC actor.
+    rpc_notifs_rx: channel::Receiver<RpcNotification>,
+    /// Channel over which we send requests to DirectSend actor.
+    ds_reqs_tx: channel::Sender<DirectSendRequest>,
+    /// Channel over which we receive notifications from DirectSend actor.
+    ds_notifs_rx: channel::Receiver<DirectSendNotification>,
+    /// Channel over which we send requests to the ConnectivityManager actor.
+    conn_mgr_reqs_tx: channel::Sender<ConnectivityRequest>,
+    /// Channel to receive requests from other actors.
+    requests_rx: channel::Receiver<NetworkRequest>,
+    /// The maximum number of concurrent NetworkRequests that can be handled.
+    /// Back-pressure takes effect via bounded mpsc channel beyond the limit.
+    max_concurrent_reqs: u32,
+    /// The maximum number of concurrent Notifications from Peer Manager,
+    /// RPC and Direct Send that can be handled.
+    /// Back-pressure takes effect via bounded mpsc channel beyond the limit.
+    max_concurrent_notifs: u32,
+}
+
+impl<TSubstream> NetworkProvider<TSubstream>
+where
+    TSubstream: Debug + Send,
+{
+    pub fn new(
+        peer_mgr_notifs_rx: channel::Receiver<PeerManagerNotification<TSubstream>>,
+        rpc_reqs_tx: channel::Sender<RpcRequest>,
+        rpc_notifs_rx: channel::Receiver<RpcNotification>,
+        ds_reqs_tx: channel::Sender<DirectSendRequest>,
+        ds_notifs_rx: channel::Receiver<DirectSendNotification>,
+        conn_mgr_reqs_tx: channel::Sender<ConnectivityRequest>,
+        requests_rx: channel::Receiver<NetworkRequest>,
+        upstream_handlers: HashMap<ProtocolId, channel::Sender<NetworkNotification>>,
+        max_concurrent_reqs: u32,
+        max_concurrent_notifs: u32,
+    ) -> Self {
+        Self {
+            upstream_handlers,
+            peer_mgr_notifs_rx,
+            rpc_reqs_tx,
+            rpc_notifs_rx,
+            ds_reqs_tx,
+            ds_notifs_rx,
+            conn_mgr_reqs_tx,
+            requests_rx,
+            max_concurrent_reqs,
+            max_concurrent_notifs,
+        }
+    }
+
+    async fn handle_network_request(
+        req: NetworkRequest,
+        mut rpc_reqs_tx: channel::Sender<RpcRequest>,
+        mut ds_reqs_tx: channel::Sender<DirectSendRequest>,
+        mut conn_mgr_reqs_tx: channel::Sender<ConnectivityRequest>,
+    ) {
+        trace!("NetworkRequest::{:?}", req);
+        match req {
+            NetworkRequest::SendRpc(peer_id, req) => {
+                rpc_reqs_tx
+                    .send(RpcRequest::SendRpc(peer_id, req))
+                    .await
+                    .unwrap();
+            }
+            NetworkRequest::SendMessage(peer_id, msg) => {
+                counters::DIRECT_SEND_MESSAGES_SENT.inc();
+                counters::DIRECT_SEND_BYTES_SENT.inc_by(msg.mdata.len() as i64);
+                ds_reqs_tx
+                    .send(DirectSendRequest::SendMessage(peer_id, msg))
+                    .await
+                    .unwrap();
+            }
+            NetworkRequest::UpdateEligibleNodes(nodes) => {
+                conn_mgr_reqs_tx
+                    .send(ConnectivityRequest::UpdateEligibleNodes(nodes))
+                    .await
+                    .unwrap();
+            }
+        }
+    }
+
+    async fn handle_peer_mgr_notification(
+        notif: PeerManagerNotification<TSubstream>,
+        mut upstream_handlers: HashMap<ProtocolId, channel::Sender<NetworkNotification>>,
+    ) {
+        trace!("PeerManagerNotification::{:?}", notif);
+        match notif {
+            PeerManagerNotification::NewPeer(peer_id, _addr) => {
+                counters::CONNECTED_PEERS.inc();
+                for ch in upstream_handlers.values_mut() {
+                    ch.send(NetworkNotification::NewPeer(peer_id))
+                        .await
+                        .unwrap();
+                }
+            }
+            PeerManagerNotification::LostPeer(peer_id, _addr) => {
+                counters::CONNECTED_PEERS.dec();
+                for ch in upstream_handlers.values_mut() {
+                    ch.send(NetworkNotification::LostPeer(peer_id))
+                        .await
+                        .unwrap();
+                }
+            }
+            _ => {
+                unreachable!("Received unexpected event from PeerManager");
+            }
+        }
+    }
+
+    async fn handle_rpc_notification(
+        notif: RpcNotification,
+        mut upstream_handlers: HashMap<ProtocolId, channel::Sender<NetworkNotification>>,
+    ) {
+        trace!("RpcNotification::{:?}", notif);
+        match notif {
RpcNotification::RecvRpc(peer_id, req) => { + if let Some(ch) = upstream_handlers.get_mut(&req.protocol) { + ch.send(NetworkNotification::RecvRpc(peer_id, req)) + .await + .unwrap(); + } else { + unreachable!(); + } + } + } + } + + async fn handle_ds_notification( + mut upstream_handlers: HashMap>, + notif: DirectSendNotification, + ) { + trace!("DirectSendNotification::{:?}", notif); + match notif { + DirectSendNotification::RecvMessage(peer_id, msg) => { + counters::DIRECT_SEND_MESSAGES_RECEIVED.inc(); + counters::DIRECT_SEND_BYTES_RECEIVED.inc_by(msg.mdata.len() as i64); + let ch = upstream_handlers + .get_mut(&msg.protocol) + .expect("DirectSend protocol not registered"); + ch.send(NetworkNotification::RecvMessage(peer_id, msg)) + .await + .unwrap(); + } + } + } + + pub async fn start(self) { + let rpc_reqs_tx = self.rpc_reqs_tx.clone(); + let ds_reqs_tx = self.ds_reqs_tx.clone(); + let conn_mgr_reqs_tx = self.conn_mgr_reqs_tx.clone(); + let mut reqs = self + .requests_rx + .map(move |req| { + Self::handle_network_request( + req, + rpc_reqs_tx.clone(), + ds_reqs_tx.clone(), + conn_mgr_reqs_tx.clone(), + ) + .boxed() + }) + .buffer_unordered(self.max_concurrent_reqs as usize); + + let upstream_handlers = self.upstream_handlers.clone(); + let mut peer_mgr_notifs = self + .peer_mgr_notifs_rx + .map(move |notif| { + Self::handle_peer_mgr_notification(notif, upstream_handlers.clone()).boxed() + }) + .buffer_unordered(self.max_concurrent_notifs as usize); + + let upstream_handlers = self.upstream_handlers.clone(); + let mut rpc_notifs = self + .rpc_notifs_rx + .map(move |notif| { + Self::handle_rpc_notification(notif, upstream_handlers.clone()).boxed() + }) + .buffer_unordered(self.max_concurrent_notifs as usize); + + let upstream_handlers = self.upstream_handlers.clone(); + let mut ds_notifs = self + .ds_notifs_rx + .map(|notif| Self::handle_ds_notification(upstream_handlers.clone(), notif).boxed()) + .buffer_unordered(self.max_concurrent_notifs as usize); + + loop { + futures::select! { + _ = reqs.select_next_some() => {}, + _ = peer_mgr_notifs.select_next_some() => {}, + _ = rpc_notifs.select_next_some() => {}, + _ = ds_notifs.select_next_some() => {} + complete => { + crit!("Network provider actor terminated"); + break; + } + } + } + } +} diff --git a/network/src/lib.rs b/network/src/lib.rs new file mode 100644 index 0000000000000..328f83831fb8d --- /dev/null +++ b/network/src/lib.rs @@ -0,0 +1,31 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +// The nightly features that are commonly needed with async / await +// Lets turn these on so that we can experiment a little bit +#![feature(async_await)] +// +// Increase recursion limit to allow for use of select! macro. +#![recursion_limit = "1024"] +// + +// Public exports +pub use common::NetworkPublicKeys; +pub use interface::NetworkProvider; + +pub mod interface; +pub mod proto; +pub mod protocols; +pub mod validator_network; + +mod common; +mod connectivity_manager; +mod counters; +mod error; +mod peer_manager; +mod sink; +mod transport; +mod utils; + +/// Type for unique identifier associated with each network protocol +pub type ProtocolId = bytes::Bytes; diff --git a/network/src/peer_manager/error.rs b/network/src/peer_manager/error.rs new file mode 100644 index 0000000000000..e9f27b8247864 --- /dev/null +++ b/network/src/peer_manager/error.rs @@ -0,0 +1,48 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! 
Errors that originate from the PeerManager module + +use failure::Fail; +use futures::channel::oneshot; +use parity_multiaddr::Multiaddr; +use types::PeerId; + +#[derive(Debug, Fail)] +pub enum PeerManagerError { + #[fail(display = "IO error: {}", _0)] + IoError(#[fail(cause)] ::std::io::Error), + + #[fail(display = "Transport error: {}", _0)] + TransportError(#[fail(cause)] ::failure::Error), + + #[fail(display = "Shutting down Peer")] + ShuttingDownPeer, + + #[fail(display = "Not connected with Peer {}", _0)] + NotConnected(PeerId), + + #[fail(display = "Already connected at {}", _0)] + AlreadyConnected(Multiaddr), + + #[fail(display = "Sending end of oneshot dropped")] + OneshotSenderDropped, +} + +impl PeerManagerError { + pub fn from_transport_error>(error: E) -> Self { + PeerManagerError::TransportError(error.into()) + } +} + +impl From<::std::io::Error> for PeerManagerError { + fn from(error: ::std::io::Error) -> Self { + PeerManagerError::IoError(error) + } +} + +impl From for PeerManagerError { + fn from(_: oneshot::Canceled) -> Self { + PeerManagerError::OneshotSenderDropped + } +} diff --git a/network/src/peer_manager/mod.rs b/network/src/peer_manager/mod.rs new file mode 100644 index 0000000000000..49954198cec3f --- /dev/null +++ b/network/src/peer_manager/mod.rs @@ -0,0 +1,1047 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! The PeerManager module is responsible for establishing connections between Peers and for +//! opening/receiving new substreams on those connections. +//! +//! ## Implementation +//! +//! The PeerManager is implemented as a number of actors: +//! * A main event loop actor which is responsible for handling requests and sending +//! notification about new/lost Peers to the rest of the network stack. +//! * An actor responsible for dialing and listening for new connections. +//! * An actor per Peer which owns the underlying connection and is responsible for listening for +//! and opening substreams as well as negotiating particular protocols on those substreams. +use crate::{common::NegotiatedSubstream, counters, protocols::identity::Identity, ProtocolId}; +use channel; +use futures::{ + channel::oneshot, + future::{BoxFuture, FutureExt, TryFutureExt}, + sink::SinkExt, + stream::{Fuse, FuturesUnordered, StreamExt}, +}; +use logger::prelude::*; +use netcore::{ + multiplexing::StreamMultiplexer, + negotiate::{negotiate_inbound, negotiate_outbound_interactive, negotiate_outbound_select}, + transport::{ConnectionOrigin, Transport}, +}; +use parity_multiaddr::Multiaddr; +use std::{collections::HashMap, marker::PhantomData}; +use tokio::runtime::TaskExecutor; +use types::PeerId; + +mod error; +#[cfg(test)] +mod tests; + +pub use self::error::PeerManagerError; + +/// Notifications about new/lost peers. +#[derive(Debug)] +pub enum PeerManagerNotification { + NewPeer(PeerId, Multiaddr), + LostPeer(PeerId, Multiaddr), + NewInboundSubstream(PeerId, NegotiatedSubstream), +} + +/// Request received by PeerManager from upstream actors. 
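+///
+/// Most callers go through the [`PeerManagerRequestSender`] wrapper below instead of
+/// constructing these variants by hand. A minimal sketch (variable names are illustrative,
+/// not part of this module):
+///
+/// ```ignore
+/// let mut sender = PeerManagerRequestSender::new(peer_mgr_reqs_tx);
+/// sender.dial_peer(peer_id, addr).await?;
+/// let substream = sender.open_substream(peer_id, protocol).await?;
+/// ```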
+#[derive(Debug)] +pub enum PeerManagerRequest { + DialPeer( + PeerId, + Multiaddr, + oneshot::Sender>, + ), + DisconnectPeer(PeerId, oneshot::Sender>), + OpenSubstream( + PeerId, + ProtocolId, + oneshot::Sender>, + ), +} + +/// Convenience wrapper around a `channel::Sender` which makes it easy to issue +/// requests and await the responses from PeerManager +pub struct PeerManagerRequestSender { + inner: channel::Sender>, +} + +impl Clone for PeerManagerRequestSender { + fn clone(&self) -> Self { + Self::new(self.inner.clone()) + } +} + +impl PeerManagerRequestSender { + /// Construct a new PeerManagerRequestSender with a raw channel::Sender + pub fn new(sender: channel::Sender>) -> Self { + Self { inner: sender } + } + + /// Request that a given Peer be dialed at the provided `Multiaddr` and synchronously wait for + /// the request to be performed. + pub async fn dial_peer( + &mut self, + peer_id: PeerId, + addr: Multiaddr, + ) -> Result<(), PeerManagerError> { + let (oneshot_tx, oneshot_rx) = oneshot::channel(); + let request = PeerManagerRequest::DialPeer(peer_id, addr, oneshot_tx); + self.inner.send(request).await.unwrap(); + oneshot_rx.await? + } + + /// Request that a given Peer be disconnected and synchronously wait for the request to be + /// performed. + pub async fn disconnect_peer(&mut self, peer_id: PeerId) -> Result<(), PeerManagerError> { + let (oneshot_tx, oneshot_rx) = oneshot::channel(); + let request = PeerManagerRequest::DisconnectPeer(peer_id, oneshot_tx); + self.inner.send(request).await.unwrap(); + oneshot_rx.await? + } + + /// Request that a new substream be opened with the given Peer and that the provided `protocol` + /// be negotiated on that substream and synchronously wait for the request to be performed. + pub async fn open_substream( + &mut self, + peer_id: PeerId, + protocol: ProtocolId, + ) -> Result { + let (oneshot_tx, oneshot_rx) = oneshot::channel(); + let request = PeerManagerRequest::OpenSubstream(peer_id, protocol, oneshot_tx); + self.inner.send(request).await.unwrap(); + // TODO(philiphayes): If this error changes, also change rpc errors to + // handle appropriate cases. + oneshot_rx + .await + .map_err(|_| PeerManagerError::NotConnected(peer_id))? + } +} + +#[derive(Debug, PartialEq, Eq)] +enum DisconnectReason { + Requested, + ConnectionLost, +} + +#[derive(Debug)] +enum InternalEvent +where + TMuxer: StreamMultiplexer, +{ + NewConnection(Identity, Multiaddr, ConnectionOrigin, TMuxer), + NewSubstream(PeerId, NegotiatedSubstream), + PeerDisconnected(PeerId, ConnectionOrigin, DisconnectReason), +} + +/// Responsible for handling and maintaining connections to other Peers +pub struct PeerManager +where + TTransport: Transport, + TMuxer: StreamMultiplexer, +{ + /// A handle to a tokio executor. + executor: TaskExecutor, + /// PeerId of "self". + own_peer_id: PeerId, + /// Address to listen on for incoming connections. + listen_addr: Multiaddr, + /// Connection Listener, listening on `listen_addr` + connection_handler: Option>, + /// Map from PeerId to corresponding Peer object. + active_peers: HashMap>, + /// Channel to receive requests from other actors. + requests_rx: channel::Receiver>, + /// Map from protocol to handler for substreams which want to "speak" that protocol. + protocol_handlers: + HashMap>>, + /// Channel to send NewPeer/LostPeer notifications to other actors. + /// Note: NewInboundSubstream notifications are not sent via these channels. 
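+    /// Each subscriber in this list receives its own copy of every NewPeer/LostPeer event.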
+ peer_event_handlers: Vec>>, + /// Channel used to send Dial requests to the ConnectionHandler actor + dial_request_tx: channel::Sender, + /// Internal event Receiver + internal_event_rx: channel::Receiver>, + /// Internal event Sender + internal_event_tx: channel::Sender>, + /// A map of outstanding disconnect requests + outstanding_disconnect_requests: HashMap>>, + /// Pin the transport type corresponding to this PeerManager instance + phantom_transport: PhantomData, +} + +impl PeerManager +where + TTransport: Transport + Send + 'static, + TTransport::Listener: 'static, + TTransport::Inbound: 'static, + TMuxer: StreamMultiplexer + 'static, +{ + /// Construct a new PeerManager actor + pub fn new( + transport: TTransport, + executor: TaskExecutor, + own_peer_id: PeerId, + listen_addr: Multiaddr, + requests_rx: channel::Receiver>, + protocol_handlers: HashMap< + ProtocolId, + channel::Sender>, + >, + peer_event_handlers: Vec>>, + ) -> Self { + let (internal_event_tx, internal_event_rx) = + channel::new(1024, &counters::PENDING_PEER_MANAGER_INTERNAL_EVENTS); + let (dial_request_tx, dial_request_rx) = + channel::new(1024, &counters::PENDING_PEER_MANAGER_DIAL_REQUESTS); + let (connection_handler, listen_addr) = ConnectionHandler::new( + transport, + listen_addr, + dial_request_rx, + internal_event_tx.clone(), + ); + + Self { + executor, + own_peer_id, + listen_addr, + connection_handler: Some(connection_handler), + active_peers: HashMap::new(), + requests_rx, + protocol_handlers, + peer_event_handlers, + dial_request_tx, + internal_event_tx, + internal_event_rx, + outstanding_disconnect_requests: HashMap::new(), + phantom_transport: PhantomData, + } + } + + /// Get the [`Multiaddr`] we're listening for incoming connections on + pub fn listen_addr(&self) -> &Multiaddr { + &self.listen_addr + } + + /// Start listening on the set address and return a future which runs PeerManager + pub async fn start(mut self) { + // Start listening for connections. + self.start_connection_listener(); + loop { + ::futures::select! { + maybe_internal_event = self.internal_event_rx.next() => { + if let Some(event) = maybe_internal_event { + self.handle_internal_event(event).await; + } + } + maybe_request = self.requests_rx.next() => { + if let Some(request) = maybe_request { + self.handle_request(request).await; + } + } + complete => { + crit!("Peer manager actor terminated"); + break; + } + } + } + } + + async fn handle_internal_event(&mut self, event: InternalEvent) { + trace!("InternalEvent::{:?}", event); + match event { + InternalEvent::NewConnection(identity, addr, origin, conn) => { + self.add_peer(identity, addr, origin, conn).await; + } + InternalEvent::NewSubstream(peer_id, substream) => { + let ch = self + .protocol_handlers + .get_mut(&substream.protocol) + .expect("Received substream for unknown protocol"); + let event = PeerManagerNotification::NewInboundSubstream(peer_id, substream); + ch.send(event).await.unwrap(); + } + InternalEvent::PeerDisconnected(peer_id, origin, _reason) => { + let peer = self + .active_peers + .remove(&peer_id) + .expect("Should have a handle to Peer"); + + // If we receive a PeerDisconnected event and the connection origin isn't the same + // as the one we have stored in PeerManager this particular event is from a Peer + // actor which is being shutdown due to simultaneous dial tie-breaking and we don't + // need to send a LostPeer notification to all subscribers. 
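+            // In that case the entry we just removed belongs to the replacement connection,
+            // so put it back untouched and skip the notification: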
+ if peer.origin != origin { + self.active_peers.insert(peer_id, peer); + return; + } + info!("Disconnected from peer: {}", peer_id.short_str()); + if let Some(oneshot_tx) = self.outstanding_disconnect_requests.remove(&peer_id) { + if oneshot_tx.send(Ok(())).is_err() { + error!("oneshot channel receiver dropped"); + } + } + // Send LostPeer notifications to subscribers + for ch in &mut self.peer_event_handlers { + ch.send(PeerManagerNotification::LostPeer( + peer_id, + peer.address().clone(), + )) + .await + .unwrap(); + } + } + } + } + + async fn handle_request(&mut self, request: PeerManagerRequest) { + trace!("PeerManagerRequest::{:?}", request); + match request { + PeerManagerRequest::DialPeer(requested_peer_id, addr, response_tx) => { + // Only dial peers which we aren't already connected with + if let Some(peer) = self.active_peers.get(&requested_peer_id) { + let error = if peer.is_shutting_down() { + PeerManagerError::ShuttingDownPeer + } else { + PeerManagerError::AlreadyConnected(peer.address().to_owned()) + }; + debug!( + "Already connected with Peer {} at address {}, not dialing address {}", + peer.peer_id().short_str(), + peer.address(), + addr + ); + + if response_tx.send(Err(error)).is_err() { + warn!( + "Receiver for DialPeer {} dropped", + requested_peer_id.short_str() + ); + } + } else { + self.dial_peer(requested_peer_id, addr, response_tx).await; + }; + } + PeerManagerRequest::DisconnectPeer(peer_id, response_tx) => { + self.disconnect_peer(peer_id, response_tx).await; + } + PeerManagerRequest::OpenSubstream(peer_id, protocol, request_tx) => { + match self.active_peers.get_mut(&peer_id) { + Some(ref mut peer) if !peer.is_shutting_down() => { + peer.open_substream(protocol, request_tx).await; + } + _ => { + // If we don't have a connection open with this peer, or if the connection + // is currently undergoing shutdown we should return an error to the + // requester + if request_tx + .send(Err(PeerManagerError::NotConnected(peer_id))) + .is_err() + { + warn!( + "Request for substream to peer {} failed, but receiver dropped too", + peer_id.short_str() + ); + } + } + } + } + } + } + + fn start_connection_listener(&mut self) { + let connection_handler = self.connection_handler.take().unwrap(); + self.executor + .spawn(connection_handler.listen().boxed().unit_error().compat()); + } + + /// In the event two peers simultaneously dial each other we need to be able to do + /// tie-breaking to determine which connection to keep and which to drop in a deterministic + /// way. One simple way is to compare our local PeerId with that of the remote's PeerId and + /// keep the connection where the peer with the greater PeerId is the dialer. + /// + /// Returns `true` if the existing connection should be dropped and `false` if the new + /// connection should be dropped. 
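+    ///
+    /// For example (illustrative only): suppose we are peer `A`, the remote is peer `B`,
+    /// and `A < B`. If our existing connection is `Outbound` (we dialed `B`) and a new
+    /// `Inbound` connection arrives (`B` dialed us), then `own_peer_id < remote_peer_id`
+    /// holds, so the existing connection is dropped in favor of the one dialed by `B`.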
+ fn simultaneous_dial_tie_breaking( + own_peer_id: PeerId, + remote_peer_id: PeerId, + existing_origin: ConnectionOrigin, + new_origin: ConnectionOrigin, + ) -> bool { + match (existing_origin, new_origin) { + // The remote dialed us twice for some reason, drop the new incoming connection + (ConnectionOrigin::Inbound, ConnectionOrigin::Inbound) => false, + (ConnectionOrigin::Inbound, ConnectionOrigin::Outbound) => remote_peer_id < own_peer_id, + (ConnectionOrigin::Outbound, ConnectionOrigin::Inbound) => own_peer_id < remote_peer_id, + // We should never dial the same peer twice, but if we do drop the new connection + (ConnectionOrigin::Outbound, ConnectionOrigin::Outbound) => false, + } + } + + async fn add_peer( + &mut self, + identity: Identity, + address: Multiaddr, + origin: ConnectionOrigin, + connection: TMuxer, + ) { + let peer_id = identity.peer_id(); + assert!(self.own_peer_id != peer_id); + + let mut send_new_peer_notification = true; + + // Check for and handle simultaneous dialing + if let Some(mut peer) = self.active_peers.remove(&peer_id) { + if Self::simultaneous_dial_tie_breaking( + self.own_peer_id, + peer.peer_id(), + peer.origin(), + origin, + ) { + // Drop the existing connection and replace it with the new connection + peer.disconnect().await; + info!( + "Closing existing connection with Peer {} to mitigate simultaneous dial", + peer_id.short_str() + ); + send_new_peer_notification = false; + } else { + // Drop the new connection and keep the one already stored in active_peers + connection.close().await.unwrap_or_else(|e| { + error!( + "Closing connection with Peer {} failed with error: {}", + peer_id.short_str(), + e + ) + }); + info!( + "Closing incoming connection with Peer {} to mitigate simultaneous dial", + peer_id.short_str() + ); + // Put the existing connection back + self.active_peers.insert(peer.peer_id(), peer); + return; + } + } + + let (peer_req_tx, peer_req_rx) = channel::new( + 1024, + &counters::OP_COUNTERS + .peer_gauge(&counters::PENDING_PEER_REQUESTS, &peer_id.short_str()), + ); + let peer = Peer::new( + identity, + connection, + origin, + self.protocol_handlers.keys().cloned().collect(), + self.internal_event_tx.clone(), + peer_req_rx, + ); + let peer_handle = PeerHandle::new(peer_id, address.clone(), origin, peer_req_tx); + info!( + "{:?} connection with peer {} established", + origin, + peer_id.short_str() + ); + self.active_peers.insert(peer_id, peer_handle); + self.executor + .spawn(peer.start().boxed().unit_error().compat()); + + if send_new_peer_notification { + for ch in &mut self.peer_event_handlers { + ch.send(PeerManagerNotification::NewPeer(peer_id, address.clone())) + .await + .unwrap(); + } + } + } + + async fn dial_peer( + &mut self, + peer_id: PeerId, + address: Multiaddr, + response_tx: oneshot::Sender>, + ) { + let request = ConnectionHandlerRequest::DialPeer(peer_id, address, response_tx); + self.dial_request_tx.send(request).await.unwrap(); + } + + // Send a Disconnect request to the Peer actor corresponding with `peer_id`. 
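+    // The response channel is stashed in `outstanding_disconnect_requests` and is only
+    // completed once the Peer actor reports PeerDisconnected back to this event loop.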
+ async fn disconnect_peer( + &mut self, + peer_id: PeerId, + response_tx: oneshot::Sender>, + ) { + if let Some(peer) = self.active_peers.get_mut(&peer_id) { + peer.disconnect().await; + self.outstanding_disconnect_requests + .insert(peer_id, response_tx); + } else if response_tx + .send(Err(PeerManagerError::NotConnected(peer_id))) + .is_err() + { + info!( + "Failed to disconnect from peer {}, but result receiver dropped", + peer_id.short_str() + ); + } + } +} + +enum ConnectionHandlerRequest { + DialPeer( + PeerId, + Multiaddr, + oneshot::Sender>, + ), +} + +/// Responsible for listening for new incoming connections +struct ConnectionHandler +where + TTransport: Transport, + TMuxer: StreamMultiplexer, +{ + /// [`Transport`] that is used to establish connections + transport: TTransport, + listener: Fuse, + dial_request_rx: channel::Receiver, + internal_event_tx: channel::Sender>, +} + +impl ConnectionHandler +where + TTransport: Transport, + TTransport::Listener: 'static, + TTransport::Inbound: 'static, + TTransport::Outbound: 'static, + TMuxer: StreamMultiplexer + 'static, +{ + fn new( + transport: TTransport, + listen_addr: Multiaddr, + dial_request_rx: channel::Receiver, + internal_event_tx: channel::Sender>, + ) -> (Self, Multiaddr) { + let (listener, listen_addr) = transport.listen_on(listen_addr).unwrap(); + debug!("listening on {:?}", listen_addr); + + ( + Self { + transport, + listener: listener.fuse(), + dial_request_rx, + internal_event_tx, + }, + listen_addr, + ) + } + + async fn listen(mut self) { + let mut pending_inbound_connections = FuturesUnordered::new(); + let mut pending_outbound_connections = FuturesUnordered::new(); + + debug!("Incoming connections listener Task started"); + + loop { + futures::select! { + dial_request = self.dial_request_rx.select_next_some() => { + if let Some(fut) = self.dial_peer(dial_request) { + pending_outbound_connections.push(fut); + } + }, + incoming_connection = self.listener.select_next_some() => { + match incoming_connection { + Ok((upgrade, addr)) => { + debug!("Incoming connection from {}", addr); + pending_inbound_connections.push(upgrade.map(|out| (out, addr))); + } + Err(e) => { + warn!("Incoming connection error {}", e); + } + } + }, + (upgrade, addr, peer_id, response_tx) = pending_outbound_connections.select_next_some() => { + self.handle_completed_outbound_upgrade(upgrade, addr, peer_id, response_tx).await; + }, + (upgrade, addr) = pending_inbound_connections.select_next_some() => { + self.handle_completed_inbound_upgrade(upgrade, addr).await; + }, + complete => break, + } + } + + error!("Incoming connections listener Task ended"); + } + + fn dial_peer( + &self, + dial_peer_request: ConnectionHandlerRequest, + ) -> Option< + BoxFuture< + 'static, + ( + Result<(Identity, TMuxer), TTransport::Error>, + Multiaddr, + PeerId, + oneshot::Sender>, + ), + >, + > { + match dial_peer_request { + ConnectionHandlerRequest::DialPeer(peer_id, address, response_tx) => { + match self.transport.dial(address.clone()) { + Ok(upgrade) => Some( + upgrade + .map(move |out| (out, address, peer_id, response_tx)) + .boxed(), + ), + Err(error) => { + if response_tx + .send(Err(PeerManagerError::from_transport_error(error))) + .is_err() + { + warn!( + "Receiver for DialPeer {} request dropped", + peer_id.short_str() + ); + } + None + } + } + } + } + } + + async fn handle_completed_outbound_upgrade( + &mut self, + upgrade: Result<(Identity, TMuxer), TTransport::Error>, + addr: Multiaddr, + peer_id: PeerId, + response_tx: oneshot::Sender>, + ) { + match 
upgrade { + Ok((identity, connection)) => { + let response = if identity.peer_id() == peer_id { + debug!( + "Peer '{}' successfully dialed at '{}'", + peer_id.short_str(), + addr + ); + let event = InternalEvent::NewConnection( + identity, + addr, + ConnectionOrigin::Outbound, + connection, + ); + // Send the new connection to PeerManager + self.internal_event_tx.send(event).await.unwrap(); + Ok(()) + } else { + let e = ::failure::format_err!( + "Dialed PeerId ({}) differs from expected PeerId ({})", + identity.peer_id().short_str(), + peer_id.short_str() + ); + + warn!("{}", e); + + Err(PeerManagerError::from_transport_error(e)) + }; + + if response_tx.send(response).is_err() { + warn!( + "Receiver for DialPeer {} request dropped", + peer_id.short_str() + ); + } + } + Err(error) => { + error!("Error dialing Peer {} at {}", peer_id.short_str(), addr); + + if response_tx + .send(Err(PeerManagerError::from_transport_error(error))) + .is_err() + { + warn!( + "Receiver for DialPeer {} request dropped", + peer_id.short_str() + ); + } + } + } + } + + async fn handle_completed_inbound_upgrade( + &mut self, + upgrade: Result<(Identity, TMuxer), TTransport::Error>, + addr: Multiaddr, + ) { + match upgrade { + Ok((identity, connection)) => { + debug!("Connection from {} successfully upgraded", addr); + let event = InternalEvent::NewConnection( + identity, + addr, + ConnectionOrigin::Inbound, + connection, + ); + // Send the new connection to PeerManager + self.internal_event_tx.send(event).await.unwrap(); + } + Err(e) => { + warn!("Connection from {} failed to upgrade {}", addr, e); + } + } + } +} + +struct PeerHandle { + peer_id: PeerId, + sender: channel::Sender>, + origin: ConnectionOrigin, + address: Multiaddr, + is_shutting_down: bool, +} + +impl PeerHandle { + pub fn new( + peer_id: PeerId, + address: Multiaddr, + origin: ConnectionOrigin, + sender: channel::Sender>, + ) -> Self { + Self { + peer_id, + address, + origin, + sender, + is_shutting_down: false, + } + } + + pub fn is_shutting_down(&self) -> bool { + self.is_shutting_down + } + + pub fn address(&self) -> &Multiaddr { + &self.address + } + + pub fn peer_id(&self) -> PeerId { + self.peer_id + } + + pub fn origin(&self) -> ConnectionOrigin { + self.origin + } + + pub async fn open_substream( + &mut self, + protocol: ProtocolId, + response_tx: oneshot::Sender>, + ) { + // If we fail to send the request to the Peer, then it must have already been shutdown. + if self + .sender + .send(PeerRequest::OpenSubstream(protocol, response_tx)) + .await + .is_err() + { + error!( + "Sending OpenSubstream request to Peer {} \ + failed because it has already been shutdown.", + self.peer_id.short_str() + ); + } + } + + pub async fn disconnect(&mut self) { + // If we fail to send the request to the Peer, then it must have already been shutdown. 
+ if self + .sender + .send(PeerRequest::CloseConnection) + .await + .is_err() + { + error!( + "Sending CloseConnection request to Peer {} \ + failed because it has already been shutdown.", + self.peer_id.short_str() + ); + } + self.is_shutting_down = true; + } +} + +#[derive(Debug)] +enum PeerRequest { + OpenSubstream( + ProtocolId, + oneshot::Sender>, + ), + CloseConnection, +} + +struct Peer +where + TMuxer: StreamMultiplexer, +{ + /// Identity of the remote peer + identity: Identity, + connection: TMuxer, + own_supported_protocols: Vec, + internal_event_tx: channel::Sender>, + requests_rx: channel::Receiver>, + origin: ConnectionOrigin, + shutdown: bool, +} + +impl Peer +where + TMuxer: StreamMultiplexer + 'static, + TMuxer::Substream: 'static, + TMuxer::Outbound: 'static, +{ + fn new( + identity: Identity, + connection: TMuxer, + origin: ConnectionOrigin, + own_supported_protocols: Vec, + internal_event_tx: channel::Sender>, + requests_rx: channel::Receiver>, + ) -> Self { + Self { + identity, + connection, + origin, + own_supported_protocols, + internal_event_tx, + requests_rx, + shutdown: false, + } + } + + async fn start(mut self) { + let mut substream_rx = self.connection.listen_for_inbound().fuse(); + let mut pending_outbound_substreams = FuturesUnordered::new(); + let mut pending_inbound_substreams = FuturesUnordered::new(); + + loop { + futures::select! { + maybe_req = self.requests_rx.next() => { + if let Some(request) = maybe_req { + self.handle_request(&mut pending_outbound_substreams, request).await; + } else { + // This branch will only be taken if the PeerRequest sender for this Peer + // gets dropped. This should never happen because PeerManager should also + // issue a shutdown request before dropping the sender + unreachable!( + "Peer {} PeerRequest sender gets dropped", + self.identity.peer_id().short_str() + ); + } + }, + maybe_substream = substream_rx.next() => { + match maybe_substream { + Some(Ok(substream)) => { + self.handle_inbound_substream(&mut pending_inbound_substreams, substream); + } + Some(Err(e)) => { + warn!("Inbound substream error {:?} with peer {}", + e, self.identity.peer_id().short_str()); + self.close_connection(DisconnectReason::ConnectionLost).await; + } + None => { + warn!("Inbound substreams exhausted with peer {}", + self.identity.peer_id().short_str()); + self.close_connection(DisconnectReason::ConnectionLost).await; + } + } + }, + inbound_substream = pending_inbound_substreams.select_next_some() => { + match inbound_substream { + Ok(negotiated_substream) => { + let event = InternalEvent::NewSubstream( + self.identity.peer_id(), + negotiated_substream, + ); + self.internal_event_tx.send(event).await.unwrap(); + } + Err(e) => { + error!( + "Inbound substream negotiation for peer {} failed: {}", + self.identity.peer_id().short_str(), e + ); + } + } + }, + _ = pending_outbound_substreams.select_next_some() => { + // Do nothing since these futures have an output of "()" + }, + complete => unreachable!(), + } + + if self.shutdown { + break; + } + } + debug!( + "Peer actor '{}' shutdown", + self.identity.peer_id().short_str() + ); + } + + async fn handle_request<'a>( + &'a mut self, + pending: &'a mut FuturesUnordered>, + request: PeerRequest, + ) { + trace!( + "Peer {} PeerRequest::{:?}", + self.identity.peer_id().short_str(), + request + ); + match request { + PeerRequest::OpenSubstream(protocol, channel) => { + pending.push(self.handle_open_outbound_substream_request(protocol, channel)); + } + PeerRequest::CloseConnection => { + 
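+                // close_connection sets `self.shutdown`, which makes the event loop in
+                // `start` break after this iteration.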
self.close_connection(DisconnectReason::Requested).await;
+            }
+        }
+    }
+
+    fn handle_open_outbound_substream_request(
+        &self,
+        protocol: ProtocolId,
+        channel: oneshot::Sender<Result<TMuxer::Substream, PeerManagerError>>,
+    ) -> BoxFuture<'static, ()> {
+        let outbound = self.connection.open_outbound();
+        let optimistic_negotiation = self.identity.is_protocol_supported(&protocol);
+        let negotiate = Self::negotiate_outbound_substream(
+            self.identity.peer_id(),
+            outbound,
+            protocol,
+            optimistic_negotiation,
+            channel,
+        );
+
+        negotiate.boxed()
+    }
+
+    async fn negotiate_outbound_substream(
+        peer_id: PeerId,
+        outbound_fut: TMuxer::Outbound,
+        protocol: ProtocolId,
+        optimistic_negotiation: bool,
+        channel: oneshot::Sender<Result<TMuxer::Substream, PeerManagerError>>,
+    ) {
+        let response = match outbound_fut.await {
+            Ok(substream) => {
+                // TODO(bmwill) Evaluate if we should still try to open and negotiate an outbound
+                // substream even though we know for a fact that the Identity struct of this Peer
+                // doesn't include the protocol we're interested in.
+                if optimistic_negotiation {
+                    negotiate_outbound_select(substream, &protocol).await
+                } else {
+                    warn!(
+                        "Negotiating outbound substream interactively: Protocol({:?}) PeerId({})",
+                        protocol,
+                        peer_id.short_str()
+                    );
+                    negotiate_outbound_interactive(substream, [&protocol])
+                        .await
+                        .map(|(substream, _protocol)| substream)
+                }
+            }
+            Err(e) => Err(e),
+        }
+        .map_err(Into::into);
+
+        match response {
+            Ok(_) => debug!(
+                "Successfully negotiated outbound substream '{:?}' with Peer {}",
+                protocol,
+                peer_id.short_str()
+            ),
+            Err(ref e) => debug!(
+                "Unable to negotiate outbound substream '{:?}' with Peer {}: {}",
+                protocol,
+                peer_id.short_str(),
+                e
+            ),
+        }
+
+        if channel.send(response).is_err() {
+            warn!(
+                "oneshot channel receiver dropped for new substream with peer {} for protocol {:?}",
+                peer_id.short_str(),
+                protocol
+            );
+        }
+    }
+
+    fn handle_inbound_substream<'a>(
+        &'a mut self,
+        pending: &'a mut FuturesUnordered<
+            BoxFuture<'static, Result<NegotiatedSubstream<TMuxer::Substream>, PeerManagerError>>,
+        >,
+        substream: TMuxer::Substream,
+    ) {
+        trace!(
+            "New inbound substream from peer '{}'",
+            self.identity.peer_id().short_str()
+        );
+
+        let negotiate =
+            Self::negotiate_inbound_substream(substream, self.own_supported_protocols.clone());
+        pending.push(negotiate.boxed());
+    }
+
+    async fn negotiate_inbound_substream(
+        substream: TMuxer::Substream,
+        own_supported_protocols: Vec<ProtocolId>,
+    ) -> Result<NegotiatedSubstream<TMuxer::Substream>, PeerManagerError> {
+        let (substream, protocol) = negotiate_inbound(substream, own_supported_protocols).await?;
+        Ok(NegotiatedSubstream {
+            protocol,
+            substream,
+        })
+    }
+
+    async fn close_connection(&mut self, reason: DisconnectReason) {
+        match self.connection.close().await {
+            Err(e) => {
+                error!(
+                    "Failed to gracefully close connection with peer: {}; error: {}",
+                    self.identity.peer_id().short_str(),
+                    e
+                );
+            }
+            Ok(_) => {
+                info!(
+                    "Closed connection with peer: {}, reason: {:?}",
+                    self.identity.peer_id().short_str(),
+                    reason
+                );
+            }
+        }
+        // If the graceful shutdown above fails, the connection will be forcefully terminated once
+        // the connection struct is dropped. Setting the `shutdown` flag to true ensures that the
+        // peer actor will terminate and close the connection in the process.
+        self.shutdown = true;
+        // We send a PeerDisconnected event to peer manager as a result (or in case of a failure
+        // above, in anticipation of) closing the connection.
+ + self.internal_event_tx + .send(InternalEvent::PeerDisconnected( + self.identity.peer_id(), + self.origin, + reason, + )) + .await + .unwrap(); + } +} diff --git a/network/src/peer_manager/tests.rs b/network/src/peer_manager/tests.rs new file mode 100644 index 0000000000000..2406284224e12 --- /dev/null +++ b/network/src/peer_manager/tests.rs @@ -0,0 +1,729 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + peer_manager::{ + DisconnectReason, InternalEvent, Peer, PeerHandle, PeerManager, PeerManagerNotification, + PeerManagerRequest, + }, + protocols::{ + identity::{exchange_identity, Identity}, + peer_id_exchange::PeerIdExchange, + }, + ProtocolId, +}; +use channel; +use futures::{ + channel::oneshot, + compat::Compat01As03, + executor::block_on, + future::{join, FutureExt, TryFutureExt}, + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + stream::StreamExt, +}; +use memsocket::MemorySocket; +use netcore::{ + multiplexing::{ + yamux::{Mode, StreamHandle, Yamux}, + StreamMultiplexer, + }, + negotiate::{negotiate_inbound, negotiate_outbound_interactive}, + transport::{boxed::BoxedTransport, memory::MemoryTransport, ConnectionOrigin, TransportExt}, +}; +use parity_multiaddr::Multiaddr; +use std::{collections::HashMap, io, time::Duration}; +use tokio::{runtime::TaskExecutor, timer::Timeout}; +use types::PeerId; + +const HELLO_PROTOCOL: &[u8] = b"/hello-world/1.0.0"; + +// Builds a concrete typed transport (instead of using impl Trait) for testing PeerManager. +// Specificly this transport is compatible with the `build_test_connection` test helper making +// it easy to build connections without going through the whole transport pipeline. +fn build_test_transport( + own_peer_id: PeerId, +) -> BoxedTransport<(Identity, Yamux), std::io::Error> { + let memory_transport = MemoryTransport::default(); + let peer_id_exchange_config = PeerIdExchange::new(own_peer_id); + let own_identity = Identity::new(own_peer_id, Vec::new()); + + memory_transport + .and_then(move |socket, origin| peer_id_exchange_config.exchange_peer_id(socket, origin)) + .and_then(|(peer_id, socket), origin| { + async move { + let muxer = Yamux::upgrade_connection(socket, origin).await?; + Ok((peer_id, muxer)) + } + }) + .and_then(move |(peer_id, muxer), origin| { + async move { + let (identity, muxer) = exchange_identity(&own_identity, muxer, origin).await?; + assert_eq!(identity.peer_id(), peer_id); + + Ok((identity, muxer)) + } + }) + .boxed() +} + +fn build_test_connection() -> (Yamux, Yamux) { + let (dialer, listener) = MemorySocket::new_pair(); + + ( + Yamux::new(dialer, Mode::Client), + Yamux::new(listener, Mode::Server), + ) +} + +fn build_test_identity(peer_id: PeerId) -> Identity { + Identity::new(peer_id, Vec::new()) +} + +fn build_test_peer( + origin: ConnectionOrigin, +) -> ( + Peer>, + PeerHandle>, + Yamux, + channel::Receiver>>, +) { + let (a, b) = build_test_connection(); + let identity = build_test_identity(PeerId::random()); + let peer_id = identity.peer_id(); + let (internal_event_tx, internal_event_rx) = channel::new_test(1); + let (peer_req_tx, peer_req_rx) = channel::new_test(0); + + let peer = Peer::new( + identity, + a, + origin, + vec![ProtocolId::from_static(HELLO_PROTOCOL)], + internal_event_tx, + peer_req_rx, + ); + let peer_handle = PeerHandle::new(peer_id, Multiaddr::empty(), origin, peer_req_tx); + + (peer, peer_handle, b, internal_event_rx) +} + +fn build_test_connected_peers() -> ( + ( + Peer>, + PeerHandle>, + channel::Receiver>>, + 
), + ( + Peer>, + PeerHandle>, + channel::Receiver>>, + ), +) { + let (peer_a, peer_handle_a, connection_a, internal_event_rx_a) = + build_test_peer(ConnectionOrigin::Inbound); + let (mut peer_b, peer_handle_b, _connection_b, internal_event_rx_b) = + build_test_peer(ConnectionOrigin::Outbound); + // Make sure both peers are connected + peer_b.connection = connection_a; + + ( + (peer_a, peer_handle_a, internal_event_rx_a), + (peer_b, peer_handle_b, internal_event_rx_b), + ) +} + +#[test] +fn peer_open_substream() { + let (peer, _peer_handle, connection, _internal_event_rx) = + build_test_peer(ConnectionOrigin::Inbound); + + let server = async move { + let substream_listener = connection.listen_for_inbound(); + let (substream, _substream_listener) = substream_listener.into_future().await; + let (mut substream, _protocol) = + negotiate_inbound(substream.unwrap().unwrap(), [HELLO_PROTOCOL]) + .await + .unwrap(); + substream.write_all(b"hello world").await.unwrap(); + substream.flush().await.unwrap(); + substream.close().await.unwrap(); + // Wait to read EOF from the other side in order to hold open the connection for + // the remote side to read + let mut buf = Vec::new(); + substream.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf.len(), 0); + }; + + let client = async move { + let (substream_tx, substream_rx) = oneshot::channel(); + peer.handle_open_outbound_substream_request( + ProtocolId::from_static(HELLO_PROTOCOL), + substream_tx, + ) + .await; + let mut substream = substream_rx.await.unwrap().unwrap(); + let mut buf = Vec::new(); + substream.read_to_end(&mut buf).await.unwrap(); + substream.close().await.unwrap(); + assert_eq!(buf, b"hello world"); + }; + + block_on(join(server, client)); +} + +// Test that if two peers request to open a substream with each other simultaneously that +// we won't deadlock. 
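+// The deadlock-freedom comes from Peer::start() driving both directions in a
+// single select! loop: outbound negotiations are parked in
+// pending_outbound_substreams while inbound substreams are polled concurrently,
+// so neither side can block the other.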
+#[test] +fn peer_open_substream_simultaneous() { + let mut runtime = ::tokio::runtime::Runtime::new().unwrap(); + let ( + (peer_a, mut peer_handle_a, mut internal_event_rx_a), + (peer_b, mut peer_handle_b, mut internal_event_rx_b), + ) = build_test_connected_peers(); + + let test = async move { + let (substream_tx_a, substream_rx_a) = oneshot::channel(); + let (substream_tx_b, substream_rx_b) = oneshot::channel(); + + // Send open substream requests to both peer_a and peer_b + peer_handle_a + .open_substream(ProtocolId::from_static(HELLO_PROTOCOL), substream_tx_a) + .await; + peer_handle_b + .open_substream(ProtocolId::from_static(HELLO_PROTOCOL), substream_tx_b) + .await; + + // These both should complete, but in the event they deadlock wrap them in a timeout + let timeout_a = Compat01As03::new(Timeout::new( + substream_rx_a.boxed().compat(), + Duration::from_secs(10), + )); + let timeout_b = Compat01As03::new(Timeout::new( + substream_rx_b.boxed().compat(), + Duration::from_secs(10), + )); + let _ = timeout_a.await.unwrap().unwrap(); + let _ = timeout_b.await.unwrap().unwrap(); + + // Check that we recieved the new inbound substream for both peers + assert_new_substream_event(peer_handle_a.peer_id, &mut internal_event_rx_a).await; + assert_new_substream_event(peer_handle_b.peer_id, &mut internal_event_rx_b).await; + + // Shut one peers and the other should shutdown due to ConnectionLost + peer_handle_a.disconnect().await; + + // Check that we recieved both shutdown events + assert_peer_disconnected_event( + peer_handle_a.peer_id, + DisconnectReason::Requested, + &mut internal_event_rx_a, + ) + .await; + assert_peer_disconnected_event( + peer_handle_b.peer_id, + DisconnectReason::ConnectionLost, + &mut internal_event_rx_b, + ) + .await; + }; + + runtime.spawn(peer_a.start().boxed().unit_error().compat()); + runtime.spawn(peer_b.start().boxed().unit_error().compat()); + + runtime + .block_on_all(test.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn peer_disconnect_request() { + let (peer, mut peer_handle, _connection, mut internal_event_rx) = + build_test_peer(ConnectionOrigin::Inbound); + + let test = async move { + peer_handle.disconnect().await; + assert_peer_disconnected_event( + peer_handle.peer_id, + DisconnectReason::Requested, + &mut internal_event_rx, + ) + .await; + }; + + block_on(join(test, peer.start())); +} + +#[test] +fn peer_disconnect_connection_lost() { + let (peer, peer_handle, connection, mut internal_event_rx) = + build_test_peer(ConnectionOrigin::Inbound); + + let test = async move { + connection.close().await.unwrap(); + assert_peer_disconnected_event( + peer_handle.peer_id, + DisconnectReason::ConnectionLost, + &mut internal_event_rx, + ) + .await; + }; + + block_on(join(test, peer.start())); +} + +#[test] +#[should_panic] +fn peer_panics_when_request_tx_has_dropped() { + let (peer, peer_handle, _conn, _event_rx) = build_test_peer(ConnectionOrigin::Inbound); + + drop(peer_handle); + block_on(peer.start()); +} + +// +// Simultaneous Dial Tests +// + +fn ordered_peer_ids(num: usize) -> Vec { + let mut ids = Vec::new(); + for _ in 0..num { + ids.push(PeerId::random()); + } + ids.sort(); + ids +} + +fn build_test_peer_manager( + executor: TaskExecutor, + peer_id: PeerId, +) -> ( + PeerManager< + BoxedTransport<(Identity, Yamux), std::io::Error>, + Yamux, + >, + channel::Sender>, + channel::Receiver>, +) { + let protocol = ProtocolId::from_static(HELLO_PROTOCOL); + let (peer_manager_request_tx, peer_manager_request_rx) = channel::new_test(0); + let 
(hello_tx, hello_rx) = channel::new_test(0); + let mut protocol_handlers = HashMap::new(); + protocol_handlers.insert(protocol.clone(), hello_tx); + + let peer_manager = PeerManager::new( + build_test_transport(peer_id), + executor.clone(), + peer_id, + "/memory/0".parse().unwrap(), + peer_manager_request_rx, + protocol_handlers, + Vec::new(), + ); + + (peer_manager, peer_manager_request_tx, hello_rx) +} + +async fn open_hello_substream(connection: &T) -> io::Result<()> { + let outbound = connection.open_outbound().await?; + let (_, _) = negotiate_outbound_interactive(outbound, [HELLO_PROTOCOL]).await?; + Ok(()) +} + +async fn assert_new_substream_event( + peer_id: PeerId, + internal_event_rx: &mut channel::Receiver>, +) { + match internal_event_rx.next().await { + Some(InternalEvent::NewSubstream(actual_peer_id, _)) => { + assert_eq!(actual_peer_id, peer_id); + } + event => { + panic!("Expected a NewSubstream, received: {:?}", event); + } + } +} + +async fn assert_peer_disconnected_event( + peer_id: PeerId, + reason: DisconnectReason, + internal_event_rx: &mut channel::Receiver>, +) { + match internal_event_rx.next().await { + Some(InternalEvent::PeerDisconnected(actual_peer_id, _origin, actual_reason)) => { + assert_eq!(actual_peer_id, peer_id); + assert_eq!(actual_reason, reason); + } + event => { + panic!( + "Expected a Requested PeerDisconnected, received: {:?}", + event + ); + } + } +} + +// This helper function is used to help identify that the expected connection was dropped due +// to simultaneous dial tie-breaking. It also checks the correct events were sent from the +// Peer actors to PeerManager's internal_event_rx. +async fn check_correct_connection_is_live( + live_connection: TMuxer, + dropped_connection: TMuxer, + expected_peer_id: PeerId, + requested_shutdown: bool, + mut internal_event_rx: &mut channel::Receiver>, +) { + // If PeerManager needed to kill the existing connection we'll see a Requested shutdown + // event + if requested_shutdown { + assert_peer_disconnected_event( + expected_peer_id, + DisconnectReason::Requested, + &mut internal_event_rx, + ) + .await; + } + + assert!(open_hello_substream(&dropped_connection).await.is_err()); + assert!(open_hello_substream(&live_connection).await.is_ok()); + + // Make sure we get the incoming substream and shutdown events + assert_new_substream_event(expected_peer_id, &mut internal_event_rx).await; + + live_connection.close().await.unwrap(); + + assert_peer_disconnected_event( + expected_peer_id, + DisconnectReason::ConnectionLost, + &mut internal_event_rx, + ) + .await; +} + +#[test] +fn peer_manager_simultaneous_dial_two_inbound() { + let mut runtime = ::tokio::runtime::Runtime::new().unwrap(); + + // Create a list of ordered PeerIds so we can ensure how PeerIds will be compared. 
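+    // An illustrative sketch of the tie-break rule that this test and the
+    // following simultaneous-dial tests pin down (hypothetical helper,
+    // reconstructed from the expected outcomes; PeerId ordering comes from
+    // ordered_peer_ids above):
+    //
+    //     fn keep_new_connection(
+    //         own: PeerId,
+    //         remote: PeerId,
+    //         existing: ConnectionOrigin,
+    //         new: ConnectionOrigin,
+    //     ) -> bool {
+    //         match (existing, new) {
+    //             // Same origin twice: keep the first connection.
+    //             (e, n) if e == n => false,
+    //             // Mixed origins: keep the connection whose dialer has the
+    //             // greater PeerId; an Inbound connection was dialed by the
+    //             // remote peer.
+    //             (_, ConnectionOrigin::Inbound) => remote > own,
+    //             (_, ConnectionOrigin::Outbound) => own > remote,
+    //         }
+    //     }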
+ let ids = ordered_peer_ids(2); + let (mut peer_manager, _request_tx, _hello_rx) = + build_test_peer_manager(runtime.executor(), ids[1]); + + let test = async move { + // + // Two inbound connections + // + let (outbound1, inbound1) = build_test_connection(); + peer_manager + .add_peer( + build_test_identity(ids[0]), + Multiaddr::empty(), + ConnectionOrigin::Inbound, + inbound1, + ) + .await; + let (outbound2, inbound2) = build_test_connection(); + peer_manager + .add_peer( + build_test_identity(ids[0]), + Multiaddr::empty(), + ConnectionOrigin::Inbound, + inbound2, + ) + .await; + + // outbound2 should have been dropped since it was the second inbound connection + check_correct_connection_is_live( + outbound1, + outbound2, + ids[0], + false, + &mut peer_manager.internal_event_rx, + ) + .await; + }; + + runtime + .block_on(test.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn peer_manager_simultaneous_dial_inbound_outbout_remote_id_larger() { + let mut runtime = ::tokio::runtime::Runtime::new().unwrap(); + + // Create a list of ordered PeerIds so we can ensure how PeerIds will be compared. + let ids = ordered_peer_ids(2); + let (mut peer_manager, _request_tx, _hello_rx) = + build_test_peer_manager(runtime.executor(), ids[0]); + + let test = async move { + // + // Inbound first, outbound second with own_peer_id < remote_peer_id + // + let (outbound1, inbound1) = build_test_connection(); + peer_manager + .add_peer( + build_test_identity(ids[1]), + Multiaddr::empty(), + ConnectionOrigin::Inbound, + inbound1, + ) + .await; + let (outbound2, inbound2) = build_test_connection(); + peer_manager + .add_peer( + build_test_identity(ids[1]), + Multiaddr::empty(), + ConnectionOrigin::Outbound, + outbound2, + ) + .await; + + // inbound2 should be dropped because for outbound1 the remote peer has a greater + // PeerId and is the "dialer" + check_correct_connection_is_live( + outbound1, + inbound2, + ids[1], + false, + &mut peer_manager.internal_event_rx, + ) + .await; + }; + + runtime + .block_on(test.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn peer_manager_simultaneous_dial_inbound_outbout_own_id_larger() { + let mut runtime = ::tokio::runtime::Runtime::new().unwrap(); + + // Create a list of ordered PeerIds so we can ensure how PeerIds will be compared. + let ids = ordered_peer_ids(2); + let (mut peer_manager, _request_tx, _hello_rx) = + build_test_peer_manager(runtime.executor(), ids[1]); + + let test = async move { + // + // Inbound first, outbound second with remote_peer_id < own_peer_id + // + let (outbound1, inbound1) = build_test_connection(); + peer_manager + .add_peer( + build_test_identity(ids[0]), + Multiaddr::empty(), + ConnectionOrigin::Inbound, + inbound1, + ) + .await; + let (outbound2, inbound2) = build_test_connection(); + peer_manager + .add_peer( + build_test_identity(ids[0]), + Multiaddr::empty(), + ConnectionOrigin::Outbound, + outbound2, + ) + .await; + + // outbound1 should be dropped because for inbound2 PeerManager's PeerId is greater and + // is the "dialer" + check_correct_connection_is_live( + inbound2, + outbound1, + ids[0], + true, + &mut peer_manager.internal_event_rx, + ) + .await; + }; + + runtime + .block_on(test.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn peer_manager_simultaneous_dial_outbound_inbound_remote_id_larger() { + let mut runtime = ::tokio::runtime::Runtime::new().unwrap(); + + // Create a list of ordered PeerIds so we can ensure how PeerIds will be compared. 
+ let ids = ordered_peer_ids(2); + let (mut peer_manager, _request_tx, _hello_rx) = + build_test_peer_manager(runtime.executor(), ids[0]); + + let test = async move { + // + // Outbound first, inbound second with own_peer_id < remote_peer_id + // + let (outbound1, inbound1) = build_test_connection(); + peer_manager + .add_peer( + build_test_identity(ids[1]), + Multiaddr::empty(), + ConnectionOrigin::Outbound, + outbound1, + ) + .await; + let (outbound2, inbound2) = build_test_connection(); + peer_manager + .add_peer( + build_test_identity(ids[1]), + Multiaddr::empty(), + ConnectionOrigin::Inbound, + inbound2, + ) + .await; + + // inbound1 should be dropped because for outbound2 the remote peer has a greater + // PeerID and is the "dialer" + check_correct_connection_is_live( + outbound2, + inbound1, + ids[1], + true, + &mut peer_manager.internal_event_rx, + ) + .await; + }; + + runtime + .block_on(test.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn peer_manager_simultaneous_dial_outbound_inbound_own_id_larger() { + let mut runtime = ::tokio::runtime::Runtime::new().unwrap(); + + // Create a list of ordered PeerIds so we can ensure how PeerIds will be compared. + let ids = ordered_peer_ids(2); + let (mut peer_manager, _request_tx, _hello_rx) = + build_test_peer_manager(runtime.executor(), ids[1]); + + let test = async move { + // + // Outbound first, inbound second with remote_peer_id < own_peer_id + // + let (outbound1, inbound1) = build_test_connection(); + peer_manager + .add_peer( + build_test_identity(ids[0]), + Multiaddr::empty(), + ConnectionOrigin::Outbound, + outbound1, + ) + .await; + let (outbound2, inbound2) = build_test_connection(); + peer_manager + .add_peer( + build_test_identity(ids[0]), + Multiaddr::empty(), + ConnectionOrigin::Inbound, + inbound2, + ) + .await; + + // outbound2 should be dropped because for inbound1 PeerManager's PeerId is greater and + // is the "dialer" + check_correct_connection_is_live( + inbound1, + outbound2, + ids[0], + false, + &mut peer_manager.internal_event_rx, + ) + .await; + }; + + runtime + .block_on(test.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn peer_manager_simultaneous_dial_two_outbound() { + let mut runtime = ::tokio::runtime::Runtime::new().unwrap(); + + // Create a list of ordered PeerIds so we can ensure how PeerIds will be compared. + let ids = ordered_peer_ids(2); + let (mut peer_manager, _request_tx, _hello_rx) = + build_test_peer_manager(runtime.executor(), ids[1]); + + let test = async move { + // + // Two Outbound connections + // + let (outbound1, inbound1) = build_test_connection(); + peer_manager + .add_peer( + build_test_identity(ids[0]), + Multiaddr::empty(), + ConnectionOrigin::Outbound, + outbound1, + ) + .await; + let (outbound2, inbound2) = build_test_connection(); + peer_manager + .add_peer( + build_test_identity(ids[0]), + Multiaddr::empty(), + ConnectionOrigin::Outbound, + outbound2, + ) + .await; + + // inbound2 should have been dropped since it was the second outbound connection + check_correct_connection_is_live( + inbound1, + inbound2, + ids[0], + false, + &mut peer_manager.internal_event_rx, + ) + .await; + }; + + runtime + .block_on(test.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn peer_manager_simultaneous_dial_disconnect_event() { + let mut runtime = ::tokio::runtime::Runtime::new().unwrap(); + + // Create a list of ordered PeerIds so we can ensure how PeerIds will be compared. 
+ let ids = ordered_peer_ids(2); + let (mut peer_manager, _request_tx, _hello_rx) = + build_test_peer_manager(runtime.executor(), ids[1]); + + let test = async move { + let (outbound, _inbound) = build_test_connection(); + peer_manager + .add_peer( + build_test_identity(ids[0]), + Multiaddr::empty(), + ConnectionOrigin::Outbound, + outbound, + ) + .await; + + // Create a PeerDisconnect event with the opposite origin of the one stored in + // PeerManager to ensure that handling the event wont cause the PeerHandle to be + // removed from PeerManager + let event = InternalEvent::PeerDisconnected( + ids[0], + ConnectionOrigin::Inbound, + DisconnectReason::ConnectionLost, + ); + peer_manager.handle_internal_event(event).await; + + assert!(peer_manager.active_peers.contains_key(&ids[0])); + }; + + runtime + .block_on(test.boxed().unit_error().compat()) + .unwrap(); +} diff --git a/network/src/proto/consensus.proto b/network/src/proto/consensus.proto new file mode 100644 index 0000000000000..2e6af492d826d --- /dev/null +++ b/network/src/proto/consensus.proto @@ -0,0 +1,146 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package network; + +import "ledger_info.proto"; +import "transaction.proto"; + +message ConsensusMsg { + oneof message { + Proposal proposal = 1; + Vote vote = 2; + RequestBlock request_block = 3; + RespondBlock respond_block = 4; + NewRound new_round = 5; + RequestChunk request_chunk = 6; + RespondChunk respond_chunk = 7; + } +} + +message Proposal { + // The proposed block + Block proposed_block = 1; + // Author of the proposal + bytes proposer = 2; + // Optional timeout quorum certificate if this proposal is generated by + // timeout + PacemakerTimeoutCertificate timeout_quorum_cert = 3; + // The highest ledger info + QuorumCert highest_ledger_info = 4; +} + +message PacemakerTimeout { + // Round that has timed out (e.g. we propose to switch to round + 1) + uint64 round = 1; + // Author of timeout + bytes author = 2; + // Signature that this timeout was authored by owner + bytes signature = 3; +} + +message NewRound { + // Highest quorum certificate known after a timeout ouf a round. + QuorumCert highest_quorum_cert = 1; + // Timeout + PacemakerTimeout pacemaker_timeout = 2; + // Author of new round message + bytes author = 3; + // Signature that this timeout was authored by owner + bytes signature = 4; + // The highest ledger info + QuorumCert highest_ledger_info = 5; +} + +message PacemakerTimeoutCertificate { + // Round for which this certificate was created + uint64 round = 1; + // List of certified timeouts + repeated PacemakerTimeout timeouts = 2; +} + +message Block { + // This block's id as a hash value + bytes id = 1; + // Parent block id of this block as a hash value (all zeros to indicate the + // genesis block) + bytes parent_id = 2; + // Payload of the block (e.g. one or more transaction(s) + bytes payload = 3; + // The round of the block (internal monotonically increasing counter). + uint64 round = 4; + // The height of the block (position in the chain). 
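+  // Note: height can trail round, since rounds advance on timeouts (see
+  // PacemakerTimeout) even when no block extends the chain.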
+ uint64 height = 5; + // The approximate physical microseconds since the epoch when the block was proposed + uint64 timestamp_usecs = 6; + // Contains the quorum certified ancestor and whether the quorum certified + // ancestor was voted on successfully + QuorumCert quorum_cert = 7; + // Author of the block that can be validated by the author's public key and + // the signature + bytes author = 8; + // Signature that the hash of this block has been authored by the owner of the + // private key + bytes signature = 9; +} + +message QuorumCert { + // Ancestor of this block (could be a parent) + bytes block_id = 1; + /// The execution state id of the corresponding block + bytes state_id = 2; + uint64 version = 3; + /// The round of a certified block. + uint64 round = 4; + // LedgerInfo with at least 2f+1 signatures. The LedgerInfo's consensus data + // hash is a digest that covers ancestor_id, state_id and round. + types.LedgerInfoWithSignatures signed_ledger_info = 5; +} + +message Vote { + // The id of the proposed block. + bytes proposed_block_id = 1; + // The id of the state generated by the StateExecutor after executing the + // proposed block. + bytes executed_state_id = 2; + uint64 version = 3; + uint64 round = 4; + // Author of the vote. + bytes author = 5; + // The ledger info carried with the vote (corresponding to the block of a + // potentially committed txn). + types.LedgerInfo ledger_info = 6; + // Signature of the ledger info. + bytes signature = 7; +} + +message RequestBlock { + // The id of the requested block. + bytes block_id = 1; + uint64 num_blocks = 2; +} + +enum BlockRetrievalStatus { + // Successfully fill in the request. + SUCCEEDED = 0; + // Can not find the block corresponding to block_id. + ID_NOT_FOUND = 1; + // Can not find enough blocks but find some. + NOT_ENOUGH_BLOCKS = 2; +} + +message RespondBlock { + BlockRetrievalStatus status = 1; + // The responded block. + repeated Block blocks = 2; +} + +message RequestChunk { + uint64 start_version = 1; + QuorumCert target = 2; + uint64 batch_size = 3; +} + +message RespondChunk { types.TransactionListWithProof txn_list_with_proof = 1; } diff --git a/network/src/proto/mempool.proto b/network/src/proto/mempool.proto new file mode 100644 index 0000000000000..578e5734b5325 --- /dev/null +++ b/network/src/proto/mempool.proto @@ -0,0 +1,18 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package network; + +import "transaction.proto"; + +/* MempoolSyncMsg represents the messages exchanging between validators to keep + * transactions in sync. The proto definition provides the spec on the wire so + * that others can implement their mempool service in various languages. + * Mempool service is responsible for sending and receiving MempoolSyncMsg + * across validators. */ +message MempoolSyncMsg { + bytes peer_id = 1; + repeated types.SignedTransaction transactions = 2; +} diff --git a/network/src/proto/mod.rs b/network/src/proto/mod.rs new file mode 100644 index 0000000000000..3be4275933ac9 --- /dev/null +++ b/network/src/proto/mod.rs @@ -0,0 +1,20 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! 
Protobuf definitions for data structures sent over the network +mod consensus; +mod mempool; +mod network; + +use types::proto::{ledger_info, transaction}; + +pub use self::{ + consensus::{ + Block, BlockRetrievalStatus, ConsensusMsg, NewRound, PacemakerTimeout, + PacemakerTimeoutCertificate, Proposal, QuorumCert, RequestBlock, RequestChunk, + RespondBlock, RespondChunk, Vote, + }, + mempool::MempoolSyncMsg, + network::{DiscoveryMsg, IdentityMsg, Note, PeerInfo, Ping, Pong}, +}; +pub use transaction::SignedTransaction; diff --git a/network/src/proto/network.proto b/network/src/proto/network.proto new file mode 100644 index 0000000000000..f4f09932f0fcc --- /dev/null +++ b/network/src/proto/network.proto @@ -0,0 +1,45 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package network; + +// A `PeerInfo` represents the network address(es) of a Peer at some epoch. +message PeerInfo { + // Addresses this peer can be reached at. + // An address is a byte array in the + // [multiaddr](https://multiformats.io/multiaddr/) format. + repeated bytes addrs = 1; + // Monotonically increasing incarnation number. This is usually a timestamp. + uint64 epoch = 2; +} + +// A `Note` represents a signed PeerInfo. The signature should be of the peer +// whose info is being sent. +message Note { + // Id of the peer. + bytes peer_id = 1; + // Serialized PeerInfo. + bytes peer_info = 2; + // Each peer signs its serialized PeerInfo and includes both the PeerInfo and + // the sign in a note it sends to another peer. + bytes signature = 3; +} + +// Discovery message exchanged as part of the discovery protocol. +// The discovery message sent by a peer consists of notes for all the peers the +// sending peer knows about. +message DiscoveryMsg { repeated Note notes = 1; } + +// Identity message exchanged as part of the Identity protocol. +message IdentityMsg { + bytes peer_id = 1; + repeated bytes supported_protocols = 2; +} + +// Ping message sent as liveness probe. +message Ping {} + +// Pong message sent as reponse to liveness probe. +message Pong {} diff --git a/network/src/protocols/direct_send/mod.rs b/network/src/protocols/direct_send/mod.rs new file mode 100644 index 0000000000000..e3b2cc1e3f266 --- /dev/null +++ b/network/src/protocols/direct_send/mod.rs @@ -0,0 +1,316 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Protocol for fire-and-forget style message delivery to a peer +//! +//! DirectSend protocol takes advantage of [muxers] and [substream negotiation] to build a simple +//! best effort message delivery protocol. Concretely, +//! +//! 1. Every message runs in its own ephemeral substream. The substream is directional in the way +//! that only the dialer sends a message to the listener, but no messages or acknowledgements +//! sending back on the other direction. So the message delivery is best effort and not +//! guranteed. Because the substreams are independent, there is no gurantee on the ordering +//! of the message delivery either. +//! 2. An DirectSend call negotiates which protocol to speak using [`protocol-select`]. This +//! allows simple versioning of message delivery and negotiation of which message types are +//! supported. In the future, we can potentially support multiple backwards-incompatible +//! versions of any messages. +//! 3. The actual structure of the wire messages is left for higher layers to specify. The +//! 
DirectSend protocol is only concerned with shipping around opaque blobs. Current libra +//! DirectSend clients (consensus, mempool) mostly send protobuf enums around over a single +//! DirectSend protocol, e.g., `/libra/consensus/direct_send/0.1.0`. +//! +//! ## Wire Protocol (dialer): +//! +//! To send a message to a remote peer, the dialer +//! +//! 1. Requests a new outbound substream from the muxer. +//! 2. Negotiates the substream using [`protocol-select`] to the protocol they +//! wish to speak, e.g., `/libra/mempool/direct_send/0.1.0`. +//! 3. Sends the serialized message on the newly negotiated substream. +//! 4. Drops the substream. +//! +//! ## Wire Protocol (listener): +//! +//! To receive a message from remote peers, the listener +//! +//! 1. Polls for new inbound substreams on the muxer. +//! 2. Negotiates inbound substreams using [`protocol-select`]. The negotiation +//! must only succeed if the requested protocol is actually supported. +//! 3. Awaits the serialized message on the newly negotiated substream. +//! 4. Drops the substream. +//! +//! Note: negotiated substreams are currently framed with the +//! [muiltiformats unsigned varint length-prefix](https://github.com/multiformats/unsigned-varint) +//! +//! [muxers]: ../../../netcore/multiplexing/index.html +//! [substream negotiation]: ../../../netcore/negotiate/index.html +//! [`protocol-select`]: ../../../netcore/negotiate/index.html +use crate::{ + counters, + error::NetworkError, + peer_manager::{PeerManagerNotification, PeerManagerRequestSender}, + ProtocolId, +}; +use bytes::Bytes; +use channel; +use futures::{ + compat::Sink01CompatExt, + future::{FutureExt, TryFutureExt}, + io::{AsyncRead, AsyncReadExt, AsyncWrite}, + sink::SinkExt, + stream::StreamExt, +}; +use logger::prelude::*; +use std::{ + collections::{hash_map::Entry, HashMap}, + fmt::Debug, +}; +use tokio::{codec::Framed, runtime::TaskExecutor}; +use types::PeerId; +use unsigned_varint::codec::UviBytes; + +#[cfg(test)] +mod test; + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum DirectSendRequest { + /// A request to send out a message. + SendMessage(PeerId, Message), +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum DirectSendNotification { + /// A notification that a DirectSend message is received. + RecvMessage(PeerId, Message), +} + +#[derive(Clone, Eq, PartialEq)] +pub struct Message { + /// Message type. + pub protocol: ProtocolId, + /// Serialized message data. + pub mdata: Bytes, +} + +impl Debug for Message { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mdata_str = if self.mdata.len() <= 10 { + format!("{:?}", self.mdata) + } else { + format!("{:?}...", self.mdata.slice_to(10)) + }; + write!( + f, + "Message {{ protocol: {:?}, mdata: {} }}", + self.protocol, mdata_str + ) + } +} + +/// The DirectSend actor. +pub struct DirectSend { + /// A handle to a tokio executor. + executor: TaskExecutor, + /// Channel to receive requests from other upstream actors. + ds_requests_rx: channel::Receiver, + /// Channels to send notifictions to upstream actors. + ds_notifs_tx: channel::Sender, + /// Channel to receive notifications from PeerManager. + peer_mgr_notifs_rx: channel::Receiver>, + /// Channel to send requests to PeerManager. + peer_mgr_reqs_tx: PeerManagerRequestSender, + /// Outbound message queues for each (PeerId, ProtocolId) pair. 
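+    /// Each entry holds the sending end of a bounded queue feeding one
+    /// dedicated outbound substream; when that substream dies the entry is
+    /// removed, and the next send lazily opens a fresh substream.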
+ message_queues: HashMap<(PeerId, ProtocolId), channel::Sender>, +} + +impl DirectSend +where + TSubstream: AsyncRead + AsyncWrite + Send + Unpin + Debug + 'static, +{ + pub fn new( + executor: TaskExecutor, + ds_requests_rx: channel::Receiver, + ds_notifs_tx: channel::Sender, + peer_mgr_notifs_rx: channel::Receiver>, + peer_mgr_reqs_tx: PeerManagerRequestSender, + ) -> Self { + Self { + executor, + ds_requests_rx, + ds_notifs_tx, + peer_mgr_notifs_rx, + peer_mgr_reqs_tx, + message_queues: HashMap::new(), + } + } + + pub async fn start(mut self) { + loop { + futures::select! { + req = self.ds_requests_rx.select_next_some() => { + self.handle_direct_send_request(req).await; + } + notif = self.peer_mgr_notifs_rx.select_next_some() => { + self.handle_peer_mgr_notification(notif); + } + complete => { + crit!("Direct send actor terminated"); + break; + } + } + } + } + + // Handle PeerManagerNotification, which can only be NewInboundSubstream for now. + fn handle_peer_mgr_notification(&self, notif: PeerManagerNotification) { + trace!("PeerManagerNotification::{:?}", notif); + match notif { + PeerManagerNotification::NewInboundSubstream(peer_id, substream) => { + self.executor.spawn( + Self::handle_inbound_substream( + peer_id, + substream.protocol, + substream.substream, + self.ds_notifs_tx.clone(), + ) + .boxed() + .unit_error() + .compat(), + ); + } + _ => unreachable!("Unexpected PeerManagerNotification"), + } + } + + // Handle a new inbound substream. Keep forwarding the messages to the NetworkProvider. + async fn handle_inbound_substream( + peer_id: PeerId, + protocol: ProtocolId, + substream: TSubstream, + mut ds_notifs_tx: channel::Sender, + ) { + let mut substream = + Framed::new(substream.compat(), UviBytes::::default()).sink_compat(); + while let Some(item) = substream.next().await { + match item { + Ok(data) => { + let notif = DirectSendNotification::RecvMessage( + peer_id, + Message { + protocol: protocol.clone(), + mdata: data.freeze(), + }, + ); + ds_notifs_tx + .send(notif) + .await + .expect("DirectSendNotification send error"); + } + Err(e) => { + warn!( + "DirectSend substream with peer {} receives error {}", + peer_id.short_str(), + e + ); + break; + } + } + } + warn!( + "DirectSend inbound substream with peer {} closed", + peer_id.short_str() + ); + } + + // Create a new message queue and spawn a task to forward the messages from the queue to the + // corresponding substream. + async fn start_message_queue_handler( + executor: TaskExecutor, + mut peer_mgr_reqs_tx: PeerManagerRequestSender, + peer_id: PeerId, + protocol: ProtocolId, + ) -> Result, NetworkError> { + // Create a channel for the (PeerId, ProtocolId) pair. + let (msg_tx, msg_rx) = channel::new::( + 1024, + &counters::OP_COUNTERS.peer_gauge( + &counters::PENDING_DIRECT_SEND_OUTBOUND_MESSAGES, + &peer_id.short_str(), + ), + ); + + // Open a new substream for the (PeerId, ProtocolId) pair + let raw_substream = peer_mgr_reqs_tx.open_substream(peer_id, protocol).await?; + let substream = + Framed::new(raw_substream.compat(), UviBytes::::default()).sink_compat(); + + // Spawn a task to forward the messages from the queue to the substream. 
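+        // msg_rx yields Bytes items; mapping each to Ok lets the stream drive
+        // the framed sink directly via forward(), which completes once the
+        // queue closes or the substream errors out.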
+ let f_substream = async move { + if let Err(e) = msg_rx.map(Ok).forward(substream).await { + warn!( + "Forward messages to peer {} error {:?}", + peer_id.short_str(), + e + ); + } + // The messages in queue will be dropped + counters::DIRECT_SEND_MESSAGES_DROPPED.inc_by( + counters::OP_COUNTERS + .peer_gauge( + &counters::PENDING_DIRECT_SEND_OUTBOUND_MESSAGES, + &peer_id.short_str(), + ) + .get(), + ); + }; + executor.spawn(f_substream.boxed().unit_error().compat()); + + Ok(msg_tx) + } + + // Try to send a message to the message queue. + async fn try_send_msg( + &mut self, + peer_id: PeerId, + msg: Message, + peer_mgr_reqs_tx: PeerManagerRequestSender, + ) -> Result<(), NetworkError> { + let protocol = msg.protocol.clone(); + + let substream_queue_tx = match self.message_queues.entry((peer_id, protocol.clone())) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => { + let msg_tx = Self::start_message_queue_handler( + self.executor.clone(), + peer_mgr_reqs_tx, + peer_id, + protocol.clone(), + ) + .await?; + entry.insert(msg_tx) + } + }; + + substream_queue_tx.send(msg.mdata).await.map_err(|e| { + self.message_queues.remove(&(peer_id, protocol)); + e.into() + }) + } + + // Handle DirectSendRequest, which can only be SendMessage request for now. + async fn handle_direct_send_request(&mut self, req: DirectSendRequest) { + trace!("DirectSendRequest::{:?}", req); + match req { + DirectSendRequest::SendMessage(peer_id, msg) => { + if let Err(e) = self + .try_send_msg(peer_id, msg.clone(), self.peer_mgr_reqs_tx.clone()) + .await + { + counters::DIRECT_SEND_MESSAGES_DROPPED.inc(); + warn!("DirectSend to peer {} failed: {}", peer_id.short_str(), e); + } + } + } + } +} diff --git a/network/src/protocols/direct_send/test.rs b/network/src/protocols/direct_send/test.rs new file mode 100644 index 0000000000000..babcee81772d4 --- /dev/null +++ b/network/src/protocols/direct_send/test.rs @@ -0,0 +1,520 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + common::NegotiatedSubstream, + peer_manager::{ + PeerManagerError, PeerManagerNotification, PeerManagerRequest, PeerManagerRequestSender, + }, + protocols::direct_send::{DirectSend, DirectSendNotification, DirectSendRequest, Message}, + ProtocolId, +}; +use bytes::Bytes; +use channel; +use futures::{ + compat::Sink01CompatExt, + future::{FutureExt, TryFutureExt}, + io::AsyncReadExt, + sink::SinkExt, + stream::StreamExt, +}; +use memsocket::MemorySocket; +use tokio::{ + codec::Framed, + runtime::{Runtime, TaskExecutor}, +}; +use types::PeerId; +use unsigned_varint::codec::UviBytes; + +const PROTOCOL_1: &[u8] = b"/direct_send/1.0.0"; +const PROTOCOL_2: &[u8] = b"/direct_send/2.0.0"; +const MESSAGE_1: &[u8] = b"Direct Send 1"; +const MESSAGE_2: &[u8] = b"Direct Send 2"; +const MESSAGE_3: &[u8] = b"Direct Send 3"; + +fn start_direct_send_actor( + executor: TaskExecutor, +) -> ( + channel::Sender, + channel::Receiver, + channel::Sender>, + channel::Receiver>, +) { + let (ds_requests_tx, ds_requests_rx) = channel::new_test(8); + let (ds_notifs_tx, ds_notifs_rx) = channel::new_test(8); + let (peer_mgr_notifs_tx, peer_mgr_notifs_rx) = channel::new_test(8); + let (peer_mgr_reqs_tx, peer_mgr_reqs_rx) = channel::new_test(8); + let direct_send = DirectSend::new( + executor.clone(), + ds_requests_rx, + ds_notifs_tx, + peer_mgr_notifs_rx, + PeerManagerRequestSender::new(peer_mgr_reqs_tx), + ); + executor.spawn(direct_send.start().boxed().unit_error().compat()); + + ( + ds_requests_tx, + 
ds_notifs_rx, + peer_mgr_notifs_tx, + peer_mgr_reqs_rx, + ) +} + +async fn expect_network_provider_recv_message( + ds_notifs_rx: &mut channel::Receiver, + expected_peer_id: PeerId, + expected_protocol: &'static [u8], + expected_message: &'static [u8], +) { + match ds_notifs_rx.next().await.unwrap() { + DirectSendNotification::RecvMessage(peer_id, msg) => { + assert_eq!(peer_id, expected_peer_id); + assert_eq!(msg.protocol.as_ref(), expected_protocol); + assert_eq!(msg.mdata, Bytes::from_static(expected_message)); + } + } +} + +async fn expect_open_substream_request( + peer_mgr_reqs_rx: &mut channel::Receiver>, + expected_peer_id: PeerId, + expected_protocol: &'static [u8], + response: Result, +) where + TSubstream: std::fmt::Debug, +{ + match peer_mgr_reqs_rx.next().await.unwrap() { + PeerManagerRequest::OpenSubstream(peer_id, protocol, substream_tx) => { + assert_eq!(peer_id, expected_peer_id); + assert_eq!(protocol.as_ref(), expected_protocol); + substream_tx.send(response).unwrap(); + } + _ => panic!("Unexpected event"), + } +} + +#[test] +fn test_inbound_substream() { + let mut rt = Runtime::new().unwrap(); + + let (_ds_requests_tx, mut ds_notifs_rx, mut peer_mgr_notifs_tx, _peer_mgr_reqs_rx) = + start_direct_send_actor(rt.executor()); + + let peer_id = PeerId::random(); + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + + // The dialer sends two messages to the listener. + let f_substream = async move { + let mut dialer_substream = + Framed::new(dialer_substream.compat(), UviBytes::default()).sink_compat(); + dialer_substream + .send(Bytes::from_static(MESSAGE_1)) + .await + .unwrap(); + dialer_substream + .send(Bytes::from_static(MESSAGE_2)) + .await + .unwrap(); + }; + + // Fake the listener NetworkProvider to notify DirectSend of the inbound substream. + let f_network_provider = async move { + peer_mgr_notifs_tx + .send(PeerManagerNotification::NewInboundSubstream( + peer_id, + NegotiatedSubstream { + protocol: ProtocolId::from_static(&PROTOCOL_1[..]), + substream: listener_substream, + }, + )) + .await + .unwrap(); + + // The listener should receive these two messages + expect_network_provider_recv_message(&mut ds_notifs_rx, peer_id, PROTOCOL_1, MESSAGE_1) + .await; + expect_network_provider_recv_message(&mut ds_notifs_rx, peer_id, PROTOCOL_1, MESSAGE_2) + .await; + }; + + rt.spawn(f_substream.boxed().unit_error().compat()); + rt.block_on(f_network_provider.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn test_inbound_substream_closed() { + let mut rt = Runtime::new().unwrap(); + + let (_ds_requests_tx, mut ds_notifs_rx, mut peer_mgr_notifs_tx, _peer_mgr_reqs_rx) = + start_direct_send_actor(rt.executor()); + + let peer_id = PeerId::random(); + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + + // The dialer sends a message to the listener. 
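+    // On the wire each frame is unsigned-varint length-prefixed by UviBytes:
+    // the 13-byte MESSAGE_1 ("Direct Send 1") travels as 0x0d followed by the
+    // 13 payload bytes.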
+ let f_substream = async move { + let mut dialer_substream = + Framed::new(dialer_substream.compat(), UviBytes::default()).sink_compat(); + dialer_substream + .send(Bytes::from_static(MESSAGE_1)) + .await + .unwrap(); + // close the substream on the dialer side + drop(dialer_substream); + }; + + // Fake the listener NetworkProvider + let f_network_provider = async move { + peer_mgr_notifs_tx + .send(PeerManagerNotification::NewInboundSubstream( + peer_id, + NegotiatedSubstream { + protocol: ProtocolId::from_static(&PROTOCOL_1[..]), + substream: listener_substream, + }, + )) + .await + .unwrap(); + + expect_network_provider_recv_message(&mut ds_notifs_rx, peer_id, PROTOCOL_1, MESSAGE_1) + .await; + }; + + rt.spawn(f_substream.boxed().unit_error().compat()); + rt.block_on(f_network_provider.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn test_outbound_single_protocol() { + let mut rt = Runtime::new().unwrap(); + + let (mut ds_requests_tx, _ds_notifs_rx, _peer_mgr_notifs_tx, mut peer_mgr_reqs_rx) = + start_direct_send_actor(rt.executor()); + + let peer_id = PeerId::random(); + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + + // Fake the dialer NetworkProvider + let f_network_provider = async move { + // Send 2 messages with the same protocol + ds_requests_tx + .send(DirectSendRequest::SendMessage( + peer_id, + Message { + protocol: Bytes::from_static(&PROTOCOL_1[..]), + mdata: Bytes::from_static(MESSAGE_1), + }, + )) + .await + .unwrap(); + ds_requests_tx + .send(DirectSendRequest::SendMessage( + peer_id, + Message { + protocol: Bytes::from_static(&PROTOCOL_1[..]), + mdata: Bytes::from_static(MESSAGE_2), + }, + )) + .await + .unwrap(); + + // DirectSend actor should request to open a substream with the same protocol + expect_open_substream_request( + &mut peer_mgr_reqs_rx, + peer_id, + PROTOCOL_1, + Ok(dialer_substream), + ) + .await; + }; + + // The listener should receive these two messages. + let f_substream = async move { + let mut listener_substream = + Framed::new(listener_substream.compat(), UviBytes::::default()).sink_compat(); + let msg = listener_substream.next().await.unwrap().unwrap(); + assert_eq!(msg.as_ref(), MESSAGE_1); + let msg = listener_substream.next().await.unwrap().unwrap(); + assert_eq!(msg.as_ref(), MESSAGE_2); + }; + + rt.spawn(f_network_provider.boxed().unit_error().compat()); + rt.block_on(f_substream.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn test_outbound_multiple_protocols() { + let mut rt = Runtime::new().unwrap(); + + let (mut ds_requests_tx, _ds_notifs_rx, _peer_mgr_notifs_tx, mut peer_mgr_reqs_rx) = + start_direct_send_actor(rt.executor()); + + let peer_id = PeerId::random(); + let (dialer_substream_1, listener_substream_1) = MemorySocket::new_pair(); + let (dialer_substream_2, listener_substream_2) = MemorySocket::new_pair(); + + // Fake the dialer NetworkProvider + let f_network_provider = async move { + // Send 2 messages with different protocols to the same peer + ds_requests_tx + .send(DirectSendRequest::SendMessage( + peer_id, + Message { + protocol: Bytes::from_static(&PROTOCOL_1[..]), + mdata: Bytes::from_static(MESSAGE_1), + }, + )) + .await + .unwrap(); + ds_requests_tx + .send(DirectSendRequest::SendMessage( + peer_id, + Message { + protocol: Bytes::from_static(&PROTOCOL_2[..]), + mdata: Bytes::from_static(MESSAGE_2), + }, + )) + .await + .unwrap(); + + // DirectSend actor should request to open 2 different substreams. 
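+        // One substream per (PeerId, ProtocolId) pair, mirroring the keying of
+        // DirectSend::message_queues.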
+ expect_open_substream_request( + &mut peer_mgr_reqs_rx, + peer_id, + PROTOCOL_1, + Ok(dialer_substream_1), + ) + .await; + expect_open_substream_request( + &mut peer_mgr_reqs_rx, + peer_id, + PROTOCOL_2, + Ok(dialer_substream_2), + ) + .await; + }; + + // The listener should receive 1 message on each substream. + let f_substream = async move { + let mut listener_substream_1 = + Framed::new(listener_substream_1.compat(), UviBytes::::default()).sink_compat(); + let msg = listener_substream_1.next().await.unwrap().unwrap(); + assert_eq!(msg.as_ref(), MESSAGE_1); + let mut listener_substream_2 = + Framed::new(listener_substream_2.compat(), UviBytes::::default()).sink_compat(); + let msg = listener_substream_2.next().await.unwrap().unwrap(); + assert_eq!(msg.as_ref(), MESSAGE_2); + }; + + rt.spawn(f_network_provider.boxed().unit_error().compat()); + rt.block_on(f_substream.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn test_outbound_not_connected() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + + let (mut ds_requests_tx, _ds_notifs_rx, _peer_mgr_notifs_tx, mut peer_mgr_reqs_rx) = + start_direct_send_actor(rt.executor()); + + let peer_id = PeerId::random(); + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + + // Fake the dialer NetworkProvider + let f_network_provider = async move { + // Request DirectSend to send the first message + ds_requests_tx + .send(DirectSendRequest::SendMessage( + peer_id, + Message { + protocol: Bytes::from_static(&PROTOCOL_1[..]), + mdata: Bytes::from_static(MESSAGE_1), + }, + )) + .await + .unwrap(); + + // PeerManager returns the NotConnected error + expect_open_substream_request( + &mut peer_mgr_reqs_rx, + peer_id, + PROTOCOL_1, + Err(PeerManagerError::NotConnected(peer_id)), + ) + .await; + + // Request DirectSend to send the second message + ds_requests_tx + .send(DirectSendRequest::SendMessage( + peer_id, + Message { + protocol: Bytes::from_static(&PROTOCOL_1[..]), + mdata: Bytes::from_static(MESSAGE_2), + }, + )) + .await + .unwrap(); + + // PeerManager returns the substream + expect_open_substream_request( + &mut peer_mgr_reqs_rx, + peer_id, + PROTOCOL_1, + Ok(dialer_substream), + ) + .await; + }; + + // The listener should receive the message. + let f_substream = async move { + let mut listener_substream = + Framed::new(listener_substream.compat(), UviBytes::::default()).sink_compat(); + let msg = listener_substream.next().await.unwrap().unwrap(); + // Only the second message should be received, because when the first message is sent, + // the peer isn't connected. 
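+        // The first send failed while opening its substream (NotConnected), so
+        // DirectSend counted the message as dropped rather than retrying it.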
+ assert_eq!(msg.as_ref(), MESSAGE_2); + }; + + rt.spawn(f_network_provider.boxed().unit_error().compat()); + rt.block_on(f_substream.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +fn test_outbound_connection_closed() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + + let (mut ds_requests_tx, _ds_notifs_rx, _peer_mgr_notifs_tx, mut peer_mgr_reqs_rx) = + start_direct_send_actor(rt.executor()); + + let peer_id = PeerId::random(); + let (dialer_substream_1, listener_substream_1) = MemorySocket::new_pair(); + let (dialer_substream_2, listener_substream_2) = MemorySocket::new_pair(); + + // Send the first message and open the first substream + let f_first_message = async move { + // Request DirectSend to send the first message + ds_requests_tx + .send(DirectSendRequest::SendMessage( + peer_id, + Message { + protocol: Bytes::from_static(&PROTOCOL_1[..]), + mdata: Bytes::from_static(MESSAGE_1), + }, + )) + .await + .unwrap(); + + // PeerManager returns the first substream + expect_open_substream_request( + &mut peer_mgr_reqs_rx, + peer_id, + PROTOCOL_1, + Ok(dialer_substream_1), + ) + .await; + + (ds_requests_tx, peer_mgr_reqs_rx) + }; + let (mut ds_requests_tx, mut peer_mgr_reqs_rx) = rt + .block_on(f_first_message.boxed().unit_error().compat()) + .unwrap(); + + // Receive the first message and close the first substream + let f_close_first_substream = async move { + let mut listener_substream = + Framed::new(listener_substream_1.compat(), UviBytes::::default()).sink_compat(); + let msg = listener_substream.next().await.unwrap().unwrap(); + // The listener should receive the first message. + assert_eq!(msg.as_ref(), MESSAGE_1); + // Close the substream by dropping it on the listener side + drop(listener_substream); + }; + rt.block_on(f_close_first_substream.boxed().unit_error().compat()) + .unwrap(); + + // Send the second message while the connection is still lost. + let f_second_message = async move { + // Request DirectSend to send the second message + ds_requests_tx + .send(DirectSendRequest::SendMessage( + peer_id, + Message { + protocol: Bytes::from_static(&PROTOCOL_1[..]), + mdata: Bytes::from_static(MESSAGE_2), + }, + )) + .await + .unwrap(); + + ds_requests_tx + }; + let mut ds_requests_tx = rt + .block_on(f_second_message.boxed().unit_error().compat()) + .unwrap(); + + // Keep sending the third message and open the second substream + let f_third_message = async move { + // Request DirectSend to send the third message + loop { + ds_requests_tx + .send(DirectSendRequest::SendMessage( + peer_id, + Message { + protocol: Bytes::from_static(&PROTOCOL_1[..]), + mdata: Bytes::from_static(MESSAGE_3), + }, + )) + .await + .unwrap(); + } + }; + rt.spawn(f_third_message.boxed().unit_error().compat()); + + let f_open_second_substream = async move { + // PeerManager returns the second substream + expect_open_substream_request( + &mut peer_mgr_reqs_rx, + peer_id, + PROTOCOL_1, + Ok(dialer_substream_2), + ) + .await; + + peer_mgr_reqs_rx + }; + let mut peer_mgr_reqs_rx = rt + .block_on(f_open_second_substream.boxed().unit_error().compat()) + .unwrap(); + + // Fake peer manager to keep the PeerManagerRequest receiver + let f_peer_manager = async move { + loop { + expect_open_substream_request( + &mut peer_mgr_reqs_rx, + peer_id, + PROTOCOL_1, + Err(PeerManagerError::NotConnected(peer_id)), + ) + .await; + } + }; + rt.spawn(f_peer_manager.boxed().unit_error().compat()); + + // The listener should only receive the third message through the second substream. 
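+    // MESSAGE_2 was queued against the dead first substream and dropped when
+    // that send failed; the MESSAGE_3 loop keeps retrying until one send lands
+    // on the freshly opened second substream.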
+ let f_second_substream = async move { + let mut listener_substream = + Framed::new(listener_substream_2.compat(), UviBytes::::default()).sink_compat(); + let msg = listener_substream.next().await.unwrap().unwrap(); + assert_eq!(msg.as_ref(), MESSAGE_3); + }; + rt.block_on(f_second_substream.boxed().unit_error().compat()) + .unwrap(); +} diff --git a/network/src/protocols/discovery/mod.rs b/network/src/protocols/discovery/mod.rs new file mode 100644 index 0000000000000..deabb75bf90c5 --- /dev/null +++ b/network/src/protocols/discovery/mod.rs @@ -0,0 +1,522 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Protocol to discover network addresses of other peers on the Libra network +//! +//! ## Implementation +//! +//! The discovery module is implemented as a stand-alone actor in the Network sub-system of the +//! Libra stack. The actor participates in discovery by periodically sending its observed state of +//! the network to a randomly chosen peer. Other peers are also expected to be running the same +//! protocol. Therefore, in expectation, every peer expects to hear from 1 other peer in each +//! round. On hearing from the remote peer, the local discovery module tries to reconcile its state +//! to reflect any changes. In addition to updating its state, it also passes on new infromation to +//! the [`ConnectivityManager`] module. +//! +//! For the initial bootstrap of a node, it sends the discovery message to a randomly chosen seed +//! peer in each round. The message only contains the identity of this peer unless it learns more +//! about the network membership from another peer. +//! +//! Currently we do not use this mechanism to detect peer failures - instead, we simply connect to +//! all the peers in the network, and hope to learn about their failure on connection errors. +//! +//! TODO: We need to handle to case of peers who may no longer be a part of the network. +//! +//! ## Future work +//! +//! - Currently, we do not try to detect/punish nodes which are just lurking (without contributing +//! to the protocol), or actively trying to spread misinformation in the network. In the future, we +//! plan to remedy this by introducing a module dedicated to detecting byzantine behavior, and by +//! making the discovery protocol itself tolerant to byzantine faults. +//! - As an optimization, instead of creating a new substream to the chosen peer in each round, we +//! could maintain a cache of open substreams which could be re-used across numerous rounds. +//! +//! 
[`ConnectivityManager`]: ../../connectivity_manager +use crate::{ + common::NegotiatedSubstream, + connectivity_manager::ConnectivityRequest, + error::{NetworkError, NetworkErrorKind}, + peer_manager::{PeerManagerNotification, PeerManagerRequestSender}, + proto::{DiscoveryMsg, Note, PeerInfo}, + utils, NetworkPublicKeys, ProtocolId, +}; +use bytes::Bytes; +use channel; +use crypto::{ + hash::{CryptoHasher, DiscoveryMsgHasher}, + HashValue, Signature, +}; +use futures::{ + compat::{Future01CompatExt, Sink01CompatExt}, + future::{Future, FutureExt, TryFutureExt}, + io::{AsyncRead, AsyncReadExt, AsyncWrite}, + sink::SinkExt, + stream::{FusedStream, FuturesUnordered, Stream, StreamExt}, +}; +use logger::prelude::*; +use parity_multiaddr::Multiaddr; +use protobuf::{self, Message}; +use rand::{rngs::SmallRng, FromEntropy, Rng}; +use std::{ + collections::HashMap, + convert::TryFrom, + fmt::Debug, + pin::Pin, + sync::{Arc, RwLock}, + time::{Duration, SystemTime}, +}; +use tokio::{codec::Framed, prelude::FutureExt as _}; +use types::{ + validator_signer::ValidatorSigner as Signer, + validator_verifier::ValidatorVerifier as SignatureValidator, PeerId, +}; +use unsigned_varint::codec::UviBytes; + +#[cfg(test)] +mod test; + +pub const DISCOVERY_PROTOCOL_NAME: &[u8] = b"/libra/discovery/0.1.0"; + +/// The actor running the discovery protocol. +pub struct Discovery { + /// Note for self. + self_note: Note, + /// Validator for verifying signatures on messages. + trusted_peers: Arc>>, + /// Current state, maintaining the most recent Note for each peer, alongside parsed PeerInfo. + known_peers: HashMap, + /// Info for seed peers. + seed_peers: HashMap, + /// Currently connected peers. + connected_peers: HashMap, + /// Ticker to trigger state send to a random peer. In production, the ticker is likely to be + /// fixed duration interval timer. + ticker: TTicker, + /// Channel to send requests to PeerManager. + peer_mgr_reqs_tx: PeerManagerRequestSender, + /// Channel to receive notifications from PeerManager. + peer_mgr_notifs_rx: channel::Receiver>, + /// Channel to send requests to ConnectivityManager. + conn_mgr_reqs_tx: channel::Sender, + /// Message timeout duration. + msg_timeout: Duration, + /// Random-number generator. + rng: SmallRng, +} + +impl Discovery +where + TTicker: Stream + FusedStream + Unpin, + TSubstream: AsyncRead + AsyncWrite + Send + Unpin + Debug + 'static, +{ + pub fn new( + self_peer_id: PeerId, + self_addrs: Vec, + signer: Signer, + seed_peers: HashMap, + trusted_peers: Arc>>, + ticker: TTicker, + peer_mgr_reqs_tx: PeerManagerRequestSender, + peer_mgr_notifs_rx: channel::Receiver>, + conn_mgr_reqs_tx: channel::Sender, + msg_timeout: Duration, + ) -> Self { + let self_peer_info = create_peer_info(self_addrs); + let self_note = create_note(&signer, self_peer_id, self_peer_info.clone()); + let known_peers = vec![(self_peer_id, (self_peer_info, self_note.clone()))] + .into_iter() + .collect(); + Self { + self_note, + seed_peers, + trusted_peers, + known_peers, + connected_peers: HashMap::new(), + ticker, + peer_mgr_reqs_tx, + peer_mgr_notifs_rx, + conn_mgr_reqs_tx, + msg_timeout, + rng: SmallRng::from_entropy(), + } + } + + // Connect with all the seed peers. If current node is also a seed peer, remove it from the + // list. 
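+    // Discovery does not dial peers itself: seed addresses are handed to the
+    // ConnectivityManager as UpdateAddresses requests, and that module is
+    // responsible for actually establishing the connections.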
+ async fn connect_to_seed_peers(&mut self) { + let self_peer_id = PeerId::try_from(self.self_note.get_peer_id()).unwrap(); + for (peer_id, peer_info) in self + .seed_peers + .iter() + .filter(|(peer_id, _)| **peer_id != self_peer_id) + { + self.conn_mgr_reqs_tx + .send(ConnectivityRequest::UpdateAddresses( + *peer_id, + peer_info + .get_addrs() + .iter() + .map(|addr| Multiaddr::try_from(addr.clone()).unwrap()) + .collect(), + )) + .await + .unwrap(); + } + } + + // Starts the main event loop for the discovery actor. We bootstrap by first dialing all the + // seed peers, and then entering the event handling loop. Messages are received from: + // - a ticker to trigger discovery message send to a random connected peer + // - an incoming substream from a peer wishing to send its state + // - an internal task once it has processed incoming messages from a peer, and wishes for + // discovery actor to update its state. + pub async fn start(mut self) { + // Bootstrap by connecting to seed peers. + self.connect_to_seed_peers().await; + let mut unprocessed_inbound = FuturesUnordered::new(); + let mut unprocessed_outbound = FuturesUnordered::new(); + loop { + futures::select! { + _ = self.ticker.select_next_some() => { + self.handle_tick(&mut unprocessed_outbound); + } + notif = self.peer_mgr_notifs_rx.select_next_some() => { + self.handle_peer_mgr_notification(notif, &mut unprocessed_inbound); + }, + (peer_id, stream_result) = unprocessed_inbound.select_next_some() => { + match stream_result { + Ok(remote_notes) => { + self.reconcile(peer_id, remote_notes).await; + } + Err(e) => { + warn!("Failure in processing stream from peer: {}. Error: {:?}", + peer_id.short_str(), e); + } + } + }, + _ = unprocessed_outbound.select_next_some() => {} + complete => { + crit!("Discovery actor terminated"); + break; + } + } + } + } + + // Handles a clock "tick" by: + // 1. Selecting a random peer to send state to. + // 2. Compose the msg to send. + // 3. Spawn off a new task to push the msg to the peer. + fn handle_tick<'a>( + &'a mut self, + unprocessed_outbound: &'a mut FuturesUnordered + Send>>>, + ) { + // On each tick, we choose a random neighbor and push our state to it. + if let Some(peer) = self.choose_random_neighbor() { + // We clone `peer_mgr_reqs_tx` member of Self, since using `self` inside fut below + // triggers some lifetime errors. + let sender = self.peer_mgr_reqs_tx.clone(); + // Compose discovery msg to send. + let msg = self.compose_discovery_msg(); + let timeout = self.msg_timeout; + let fut = async move { + if let Err(err) = push_state_to_peer(sender, peer, msg) + .boxed() + .compat() + .timeout(timeout) + .compat() + .await + { + warn!( + "Failed to send discovery msg to {}; error: {:?}", + peer.short_str(), + err + ); + } + }; + unprocessed_outbound.push(fut.boxed()); + } + } + + fn handle_peer_mgr_notification<'a>( + &'a mut self, + notif: PeerManagerNotification, + unprocessed_inbound: &'a mut FuturesUnordered< + Pin, NetworkError>)> + Send>>, + >, + ) { + trace!("PeerManagerNotification::{:?}", notif); + match notif { + PeerManagerNotification::NewPeer(peer_id, addr) => { + // Add peer to connected peer list. + self.connected_peers.insert(peer_id, addr); + } + PeerManagerNotification::LostPeer(peer_id, addr) => { + match self.connected_peers.get(&peer_id) { + Some(curr_addr) if *curr_addr == addr => { + // Remove node from connected peers list. 
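+                        // (Matching on the address guards against a stale LostPeer
+                        // notification for an older connection removing the entry for a
+                        // newer connection to the same peer.)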
+ self.connected_peers.remove(&peer_id); + } + _ => { + debug!( + "Received redundant lost peer notfication for {}", + peer_id.short_str() + ); + } + } + } + PeerManagerNotification::NewInboundSubstream(peer_id, substream) => { + // We should not receive substreams from peer manager for any other protocol. + assert_eq!(substream.protocol, DISCOVERY_PROTOCOL_NAME); + // Add future to handle new inbound substream. + unprocessed_inbound.push( + handle_inbound_substream( + self.trusted_peers.clone(), + peer_id, + substream, + self.msg_timeout, + ) + .boxed(), + ); + } + } + } + + // Chooses a random connected neighbour. + fn choose_random_neighbor(&mut self) -> Option { + if !self.connected_peers.is_empty() { + let peers: Vec<_> = self.connected_peers.keys().collect(); + let idx = self.rng.gen_range(0, peers.len()); + Some(*peers[idx]) + } else { + None + } + } + + // Creates DiscoveryMsg to be sent to some remote peer. + fn compose_discovery_msg(&self) -> DiscoveryMsg { + let mut msg = DiscoveryMsg::new(); + let notes = msg.mut_notes(); + for (_, note) in self.known_peers.values() { + notes.push(note.clone()); + } + msg + } + + // Updates local state by reconciling with notes received from some remote peer. + // Assumption: `remote_notes` have already been verified for signature validity and content. + async fn reconcile(&mut self, remote_peer: PeerId, remote_notes: Vec) { + // If a peer is previously unknown, or has a newer epoch number, we update its + // corresponding entry in the map. + let self_peer_id = PeerId::try_from(self.self_note.get_peer_id()).unwrap(); + for note in remote_notes { + let peer_id = PeerId::try_from(note.get_peer_id()).unwrap(); + let peer_info: PeerInfo = protobuf::parse_from_bytes(note.get_peer_info()).unwrap(); + match self.known_peers.get_mut(&peer_id) { + // If we know about this peer, and receive the same or an older epoch, we do + // nothing. + Some((ref curr_peer_info, _)) + if peer_info.get_epoch() <= curr_peer_info.get_epoch() => + { + if peer_info.get_epoch() < curr_peer_info.get_epoch() { + debug!( + "Received stale note for peer: {} from peer: {}", + peer_id.short_str(), + remote_peer + ); + } + continue; + } + _ => { + info!( + "Received updated note for peer: {} from peer: {}", + peer_id.short_str(), + remote_peer.short_str() + ); + // We can never receive a note with a higher epoch number on us than what we + // ourselves have broadcasted. + assert_ne!(peer_id, self_peer_id); + // Update internal state of the peer with new Note. + self.known_peers.insert(peer_id, (peer_info.clone(), note)); + self.conn_mgr_reqs_tx + .send(ConnectivityRequest::UpdateAddresses( + peer_id, + peer_info + .get_addrs() + .iter() + .map(|addr| Multiaddr::try_from(addr.clone()).unwrap()) + .collect(), + )) + .await + .unwrap(); + } + } + } + } +} + +// Creates a PeerInfo combining the given addresses with the current unix timestamp as epoch. +fn create_peer_info(addrs: Vec) -> PeerInfo { + let mut peer_info = PeerInfo::new(); + // TODO: Currently, SystemTime::now() in Rust is not guaranteed to use a monotonic clock. + // At the moment, it's unclear how to do this in a platform-agnostic way. For Linux, we + // could use something like the [timerfd trait](https://docs.rs/crate/timerfd/1.0.0). 
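+    // Until a monotonic source is used, we rely on wall-clock time; a clock that jumps
+    // backwards would make freshly created notes carry a smaller epoch and thus look
+    // stale to remote peers.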
+ let time_since_epoch = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("System clock reset to before unix epoch") + .as_millis() as u64; + peer_info.set_epoch(time_since_epoch); + peer_info.set_addrs(addrs.into_iter().map(|addr| addr.as_ref().into()).collect()); + peer_info +} + +// Creates a note by signing the given peer info, and combining the signature, peer_info and +// peer_id into a note. +fn create_note(signer: &Signer, peer_id: PeerId, peer_info: PeerInfo) -> Note { + let raw_info = peer_info.write_to_bytes().unwrap(); + // Now that we have the serialized peer info, we sign it. + let signature = sign(&signer, &raw_info); + let mut note = Note::default(); + note.set_peer_id(peer_id.into()); + note.set_peer_info(raw_info.into()); + note.set_signature(signature.into()); + note +} + +// Handles an inbound substream from a remote peer as follows: +// 1. Reads the DiscoveryMsg sent by the remote. +// 2. Verifies signatures on all notes contained in the message. +// 3. Sends a message to the discovery peer with the notes received from the remote. +async fn handle_inbound_substream( + trusted_peers: Arc>>, + peer_id: PeerId, + substream: NegotiatedSubstream, + timeout: Duration, +) -> (PeerId, Result, NetworkError>) +where + TSubstream: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + // Read msg from stream. + let result = recv_msg(substream.substream) + .boxed() + .compat() + .timeout(timeout) + .compat() + .map_err(Into::::into) + .await + .and_then(|mut msg| { + let _ = msg + .get_notes() + .iter() + .map(|note| { + is_valid(¬e.to_owned(), trusted_peers.clone()).map_err(|e| { + security_log(SecurityEvent::InvalidNetworkPeer) + .error(&e) + .data(¬e) + .data(&trusted_peers) + .log(); + e + }) + }) + .collect::, NetworkError>>()?; + Ok(msg.take_notes().into_vec()) + }); + (peer_id, result) +} + +// Verifies validity of notes. Following conditions should be met for validity: +// 1. We should be able to correctly parse the peer id in each note. +// 2. The signature should be verified to be of the given peer for the serialized peer info. +// 3. The address(es) in the PeerInfo should be correctly parsable as Multiaddrs. 
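+// Note: `verify_signatures` below reuses `ValidatorVerifier` with a quorum size of 1, so a
+// note is considered valid as soon as the author's own signature over the serialized
+// PeerInfo verifies against the trusted-peers set.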
+fn is_valid( + note: &Note, + trusted_peers: Arc>>, +) -> Result<(), NetworkError> { + let peer_id = PeerId::try_from(note.get_peer_id()) + .map_err(|err| err.context(NetworkErrorKind::ParsingError))?; + verify_signatures( + trusted_peers, + peer_id, + note.get_signature(), + note.get_peer_info(), + )?; + let peer_info: PeerInfo = protobuf::parse_from_bytes(note.get_peer_info())?; + for addr in peer_info.get_addrs() { + let _: Multiaddr = Multiaddr::try_from(addr.clone())?; + } + Ok(()) +} + +fn get_hash(msg: &[u8]) -> HashValue { + let mut hasher = DiscoveryMsgHasher::default(); + hasher.write(msg); + hasher.finish() +} + +fn verify_signatures( + trusted_peers: Arc>>, + signer: PeerId, + signature: &[u8], + msg: &[u8], +) -> Result<(), NetworkError> { + let verifier = SignatureValidator::new( + trusted_peers + .read() + .unwrap() + .iter() + .map(|(peer_id, network_public_keys)| { + (*peer_id, network_public_keys.signing_public_key) + }) + .collect(), + 1, /* quorum size */ + ); + let signature = Signature::from_compact(signature) + .map_err(|err| err.context(NetworkErrorKind::SignatureError))?; + verifier.verify_signature(signer, get_hash(msg), &signature)?; + Ok(()) +} + +fn sign(signer: &Signer, msg: &[u8]) -> Vec { + signer + .sign_message(get_hash(msg)) + .unwrap() + .to_compact() + .to_vec() +} + +async fn push_state_to_peer( + mut sender: PeerManagerRequestSender, + peer_id: PeerId, + msg: DiscoveryMsg, +) -> Result<(), NetworkError> +where + TSubstream: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + trace!( + "Push discovery message to peer {} msg: {:?}", + peer_id.short_str(), + msg + ); + // Request a new substream to peer. + let substream = sender + .open_substream(peer_id, ProtocolId::from_static(DISCOVERY_PROTOCOL_NAME)) + .await?; + // Messages are length-prefixed. Wrap in a framed stream. + let mut substream = Framed::new(substream.compat(), UviBytes::default()).sink_compat(); + // Send serialized message to peer. + let bytes = msg + .write_to_bytes() + .expect("writing protobuf failed; should never happen"); + substream.send(Bytes::from(bytes)).await?; + Ok(()) +} + +async fn recv_msg(substream: TSubstream) -> Result +where + TSubstream: AsyncRead + AsyncWrite + Unpin, +{ + // Messages are length-prefixed. Wrap in a framed stream. + let mut substream = Framed::new(substream.compat(), UviBytes::::default()).sink_compat(); + // Read the message. 
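+    // (`utils::read_proto` is assumed to decode exactly one length-prefixed frame from the
+    // stream and parse it into the requested protobuf message type.)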
+ utils::read_proto(&mut substream).await +} diff --git a/network/src/protocols/discovery/test.rs b/network/src/protocols/discovery/test.rs new file mode 100644 index 0000000000000..dc34c0e840dab --- /dev/null +++ b/network/src/protocols/discovery/test.rs @@ -0,0 +1,301 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crate::{peer_manager::PeerManagerRequest, proto::DiscoveryMsg}; +use core::str::FromStr; +use crypto::{signing, x25519}; +use futures::future::{FutureExt, TryFutureExt}; +use memsocket::MemorySocket; +use tokio::runtime::Runtime; + +fn get_random_seed() -> PeerInfo { + let mut peer_info = PeerInfo::new(); + peer_info.set_epoch(1); + peer_info.mut_addrs().push( + Multiaddr::from_str("/ip4/127.0.0.1/tcp/9090") + .unwrap() + .as_ref() + .into(), + ); + peer_info +} + +fn setup_discovery( + rt: &mut Runtime, + peer_id: PeerId, + address: Multiaddr, + seed_peer_id: PeerId, + seed_peer_info: PeerInfo, + signer: Signer, + trusted_peers: Arc>>, +) -> ( + channel::Receiver>, + channel::Receiver, + channel::Sender>, + channel::Sender<()>, +) { + let (peer_mgr_reqs_tx, peer_mgr_reqs_rx) = channel::new_test(0); + let (conn_mgr_reqs_tx, conn_mgr_reqs_rx) = channel::new_test(1); + let (peer_mgr_notifs_tx, peer_mgr_notifs_rx) = channel::new_test(0); + let (ticker_tx, ticker_rx) = channel::new_test(0); + let discovery = { + Discovery::new( + peer_id, + vec![address], + signer, + vec![(seed_peer_id, seed_peer_info)].into_iter().collect(), + trusted_peers, + ticker_rx, + PeerManagerRequestSender::new(peer_mgr_reqs_tx), + peer_mgr_notifs_rx, + conn_mgr_reqs_tx, + Duration::from_secs(180), + ) + }; + rt.spawn(discovery.start().boxed().unit_error().compat()); + ( + peer_mgr_reqs_rx, + conn_mgr_reqs_rx, + peer_mgr_notifs_tx, + ticker_tx, + ) +} + +fn get_addrs(note: &Note) -> Vec { + let peer_info: PeerInfo = protobuf::parse_from_bytes(note.get_peer_info()).unwrap(); + let mut addrs = vec![]; + for addr in peer_info.get_addrs() { + addrs.push(Multiaddr::try_from(addr.clone()).unwrap()); + } + addrs +} + +async fn expect_address_update( + conn_mgr_reqs_rx: &mut channel::Receiver, + peer_id: PeerId, + addr: Multiaddr, +) { + match conn_mgr_reqs_rx.next().await.unwrap() { + ConnectivityRequest::UpdateAddresses(p, addrs) => { + assert_eq!(peer_id, p); + assert_eq!(1, addrs.len()); + assert_eq!(addr, addrs[0]); + } + _ => { + panic!("unexpected request to connectivity manager"); + } + } +} + +fn generate_network_pub_keys_and_signer(peer_id: PeerId) -> (NetworkPublicKeys, Signer) { + let (signing_priv_key, signing_pub_key) = signing::generate_keypair(); + let (_, identity_pub_key) = x25519::generate_keypair(); + ( + NetworkPublicKeys { + signing_public_key: signing_pub_key, + identity_public_key: identity_pub_key, + }, + Signer::new(peer_id, signing_pub_key, signing_priv_key), + ) +} + +#[test] +// Test behavior on receipt of an inbound DiscoveryMsg. +fn inbound() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + // Setup self. + let peer_id = PeerId::random(); + let address = Multiaddr::from_str("/ip4/127.0.0.1/tcp/9090").unwrap(); + let (self_pub_keys, self_signer) = generate_network_pub_keys_and_signer(peer_id); + // Setup seed. 
+ let mut seed_peer_info = get_random_seed(); + let seed_peer_id = PeerId::random(); + let (seed_pub_keys, seed_signer) = generate_network_pub_keys_and_signer(seed_peer_id); + let trusted_peers = Arc::new(RwLock::new( + vec![(seed_peer_id, seed_pub_keys), (peer_id, self_pub_keys)] + .into_iter() + .collect(), + )); + // Setup discovery. + let (_, mut conn_mgr_reqs_rx, mut peer_mgr_notifs_tx, _) = setup_discovery( + &mut rt, + peer_id, + address.clone(), + seed_peer_id, + seed_peer_info.clone(), + self_signer, + trusted_peers.clone(), + ); + + // Fake connectivity manager and dialer. + let f_peer_mgr = async move { + let seed_peer_address = Multiaddr::try_from(seed_peer_info.get_addrs()[0].clone()).unwrap(); + // Connectivity manager receives addresses of the seed peer during bootstrap. + expect_address_update( + &mut conn_mgr_reqs_rx, + seed_peer_id, + seed_peer_address.clone(), + ) + .await; + + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + // Notify discovery actor of inbound substream. + + peer_mgr_notifs_tx + .send(PeerManagerNotification::NewInboundSubstream( + seed_peer_id, + NegotiatedSubstream { + protocol: ProtocolId::from_static(DISCOVERY_PROTOCOL_NAME), + substream: listener_substream, + }, + )) + .await + .unwrap(); + // Wrap dialer substream in a framed substream. + let mut dialer_substream = + Framed::new(dialer_substream.compat(), UviBytes::::default()).sink_compat(); + + // Send DiscoveryMsg consisting of 2 notes to the discovery actor - one note for the + // seed peer and one for another peer. The discovery actor should send addresses of the new + // peer to the connectivity manager. + let peer_id_other = PeerId::random(); + let address_other = Multiaddr::from_str("/ip4/172.29.52.192/tcp/8080").unwrap(); + let seed_note = create_note(&seed_signer, seed_peer_id, seed_peer_info.clone()); + let (pub_keys_other, signer_other) = generate_network_pub_keys_and_signer(peer_id_other); + trusted_peers + .write() + .unwrap() + .insert(peer_id_other, pub_keys_other); + let note_other = { + let mut peer_info = PeerInfo::new(); + let addrs = peer_info.mut_addrs(); + addrs.clear(); + addrs.push(address_other.as_ref().into()); + create_note(&signer_other, peer_id_other, peer_info) + }; + let mut msg = DiscoveryMsg::new(); + msg.mut_notes().push(note_other.clone()); + msg.mut_notes().push(seed_note.clone()); + dialer_substream + .send(msg.write_to_bytes().unwrap().into()) + .await + .unwrap(); + + // Connectivity manager receives address of new peer. + expect_address_update(&mut conn_mgr_reqs_rx, peer_id_other, address_other).await; + + // Connectivity manager receives a connect to the seed peer at the same address. + expect_address_update(&mut conn_mgr_reqs_rx, seed_peer_id, seed_peer_address).await; + + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + // Notify discovery actor of inbound substream. + + peer_mgr_notifs_tx + .send(PeerManagerNotification::NewInboundSubstream( + peer_id_other, + NegotiatedSubstream { + protocol: ProtocolId::from_static(DISCOVERY_PROTOCOL_NAME), + substream: listener_substream, + }, + )) + .await + .unwrap(); + // Wrap dialer substream in a framed substream. + let mut dialer_substream = + Framed::new(dialer_substream.compat(), UviBytes::::default()).sink_compat(); + // Compose new msg. 
+ let mut msg = DiscoveryMsg::new(); + msg.mut_notes().push(note_other); + let new_seed_addr = Multiaddr::from_str("/ip4/127.0.0.1/tcp/8098").unwrap(); + { + seed_peer_info.set_epoch(3000); + seed_peer_info.mut_addrs().clear(); + seed_peer_info + .mut_addrs() + .push(new_seed_addr.as_ref().into()); + let seed_note = create_note(&seed_signer, seed_peer_id, seed_peer_info); + msg.mut_notes().push(seed_note); + dialer_substream + .send(msg.write_to_bytes().unwrap().into()) + .await + .unwrap(); + } + + // Connectivity manager receives new address of seed peer. + expect_address_update(&mut conn_mgr_reqs_rx, seed_peer_id, new_seed_addr).await; + }; + rt.block_on(f_peer_mgr.boxed().unit_error().compat()) + .unwrap(); +} + +#[test] +// Test that discovery actor sends a DiscoveryMsg to a neighbor on receiving a clock tick. +fn outbound() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + // Setup self. + let peer_id = PeerId::random(); + let address = Multiaddr::from_str("/ip4/127.0.0.1/tcp/9090").unwrap(); + let (self_pub_keys, self_signer) = generate_network_pub_keys_and_signer(peer_id); + // Setup seed. + let seed_peer_id = PeerId::random(); + let seed_peer_info = get_random_seed(); + let (seed_pub_keys, _) = generate_network_pub_keys_and_signer(seed_peer_id); + let trusted_peers = Arc::new(RwLock::new( + vec![(seed_peer_id, seed_pub_keys), (peer_id, self_pub_keys)] + .into_iter() + .collect(), + )); + // Setup discovery. + let (mut peer_mgr_reqs_rx, _conn_mgr_req_rx, mut peer_mgr_notifs_tx, mut ticker_tx) = + setup_discovery( + &mut rt, + peer_id, + address.clone(), + seed_peer_id, + seed_peer_info.clone(), + self_signer, + trusted_peers.clone(), + ); + + // Fake connectivity manager and dialer. + let f_peer_mgr = async move { + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + let seed_peer_address = Multiaddr::try_from(seed_peer_info.get_addrs()[0].clone()).unwrap(); + // Notify discovery actor of connection to seed peer. + peer_mgr_notifs_tx + .send(PeerManagerNotification::NewPeer( + seed_peer_id, + seed_peer_address, + )) + .await + .unwrap(); + + // Trigger outbound msg. + ticker_tx.send(()).await.unwrap(); + + // Request outgoing substream from PeerManager. + match peer_mgr_reqs_rx.next().await.unwrap() { + PeerManagerRequest::OpenSubstream(peer, protocol, ch) => { + assert_eq!(peer, seed_peer_id); + assert_eq!(protocol, DISCOVERY_PROTOCOL_NAME); + ch.send(Ok(dialer_substream)).unwrap(); + } + _ => { + panic!("unexpected request to peer manager"); + } + } + + // Receive DiscoveryMsg from actor. The message should contain only a note for the + // sending peer since it doesn't yet have the note for the seed peer. + let msg = recv_msg(listener_substream).await.unwrap(); + assert_eq!(1, msg.get_notes().len()); + assert_eq!(Vec::from(peer_id), msg.get_notes()[0].get_peer_id()); + assert_eq!(address, get_addrs(&msg.get_notes()[0])[0]); + }; + + rt.block_on(f_peer_mgr.boxed().unit_error().compat()) + .unwrap(); +} diff --git a/network/src/protocols/health_checker/mod.rs b/network/src/protocols/health_checker/mod.rs new file mode 100644 index 0000000000000..d15c6b8960d9b --- /dev/null +++ b/network/src/protocols/health_checker/mod.rs @@ -0,0 +1,284 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Protocol used to ensure peer liveness +//! +//! The HealthChecker is responsible for ensuring liveness of all peers of a node. +//! 
It does so by periodically selecting a random connected peer and sending a Ping probe. A
+//! healthy peer is expected to respond with a corresponding Pong message.
+//!
+//! If a certain number of successive liveness probes for a peer fail, the HealthChecker initiates a
+//! disconnect from the peer. It relies on ConnectivityManager or the remote peer to re-establish
+//! the connection.
+//!
+//! Future Work
+//! -----------
+//! We can make a few other improvements to the health checker. These are:
+//! - Make the policy for interpreting ping failures pluggable
+//! - Use successful inbound pings as a sign of the remote node being healthy
+//! - Ping a peer only in periods of no application-level communication with the peer
+use crate::{
+    error::NetworkError,
+    peer_manager::{PeerManagerNotification, PeerManagerRequestSender},
+    proto::{Ping, Pong},
+    utils::read_proto,
+    ProtocolId,
+};
+use bytes::Bytes;
+use channel;
+use futures::{
+    compat::{Future01CompatExt, Sink01CompatExt},
+    future::{FutureExt, TryFutureExt},
+    io::{AsyncRead, AsyncReadExt, AsyncWrite},
+    sink::SinkExt,
+    stream::{FusedStream, FuturesUnordered, Stream, StreamExt},
+};
+use logger::prelude::*;
+use protobuf::{self, Message};
+use rand::{rngs::SmallRng, seq::SliceRandom, FromEntropy};
+use std::{collections::HashMap, fmt::Debug, time::Duration};
+use tokio::{codec::Framed, prelude::FutureExt as _};
+use types::PeerId;
+use unsigned_varint::codec::UviBytes;
+
+#[cfg(test)]
+mod test;
+
+/// Protocol name for Ping.
+pub const PING_PROTOCOL_NAME: &[u8] = b"/libra/ping/0.1.0";
+
+/// The actor performing health checks by running the Ping protocol
+pub struct HealthChecker<TTicker, TSubstream> {
+    /// Ticker to trigger pings to a random peer. In production, the ticker is likely to be a
+    /// fixed-duration interval timer.
+    ticker: TTicker,
+    /// Channel to send requests to PeerManager.
+    peer_mgr_reqs_tx: PeerManagerRequestSender<TSubstream>,
+    /// Channel to receive notifications from PeerManager about new/lost connections.
+    peer_mgr_notifs_rx: channel::Receiver<PeerManagerNotification<TSubstream>>,
+    /// Map from connected peer to last round of successful ping, and number of failures since
+    /// then.
+    connected: HashMap<PeerId, (u64, u64)>,
+    /// Random-number generator.
+    rng: SmallRng,
+    /// Ping timeout duration.
+    ping_timeout: Duration,
+    /// Number of successive ping failures we tolerate before declaring a node as unhealthy and
+    /// disconnecting from it. In the future, this can be replaced with a more general failure
+    /// detection policy.
+    ping_failures_tolerated: u64,
+    /// Counter incremented in each round of health checks
+    round: u64,
+}
+
+impl<TTicker, TSubstream> HealthChecker<TTicker, TSubstream>
+where
+    TTicker: Stream<Item = ()> + FusedStream + Unpin,
+    TSubstream: AsyncRead + AsyncWrite + Send + Unpin + Debug + 'static,
+{
+    /// Create a new instance of the [`HealthChecker`] actor.
+    pub fn new(
+        ticker: TTicker,
+        peer_mgr_reqs_tx: PeerManagerRequestSender<TSubstream>,
+        peer_mgr_notifs_rx: channel::Receiver<PeerManagerNotification<TSubstream>>,
+        ping_timeout: Duration,
+        ping_failures_tolerated: u64,
+    ) -> Self {
+        HealthChecker {
+            ticker,
+            peer_mgr_reqs_tx,
+            peer_mgr_notifs_rx,
+            connected: HashMap::new(),
+            rng: SmallRng::from_entropy(),
+            ping_timeout,
+            ping_failures_tolerated,
+            round: 0,
+        }
+    }
+
+    pub async fn start(mut self) {
+        let mut tick_handlers = FuturesUnordered::new();
+        let mut ping_handlers = FuturesUnordered::new();
+        loop {
+            futures::select!
{ + notif = self.peer_mgr_notifs_rx.select_next_some() => { + match notif { + PeerManagerNotification::NewPeer(peer_id, _) => { + self.connected.insert(peer_id, (self.round, 0)); + } + PeerManagerNotification::LostPeer(peer_id, _) => { + self.connected.remove(&peer_id); + } + PeerManagerNotification::NewInboundSubstream(peer_id, substream) => { + assert_eq!(substream.protocol, PING_PROTOCOL_NAME); + ping_handlers.push(Self::handle_ping(peer_id, substream.substream)); + } + } + } + _ = self.ticker.select_next_some() => { + self.round += 1; + debug!("Round number: {}", self.round); + match self.get_random_peer() { + Some(peer_id) => { + debug!("Will ping: {}", peer_id.short_str()); + tick_handlers.push( + Self::ping_peer( + peer_id, + self.round, + self.peer_mgr_reqs_tx.clone(), + self.ping_timeout.clone())); + } + None => { + debug!("No connected peer to ping"); + } + } + } + res = tick_handlers.select_next_some() => { + let (peer_id, round, ping_result) = res; + self.handle_ping_result(peer_id, round, ping_result).await; + } + _ = ping_handlers.select_next_some() => {} + complete => { + crit!("Health checker actor terminated"); + break; + } + } + } + } + + async fn handle_ping_result( + &mut self, + peer_id: PeerId, + round: u64, + ping_result: Result<(), NetworkError>, + ) { + debug!("Got result for ping round: {}", round); + match ping_result { + Ok(_) => { + debug!("Ping successful for peer: {}", peer_id.short_str()); + // Update last successful ping to current round. + self.connected + .entry(peer_id) + .and_modify(|(ref mut r, ref mut count)| { + if round > *r { + *r = round; + *count = 0; + } + }); + } + Err(err) => { + warn!( + "Ping failed for peer: {} with error: {:?}", + peer_id.short_str(), + err + ); + match self.connected.get_mut(&peer_id) { + None => { + // If we are no longer connected to the peer, we ignore ping + // failure. + } + Some((ref mut prev, ref mut failures)) => { + // If this is the result of an older ping, we ignore it. + if *prev > round { + return; + } + // Increment num of failures. If the ping failures are now more than + // `self.ping_failures_tolerated`, we disconnect from the node. + // The HealthChecker only performs the disconnect. It relies on + // ConnectivityManager or the remote peer to re-establish the connection. + *failures += 1; + if *failures > self.ping_failures_tolerated { + info!("Disonnecting from peer: {}", peer_id.short_str()); + if let Err(err) = self.peer_mgr_reqs_tx.disconnect_peer(peer_id).await { + warn!( + "Failed to disconnect from peer: {} with error: {:?}", + peer_id.short_str(), + err + ); + } + } + } + } + } + } + } + + async fn ping_peer( + peer_id: PeerId, + round: u64, + peer_mgr_reqs_tx: PeerManagerRequestSender, + ping_timeout: Duration, + ) -> (PeerId, u64, Result<(), NetworkError>) { + let ping_result = async move |mut peer_mgr_reqs_tx: PeerManagerRequestSender< + TSubstream, + >| + -> Result<(), NetworkError> { + // Request a new substream to peer. + debug!( + "Opening a new substream with peer: {} for Ping", + peer_id.short_str() + ); + let substream = peer_mgr_reqs_tx + .open_substream(peer_id, ProtocolId::from_static(PING_PROTOCOL_NAME)) + .await?; + // Messages are length-prefixed. Wrap in a framed stream. + let mut substream = Framed::new(substream.compat(), UviBytes::default()).sink_compat(); + // Send Ping. + debug!("Sending Ping to peer: {}", peer_id.short_str()); + substream + .send(Bytes::from(Ping::new().write_to_bytes().unwrap())) + .await?; + // Read Pong. 
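+            // (A Pong that never arrives is not handled here: the whole `ping_result` future
+            // is wrapped in a `ping_timeout` below, which surfaces as a timeout error.)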
+ debug!("Waiting for Pong from peer: {}", peer_id.short_str()); + let _: Pong = read_proto(&mut substream).await?; + // Return success. + Ok(()) + }; + ( + peer_id, + round, + ping_result(peer_mgr_reqs_tx.clone()) + .boxed() + .compat() + .timeout(ping_timeout) + .compat() + .map_err(Into::::into) + .await, + ) + } + + async fn handle_ping(peer_id: PeerId, substream: TSubstream) { + // Messages are length-prefixed. Wrap in a framed stream. + let mut substream = + Framed::new(substream.compat(), UviBytes::::default()).sink_compat(); + // Read ping. + trace!("Waiting for Ping on new substream"); + let maybe_ping: Result = read_proto(&mut substream).await; + if let Err(err) = maybe_ping { + warn!( + "Failed to read ping from peer: {}. Error: {:?}", + peer_id.short_str(), + err + ); + return; + } + // Send Pong. + trace!("Sending Pong back"); + if let Err(err) = substream + .send(Bytes::from(Pong::new().write_to_bytes().unwrap())) + .await + { + warn!( + "Failed to send pong to peer: {}. Error: {:?}", + peer_id.short_str(), + err + ); + return; + } + } + + fn get_random_peer(&mut self) -> Option { + let peers: Vec<_> = self.connected.keys().cloned().collect(); + peers.choose(&mut self.rng).cloned() + } +} diff --git a/network/src/protocols/health_checker/test.rs b/network/src/protocols/health_checker/test.rs new file mode 100644 index 0000000000000..87bac4937d9af --- /dev/null +++ b/network/src/protocols/health_checker/test.rs @@ -0,0 +1,376 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crate::{common::NegotiatedSubstream, peer_manager::PeerManagerRequest}; +use futures::future::{FutureExt, TryFutureExt}; +use memsocket::MemorySocket; +use parity_multiaddr::Multiaddr; +use std::str::FromStr; +use tokio::runtime::Runtime; + +const PING_TIMEOUT: Duration = Duration::from_millis(500); + +fn setup_permissive_health_checker( + rt: &mut Runtime, + ping_failures_tolerated: u64, +) -> ( + channel::Receiver>, + channel::Sender>, + channel::Sender<()>, +) { + let (ticker_tx, ticker_rx) = channel::new_test(0); + let (peer_mgr_reqs_tx, peer_mgr_reqs_rx) = channel::new_test(0); + let (peer_mgr_notifs_tx, peer_mgr_notifs_rx) = channel::new_test(0); + let health_checker = HealthChecker::new( + ticker_rx, + PeerManagerRequestSender::new(peer_mgr_reqs_tx), + peer_mgr_notifs_rx, + PING_TIMEOUT, + ping_failures_tolerated, + ); + rt.spawn(health_checker.start().boxed().unit_error().compat()); + (peer_mgr_reqs_rx, peer_mgr_notifs_tx, ticker_tx) +} + +fn setup_default_health_checker( + rt: &mut Runtime, +) -> ( + channel::Receiver>, + channel::Sender>, + channel::Sender<()>, +) { + let (ticker_tx, ticker_rx) = channel::new_test(0); + let (peer_mgr_reqs_tx, peer_mgr_reqs_rx) = channel::new_test(0); + let (peer_mgr_notifs_tx, peer_mgr_notifs_rx) = channel::new_test(0); + let health_checker = HealthChecker::new( + ticker_rx, + PeerManagerRequestSender::new(peer_mgr_reqs_tx), + peer_mgr_notifs_rx, + PING_TIMEOUT, + 0, + ); + rt.spawn(health_checker.start().boxed().unit_error().compat()); + (peer_mgr_reqs_rx, peer_mgr_notifs_tx, ticker_tx) +} + +async fn send_ping_expect_pong(substream: MemorySocket) { + // Messages are length-prefixed. Wrap in a framed stream. + let mut substream = Framed::new(substream.compat(), UviBytes::::default()).sink_compat(); + // Send ping. + substream + .send(Bytes::from(Ping::new().write_to_bytes().unwrap())) + .await + .unwrap(); + // Expect Pong. 
+ let _: Pong = read_proto(&mut substream).await.unwrap(); +} + +async fn expect_ping_send_ok(substream: MemorySocket) { + // Messages are length-prefixed. Wrap in a framed stream. + let mut substream = Framed::new(substream.compat(), UviBytes::::default()).sink_compat(); + // Read ping. + let _: Ping = read_proto(&mut substream).await.unwrap(); + // Send Pong. + substream + .send(Bytes::from(Pong::new().write_to_bytes().unwrap())) + .await + .unwrap(); +} + +async fn expect_ping_send_notok(substream: MemorySocket) { + // Messages are length-prefixed. Wrap in a framed stream. + let mut substream = Framed::new(substream.compat(), UviBytes::::default()).sink_compat(); + // Read ping. + let _: Ping = read_proto(&mut substream).await.unwrap(); + substream.close().await.unwrap(); +} + +async fn expect_ping_timeout(substream: MemorySocket) { + // Messages are length-prefixed. Wrap in a framed stream. + let mut substream = Framed::new(substream.compat(), UviBytes::::default()).sink_compat(); + // Read ping. + let _: Ping = read_proto(&mut substream).await.unwrap(); + // Sleep for ping timeout plus a little bit. + std::thread::sleep(PING_TIMEOUT + Duration::from_millis(100)); +} + +async fn open_substream_and_notify( + peer_id: PeerId, + peer_mgr_notifs_tx: &mut channel::Sender>, +) -> MemorySocket { + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + + peer_mgr_notifs_tx + .send(PeerManagerNotification::NewInboundSubstream( + peer_id, + NegotiatedSubstream { + protocol: ProtocolId::from_static(PING_PROTOCOL_NAME), + substream: listener_substream, + }, + )) + .await + .unwrap(); + dialer_substream +} + +async fn expect_disconnect( + peer_id: PeerId, + peer_mgr_reqs_rx: &mut channel::Receiver>, +) { + match peer_mgr_reqs_rx.next().await.unwrap() { + PeerManagerRequest::DisconnectPeer(peer, ch) => { + assert_eq!(peer, peer_id); + ch.send(Ok(())).unwrap(); + } + _ => { + panic!("unexpected request to peer manager"); + } + } +} + +async fn expect_open_substream( + peer_id: PeerId, + peer_mgr_reqs_rx: &mut channel::Receiver>, +) -> MemorySocket { + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + match peer_mgr_reqs_rx.next().await.unwrap() { + PeerManagerRequest::OpenSubstream(peer, protocol, ch) => { + assert_eq!(peer, peer_id); + assert_eq!(protocol, PING_PROTOCOL_NAME); + ch.send(Ok(dialer_substream)).unwrap(); + } + _ => { + panic!("unexpected request to peer manager"); + } + } + listener_substream +} + +#[test] +fn outbound() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + let (mut peer_mgr_reqs_rx, mut peer_mgr_notifs_tx, mut ticker_tx) = + setup_default_health_checker(&mut rt); + + let events_f = async move { + // Trigger ping to a peer. This should do nothing. + ticker_tx.send(()).await.unwrap(); + + // Notify HealthChecker of new connected node. + let peer_id = PeerId::random(); + let peer_address = Multiaddr::from_str("/ip4/127.0.0.1/tcp/9090").unwrap(); + + peer_mgr_notifs_tx + .send(PeerManagerNotification::NewPeer( + peer_id, + peer_address.clone(), + )) + .await + .unwrap(); + + // Trigger ping to a peer. This should ping the newly added peer. + ticker_tx.send(()).await.unwrap(); + + // Health checker should request for a new substream. + let listener_substream = expect_open_substream(peer_id, &mut peer_mgr_reqs_rx).await; + + // Health checker should send a ping request. 
+ expect_ping_send_ok(listener_substream).await; + }; + rt.block_on(events_f.boxed().unit_error().compat()).unwrap(); +} + +#[test] +fn inbound() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + let (_, mut peer_mgr_notifs_tx, _) = setup_default_health_checker(&mut rt); + + let events_f = async move { + let peer_id = PeerId::random(); + + // Send notification about incoming Ping substream. + let dialer_substream = open_substream_and_notify(peer_id, &mut peer_mgr_notifs_tx).await; + + // Send ping and expect pong in return. + send_ping_expect_pong(dialer_substream).await; + }; + rt.block_on(events_f.boxed().unit_error().compat()).unwrap(); +} + +#[test] +fn outbound_failure_permissive() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + let ping_failures_tolerated = 10; + let (mut peer_mgr_reqs_rx, mut peer_mgr_notifs_tx, mut ticker_tx) = + setup_permissive_health_checker(&mut rt, ping_failures_tolerated); + + let events_f = async move { + // Trigger ping to a peer. This should do nothing. + ticker_tx.send(()).await.unwrap(); + + // Notify HealthChecker of new connected node. + let peer_id = PeerId::random(); + let peer_address = Multiaddr::from_str("/ip4/127.0.0.1/tcp/9090").unwrap(); + + peer_mgr_notifs_tx + .send(PeerManagerNotification::NewPeer( + peer_id, + peer_address.clone(), + )) + .await + .unwrap(); + + // Trigger pings to a peer. These should ping the newly added peer, but not disconnect from + // it. + for _ in 0..=ping_failures_tolerated { + ticker_tx.send(()).await.unwrap(); + // Health checker should request for a new substream. + let listener_substream = expect_open_substream(peer_id, &mut peer_mgr_reqs_rx).await; + // Health checker should send a ping request which fails. + expect_ping_send_notok(listener_substream).await; + } + // Health checker should disconnect from peer after tolerated number of failures + expect_disconnect(peer_id, &mut peer_mgr_reqs_rx).await; + }; + rt.block_on(events_f.boxed().unit_error().compat()).unwrap(); +} + +#[test] +fn ping_success_resets_fail_counter() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + let failures_triggered = 10; + let ping_failures_tolerated = 2 * 10; + let (mut peer_mgr_reqs_rx, mut peer_mgr_notifs_tx, mut ticker_tx) = + setup_permissive_health_checker(&mut rt, ping_failures_tolerated); + + let events_f = async move { + // Trigger ping to a peer. This should do nothing. + ticker_tx.send(()).await.unwrap(); + // Notify HealthChecker of new connected node. + let peer_id = PeerId::random(); + let peer_address = Multiaddr::from_str("/ip4/127.0.0.1/tcp/9090").unwrap(); + peer_mgr_notifs_tx + .send(PeerManagerNotification::NewPeer( + peer_id, + peer_address.clone(), + )) + .await + .unwrap(); + // Trigger pings to a peer. These should ping the newly added peer, but not disconnect from + // it. + { + for _ in 0..failures_triggered { + ticker_tx.send(()).await.unwrap(); + // Health checker should request for a new substream. + let listener_substream = + expect_open_substream(peer_id, &mut peer_mgr_reqs_rx).await; + // Health checker should send a ping request which fails. + expect_ping_send_notok(listener_substream).await; + } + } + // Trigger successful ping. This should reset the counter of ping failures. + { + ticker_tx.send(()).await.unwrap(); + // Health checker should request for a new substream. 
+ let listener_substream = expect_open_substream(peer_id, &mut peer_mgr_reqs_rx).await; + // Health checker should send a ping request which succeeds + expect_ping_send_ok(listener_substream).await; + } + // We would then need to fail for more than `ping_failures_tolerated` times before + // triggering disconnect. + { + for _ in 0..=ping_failures_tolerated { + ticker_tx.send(()).await.unwrap(); + // Health checker should request for a new substream. + let listener_substream = + expect_open_substream(peer_id, &mut peer_mgr_reqs_rx).await; + // Health checker should send a ping request which fails. + expect_ping_send_notok(listener_substream).await; + } + } + // Health checker should disconnect from peer after tolerated number of failures + expect_disconnect(peer_id, &mut peer_mgr_reqs_rx).await; + }; + rt.block_on(events_f.boxed().unit_error().compat()).unwrap(); +} + +#[test] +fn outbound_failure_strict() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + let (mut peer_mgr_reqs_rx, mut peer_mgr_notifs_tx, mut ticker_tx) = + setup_default_health_checker(&mut rt); + + let events_f = async move { + // Trigger ping to a peer. This should do nothing. + ticker_tx.send(()).await.unwrap(); + + // Notify HealthChecker of new connected node. + let peer_id = PeerId::random(); + let peer_address = Multiaddr::from_str("/ip4/127.0.0.1/tcp/9090").unwrap(); + + peer_mgr_notifs_tx + .send(PeerManagerNotification::NewPeer( + peer_id, + peer_address.clone(), + )) + .await + .unwrap(); + + // Trigger ping to a peer. This should ping the newly added peer. + ticker_tx.send(()).await.unwrap(); + + // Health checker should request for a new substream. + let listener_substream = expect_open_substream(peer_id, &mut peer_mgr_reqs_rx).await; + + // Health checker should send a ping request which fails. + expect_ping_send_notok(listener_substream).await; + + // Health checker should disconnect from peer. + expect_disconnect(peer_id, &mut peer_mgr_reqs_rx).await; + }; + rt.block_on(events_f.boxed().unit_error().compat()).unwrap(); +} + +#[test] +fn ping_timeout() { + ::logger::try_init_for_testing(); + let mut rt = Runtime::new().unwrap(); + let (mut peer_mgr_reqs_rx, mut peer_mgr_notifs_tx, mut ticker_tx) = + setup_default_health_checker(&mut rt); + + let events_f = async move { + // Trigger ping to a peer. This should do nothing. + ticker_tx.send(()).await.unwrap(); + + // Notify HealthChecker of new connected node. + let peer_id = PeerId::random(); + let peer_address = Multiaddr::from_str("/ip4/127.0.0.1/tcp/9090").unwrap(); + + peer_mgr_notifs_tx + .send(PeerManagerNotification::NewPeer( + peer_id, + peer_address.clone(), + )) + .await + .unwrap(); + + // Trigger ping to a peer. This should ping the newly added peer. + ticker_tx.send(()).await.unwrap(); + + // Health checker should request for a new substream. + let listener_substream = expect_open_substream(peer_id, &mut peer_mgr_reqs_rx).await; + + // Health checker should send a ping request which fails. + expect_ping_timeout(listener_substream).await; + + // Health checker should disconnect from peer. + expect_disconnect(peer_id, &mut peer_mgr_reqs_rx).await; + }; + rt.block_on(events_f.boxed().unit_error().compat()).unwrap(); +} diff --git a/network/src/protocols/identity.rs b/network/src/protocols/identity.rs new file mode 100644 index 0000000000000..414fc0b070c09 --- /dev/null +++ b/network/src/protocols/identity.rs @@ -0,0 +1,191 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! 
Protocol used to identify key information about a remote +//! +//! Currently, the information shared as part of this protocol includes the peer identity and a +//! list of protocols supported by the peer. +use crate::{proto::IdentityMsg, ProtocolId}; +use bytes::Bytes; +use futures::{ + compat::{Compat, Sink01CompatExt}, + sink::SinkExt, + stream::StreamExt, +}; +use netcore::{ + multiplexing::StreamMultiplexer, + negotiate::{negotiate_inbound, negotiate_outbound_interactive}, + transport::ConnectionOrigin, +}; +use protobuf::{self, Message}; +use std::{convert::TryFrom, io}; +use tokio::codec::Framed; +use types::PeerId; +use unsigned_varint::codec::UviBytes; + +const IDENTITY_PROTOCOL_NAME: &[u8] = b"/identity/0.1.0"; + +/// The Identity of a node +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Identity { + peer_id: PeerId, + supported_protocols: Vec, +} + +impl Identity { + pub fn new(peer_id: PeerId, supported_protocols: Vec) -> Self { + Self { + peer_id, + supported_protocols, + } + } + + pub fn peer_id(&self) -> PeerId { + self.peer_id + } + + pub fn is_protocol_supported(&self, protocol: &ProtocolId) -> bool { + self.supported_protocols + .iter() + .any(|proto| proto == protocol) + } + + pub fn supported_protocols(&self) -> &[ProtocolId] { + &self.supported_protocols + } +} + +/// The Identity exchange protocol +pub async fn exchange_identity( + own_identity: &Identity, + connection: TMuxer, + origin: ConnectionOrigin, +) -> io::Result<(Identity, TMuxer)> +where + TMuxer: StreamMultiplexer, +{ + // Perform protocol negotiation on a substream on the connection. The dialer is responsible + // for opening the substream, while the listener is responsible for listening for that + // incoming substream. + let (substream, proto) = match origin { + ConnectionOrigin::Inbound => { + let mut listener = connection.listen_for_inbound(); + let substream = listener.next().await.ok_or_else(|| { + io::Error::new( + io::ErrorKind::ConnectionAborted, + "Connection closed by remote", + ) + })??; + negotiate_inbound(substream, [IDENTITY_PROTOCOL_NAME]).await? + } + ConnectionOrigin::Outbound => { + let substream = connection.open_outbound().await?; + negotiate_outbound_interactive(substream, [IDENTITY_PROTOCOL_NAME]).await? + } + }; + + assert_eq!(proto, IDENTITY_PROTOCOL_NAME); + + // Create the Framed Sink/Stream + let mut framed_substream = + Framed::new(Compat::new(substream), UviBytes::default()).sink_compat(); + + // Build Identity Message + let mut msg = IdentityMsg::new(); + msg.set_supported_protocols(own_identity.supported_protocols().to_vec()); + msg.set_peer_id(own_identity.peer_id().into()); + + // Send serialized message to peer. 
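+    // (The `UviBytes` codec frames the payload with a multiformats unsigned-varint length
+    // prefix, matching the framing used by the other substream protocols in this crate.)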
+ let bytes = msg + .write_to_bytes() + .expect("writing protobuf failed; should never happen"); + framed_substream.send(Bytes::from(bytes)).await?; + framed_substream.close().await?; + + // Read an IdentityMsg from the Remote + let response = framed_substream.next().await.ok_or_else(|| { + io::Error::new( + io::ErrorKind::ConnectionAborted, + "Connection closed by remote", + ) + })??; + let mut response = ::protobuf::parse_from_bytes::(&response).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("Failed to parse identity msg: {}", e), + ) + })?; + let peer_id = PeerId::try_from(response.take_peer_id()).expect("Invalid PeerId"); + let identity = Identity::new(peer_id, response.take_supported_protocols()); + + Ok((identity, connection)) +} + +#[cfg(test)] +mod tests { + use crate::{ + protocols::identity::{exchange_identity, Identity}, + ProtocolId, + }; + use futures::{executor::block_on, future::join}; + use memsocket::MemorySocket; + use netcore::{ + multiplexing::yamux::{Mode, Yamux}, + transport::ConnectionOrigin, + }; + use types::PeerId; + + fn build_test_connection() -> (Yamux, Yamux) { + let (dialer, listener) = MemorySocket::new_pair(); + + ( + Yamux::new(dialer, Mode::Client), + Yamux::new(listener, Mode::Server), + ) + } + + #[test] + fn simple_identify() { + let (outbound, inbound) = build_test_connection(); + let server_identity = Identity::new( + PeerId::random(), + vec![ + ProtocolId::from_static(b"/proto/1.0.0"), + ProtocolId::from_static(b"/proto/2.0.0"), + ], + ); + let client_identity = Identity::new( + PeerId::random(), + vec![ + ProtocolId::from_static(b"/proto/1.0.0"), + ProtocolId::from_static(b"/proto/2.0.0"), + ProtocolId::from_static(b"/proto/3.0.0"), + ], + ); + let server_identity_config = server_identity.clone(); + let client_identity_config = client_identity.clone(); + + let server = async move { + let (identity, _connection) = + exchange_identity(&server_identity_config, inbound, ConnectionOrigin::Inbound) + .await + .unwrap(); + + assert_eq!(identity, client_identity); + }; + + let client = async move { + let (identity, _connection) = exchange_identity( + &client_identity_config, + outbound, + ConnectionOrigin::Outbound, + ) + .await + .unwrap(); + + assert_eq!(identity, server_identity); + }; + + block_on(join(server, client)); + } +} diff --git a/network/src/protocols/mod.rs b/network/src/protocols/mod.rs new file mode 100644 index 0000000000000..ee9cd44a62aae --- /dev/null +++ b/network/src/protocols/mod.rs @@ -0,0 +1,16 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Protocols used by network module for external APIs and internal functionality +//! +//! 
Each protocol defines a particular sequence of messages exchanged between two peers.
+pub mod direct_send;
+pub mod rpc;
+
+pub(crate) mod discovery;
+pub(crate) mod health_checker;
+pub(crate) mod identity;
+
+// Keep this module for now to be used in PeerManager tests
+#[cfg(test)]
+pub(crate) mod peer_id_exchange;
diff --git a/network/src/protocols/peer_id_exchange.rs b/network/src/protocols/peer_id_exchange.rs
new file mode 100644
index 0000000000000..f06ca6c00a22e
--- /dev/null
+++ b/network/src/protocols/peer_id_exchange.rs
@@ -0,0 +1,132 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+// Simple transport used to identify the PeerId of a remote peer
+//
+use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
+use netcore::{
+    negotiate::{negotiate_inbound, negotiate_outbound_interactive},
+    transport::ConnectionOrigin,
+};
+use std::{convert::TryInto, io};
+use types::PeerId;
+
+const PEER_ID_EXCHANGE_PROTOCOL_NAME: &[u8] = b"/peer_id_exchange/0.1.0";
+
+/// The protocol upgrade configuration.
+#[derive(Clone)]
+pub struct PeerIdExchange(PeerId);
+
+/// A simple PeerId exchange protocol
+///
+/// A PeerId is sent from each side in the form of:
+/// <1-byte length><[u8] PeerId>
+// TODO change to using u16 framing
+impl PeerIdExchange {
+    pub fn new(peer_id: PeerId) -> Self {
+        Self(peer_id)
+    }
+
+    pub async fn exchange_peer_id<TSocket>(
+        self,
+        socket: TSocket,
+        origin: ConnectionOrigin,
+    ) -> io::Result<(PeerId, TSocket)>
+    where
+        TSocket: AsyncRead + AsyncWrite + Unpin,
+    {
+        // Perform protocol negotiation
+        let (mut socket, proto) = match origin {
+            ConnectionOrigin::Inbound => {
+                negotiate_inbound(socket, [PEER_ID_EXCHANGE_PROTOCOL_NAME]).await?
+            }
+            ConnectionOrigin::Outbound => {
+                negotiate_outbound_interactive(socket, [PEER_ID_EXCHANGE_PROTOCOL_NAME]).await?
+ } + }; + + assert_eq!(proto, PEER_ID_EXCHANGE_PROTOCOL_NAME); + + // Now exchange your PeerIds + let mut buf: Vec = self.0.into(); + let buf_len = buf.len(); + assert!(buf_len == 32); + let buf_len = buf_len as u8; + buf.insert(0, buf_len); + + socket.write_all(&buf).await?; + socket.flush().await?; + socket.read_exact(&mut buf[0..1]).await?; + let len = buf[0] as usize; + buf.resize(len, 0); + + socket.read_exact(&mut buf).await?; + + Ok((buf.try_into().expect("Invalid PeerId"), socket)) + } +} + +#[cfg(test)] +mod tests { + use super::PeerIdExchange; + use futures::{ + executor::block_on, + future::join, + io::{AsyncReadExt, AsyncWriteExt}, + stream::StreamExt, + }; + use memsocket::MemorySocket; + use netcore::transport::{ + boxed::BoxedTransport, memory::MemoryTransport, Transport, TransportExt, + }; + use types::PeerId; + + // Build an unsecure transport + fn test_transport(peer_id: PeerId) -> BoxedTransport<(PeerId, MemorySocket), ::std::io::Error> { + let transport = MemoryTransport::default(); + let peer_identifier_config = PeerIdExchange::new(peer_id); + + transport + .and_then(move |socket, origin| peer_identifier_config.exchange_peer_id(socket, origin)) + .boxed() + } + + #[test] + fn peer_identifier() { + let peer_a = PeerId::random(); + let peer_b = PeerId::random(); + + // Create the listener + let (mut listener, address) = test_transport(peer_a) + .listen_on("/memory/0".parse().unwrap()) + .unwrap(); + + let server = async move { + if let Some(result) = listener.next().await { + let (upgrade, _addr) = result.unwrap(); + let (peer_id, mut socket) = upgrade.await.unwrap(); + + assert_eq!(peer_b, peer_id); + + socket.write_all(b"hello world").await.unwrap(); + socket.flush().await.unwrap(); + socket.close().await.unwrap(); + } + }; + + let client = async move { + let (peer_id, mut socket) = + test_transport(peer_b).dial(address).unwrap().await.unwrap(); + + assert_eq!(peer_a, peer_id); + + let mut buf = Vec::new(); + socket.read_to_end(&mut buf).await.unwrap(); + socket.close().await.unwrap(); + + assert_eq!(buf, b"hello world"); + }; + + block_on(join(server, client)); + } +} diff --git a/network/src/protocols/rpc/error.rs b/network/src/protocols/rpc/error.rs new file mode 100644 index 0000000000000..ea5b4ff847a06 --- /dev/null +++ b/network/src/protocols/rpc/error.rs @@ -0,0 +1,95 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! 
Rpc protocol errors + +use crate::peer_manager::PeerManagerError; +use failure::{self, Fail}; +use futures::channel::{mpsc, oneshot}; +use protobuf::error::ProtobufError; +use std::io; +use tokio::timer; +use types::PeerId; + +#[derive(Debug, Fail)] +pub enum RpcError { + #[fail(display = "IO error: {}", _0)] + IoError(#[fail(cause)] io::Error), + + #[fail(display = "Failed to open substream, not connected with peer: {}", _0)] + NotConnected(PeerId), + + #[fail(display = "Error parsing protobuf message: {:?}", _0)] + ProtobufParseError(#[fail(cause)] ProtobufError), + + #[fail(display = "Received invalid rpc response message")] + InvalidRpcResponse, + + #[fail(display = "Received unexpected rpc response message; expected remote to half-close.")] + UnexpectedRpcResponse, + + #[fail(display = "Received unexpected rpc request message; expected remote to half-close.")] + UnexpectedRpcRequest, + + #[fail(display = "Application layer unexpectedly dropped response channel")] + UnexpectedResponseChannelCancel, + + #[fail(display = "Error in application layer handling rpc request: {:?}", _0)] + ApplicationError(#[fail(cause)] failure::Error), + + #[fail(display = "Error sending on mpsc channel: {:?}", _0)] + MpscSendError(#[fail(cause)] mpsc::SendError), + + #[fail(display = "Rpc timed out")] + TimedOut, + + #[fail(display = "Error setting timeout: {:?}", _0)] + TimerError(#[fail(cause)] timer::Error), +} + +impl From for RpcError { + fn from(err: io::Error) -> Self { + RpcError::IoError(err) + } +} + +impl From for RpcError { + fn from(err: PeerManagerError) -> Self { + match err { + PeerManagerError::NotConnected(peer_id) => RpcError::NotConnected(peer_id), + _ => unreachable!("open_substream only returns NotConnected errors"), + } + } +} + +impl From for RpcError { + fn from(err: ProtobufError) -> RpcError { + RpcError::ProtobufParseError(err) + } +} + +impl From for RpcError { + fn from(_: oneshot::Canceled) -> Self { + RpcError::UnexpectedResponseChannelCancel + } +} + +impl From for RpcError { + fn from(err: mpsc::SendError) -> RpcError { + RpcError::MpscSendError(err) + } +} + +impl From> for RpcError { + fn from(err: timer::timeout::Error) -> RpcError { + if err.is_elapsed() { + RpcError::TimedOut + } else if err.is_timer() { + RpcError::TimerError(err.into_timer().unwrap()) + } else if err.is_inner() { + err.into_inner().unwrap() + } else { + unreachable!("tokio timeout Error only has 3 cases; the above if cases are therefore exhaustive.") + } + } +} diff --git a/network/src/protocols/rpc/mod.rs b/network/src/protocols/rpc/mod.rs new file mode 100644 index 0000000000000..0d64753e15146 --- /dev/null +++ b/network/src/protocols/rpc/mod.rs @@ -0,0 +1,429 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Protocol for making and handling Remote Procedure Calls +//! +//! # SLURP: Simple Libra Unary Rpc Protocol +//! +//! SLURP takes advantage of [muxers] and [substream negotiation] to build a +//! simple rpc protocol. Concretely, +//! +//! 1. Every rpc call runs in its own substream. Instead of managing a completion +//! queue of message ids, we instead delegate this handling to the muxer so +//! that the underlying substream controls the lifetime of the rpc call. +//! Additionally, on certain transports (e.g., QUIC) we avoid head-of-line +//! blocking as the substreams are independent. +//! 2. An rpc method call negotiates which method to call using [`protocol-select`]. +//! This allows simple versioning of rpc methods and negotiation of which +//! 
methods are supported. In the future, we can potentially support multiple
+//! backwards-incompatible versions of any rpc method.
+//! 3. The actual structure of the request/response wire messages is left for
+//!    higher layers to specify. The rpc protocol is only concerned with shipping
+//!    around opaque blobs. Current libra rpc clients (consensus, mempool) mostly
+//!    send protobuf enums around over a single rpc protocol,
+//!    e.g., `/libra/consensus/rpc/0.1.0`.
+//!
+//! ## Wire Protocol (dialer):
+//!
+//! To make an rpc request to a remote peer, the dialer
+//!
+//! 1. Requests a new outbound substream from the muxer.
+//! 2. Negotiates the substream using [`protocol-select`] to the rpc method they
+//!    wish to call, e.g., `/libra/mempool/rpc/0.1.0`.
+//! 3. Sends the serialized request arguments on the newly negotiated substream.
+//! 4. Half-closes their output side.
+//! 5. Awaits the serialized response message from remote.
+//! 6. Awaits the listener's half-close to complete the substream close.
+//!
+//! ## Wire Protocol (listener):
+//!
+//! To handle new rpc requests from remote peers, the listener
+//!
+//! 1. Polls for new inbound substreams on the muxer.
+//! 2. Negotiates inbound substreams using [`protocol-select`]. The negotiation
+//!    must only succeed if the requested rpc method is actually supported.
+//! 3. Awaits the serialized request arguments on the newly negotiated substream.
+//! 4. Awaits the dialer's half-close.
+//! 5. Handles the request by sending it up through the
+//!    [`NetworkProvider`](crate::interface::NetworkProvider)
+//!    actor to a higher layer rpc client like consensus or mempool, which then
+//!    sends the serialized rpc response back down to the rpc layer.
+//! 6. Sends the serialized response message to the dialer.
+//! 7. Half-closes their output side to complete the substream close.
+//!
+//! Note: negotiated substreams are currently framed with the
+//! [multiformats unsigned-varint length-prefix](https://github.com/multiformats/unsigned-varint)
+//!
+//! [muxers]: ../../../netcore/multiplexing/index.html
+//! [substream negotiation]: ../../../netcore/negotiate/index.html
+//! [`protocol-select`]: ../../../netcore/negotiate/index.html
+
+use crate::{
+    counters,
+    peer_manager::{PeerManagerNotification, PeerManagerRequestSender},
+    sink::NetworkSinkExt,
+    ProtocolId,
+};
+use bytes::Bytes;
+use channel;
+use error::RpcError;
+use futures::{
+    channel::oneshot,
+    compat::{Future01CompatExt, Sink01CompatExt},
+    future::{self, FutureExt, TryFutureExt},
+    io::{AsyncRead, AsyncReadExt, AsyncWrite},
+    sink::SinkExt,
+    stream::{select, StreamExt},
+    task::Context,
+};
+use logger::prelude::*;
+use std::{fmt::Debug, io, time::Duration};
+use tokio::{codec::Framed, prelude::FutureExt as Future01Ext};
+use types::PeerId;
+use unsigned_varint::codec::UviBytes;
+
+pub mod error;
+
+#[cfg(test)]
+mod test;
+
+/// A wrapper struct for an inbound rpc request and its associated context.
+#[derive(Debug)]
+pub struct InboundRpcRequest {
+    /// Rpc method identifier, e.g., `/libra/consensus/rpc/0.1.0`. This is used
+    /// to dispatch the request to the corresponding client handler.
+    pub protocol: ProtocolId,
+    /// The serialized request data received from the sender.
+    pub data: Bytes,
+    /// Channel over which the rpc response is sent from the upper client layer
+    /// to the rpc layer.
+    ///
+    /// The rpc actor holds onto the receiving end of this channel, awaiting the
+    /// response from the upper layer.
If there is an error in, e.g., + /// deserializing the request, the upper layer should send an [`RpcError`] + /// down the channel to signify that there was an error while handling this + /// rpc request. Currently, we just log these errors and drop the substream; + /// in the future, we will send an error response to the peer and/or log any + /// malicious behaviour. + /// + /// The upper client layer should be prepared for `res_tx` to be potentially + /// disconnected when trying to send their response, as the rpc call might + /// have timed out while handling the request. + pub res_tx: oneshot::Sender>, +} + +/// A wrapper struct for an outbound rpc request and its associated context. +#[derive(Debug)] +pub struct OutboundRpcRequest { + /// Rpc method identifier, e.g., `/libra/consensus/rpc/0.1.0`. This is the + /// protocol we will negotiate our outbound substream to. + pub protocol: ProtocolId, + /// The serialized request data to be sent to the receiver. + pub data: Bytes, + /// Channel over which the rpc response is sent from the rpc layer to the + /// upper client layer. + /// + /// If there is an error while performing the rpc protocol, e.g., the remote + /// peer drops the connection, we will send an [`RpcError`] over the channel. + pub res_tx: oneshot::Sender>, + /// The timeout duration for the entire rpc call. If the timeout elapses, the + /// rpc layer will send an [`RpcError::TimedOut`] error over the + /// `res_tx` channel to the upper client layer. + pub timeout: Duration, +} + +/// Events sent from the [`NetworkProvider`](crate::interface::NetworkProvider) +/// actor to the [`Rpc`] actor. +#[derive(Debug)] +pub enum RpcRequest { + /// Send an outbound rpc request to a remote peer. + SendRpc(PeerId, OutboundRpcRequest), +} + +/// Events sent from the [`Rpc`] actor to the +/// [`NetworkProvider`](crate::interface::NetworkProvider) actor. +#[derive(Debug)] +pub enum RpcNotification { + /// A new inbound rpc request has been received from a remote peer. + RecvRpc(PeerId, InboundRpcRequest), +} + +/// The rpc actor. +pub struct Rpc { + /// Channel to receive requests from other upstream actors. + requests_rx: channel::Receiver, + /// Channel to receive notifications from [`PeerManager`](crate::peer_manager::PeerManager). + peer_mgr_notifs_rx: channel::Receiver>, + /// Channel to send requests to [`PeerManager`](crate::peer_manager::PeerManager). + peer_mgr_reqs_tx: PeerManagerRequestSender, + /// Channels to send notifictions to upstream actors. + rpc_handler_tx: channel::Sender, + /// The timeout duration for inbound rpc calls. + inbound_rpc_timeout: Duration, + /// The maximum number of concurrent outbound rpc requests that we will + /// service before back-pressure kicks in. + max_concurrent_outbound_rpcs: u32, + /// The maximum number of concurrent inbound rpc requests that we will + /// service before back-pressure kicks in. + // TODO(philiphayes): partition inbound queue by peer to prevent one peer + // from starving other peers' rpcs? + max_concurrent_inbound_rpcs: u32, +} + +impl Rpc +where + TSubstream: AsyncRead + AsyncWrite + Send + Unpin + Debug + 'static, +{ + /// Create a new instance of the [`Rpc`] protocol actor. 
+ pub fn new( + requests_rx: channel::Receiver, + peer_mgr_notifs_rx: channel::Receiver>, + peer_mgr_reqs_tx: PeerManagerRequestSender, + rpc_handler_tx: channel::Sender, + inbound_rpc_timeout: Duration, + max_concurrent_outbound_rpcs: u32, + max_concurrent_inbound_rpcs: u32, + ) -> Self { + Self { + requests_rx, + peer_mgr_notifs_rx, + peer_mgr_reqs_tx, + rpc_handler_tx, + inbound_rpc_timeout, + max_concurrent_outbound_rpcs, + max_concurrent_inbound_rpcs, + } + } + + /// Start the [`Rpc`] actor's event loop. + pub async fn start(self) { + // unpack self to satisfy borrow checker + let requests_rx = self.requests_rx; + let peer_mgr_notifs_rx = self.peer_mgr_notifs_rx; + let peer_mgr_reqs_tx = self.peer_mgr_reqs_tx; + let rpc_handler_tx = self.rpc_handler_tx; + let inbound_rpc_timeout = self.inbound_rpc_timeout; + let max_concurrent_outbound_rpcs = self.max_concurrent_outbound_rpcs; + let max_concurrent_inbound_rpcs = self.max_concurrent_inbound_rpcs; + + // inbound and outbound requests are buffered separately + + let outbound_reqs = requests_rx + .map(move |req| handle_outbound_rpc(peer_mgr_reqs_tx.clone(), req)) + .buffer_unordered(max_concurrent_outbound_rpcs as usize); + + let inbound_notifs = peer_mgr_notifs_rx + .map(move |notif| { + handle_inbound_substream(rpc_handler_tx.clone(), notif, inbound_rpc_timeout) + }) + .buffer_unordered(max_concurrent_inbound_rpcs as usize); + + // drive all inbound and outbound futures to completion + let mut rpc_futures = select(outbound_reqs, inbound_notifs); + while let Some(_) = rpc_futures.next().await {} + + crit!("Rpc actor terminated"); + } +} + +/// Handle an outbound rpc request event. Open a new substream then run the +/// outbound rpc protocol over the substream. +/// +/// The request results (including errors) are propagated up to the rpc client +/// through the [`req.res_tx`] oneshot channel. Cancellation is done by the client +/// dropping the receiver side of the [`req.res_tx`] oneshot channel. If the +/// request is canceled, the substream will be dropped and a RST frame will be +/// sent over the muxer closing the substream. +/// +/// [`req.res_tx`]: OutboundRpcRequest::res_tx +async fn handle_outbound_rpc( + peer_mgr_tx: PeerManagerRequestSender, + req: RpcRequest, +) where + TSubstream: AsyncRead + AsyncWrite + Send + Unpin, +{ + match req { + RpcRequest::SendRpc(peer_id, req) => { + let protocol = req.protocol; + let req_data = req.data; + let mut res_tx = req.res_tx; + let timeout = req.timeout; + + // Future to run the actual outbound rpc protocol and get the results. + let mut f_rpc_res = handle_outbound_rpc_inner(peer_mgr_tx, peer_id, protocol, req_data) + .boxed() + .compat() + .timeout(timeout) + .compat() + // Convert tokio timeout::Error to RpcError + .map_err(Into::::into); + + // If the rpc client drops their oneshot receiver, this future should + // cancel the request. + let mut f_rpc_cancel = + future::poll_fn(|cx: &mut Context| res_tx.poll_cancel(cx)).fuse(); + + futures::select! { + res = f_rpc_res => { + // Log any errors. + if let Err(err) = &res { + counters::RPC_REQUESTS_FAILED.inc(); + warn!( + "Error making outbound rpc request to {}: {:?}", + peer_id.short_str(), err + ); + } + + // Propagate the results to the rpc client layer. 
+ if res_tx.send(res).is_err() { + counters::RPC_REQUESTS_CANCELLED.inc(); + debug!("Rpc client canceled outbound rpc call to {}", peer_id.short_str()); + } + }, + // The rpc client canceled the request + cancel = f_rpc_cancel => { + counters::RPC_REQUESTS_CANCELLED.inc(); + debug!("Rpc client canceled outbound rpc call to {}", peer_id.short_str()); + }, + } + } + } +} + +async fn handle_outbound_rpc_inner( + mut peer_mgr_tx: PeerManagerRequestSender, + peer_id: PeerId, + protocol: ProtocolId, + req_data: Bytes, +) -> Result +where + TSubstream: AsyncRead + AsyncWrite + Send + Unpin, +{ + let _timer = counters::RPC_LATENCY.start_timer(); + // Request a new substream with the peer. + let substream = peer_mgr_tx.open_substream(peer_id, protocol).await?; + // Rpc messages are length-prefixed. + let mut substream = Framed::new(substream.compat(), UviBytes::default()).sink_compat(); + // Send the rpc request data. + let req_len = req_data.len(); + substream.buffered_send(req_data).await?; + // We won't send anything else on this substream, so we can half-close our + // output side. + substream.close().await?; + counters::RPC_REQUESTS_SENT.inc(); + counters::RPC_REQUEST_BYTES_SENT.inc_by(req_len as i64); + + // Wait for listener's response. + let res_data = match substream.next().await { + Some(res_data) => res_data?.freeze(), + None => Err(io::Error::from(io::ErrorKind::UnexpectedEof))?, + }; + + // Wait for listener to half-close their side. + match substream.next().await { + // Remote should never send more than one response; we'll consider this + // a protocol violation and ignore their response. + Some(_) => Err(RpcError::UnexpectedRpcResponse), + None => Ok(res_data), + } +} + +/// Handle an new inbound substream. Run the inbound rpc protocol over the +/// substream. +async fn handle_inbound_substream( + notification_tx: channel::Sender, + notif: PeerManagerNotification, + timeout: Duration, +) where + TSubstream: AsyncRead + AsyncWrite + Debug + Send + Unpin, +{ + match notif { + PeerManagerNotification::NewInboundSubstream(peer_id, substream) => { + // Run the actual inbound rpc protocol. + let res = handle_inbound_substream_inner( + notification_tx, + peer_id, + substream.protocol, + substream.substream, + ) + .boxed() + .compat() + .timeout(timeout) + .compat() + .await; + + // Convert tokio timeout::Error to RpcError + let res = res.map_err(Into::::into); + + // Log any errors. + if let Err(err) = res { + counters::RPC_RESPONSES_FAILED.inc(); + warn!( + "Error handling inbound rpc request from {}: {:?}", + peer_id.short_str(), + err + ); + } + } + notif => unreachable!( + "Received unexpected event from PeerManager: {:?}, expected NewInboundSubstream", + notif + ), + } +} + +async fn handle_inbound_substream_inner( + mut notification_tx: channel::Sender, + peer_id: PeerId, + protocol: ProtocolId, + substream: TSubstream, +) -> Result<(), RpcError> +where + TSubstream: AsyncRead + AsyncWrite + Send + Unpin, +{ + // Rpc messages are length-prefixed. + let mut substream = Framed::new(substream.compat(), UviBytes::default()).sink_compat(); + // Read the rpc request data. + let req_data = match substream.next().await { + Some(req_data) => req_data?.freeze(), + None => Err(io::Error::from(io::ErrorKind::UnexpectedEof))?, + }; + counters::RPC_REQUESTS_RECEIVED.inc(); + + // Wait for dialer to half-close their side. + if substream.next().await.is_some() { + // Remote should never send more than one request; we'll consider this + // a protocol violation and ignore their request. 
+ return Err(RpcError::UnexpectedRpcRequest); + }; + + // Build the event and context we push up to upper layers for handling. + let (res_tx, res_rx) = oneshot::channel(); + let notification = RpcNotification::RecvRpc( + peer_id, + InboundRpcRequest { + protocol, + data: req_data, + res_tx, + }, + ); + // TODO(philiphayes): impl correct shutdown process so this never panics + // Forward request to upper layer. + notification_tx.send(notification).await.unwrap(); + + // Wait for response from upper layer. + let res_data = res_rx.await??; + let res_len = res_data.len(); + + // Send the response to remote + substream.buffered_send(res_data).await?; + + // We won't send anything else on this substream, so we can half-close + // our output. The initiator will have also half-closed their side before + // this, so this should gracefully shutdown the socket. + substream.close().await?; + counters::RPC_RESPONSES_SENT.inc(); + counters::RPC_RESPONSE_BYTES_SENT.inc_by(res_len as i64); + + Ok(()) +} diff --git a/network/src/protocols/rpc/test.rs b/network/src/protocols/rpc/test.rs new file mode 100644 index 0000000000000..ce855d6ea835b --- /dev/null +++ b/network/src/protocols/rpc/test.rs @@ -0,0 +1,713 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::{error::RpcError, *}; +use crate::{ + common::NegotiatedSubstream, + peer_manager::{PeerManagerNotification, PeerManagerRequest}, +}; +use futures::future::{join, join3, join4}; +use memsocket::MemorySocket; +use tokio::runtime::Runtime; + +async fn do_outbound_rpc_req( + peer_mgr_tx: PeerManagerRequestSender, + recipient: PeerId, + protocol: ProtocolId, + data: Bytes, + timeout: Duration, +) -> Result +where + TSubstream: AsyncRead + AsyncWrite + Send + Unpin, +{ + let (res_tx, res_rx) = oneshot::channel(); + let outbound_req = OutboundRpcRequest { + protocol, + data, + res_tx, + timeout, + }; + let rpc_req = RpcRequest::SendRpc(recipient, outbound_req); + handle_outbound_rpc(peer_mgr_tx, rpc_req).await; + res_rx.await.unwrap() +} + +// On the next OpenSubstream event, return the given substream. +async fn mock_peer_manager( + mut peer_mgr_rx: channel::Receiver>, + substream: TSubstream, +) { + // Return a mocked substream on the next OpenSubstream request + match peer_mgr_rx.next().await.unwrap() { + PeerManagerRequest::OpenSubstream(_peer_id, _protocol, substream_tx) => { + substream_tx.send(Ok(substream)).unwrap(); + } + req => panic!( + "Unexpected PeerManagerRequest: {:?}, expected OpenSubstream", + req + ), + } +} + +// Test the rpc substream upgrades. 
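Before the tests, it may help to see the wire framing in isolation. Below is a dependency-free sketch of the multiformats unsigned-varint length-prefix that both sides rely on; the real code uses `Framed` with `UviBytes`, and `write_frame`/`read_frame` are illustrative names invented here, with a clean EOF modeling the half-close:

```rust
use std::io::{self, Read, Write};

/// Write `data` with a multiformats unsigned-varint length prefix.
fn write_frame<W: Write>(w: &mut W, data: &[u8]) -> io::Result<()> {
    let mut len = data.len() as u64;
    loop {
        let byte = (len & 0x7f) as u8;
        len >>= 7;
        if len == 0 {
            w.write_all(&[byte])?;
            break;
        }
        w.write_all(&[byte | 0x80])?;
    }
    w.write_all(data)
}

/// Read one length-prefixed frame; returns None on a clean EOF (half-close).
fn read_frame<R: Read>(r: &mut R) -> io::Result<Option<Vec<u8>>> {
    let mut len: u64 = 0;
    let mut shift = 0;
    loop {
        let mut byte = [0u8; 1];
        match r.read(&mut byte)? {
            0 if shift == 0 => return Ok(None), // EOF before any frame
            0 => return Err(io::ErrorKind::UnexpectedEof.into()),
            _ => {}
        }
        len |= u64::from(byte[0] & 0x7f) << shift;
        if byte[0] & 0x80 == 0 {
            break;
        }
        shift += 7;
    }
    let mut buf = vec![0u8; len as usize];
    r.read_exact(&mut buf)?;
    Ok(Some(buf))
}

fn main() -> io::Result<()> {
    let mut wire = Vec::new();
    write_frame(&mut wire, b"hello")?;
    let mut cursor = io::Cursor::new(wire);
    assert_eq!(read_frame(&mut cursor)?.as_deref(), Some(&b"hello"[..]));
    assert!(read_frame(&mut cursor)?.is_none()); // EOF models the half-close
    Ok(())
}
```

The happy-path test that follows drives the real codec and both sides of the wire protocol end to end.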
+#[test] +fn upgrades() { + ::logger::try_init_for_testing(); + + let listener_peer_id = PeerId::random(); + let dialer_peer_id = PeerId::random(); + let protocol_id = b"/get_blocks/1.0.0"; + let req_data = b"hello"; + let res_data = b"goodbye"; + + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + + // Fake the dialer NetworkProvider + let (dialer_peer_mgr_reqs_tx, dialer_peer_mgr_reqs_rx) = channel::new_test(8); + let dialer_peer_mgr_reqs_tx = PeerManagerRequestSender::new(dialer_peer_mgr_reqs_tx); + let f_dialer_peer_mgr = mock_peer_manager(dialer_peer_mgr_reqs_rx, dialer_substream); + + // Fake the listener NetworkProvider + let (listener_rpc_notifs_tx, mut listener_rpc_notifs_rx) = channel::new_test(8); + let f_listener_network = async move { + // Handle the inbound rpc request + match listener_rpc_notifs_rx.next().await.unwrap() { + RpcNotification::RecvRpc(peer_id, req) => { + assert_eq!(peer_id, dialer_peer_id); + assert_eq!(req.protocol.as_ref(), protocol_id); + assert_eq!(req.data.as_ref(), req_data); + req.res_tx.send(Ok(Bytes::from_static(res_data))).unwrap(); + } + } + }; + + let substream = NegotiatedSubstream { + protocol: ProtocolId::from_static(protocol_id), + substream: listener_substream, + }; + let inbound_notif = PeerManagerNotification::NewInboundSubstream(dialer_peer_id, substream); + + // Handle the inbound substream + let f_listener_upgrade = handle_inbound_substream( + listener_rpc_notifs_tx, + inbound_notif, + Duration::from_millis(500), + ); + + // Make an outbound substream request + let f_dialer_upgrade = async move { + let res = do_outbound_rpc_req( + dialer_peer_mgr_reqs_tx, + listener_peer_id, + ProtocolId::from_static(protocol_id), + Bytes::from_static(req_data), + Duration::from_secs(1), + ) + .await; + + // Check the rpc response data + let data = res.unwrap(); + assert_eq!(data.as_ref(), res_data); + }; + + let f = join4( + f_dialer_peer_mgr, + f_dialer_upgrade, + f_listener_network, + f_listener_upgrade, + ); + Runtime::new() + .unwrap() + .block_on(f.boxed().unit_error().compat()) + .unwrap(); +} + +// An outbound rpc request should fail if the listener drops the connection after +// receiving the request. 
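The listener-failure tests below are easier to follow with the dialer's post-half-close decision logic pulled out. Here is a synchronous model of that logic, not the actual async code; `Outcome` and `dialer_outcome` are invented for illustration:

```rust
#[derive(Debug, PartialEq)]
enum Outcome {
    Ok(Vec<u8>),
    UnexpectedEof,
    UnexpectedRpcResponse,
}

/// Pure decision logic of the dialer after it half-closes: `first` and
/// `second` stand for the next two reads from the substream (None = remote
/// EOF). In the real code the second read only happens if the first succeeds.
fn dialer_outcome(first: Option<Vec<u8>>, second: Option<Vec<u8>>) -> Outcome {
    match first {
        None => Outcome::UnexpectedEof,
        Some(res) => match second {
            None => Outcome::Ok(res),
            Some(_) => Outcome::UnexpectedRpcResponse,
        },
    }
}

fn main() {
    // Well-behaved listener: one response, then half-close.
    assert_eq!(
        dialer_outcome(Some(b"goodbye".to_vec()), None),
        Outcome::Ok(b"goodbye".to_vec())
    );
    // Listener that drops before responding.
    assert_eq!(dialer_outcome(None, None), Outcome::UnexpectedEof);
    // Listener that sends two messages: protocol violation.
    assert_eq!(
        dialer_outcome(Some(vec![1]), Some(vec![2])),
        Outcome::UnexpectedRpcResponse
    );
}
```

The first failure row is exactly what the next test exercises.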
+#[test] +fn listener_close_before_response() { + ::logger::try_init_for_testing(); + + let listener_peer_id = PeerId::random(); + let protocol_id = b"/get_blocks/1.0.0"; + let req_data = b"hello"; + + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + + // Fake the dialer NetworkProvider + let (dialer_peer_mgr_reqs_tx, dialer_peer_mgr_reqs_rx) = channel::new_test(8); + let dialer_peer_mgr_reqs_tx = PeerManagerRequestSender::new(dialer_peer_mgr_reqs_tx); + let f_dialer_peer_mgr = mock_peer_manager(dialer_peer_mgr_reqs_rx, dialer_substream); + + // Make an outbound rpc request + let f_dialer_upgrade = async move { + let res = do_outbound_rpc_req( + dialer_peer_mgr_reqs_tx, + listener_peer_id, + ProtocolId::from_static(protocol_id), + Bytes::from_static(req_data), + Duration::from_secs(1), + ) + .await; + + // Check the error + let err = res.expect_err("Dialer's rpc request should fail"); + match err { + RpcError::IoError(err) => assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof), + err => panic!("Unexpected error: {:?}, expected IoError", err), + }; + }; + + // Listener reads the request but then drops the connection + let f_listener = async move { + // rpc messages are length-prefixed + let mut substream = + Framed::new(listener_substream.compat(), UviBytes::::default()).sink_compat(); + // read the rpc request data + let data = match substream.next().await { + Some(data) => data.unwrap().freeze(), + None => panic!("listener: expected rpc request from dialer"), + }; + assert_eq!(data.as_ref(), req_data); + + // Listener then suddenly drops the connection + substream.close().await.unwrap(); + }; + + let f = join3(f_dialer_peer_mgr, f_dialer_upgrade, f_listener); + Runtime::new() + .unwrap() + .block_on(f.boxed().unit_error().compat()) + .unwrap(); +} + +// An outbound rpc request should fail if the listener drops the connection after +// negotiation but before the dialer sends their request. +#[test] +fn listener_close_before_dialer_send() { + ::logger::try_init_for_testing(); + + let listener_peer_id = PeerId::random(); + let protocol_id = b"/get_blocks/1.0.0"; + let req_data = b"hello"; + + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + + // Listener immediately drops connection + drop(listener_substream); + + // Fake the dialer NetworkProvider + let (dialer_peer_mgr_reqs_tx, dialer_peer_mgr_reqs_rx) = channel::new_test(8); + let dialer_peer_mgr_reqs_tx = PeerManagerRequestSender::new(dialer_peer_mgr_reqs_tx); + let f_dialer_peer_mgr = mock_peer_manager(dialer_peer_mgr_reqs_rx, dialer_substream); + + // Make an outbound substream request + let f_dialer_upgrade = async move { + let res = do_outbound_rpc_req( + dialer_peer_mgr_reqs_tx, + listener_peer_id, + ProtocolId::from_static(protocol_id), + Bytes::from_static(req_data), + Duration::from_secs(1), + ) + .await; + + // Check the error + let err = res.expect_err("Dialer's rpc request should fail"); + match err { + RpcError::IoError(err) => assert_eq!(err.kind(), io::ErrorKind::BrokenPipe), + err => panic!("Unexpected error: {:?}, expected IoError", err), + }; + }; + + let f = join(f_dialer_peer_mgr, f_dialer_upgrade); + Runtime::new() + .unwrap() + .block_on(f.boxed().unit_error().compat()) + .unwrap(); +} + +// An inbound rpc request should fail if the dialer drops the connection after +// negotiation but before sending their request. 
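These paired errors, EOF when reading from a departed peer and a broken pipe when writing to one, also drive the listener-side tests that follow. A Unix-only sketch using std's `UnixStream::pair` as a stand-in for the in-memory `MemorySocket` (the assumption being that the in-memory socket mirrors these error kinds, which is what the tests assert):

```rust
use std::io::{ErrorKind, Read, Write};
use std::os::unix::net::UnixStream;

fn main() -> std::io::Result<()> {
    // Read side: peer drops before sending anything, so the stream just ends.
    // The rpc layer surfaces this as io::ErrorKind::UnexpectedEof.
    let (mut ours, theirs) = UnixStream::pair()?;
    drop(theirs);
    let mut buf = [0u8; 8];
    assert_eq!(ours.read(&mut buf)?, 0);

    // Write side: peer drops before we send, so the pipe is broken.
    let (mut ours, theirs) = UnixStream::pair()?;
    drop(theirs);
    // A small write may be buffered by the kernel; keep writing until the
    // broken pipe surfaces.
    let err = loop {
        if let Err(e) = ours.write_all(&[0u8; 4096]) {
            break e;
        }
    };
    assert_eq!(err.kind(), ErrorKind::BrokenPipe);
    Ok(())
}
```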
+#[test] +fn dialer_close_before_listener_recv() { + ::logger::try_init_for_testing(); + + let dialer_peer_id = PeerId::random(); + let protocol_id = b"/get_blocks/1.0.0"; + + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + + // Dialer immediately drops connection after negotiation + drop(dialer_substream); + + // Listener handles the inbound substream, but should get an EOF error + let f_listener_upgrade = async move { + let (notification_tx, _notification_rx) = channel::new_test(8); + // use inner to get Result + let res = handle_inbound_substream_inner( + notification_tx, + dialer_peer_id, + ProtocolId::from_static(protocol_id), + listener_substream, + ) + .await; + + // Check the error + let err = res.expect_err("Listener's rpc handler should fail"); + match err { + RpcError::IoError(err) => assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof), + err => panic!("Unexpected error: {:?}, expected IoError", err), + }; + }; + + Runtime::new() + .unwrap() + .block_on(f_listener_upgrade.boxed().unit_error().compat()) + .unwrap(); +} + +// An inbound rpc request should fail if the dialer drops the connection before +// reading out the response. +#[test] +fn dialer_close_before_listener_send() { + ::logger::try_init_for_testing(); + + let dialer_peer_id = PeerId::random(); + let protocol_id = b"/get_blocks/1.0.0"; + let req_data = b"hello"; + let res_data = b"goodbye"; + + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + + // Fake the listener NetworkProvider + let (listener_rpc_notifs_tx, mut listener_rpc_notifs_rx) = channel::new_test(8); + let f_listener_network = async move { + // Handle the inbound rpc request + match listener_rpc_notifs_rx.next().await.unwrap() { + RpcNotification::RecvRpc(peer_id, req) => { + assert_eq!(peer_id, dialer_peer_id); + assert_eq!(req.protocol.as_ref(), protocol_id); + assert_eq!(req.data.as_ref(), req_data); + req.res_tx.send(Ok(Bytes::from_static(res_data))).unwrap(); + } + } + }; + + // Listener handles the inbound substream, but should get a broken pipe error + let f_listener_upgrade = async move { + // use inner to get Result + let res = handle_inbound_substream_inner( + listener_rpc_notifs_tx, + dialer_peer_id, + ProtocolId::from_static(protocol_id), + listener_substream, + ) + .await; + + // Check the error + let err = res.expect_err("Listener's rpc handler should fail"); + match err { + RpcError::IoError(err) => assert_eq!(err.kind(), io::ErrorKind::BrokenPipe), + err => panic!("Unexpected error: {:?}, expected IoError", err), + }; + }; + + let f_dialer_upgrade = async move { + // Rpc messages are length-prefixed. + let mut substream = + Framed::new(dialer_substream.compat(), UviBytes::default()).sink_compat(); + // Send the rpc request data. 
+ substream + .buffered_send(Bytes::from_static(req_data)) + .await + .unwrap(); + // Dialer then suddenly drops the connection + substream.close().await.unwrap(); + }; + + let f = join3(f_listener_network, f_listener_upgrade, f_dialer_upgrade); + Runtime::new() + .unwrap() + .block_on(f.boxed().unit_error().compat()) + .unwrap(); +} + +// Sending two requests should fail +#[test] +fn dialer_sends_two_requests_err() { + ::logger::try_init_for_testing(); + + let dialer_peer_id = PeerId::random(); + let protocol_id = b"/get_blocks/1.0.0"; + let req_data = b"hello"; + + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + + // Listener handles the inbound substream, but should get an EOF error + let f_listener_upgrade = async move { + let (notification_tx, _notification_rx) = channel::new_test(8); + // use inner to get Result + let res = handle_inbound_substream_inner( + notification_tx, + dialer_peer_id, + ProtocolId::from_static(protocol_id), + listener_substream, + ) + .await; + + // Check the error + let err = res.expect_err("Listener's rpc handler should fail"); + match err { + RpcError::UnexpectedRpcRequest => {} + err => panic!("Unexpected error: {:?}, expected UnexpectedRpcRequest", err), + }; + }; + + let f_dialer_upgrade = async move { + // Rpc messages are length-prefixed. + let mut substream = + Framed::new(dialer_substream.compat(), UviBytes::default()).sink_compat(); + // Send the rpc request data. + substream + .buffered_send(Bytes::from_static(req_data)) + .await + .unwrap(); + // ERROR: Send _another_ rpc request data in the same substream. + substream + .buffered_send(Bytes::from_static(req_data)) + .await + .unwrap(); + // Dialer half-closes + substream.close().await.unwrap(); + // Listener should RST substream + if let Some(res) = substream.next().await { + panic!("Unexpected response; expected None: {:?}", res); + } + }; + + let f = join(f_listener_upgrade, f_dialer_upgrade); + + Runtime::new() + .unwrap() + .block_on(f.boxed().unit_error().compat()) + .unwrap(); +} + +// Test that outbound rpc calls will timeout. +#[test] +fn outbound_rpc_timeout() { + ::logger::try_init_for_testing(); + + let listener_peer_id = PeerId::random(); + let protocol_id = b"/get_blocks/1.0.0"; + let req_data = b"hello"; + + // Listener hangs after negotiation + let (dialer_substream, _listener_substream) = MemorySocket::new_pair(); + + // Fake the dialer NetworkProvider + let (dialer_peer_mgr_reqs_tx, dialer_peer_mgr_reqs_rx) = channel::new_test(8); + let dialer_peer_mgr_reqs_tx = PeerManagerRequestSender::new(dialer_peer_mgr_reqs_tx); + let f_dialer_peer_mgr = mock_peer_manager(dialer_peer_mgr_reqs_rx, dialer_substream); + + // Make an outbound substream request; listener hangs so this should timeout. + let f_dialer_upgrade = async move { + let res = do_outbound_rpc_req( + dialer_peer_mgr_reqs_tx, + listener_peer_id, + ProtocolId::from_static(protocol_id), + Bytes::from_static(req_data), + Duration::from_millis(100), + ) + .await; + + // Check error is timeout error + let err = res.expect_err("Dialer's rpc request should fail"); + match err { + RpcError::TimedOut => {} + err => panic!("Unexpected error: {:?}, expected TimedOut", err), + }; + }; + + let f = join(f_dialer_peer_mgr, f_dialer_upgrade); + Runtime::new() + .unwrap() + .block_on(f.boxed().unit_error().compat()) + .unwrap(); +} + +// Test that inbound rpc calls will timeout. 
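Both directions wrap their inner future with tokio 0.1's `timeout`, which produces a three-way error that the `From<timer::timeout::Error<RpcError>>` impl in `error.rs` splits back apart. A plain-enum model of that dispatch (`TimeoutError` and `ModelRpcError` are stand-ins invented here, not the real tokio or rpc types):

```rust
/// Stand-in for tokio 0.1's `timer::timeout::Error<E>`: exactly one of the
/// deadline elapsing, the timer machinery failing, or the inner future failing.
enum TimeoutError<E> {
    Elapsed,
    Timer(String),
    Inner(E),
}

#[derive(Debug, PartialEq)]
enum ModelRpcError {
    TimedOut,
    TimerError(String),
    Io(&'static str),
}

/// Mirrors the From impl in error.rs: elapsed maps to TimedOut, a timer
/// failure maps to TimerError, and an inner error passes through untouched.
fn convert(err: TimeoutError<ModelRpcError>) -> ModelRpcError {
    match err {
        TimeoutError::Elapsed => ModelRpcError::TimedOut,
        TimeoutError::Timer(e) => ModelRpcError::TimerError(e),
        TimeoutError::Inner(e) => e,
    }
}

fn main() {
    assert_eq!(convert(TimeoutError::Elapsed), ModelRpcError::TimedOut);
    assert_eq!(
        convert(TimeoutError::Inner(ModelRpcError::Io("eof"))),
        ModelRpcError::Io("eof")
    );
}
```

The inbound twin of the previous timeout test follows.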
+#[test] +fn inbound_rpc_timeout() { + ::logger::try_init_for_testing(); + + let dialer_peer_id = PeerId::random(); + let protocol_id = b"/get_blocks/1.0.0"; + + // Dialer hangs after negotiation + let (_dialer_substream, listener_substream) = MemorySocket::new_pair(); + let (listener_rpc_notifs_tx, _listener_rpc_notifs_rx) = channel::new_test(8); + + // Handle the inbound substream + let substream = NegotiatedSubstream { + protocol: ProtocolId::from_static(protocol_id), + substream: listener_substream, + }; + let inbound_notif = PeerManagerNotification::NewInboundSubstream(dialer_peer_id, substream); + let f_listener_upgrade = handle_inbound_substream( + listener_rpc_notifs_tx, + inbound_notif, + Duration::from_millis(100), + ); + + // The listener future should complete (with a timeout) despite the dialer + // hanging. + Runtime::new() + .unwrap() + .block_on(f_listener_upgrade.boxed().unit_error().compat()) + .unwrap(); +} + +// Test that outbound rpcs can be canceled before sending +#[test] +fn outbound_cancellation_before_send() { + ::logger::try_init_for_testing(); + + let listener_peer_id = PeerId::random(); + let protocol_id = b"/get_blocks/1.0.0"; + let req_data = b"hello"; + + // Fake the dialer NetworkProvider channels + let (dialer_peer_mgr_reqs_tx, _dialer_peer_mgr_reqs_rx) = channel::new_test(8); + let dialer_peer_mgr_reqs_tx = + PeerManagerRequestSender::::new(dialer_peer_mgr_reqs_tx); + + // build the rpc request future + let (res_tx, res_rx) = oneshot::channel(); + let outbound_req = OutboundRpcRequest { + protocol: ProtocolId::from_static(protocol_id), + data: Bytes::from_static(req_data), + res_tx, + timeout: Duration::from_secs(1), + }; + let rpc_req = RpcRequest::SendRpc(listener_peer_id, outbound_req); + let f_rpc = handle_outbound_rpc(dialer_peer_mgr_reqs_tx, rpc_req); + + // drop res_rx to cancel the rpc request + drop(res_rx); + + // the rpc request should finish (from the cancellation) even though there is + // no remote peer + Runtime::new() + .unwrap() + .block_on(f_rpc.boxed().unit_error().compat()) + .unwrap(); +} + +// Test that outbound rpcs can be canceled while receiving response data. +#[test] +fn outbound_cancellation_recv() { + ::logger::try_init_for_testing(); + + let mut rt = Runtime::new().unwrap(); + let executor = rt.executor(); + + let listener_peer_id = PeerId::random(); + let protocol_id = b"/get_blocks/1.0.0"; + let req_data = b"hello"; + let res_data = b"goodbye"; + + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + + // Fake the dialer NetworkProvider + let (dialer_peer_mgr_reqs_tx, dialer_peer_mgr_reqs_rx) = channel::new_test(8); + let dialer_peer_mgr_reqs_tx = PeerManagerRequestSender::new(dialer_peer_mgr_reqs_tx); + let f_dialer_peer_mgr = mock_peer_manager(dialer_peer_mgr_reqs_rx, dialer_substream); + + // triggered when listener finishes reading response to notify dialer to cancel + let (cancel_tx, cancel_rx) = oneshot::channel::<()>(); + // triggered when dialer finishes canceling the request to notify listener to + // try sending. 
+ let (cancel_done_tx, cancel_done_rx) = oneshot::channel::<()>(); + + // Make an outbound rpc request but then cancel it after sending + let f_dialer_upgrade = async move { + let (res_tx, res_rx) = oneshot::channel(); + let mut res_rx = res_rx.fuse(); + + let outbound_req = OutboundRpcRequest { + protocol: ProtocolId::from_static(protocol_id), + data: Bytes::from_static(req_data), + res_tx, + timeout: Duration::from_secs(1), + }; + let rpc_req = RpcRequest::SendRpc(listener_peer_id, outbound_req); + let (f_rpc, f_rpc_done) = + handle_outbound_rpc(dialer_peer_mgr_reqs_tx, rpc_req).remote_handle(); + executor.spawn(f_rpc.unit_error().boxed().compat()); + + futures::select! { + res = res_rx => panic!("dialer: expected cancellation signal, rpc call finished unexpectedly: {:?}", res), + _ = cancel_rx.fuse() => { + // drop res_rx to cancel rpc call + drop(res_rx); + + // wait for rpc to finish cancellation + f_rpc_done.await; + + // notify listener that cancel is finished so it can try sending + cancel_done_tx.send(()).unwrap(); + } + } + }; + + // Listener reads the request but then fails to send because the dialer canceled + let f_listener = async move { + // rpc messages are length-prefixed + let mut substream = + Framed::new(listener_substream.compat(), UviBytes::::default()).sink_compat(); + // read the rpc request data + let data = match substream.next().await { + Some(data) => data.unwrap().freeze(), + None => panic!("listener: Expected rpc request from dialer"), + }; + assert_eq!(data.as_ref(), req_data); + // wait for dialer's half-close + match substream.next().await { + None => {} + res => panic!("listener: Expected half-close: {:?}", res), + } + + // trigger dialer cancel + drop(cancel_tx); + + // wait for dialer to finish cancelling + cancel_done_rx.await.unwrap(); + + // should get an error when trying to send + match substream.send(Bytes::from_static(res_data)).await { + Err(err) => assert_eq!(io::ErrorKind::BrokenPipe, err.kind()), + res => panic!("listener: Unexpected result: {:?}", res), + } + }; + + let f = join3(f_dialer_peer_mgr, f_dialer_upgrade, f_listener); + rt.block_on(f.boxed().unit_error().compat()).unwrap(); +} + +// Test the full rpc protocol actor. 
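The full-actor test below stands up two copies of the event loop from `Rpc::start`. As a refresher on that loop's shape, here is a dependency-light sketch using today's `futures` crate rather than the 0.3-alpha compat shims above; strings stand in for real requests, and the caps mirror `max_concurrent_*_rpcs`:

```rust
use futures::{
    executor::block_on,
    future,
    stream::{self, select, StreamExt},
};

fn main() {
    // Outbound requests, each handled concurrently up to the cap.
    let outbound = stream::iter(vec!["out-1", "out-2"])
        .map(|req| future::ready(format!("handled {}", req)))
        .buffer_unordered(10); // max_concurrent_outbound_rpcs

    // Inbound substream notifications, with their own concurrency cap.
    let inbound = stream::iter(vec!["in-1"])
        .map(|notif| future::ready(format!("handled {}", notif)))
        .buffer_unordered(10); // max_concurrent_inbound_rpcs

    // Drive both directions to completion on one task, like Rpc::start.
    let mut all = select(outbound, inbound);
    block_on(async {
        while let Some(done) = all.next().await {
            println!("{}", done);
        }
    });
}
```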
+#[test] +fn rpc_protocol() { + ::logger::try_init_for_testing(); + + let listener_peer_id = PeerId::random(); + let dialer_peer_id = PeerId::random(); + let protocol_id = b"/get_blocks/1.0.0"; + let req_data = b"hello"; + let res_data = b"goodbye"; + + let mut rt = Runtime::new().unwrap(); + + let (dialer_substream, listener_substream) = MemorySocket::new_pair(); + + // Set up the dialer Rpc protocol actor + let (mut dialer_rpc_tx, dialer_rpc_rx) = channel::new_test(8); + let (_, dialer_peer_mgr_notifs_rx) = channel::new_test(8); + let (dialer_peer_mgr_reqs_tx, mut dialer_peer_mgr_reqs_rx) = channel::new_test(8); + let dialer_peer_mgr_reqs_tx = PeerManagerRequestSender::new(dialer_peer_mgr_reqs_tx); + let (rpc_handler_tx, _) = channel::new_test(8); + let dialer_rpc = Rpc::new( + dialer_rpc_rx, + dialer_peer_mgr_notifs_rx, + dialer_peer_mgr_reqs_tx, + rpc_handler_tx, + Duration::from_millis(500), + 10, + 10, + ); + + // Fake the dialer NetworkProvider + let f_dialer_network = async move { + let (res_tx, res_rx) = oneshot::channel(); + + let req = OutboundRpcRequest { + protocol: ProtocolId::from_static(protocol_id), + data: Bytes::from_static(req_data), + res_tx, + timeout: Duration::from_secs(1), + }; + + // Tell Rpc to send an rpc request + dialer_rpc_tx + .send(RpcRequest::SendRpc(listener_peer_id, req)) + .await + .unwrap(); + + // Fulfill the open substream request + match dialer_peer_mgr_reqs_rx.next().await.unwrap() { + PeerManagerRequest::OpenSubstream(peer_id, protocol, substream_tx) => { + assert_eq!(peer_id, listener_peer_id); + assert_eq!(protocol.as_ref(), protocol_id); + substream_tx.send(Ok(dialer_substream)).unwrap(); + } + _ => { + unreachable!(); + } + } + + // Check the rpc response data + let data = res_rx.await.unwrap().unwrap(); + assert_eq!(data.as_ref(), res_data); + }; + + // Set up the listener Rpc protocol actor + let (_, listener_rpc_reqs_rx) = channel::new_test(8); + let (mut listener_peer_mgr_notifs_tx, listener_peer_mgr_notifs_rx) = channel::new_test(8); + let (listener_peer_mgr_reqs_tx, _) = channel::new_test(8); + let listener_peer_mgr_reqs_tx = PeerManagerRequestSender::new(listener_peer_mgr_reqs_tx); + let (listener_rpc_notifs_tx, mut listener_rpc_notifs_rx) = channel::new_test(8); + let listener_rpc = Rpc::new( + listener_rpc_reqs_rx, + listener_peer_mgr_notifs_rx, + listener_peer_mgr_reqs_tx, + listener_rpc_notifs_tx, + Duration::from_millis(500), + 10, + 10, + ); + + // Fake the listener NetworkProvider + let f_listener_network = async move { + // Notify Rpc of a new inbound substream + + listener_peer_mgr_notifs_tx + .send(PeerManagerNotification::NewInboundSubstream( + dialer_peer_id, + NegotiatedSubstream { + protocol: ProtocolId::from_static(protocol_id), + substream: listener_substream, + }, + )) + .await + .unwrap(); + + // Handle the inbound rpc request + match listener_rpc_notifs_rx.next().await.unwrap() { + RpcNotification::RecvRpc(peer_id, req) => { + assert_eq!(peer_id, dialer_peer_id); + assert_eq!(req.protocol.as_ref(), protocol_id); + assert_eq!(req.data.as_ref(), req_data); + req.res_tx.send(Ok(Bytes::from_static(res_data))).unwrap(); + } + } + }; + + let f = join4( + f_listener_network, + listener_rpc.start(), + f_dialer_network, + dialer_rpc.start(), + ); + rt.block_on(f.boxed().unit_error().compat()).unwrap(); +} diff --git a/network/src/sink/buffered_send.rs b/network/src/sink/buffered_send.rs new file mode 100644 index 0000000000000..8070370e6ccfe --- /dev/null +++ b/network/src/sink/buffered_send.rs @@ -0,0 +1,202 @@ +// Copyright 
(c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use futures::{ + future::Future, + sink::Sink, + task::{Context, Poll}, + try_ready, +}; +use std::pin::Pin; + +/// Future for the [`buffered_send`](super::NetworkSinkExt::buffered_send) method. +#[derive(Debug)] +#[must_use = "futures do nothing unless polled"] +// TODO(philiphayes): remove +#[allow(dead_code)] +pub struct BufferedSend<'a, S: Sink + Unpin + ?Sized, Item> { + sink: &'a mut S, + item: Option, +} + +// Pinning is never projected to children. +impl + Unpin + ?Sized, Item> Unpin for BufferedSend<'_, S, Item> {} + +impl<'a, S: Sink + Unpin + ?Sized, Item> BufferedSend<'a, S, Item> { + // TODO(philiphayes): remove + #[allow(dead_code)] + pub fn new(sink: &'a mut S, item: Item) -> Self { + Self { + sink, + item: Some(item), + } + } +} + +impl + Unpin + ?Sized, Item> Future for BufferedSend<'_, S, Item> { + type Output = Result<(), S::SinkError>; + + fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll { + // Get a &mut Self from the Pin<&mut Self>. + let this = &mut *self; + + // If someone polls us after we've already sent the item, successfully or + // not, then we just return Ok(()). + if this.item.is_none() { + return Poll::Ready(Ok(())); + } + + // Poll the underlying sink until it's ready to send an item (or errors). + let mut sink = Pin::new(&mut this.sink); + try_ready!(sink.as_mut().poll_ready(context)); + + // We need ownership of the pending item to send it on the sink. We take + // it _after_ the sink.poll_ready to avoid awkward control flow with + // placing the item back in self if the underlying sink isn't ready. + let item = this + .item + .take() + .expect("We have already checked that item.is_none(), so this will never panic"); + + // Actually send the item + Poll::Ready(sink.as_mut().start_send(item)) + } +} + +#[cfg(test)] +mod test { + use crate::sink::NetworkSinkExt; + use futures::{ + channel::mpsc, executor::block_on, future::join, sink::SinkExt, stream::StreamExt, + }; + + // It should work. + #[test] + fn buffered_send() { + let (mut tx, mut rx) = mpsc::channel::(0); + + block_on(tx.send(123)).unwrap(); + assert_eq!(Some(123), block_on(rx.next())); + } + + // It should not flush where `.send` otherwise would. + #[test] + fn doesnt_flush() { + ::logger::try_init_for_testing(); + + // A 0-capacity channel + one sender gives the channel 1 available buffer + // slot. + let (tx, mut rx) = mpsc::channel::(0); + let mut tx = tx.buffer(2); + + // Initial state + // + // +-----------------+ + // | _ | _ | _ | + // +-----------------+ + // .buffer \ channel + + // `.buffer` only buffers items if the underlying sink is busy. So the + // first send should write-through to the channel, since the channel has + // 1 available buffer slot. + block_on(tx.buffered_send(1)).unwrap(); + + // +-----------------+ + // | _ | _ | 1 | + // +-----------------+ + // .buffer \ channel + + // If we used `tx.send(2)` here, it would block since `tx.send` requires + // a flush after enqueueing. However, the channel is already full, so the + // flush would never complete. Instead, we can use our new `.buffered_send` + // which doesn't mandate a flush. + + // Next two should buffer in `.buffer` since the underlying channel is full. 
+ block_on(tx.buffered_send(2)).unwrap(); + block_on(tx.buffered_send(3)).unwrap(); + + // +-----------------+ + // | 3 | 2 | 1 | + // +-----------------+ + // .buffer \ channel + + // If we used `tx.buffered_send(4)` here, it would block since both the + // channel and `.buffer` are full. + + // This call should succeed and return the item buffered in the channel. + assert_eq!(Some(1), block_on(rx.next())); + + // +-----------------+ + // | 3 | 2 | _ | => 1 + // +-----------------+ + // .buffer \ channel + + // The following calls would block, since 2 & 3 are stuck in `.buffer` + // even though the channel buffer is empty. + // assert_eq!(Some(2), block_on(rx.next())); + // assert_eq!(Some(3), block_on(rx.next())); + + // Instead, we have to manually flush `tx` while dequeueing the remaining + // items from the channel. + + // `f_flush` will complete when all items in `.buffer` are flushed down + // to the underlying channel + let f_flush = async move { + tx.flush().await.unwrap(); + }; + + let f_recv = async move { + assert_eq!(Some(2), rx.next().await); + assert_eq!(Some(3), rx.next().await); + }; + + // flush 2 + // + // +-----------------+ + // | 3 | _ | 2 | + // +-----------------+ + // .buffer \ channel + + // dequeue 2 + // + // +-----------------+ + // | 3 | _ | _ | => 2 + // +-----------------+ + // .buffer \ channel + + // flush 3 + f_flush done + // + // +-----------------+ + // | _ | _ | 3 | + // +-----------------+ + // .buffer \ channel + + // dequeue 3 + f_recv done + // + // +-----------------+ + // | _ | _ | _ | => 3 + // +-----------------+ + // .buffer \ channel + + block_on(join(f_flush, f_recv)); + } + + // Polling after the future has completed should not panic. + #[test] + fn poll_after_ready() { + let (mut tx, mut rx) = mpsc::channel::(0); + + let mut f_send = tx.send(123); + + // Poll the first time like normal. + block_on(&mut f_send).unwrap(); + // Polling after it's already complete should just resolve immediately. + block_on(&mut f_send).unwrap(); + + block_on(tx.close()).unwrap(); + + // There should only be one item in the channel. + assert_eq!(Some(123), block_on(rx.next())); + assert_eq!(None, block_on(rx.next())); + } +} diff --git a/network/src/sink/mod.rs b/network/src/sink/mod.rs new file mode 100644 index 0000000000000..65aa0a7003574 --- /dev/null +++ b/network/src/sink/mod.rs @@ -0,0 +1,26 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::sink::buffered_send::BufferedSend; +use futures::sink::Sink; + +mod buffered_send; + +// blanket trait impl for NetworkSinkExt +impl NetworkSinkExt for T where T: Sink {} + +/// Extension trait for [`Sink`] that provides network crate specific combinator +/// functions. +pub trait NetworkSinkExt: Sink { + /// Like `sink.send()` but without the mandatory flush. + /// + /// Specifically, `sink.send()` will do `sink.poll_ready()` then + /// `sink.start_send(item)` and finally a mandatory `sink.poll_flush()`. + /// This will only do `sink.poll_ready()` and `sink.start_send(item)`. 
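Spelled out with the raw `Sink` primitives, `buffered_send` is the first two steps of `send` with the third, the flush, dropped. A sketch of that contract against a futures mpsc channel, where enqueueing alone is enough for delivery:

```rust
use futures::{
    channel::mpsc, executor::block_on, future::poll_fn, sink::Sink, stream::StreamExt,
};
use std::pin::Pin;

fn main() {
    let (mut tx, mut rx) = mpsc::channel::<u32>(0);

    block_on(async {
        // 1. Wait until the sink can accept an item.
        poll_fn(|cx| Pin::new(&mut tx).poll_ready(cx)).await.unwrap();
        // 2. Enqueue the item. Note: no poll_flush afterwards.
        Pin::new(&mut tx).start_send(7).unwrap();

        // The item still arrives: for an mpsc channel, start_send is enough.
        assert_eq!(rx.next().await, Some(7));
    });
}
```

The extension method itself is declared next.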
+ fn buffered_send(&mut self, item: Item) -> BufferedSend<'_, Self, Item> + where + Self: Unpin, + { + BufferedSend::new(self, item) + } +} diff --git a/network/src/transport.rs b/network/src/transport.rs new file mode 100644 index 0000000000000..3acdae12e7f83 --- /dev/null +++ b/network/src/transport.rs @@ -0,0 +1,193 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + common::NetworkPublicKeys, + protocols::identity::{exchange_identity, Identity}, +}; +use crypto::x25519::{X25519PrivateKey, X25519PublicKey}; +use logger::prelude::*; +use netcore::{ + multiplexing::{yamux::Yamux, StreamMultiplexer}, + transport::{boxed, memory, tcp, TransportExt}, +}; +use noise::NoiseConfig; +use std::{ + collections::HashMap, + io, + sync::{Arc, RwLock}, + time::Duration, +}; +use types::PeerId; + +/// A timeout for the connection to open and complete all of the upgrade steps. +const TRANSPORT_TIMEOUT: Duration = Duration::from_secs(30); + +fn identity_key_to_peer_id( + trusted_peers: &RwLock>, + remote_static_key: &[u8], +) -> Option { + for (peer_id, public_keys) in trusted_peers.read().unwrap().iter() { + if public_keys.identity_public_key.as_bytes() == remote_static_key { + return Some(*peer_id); + } + } + + None +} + +pub fn build_memory_noise_transport( + own_identity: Identity, + identity_keypair: (X25519PrivateKey, X25519PublicKey), + trusted_peers: Arc>>, +) -> boxed::BoxedTransport<(Identity, impl StreamMultiplexer), impl ::std::error::Error> { + let memory_transport = memory::MemoryTransport::default(); + let noise_config = Arc::new(NoiseConfig::new(identity_keypair)); + + memory_transport + .and_then(move |socket, origin| { + async move { + let (remote_static_key, socket) = + noise_config.upgrade_connection(socket, origin).await?; + + if let Some(peer_id) = identity_key_to_peer_id(&trusted_peers, &remote_static_key) { + Ok((peer_id, socket)) + } else { + Err(io::Error::new(io::ErrorKind::Other, "Not a trusted peer")) + } + } + }) + .and_then(|(peer_id, socket), origin| { + async move { + let muxer = Yamux::upgrade_connection(socket, origin).await?; + Ok((peer_id, muxer)) + } + }) + .and_then(move |(peer_id, muxer), origin| { + async move { + let (identity, muxer) = exchange_identity(&own_identity, muxer, origin).await?; + + if identity.peer_id() == peer_id { + Ok((identity, muxer)) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + format!( + "PeerId received from Noise Handshake ({}) doesn't match one received from Identity Exchange ({})", + peer_id.short_str(), + identity.peer_id().short_str() + ) + )) + } + } + }) + .with_timeout(TRANSPORT_TIMEOUT) + .boxed() +} + +pub fn build_memory_transport( + own_identity: Identity, +) -> boxed::BoxedTransport<(Identity, impl StreamMultiplexer), impl ::std::error::Error> { + let memory_transport = memory::MemoryTransport::default(); + + memory_transport + .and_then(|socket, origin| { + async move { + let muxer = Yamux::upgrade_connection(socket, origin).await?; + Ok(muxer) + } + }) + .and_then(move |muxer, origin| { + async move { + let (identity, muxer) = exchange_identity(&own_identity, muxer, origin).await?; + + Ok((identity, muxer)) + } + }) + .with_timeout(TRANSPORT_TIMEOUT) + .boxed() +} + +//TODO(bmwill) Maybe create an Either Transport so we can merge the building of Memory + Tcp +pub fn build_tcp_noise_transport( + own_identity: Identity, + identity_keypair: (X25519PrivateKey, X25519PublicKey), + trusted_peers: Arc>>, +) -> boxed::BoxedTransport<(Identity, impl 
StreamMultiplexer), impl ::std::error::Error> { + let tcp_transport = tcp::TcpTransport::default(); + let noise_config = Arc::new(NoiseConfig::new(identity_keypair)); + + tcp_transport + .and_then(move |socket, origin| { + async move { + let (remote_static_key, socket) = + noise_config.upgrade_connection(socket, origin).await?; + + if let Some(peer_id) = identity_key_to_peer_id(&trusted_peers, &remote_static_key) { + Ok((peer_id, socket)) + } else { + security_log(SecurityEvent::InvalidNetworkPeer) + .error("UntrustedPeer") + .data(&trusted_peers) + .data(&remote_static_key) + .log(); + Err(io::Error::new(io::ErrorKind::Other, "Not a trusted peer")) + } + } + }) + .and_then(|(peer_id, socket), origin| { + async move { + let muxer = Yamux::upgrade_connection(socket, origin).await?; + Ok((peer_id, muxer)) + } + }) + .and_then(move |(peer_id, muxer), origin| { + async move { + let (identity, muxer) = exchange_identity(&own_identity, muxer, origin).await?; + + if identity.peer_id() == peer_id { + Ok((identity, muxer)) + } else { + security_log(SecurityEvent::InvalidNetworkPeer) + .error("InvalidIdentity") + .data(&identity) + .data(&peer_id) + .data(&origin) + .log(); + Err(io::Error::new( + io::ErrorKind::Other, + format!( + "PeerId received from Noise Handshake ({}) doesn't match one received from Identity Exchange ({})", + peer_id.short_str(), + identity.peer_id().short_str() + ) + )) + } + } + }) + .with_timeout(TRANSPORT_TIMEOUT) + .boxed() +} + +pub fn build_tcp_transport( + own_identity: Identity, +) -> boxed::BoxedTransport<(Identity, impl StreamMultiplexer), impl ::std::error::Error> { + let tcp_transport = tcp::TcpTransport::default(); + + tcp_transport + .and_then(|socket, origin| { + async move { + let muxer = Yamux::upgrade_connection(socket, origin).await?; + Ok(muxer) + } + }) + .and_then(move |muxer, origin| { + async move { + let (identity, muxer) = exchange_identity(&own_identity, muxer, origin).await?; + + Ok((identity, muxer)) + } + }) + .with_timeout(TRANSPORT_TIMEOUT) + .boxed() +} diff --git a/network/src/utils.rs b/network/src/utils.rs new file mode 100644 index 0000000000000..416737407efc4 --- /dev/null +++ b/network/src/utils.rs @@ -0,0 +1,31 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::error::NetworkError; +use bytes::Bytes; +use futures::{ + compat::{Compat, Compat01As03Sink}, + io::AsyncRead, + stream::StreamExt, +}; +use protobuf::Message; +use std::io; +use tokio::codec::Framed; +use unsigned_varint::codec::UviBytes; + +pub async fn read_proto( + substream: &mut Compat01As03Sink, UviBytes>, Bytes>, +) -> Result +where + T: Message, + TSubstream: AsyncRead + Unpin, +{ + // Read from stream. + let data: Bytes = substream.next().await.map_or_else( + || Err(io::Error::from(io::ErrorKind::UnexpectedEof)), + |data| Ok(data?.freeze()), + )?; + // Parse to message. + let msg = protobuf::parse_from_bytes(data.as_ref())?; + Ok(msg) +} diff --git a/network/src/validator_network/consensus.rs b/network/src/validator_network/consensus.rs new file mode 100644 index 0000000000000..dd52ca7215d98 --- /dev/null +++ b/network/src/validator_network/consensus.rs @@ -0,0 +1,386 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Interface between Consensus and Network layers. 
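The consensus and mempool wrappers that follow share one shape: a channel receiver mapped through a deserializing `fn` pointer into typed events. Here is a toy version of the pattern with a `u32` codec standing in for protobuf; all of the names are invented for illustration:

```rust
use futures::{channel::mpsc, executor::block_on, stream::StreamExt};
use std::convert::TryInto;

enum RawNotification {
    NewPeer(&'static str),
    RecvMessage(&'static str, Vec<u8>),
}

#[derive(Debug, PartialEq)]
enum Event {
    NewPeer(&'static str),
    Message(&'static str, u32),
}

/// Deserialize a raw notification into a typed event, surfacing parse errors.
fn parse(n: RawNotification) -> Result<Event, &'static str> {
    match n {
        RawNotification::NewPeer(p) => Ok(Event::NewPeer(p)),
        RawNotification::RecvMessage(p, data) => {
            let bytes: [u8; 4] = data.as_slice().try_into().map_err(|_| "bad frame")?;
            Ok(Event::Message(p, u32::from_be_bytes(bytes)))
        }
    }
}

fn main() {
    let (mut tx, rx) = mpsc::channel(8);
    // Same shape as the *NetworkEvents wrappers: Receiver mapped by a parser.
    let mut events = rx.map(parse);

    block_on(async {
        tx.try_send(RawNotification::RecvMessage("peer-a", 7u32.to_be_bytes().to_vec()))
            .unwrap();
        assert_eq!(events.next().await, Some(Ok(Event::Message("peer-a", 7))));

        tx.try_send(RawNotification::NewPeer("peer-a")).unwrap();
        assert_eq!(events.next().await, Some(Ok(Event::NewPeer("peer-a"))));
    });
}
```

The consensus-specific instance of this wrapper is defined below.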
+ +use crate::{ + error::NetworkError, + interface::{NetworkNotification, NetworkRequest}, + proto::{ConsensusMsg, RequestBlock, RequestChunk, RespondBlock, RespondChunk}, + protocols::{ + direct_send::Message, + rpc::{error::RpcError, OutboundRpcRequest}, + }, + validator_network::Event, + NetworkPublicKeys, ProtocolId, +}; +use bytes::Bytes; +use channel; +use futures::{ + channel::oneshot, + stream::Map, + task::{Context, Poll}, + SinkExt, Stream, StreamExt, +}; +use pin_utils::unsafe_pinned; +use protobuf::Message as proto_msg; +use std::{pin::Pin, time::Duration}; +use types::{validator_public_keys::ValidatorPublicKeys, PeerId}; + +/// Protocol id for consensus RPC calls +pub const CONSENSUS_RPC_PROTOCOL: &[u8] = b"/libra/consensus/rpc/0.1.0"; +/// Protocol id for consensus direct-send calls +pub const CONSENSUS_DIRECT_SEND_PROTOCOL: &[u8] = b"/libra/consensus/direct-send/0.1.0"; + +/// The interface from Network to Consensus layer. +/// +/// `ConsensusNetworkEvents` is a `Stream` of `NetworkNotification` where the +/// raw `Bytes` direct-send and rpc messages are deserialized into +/// `ConsensusMessage` types. `ConsensusNetworkEvents` is a thin wrapper around +/// an `channel::Receiver`. +pub struct ConsensusNetworkEvents { + inner: Map< + channel::Receiver, + fn(NetworkNotification) -> Result, NetworkError>, + >, +} + +impl ConsensusNetworkEvents { + // This use of `unsafe_pinned` is safe because: + // 1. This struct does not implement [`Drop`] + // 2. This struct does not implement [`Unpin`] + // 3. This struct is not `#[repr(packed)]` + unsafe_pinned!( + inner: + Map< + channel::Receiver, + fn(NetworkNotification) -> Result, NetworkError>, + > + ); + + pub fn new(receiver: channel::Receiver) -> Self { + let inner = receiver.map::<_, fn(_) -> _>(|notification| match notification { + NetworkNotification::NewPeer(peer_id) => Ok(Event::NewPeer(peer_id)), + NetworkNotification::LostPeer(peer_id) => Ok(Event::LostPeer(peer_id)), + NetworkNotification::RecvRpc(peer_id, rpc_req) => { + let req_msg = ::protobuf::parse_from_bytes(rpc_req.data.as_ref())?; + Ok(Event::RpcRequest((peer_id, req_msg, rpc_req.res_tx))) + } + NetworkNotification::RecvMessage(peer_id, msg) => { + let msg = ::protobuf::parse_from_bytes(msg.mdata.as_ref())?; + Ok(Event::Message((peer_id, msg))) + } + }); + + Self { inner } + } +} + +impl Stream for ConsensusNetworkEvents { + type Item = Result, NetworkError>; + + fn poll_next(self: Pin<&mut Self>, context: &mut Context) -> Poll> { + self.inner().poll_next(context) + } +} + +/// The interface from Consensus to Networking layer. +/// +/// This is a thin wrapper around an `channel::Sender`, so it is +/// easy to clone and send off to a separate task. For example, the rpc requests +/// return Futures that encapsulate the whole flow, from sending the request to +/// remote, to finally receiving the response and deserializing. It therefore +/// makes the most sense to make the rpc call on a separate async task, which +/// requires the `ConsensusNetworkSender` to be `Clone` and `Send`. +#[derive(Clone)] +pub struct ConsensusNetworkSender { + inner: channel::Sender, +} + +impl ConsensusNetworkSender { + pub fn new(inner: channel::Sender) -> Self { + Self { inner } + } + + /// Send a fire-and-forget "direct-send" message to remote peer `recipient`. + /// + /// Currently, the returned Future simply resolves when the message has been + /// enqueued on the network actor's event queue. It therefore makes no + /// reliable delivery guarantees. 
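A two-line illustration of what "resolves when enqueued" means in practice: the send future completes as soon as the message sits in the actor's queue, even if no one ever processes it.

```rust
use futures::{channel::mpsc, executor::block_on, sink::SinkExt};

fn main() {
    let (mut tx, rx) = mpsc::channel::<&'static str>(8);
    block_on(tx.send("proposal")).unwrap(); // Ok(()) already here...
    drop(rx); // ...even though the receiver dies before ever reading it.
}
```

With that caveat, `send_to` below is just serialize-and-enqueue.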
+ pub async fn send_to( + &mut self, + recipient: PeerId, + message: ConsensusMsg, + ) -> Result<(), NetworkError> { + self.inner + .send(NetworkRequest::SendMessage( + recipient, + Message { + protocol: ProtocolId::from_static(CONSENSUS_DIRECT_SEND_PROTOCOL), + mdata: Bytes::from(message.write_to_bytes().unwrap()), + }, + )) + .await?; + Ok(()) + } + + /// Send a RequestBlock RPC request to remote peer `recipient`. Returns the + /// future `RespondBlock` returned by the remote peer. + /// + /// The rpc request can be canceled at any point by dropping the returned + /// future. + pub async fn request_block( + &mut self, + recipient: PeerId, + req_msg: RequestBlock, + timeout: Duration, + ) -> Result { + let protocol = ProtocolId::from_static(CONSENSUS_RPC_PROTOCOL); + let mut req_msg_enum = ConsensusMsg::new(); + req_msg_enum.set_request_block(req_msg); + + let mut res_msg_enum = self + .unary_rpc(recipient, protocol, req_msg_enum, timeout) + .await?; + + if res_msg_enum.has_respond_block() { + Ok(res_msg_enum.take_respond_block()) + } else { + // TODO: context + Err(RpcError::InvalidRpcResponse) + } + } + + /// Send a RequestChunk RPC request to remote peer `recipient`. Returns the + /// future `RespondChunk` returned by the remote peer. + /// + /// The rpc request can be canceled at any point by dropping the returned + /// future. + pub async fn request_chunk( + &mut self, + recipient: PeerId, + req_msg: RequestChunk, + timeout: Duration, + ) -> Result { + let protocol = ProtocolId::from_static(CONSENSUS_RPC_PROTOCOL); + let mut req_msg_enum = ConsensusMsg::new(); + req_msg_enum.set_request_chunk(req_msg); + + let mut res_msg_enum = self + .unary_rpc(recipient, protocol, req_msg_enum, timeout) + .await?; + + if res_msg_enum.has_respond_chunk() { + Ok(res_msg_enum.take_respond_chunk()) + } else { + // TODO: context + Err(RpcError::InvalidRpcResponse) + } + } + + pub async fn update_eligible_nodes( + &mut self, + validators: Vec, + ) -> Result<(), NetworkError> { + self.inner + .send(NetworkRequest::UpdateEligibleNodes( + validators + .into_iter() + .map(|keys| { + ( + *keys.account_address(), + NetworkPublicKeys { + identity_public_key: *keys.network_identity_public_key(), + signing_public_key: *keys.network_signing_public_key(), + }, + ) + }) + .collect(), + )) + .await?; + Ok(()) + } + + /// Send a unary rpc request to remote peer `recipient`. Handles + /// serialization and deserialization of the `ConsensusMsg` message enum. 
+ /// + /// TODO(philiphayes): specify error cases + async fn unary_rpc( + &mut self, + recipient: PeerId, + protocol: ProtocolId, + req_msg: ConsensusMsg, + timeout: Duration, + ) -> Result { + // serialize request + let req_data = req_msg.write_to_bytes()?.into(); + + // ask network to fulfill rpc request + let (res_tx, res_rx) = oneshot::channel(); + let req = OutboundRpcRequest { + protocol, + data: req_data, + res_tx, + timeout, + }; + self.inner + .send(NetworkRequest::SendRpc(recipient, req)) + .await?; + + // wait for response and deserialize + let res_data = res_rx.await??; + let res_msg = ::protobuf::parse_from_bytes(res_data.as_ref())?; + Ok(res_msg) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{proto::Vote, protocols::rpc::InboundRpcRequest}; + use futures::{executor::block_on, future::try_join}; + + fn new_test_vote() -> ConsensusMsg { + let mut vote = Vote::new(); + vote.set_proposed_block_id(Bytes::new()); + vote.set_executed_state_id(Bytes::new()); + vote.set_author(Bytes::new()); + vote.set_signature(Bytes::new()); + + let mut consensus_msg = ConsensusMsg::new(); + consensus_msg.set_vote(vote); + consensus_msg + } + + // Direct send messages should get deserialized through the + // `ConsensusNetworkEvents` stream. + #[test] + fn test_consensus_network_events() { + let (mut consensus_tx, consensus_rx) = channel::new_test(8); + let mut stream = ConsensusNetworkEvents::new(consensus_rx); + + let peer_id = PeerId::random(); + let consensus_msg = new_test_vote(); + let network_msg = Message { + protocol: ProtocolId::from_static(CONSENSUS_DIRECT_SEND_PROTOCOL), + mdata: consensus_msg.clone().write_to_bytes().unwrap().into(), + }; + + // Network sends inbound message to consensus + block_on(consensus_tx.send(NetworkNotification::RecvMessage(peer_id, network_msg))) + .unwrap(); + + // Consensus should receive deserialized message event + let event = block_on(stream.next()).unwrap().unwrap(); + assert_eq!(event, Event::Message((peer_id.into(), consensus_msg))); + + // Network notifies consensus about new peer + block_on(consensus_tx.send(NetworkNotification::NewPeer(peer_id))).unwrap(); + + // Consensus should receive notification + let event = block_on(stream.next()).unwrap().unwrap(); + assert_eq!(event, Event::NewPeer(peer_id.into())); + } + + // `ConsensusNetworkSender` should serialize outbound messages + #[test] + fn test_consensus_network_sender() { + let (network_reqs_tx, mut network_reqs_rx) = channel::new_test(8); + let mut sender = ConsensusNetworkSender::new(network_reqs_tx); + + let peer_id = PeerId::random(); + let consensus_msg = new_test_vote(); + let expected_network_msg = Message { + protocol: ProtocolId::from_static(CONSENSUS_DIRECT_SEND_PROTOCOL), + mdata: consensus_msg.clone().write_to_bytes().unwrap().into(), + }; + + // Send the message to network layer + block_on(sender.send_to(peer_id.into(), consensus_msg)).unwrap(); + + // Network layer should receive serialized message to send out + let event = block_on(network_reqs_rx.next()).unwrap(); + match event { + NetworkRequest::SendMessage(recv_peer_id, network_msg) => { + assert_eq!(recv_peer_id, peer_id); + assert_eq!(network_msg, expected_network_msg); + } + event => panic!("Unexpected event: {:?}", event), + } + } + + // `ConsensusNetworkEvents` should deserialize inbound RPC requests + #[test] + fn test_consensus_inbound_rpc() { + let (mut consensus_tx, consensus_rx) = channel::new_test(8); + let mut stream = ConsensusNetworkEvents::new(consensus_rx); + + // build rpc request + let 
req_msg = RequestBlock::new(); + let mut req_msg_enum = ConsensusMsg::new(); + req_msg_enum.set_request_block(req_msg); + let req_data = req_msg_enum.clone().write_to_bytes().unwrap().into(); + + let (res_tx, _) = oneshot::channel(); + let rpc_req = InboundRpcRequest { + protocol: ProtocolId::from_static(CONSENSUS_RPC_PROTOCOL), + data: req_data, + res_tx, + }; + + // mock receiving rpc request + let peer_id = PeerId::random(); + let event = NetworkNotification::RecvRpc(peer_id, rpc_req); + block_on(consensus_tx.send(event)).unwrap(); + + // request should be properly deserialized + let (res_tx, _) = oneshot::channel(); + let expected_event = Event::RpcRequest((peer_id.into(), req_msg_enum.clone(), res_tx)); + let event = block_on(stream.next()).unwrap().unwrap(); + assert_eq!(event, expected_event); + } + + // When consensus sends an rpc request, network should get a `NetworkRequest::SendRpc` + // with the serialized request. + #[test] + fn test_consensus_outbound_rpc() { + let (network_reqs_tx, mut network_reqs_rx) = channel::new_test(8); + let mut sender = ConsensusNetworkSender::new(network_reqs_tx); + + // send get_block rpc request + let peer_id = PeerId::random(); + let req_msg = RequestBlock::new(); + let f_res_msg = + sender.request_block(peer_id.into(), req_msg.clone(), Duration::from_secs(5)); + + // build rpc response + let res_msg = RespondBlock::new(); + let mut res_msg_enum = ConsensusMsg::new(); + res_msg_enum.set_respond_block(res_msg.clone()); + let res_data = res_msg_enum.write_to_bytes().unwrap().into(); + + // the future response + let f_recv = async move { + match network_reqs_rx.next().await.unwrap() { + NetworkRequest::SendRpc(recv_peer_id, req) => { + assert_eq!(recv_peer_id, peer_id); + assert_eq!(req.protocol.as_ref(), CONSENSUS_RPC_PROTOCOL); + + // check request deserializes + let mut req_msg_enum: ConsensusMsg = + ::protobuf::parse_from_bytes(req.data.as_ref()).unwrap(); + let recv_req_msg = req_msg_enum.take_request_block(); + assert_eq!(recv_req_msg, req_msg); + + // remote replies with some response message + req.res_tx.send(Ok(res_data)).unwrap(); + Ok(()) + } + event => panic!("Unexpected event: {:?}", event), + } + }; + + let (recv_res_msg, _) = block_on(try_join(f_res_msg, f_recv)).unwrap(); + assert_eq!(recv_res_msg, res_msg); + } +} diff --git a/network/src/validator_network/mempool.rs b/network/src/validator_network/mempool.rs new file mode 100644 index 0000000000000..ef421c4909cf8 --- /dev/null +++ b/network/src/validator_network/mempool.rs @@ -0,0 +1,191 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Interface between Mempool and Network layers. + +use crate::{ + error::NetworkError, + interface::{NetworkNotification, NetworkRequest}, + proto::MempoolSyncMsg, + protocols::direct_send::Message, + validator_network::Event, + ProtocolId, +}; +use bytes::Bytes; +use channel; +use futures::{ + stream::Map, + task::{Context, Poll}, + SinkExt, Stream, StreamExt, +}; +use pin_utils::unsafe_pinned; +use protobuf::Message as proto_msg; +use std::pin::Pin; +use types::PeerId; + +/// Protocol id for mempool direct-send calls +pub const MEMPOOL_DIRECT_SEND_PROTOCOL: &[u8] = b"/libra/mempool/direct-send/0.1.0"; + +/// The interface from Network to Mempool layer. +/// +/// `MempoolNetworkEvents` is a `Stream` of `NetworkNotification` where the +/// raw `Bytes` direct-send and rpc messages are deserialized into +/// `MempoolMessage` types. `MempoolNetworkEvents` is a thin wrapper around an +/// `channel::Receiver`. 
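Consensus' `unary_rpc` above and the tests just shown all ride on the same oneshot round trip: hand the network a `res_tx`, then await `res_rx`. A stripped-down model of that flow, with invented names and arithmetic in place of real request handling:

```rust
use futures::{
    channel::{mpsc, oneshot},
    executor::block_on,
    future::join,
    sink::SinkExt,
    stream::StreamExt,
};

struct Request {
    data: u32,
    res_tx: oneshot::Sender<u32>,
}

fn main() {
    let (mut req_tx, mut req_rx) = mpsc::channel::<Request>(8);

    let client = async move {
        let (res_tx, res_rx) = oneshot::channel();
        req_tx.send(Request { data: 20, res_tx }).await.unwrap();
        // Like unary_rpc: the whole call collapses into one await.
        res_rx.await.unwrap()
    };

    let server = async move {
        let req = req_rx.next().await.unwrap();
        req.res_tx.send(req.data + 1).unwrap();
    };

    let (res, ()) = block_on(join(client, server));
    assert_eq!(res, 21);
}
```

Mempool, by contrast, needs none of this rpc machinery today; its wrapper below handles only direct-sends.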
+pub struct MempoolNetworkEvents { + // TODO(philiphayes): remove pub + pub inner: Map< + channel::Receiver, + fn(NetworkNotification) -> Result, NetworkError>, + >, +} + +impl MempoolNetworkEvents { + // This use of `unsafe_pinned` is safe because: + // 1. This struct does not implement [`Drop`] + // 2. This struct does not implement [`Unpin`] + // 3. This struct is not `#[repr(packed)]` + unsafe_pinned!( + inner: + Map< + channel::Receiver, + fn(NetworkNotification) -> Result, NetworkError>, + > + ); + + pub fn new(receiver: channel::Receiver) -> Self { + let inner = receiver + // TODO(philiphayes): filter_map might be better, so we can drop + // messages that don't deserialize. + .map::<_, fn(_) -> _>(|notification| match notification { + NetworkNotification::NewPeer(peer_id) => Ok(Event::NewPeer(peer_id)), + NetworkNotification::LostPeer(peer_id) => Ok(Event::LostPeer(peer_id)), + NetworkNotification::RecvRpc(_, _) => { + unimplemented!("Mempool does not currently use RPC"); + } + NetworkNotification::RecvMessage(peer_id, msg) => { + let msg = ::protobuf::parse_from_bytes(msg.mdata.as_ref())?; + Ok(Event::Message((peer_id, msg))) + } + }); + + Self { inner } + } +} + +impl Stream for MempoolNetworkEvents { + type Item = Result, NetworkError>; + + fn poll_next(self: Pin<&mut Self>, context: &mut Context) -> Poll> { + self.inner().poll_next(context) + } +} + +/// The interface from Mempool to Networking layer. +/// +/// This is a thin wrapper around an `channel::Sender`, so it is +/// easy to clone and send off to a separate task. For example, the rpc requests +/// return Futures that encapsulate the whole flow, from sending the request to +/// remote, to finally receiving the response and deserializing. It therefore +/// makes the most sense to make the rpc call on a separate async task, which +/// requires the `MempoolNetworkSender` to be `Clone` and `Send`. +#[derive(Clone)] +pub struct MempoolNetworkSender { + // TODO(philiphayes): remove pub + pub inner: channel::Sender, +} + +impl MempoolNetworkSender { + pub fn new(inner: channel::Sender) -> Self { + Self { inner } + } + + /// Send a fire-and-forget "direct-send" message to remote peer `recipient`. + /// + /// Currently, the returned Future simply resolves when the message has been + /// enqueued on the network actor's event queue. It therefore makes no + /// reliable delivery guarantees. + pub async fn send_to( + &mut self, + recipient: PeerId, + message: MempoolSyncMsg, + ) -> Result<(), NetworkError> { + self.inner + .send(NetworkRequest::SendMessage( + recipient, + Message { + protocol: ProtocolId::from_static(MEMPOOL_DIRECT_SEND_PROTOCOL), + mdata: Bytes::from(message.write_to_bytes().unwrap()), + }, + )) + .await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::executor::block_on; + + fn new_test_sync_msg(peer_id: PeerId) -> MempoolSyncMsg { + let mut mempool_msg = MempoolSyncMsg::new(); + mempool_msg.set_peer_id(peer_id.into()); + mempool_msg.set_transactions(::protobuf::RepeatedField::from_vec(vec![])); + mempool_msg + } + + // Direct send messages should get deserialized through the + // `MempoolNetworkEvents` stream. 
+ #[test] + fn test_mempool_network_events() { + let (mut mempool_tx, mempool_rx) = channel::new_test(8); + let mut stream = MempoolNetworkEvents::new(mempool_rx); + + let peer_id = PeerId::random(); + let mempool_msg = new_test_sync_msg(peer_id); + let network_msg = Message { + protocol: ProtocolId::from_static(MEMPOOL_DIRECT_SEND_PROTOCOL), + mdata: mempool_msg.write_to_bytes().unwrap().into(), + }; + + block_on(mempool_tx.send(NetworkNotification::RecvMessage( + PeerId::from(peer_id), + network_msg, + ))) + .unwrap(); + let event = block_on(stream.next()).unwrap().unwrap(); + assert_eq!(event, Event::Message((peer_id.into(), mempool_msg))); + + block_on(mempool_tx.send(NetworkNotification::NewPeer(PeerId::from(peer_id)))).unwrap(); + let event = block_on(stream.next()).unwrap().unwrap(); + assert_eq!(event, Event::NewPeer(peer_id.into())); + } + + // `MempoolNetworkSender` should serialize outbound messages + #[test] + fn test_mempool_network_sender() { + let (network_reqs_tx, mut network_reqs_rx) = channel::new_test(8); + let mut sender = MempoolNetworkSender::new(network_reqs_tx); + + let peer_id = PeerId::random(); + let mempool_msg = new_test_sync_msg(peer_id); + let expected_network_msg = Message { + protocol: ProtocolId::from_static(MEMPOOL_DIRECT_SEND_PROTOCOL), + mdata: mempool_msg.clone().write_to_bytes().unwrap().into(), + }; + + // Send the message to network layer + block_on(sender.send_to(peer_id, mempool_msg)).unwrap(); + + // Network layer should receive serialized message to send out + let event = block_on(network_reqs_rx.next()).unwrap(); + match event { + NetworkRequest::SendMessage(recv_peer_id, network_msg) => { + assert_eq!(recv_peer_id, PeerId::from(peer_id)); + assert_eq!(network_msg, expected_network_msg); + } + event => panic!("Unexpected event: {:?}", event), + } + } +} diff --git a/network/src/validator_network/mod.rs b/network/src/validator_network/mod.rs new file mode 100644 index 0000000000000..32764e6b48a69 --- /dev/null +++ b/network/src/validator_network/mod.rs @@ -0,0 +1,63 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Network API for [`Consensus`](/consensus/index.html) and [`Mempool`](/mempool/index.html) + +pub use crate::protocols::rpc::error::RpcError; +use bytes::Bytes; +use futures::channel::oneshot; + +pub mod network_builder; + +mod consensus; +mod mempool; +#[cfg(test)] +mod test; + +// Public re-exports +pub use consensus::{ + ConsensusNetworkEvents, ConsensusNetworkSender, CONSENSUS_DIRECT_SEND_PROTOCOL, + CONSENSUS_RPC_PROTOCOL, +}; +pub use mempool::{MempoolNetworkEvents, MempoolNetworkSender, MEMPOOL_DIRECT_SEND_PROTOCOL}; +use types::PeerId; + +/// Events received by network clients in a validator +/// +/// An enumeration of the various types of messages that the network will be sending +/// to its clients. This differs from [`NetworkNotification`] since the contents are deserialized +/// into the type `TMessage` over which `Event` is generic. Note that we assume here that for every +/// consumer of this API there's a singleton message type, `TMessage`, which encapsulates all the +/// messages and RPCs that are received by that consumer. +/// +/// [`NetworkNotification`]: crate::interface::NetworkNotification +#[derive(Debug)] +pub enum Event { + /// New inbound direct-send message from peer. + Message((PeerId, TMessage)), + /// New inbound rpc request. 
The request is fulfilled by sending the
+    /// serialized response `Bytes` over the `oneshot::Sender`, where the network
+    /// layer will handle sending the response over-the-wire.
+    RpcRequest((PeerId, TMessage, oneshot::Sender<Result<Bytes, RpcError>>)),
+    /// Peer with which we have a newly established connection.
+    NewPeer(PeerId),
+    /// Peer with which we've lost our connection.
+    LostPeer(PeerId),
+}
+
+/// impl PartialEq for simpler testing
+impl<TMessage: PartialEq> PartialEq for Event<TMessage> {
+    fn eq(&self, other: &Event<TMessage>) -> bool {
+        use Event::*;
+        match (self, other) {
+            (Message((pid1, msg1)), Message((pid2, msg2))) => pid1 == pid2 && msg1 == msg2,
+            // ignore oneshot::Sender in comparison
+            (RpcRequest((pid1, msg1, _)), RpcRequest((pid2, msg2, _))) => {
+                pid1 == pid2 && msg1 == msg2
+            }
+            (NewPeer(pid1), NewPeer(pid2)) => pid1 == pid2,
+            (LostPeer(pid1), LostPeer(pid2)) => pid1 == pid2,
+            _ => false,
+        }
+    }
+}
diff --git a/network/src/validator_network/network_builder.rs b/network/src/validator_network/network_builder.rs
new file mode 100644
index 0000000000000..b51f2e8c7710d
--- /dev/null
+++ b/network/src/validator_network/network_builder.rs
@@ -0,0 +1,545 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    common::NetworkPublicKeys,
+    connectivity_manager::ConnectivityManager,
+    counters,
+    interface::NetworkProvider,
+    peer_manager::{PeerManager, PeerManagerRequestSender},
+    proto::PeerInfo,
+    protocols::{
+        direct_send::DirectSend,
+        discovery::{Discovery, DISCOVERY_PROTOCOL_NAME},
+        health_checker::{HealthChecker, PING_PROTOCOL_NAME},
+        identity::Identity,
+        rpc::Rpc,
+    },
+    transport::{
+        build_memory_noise_transport, build_memory_transport, build_tcp_noise_transport,
+        build_tcp_transport,
+    },
+    validator_network::{
+        ConsensusNetworkEvents, ConsensusNetworkSender, MempoolNetworkEvents, MempoolNetworkSender,
+    },
+    ProtocolId,
+};
+use channel;
+use crypto::{
+    x25519::{X25519PrivateKey, X25519PublicKey},
+    PrivateKey, PublicKey,
+};
+use futures::{compat::Compat01As03, FutureExt, StreamExt, TryFutureExt};
+use netcore::{multiplexing::StreamMultiplexer, transport::boxed::BoxedTransport};
+use parity_multiaddr::Multiaddr;
+use std::{
+    collections::HashMap,
+    sync::{Arc, RwLock},
+    time::Duration,
+};
+use tokio::runtime::TaskExecutor;
+use tokio_timer::Interval;
+use types::{validator_signer::ValidatorSigner, PeerId};
+
+pub const NETWORK_CHANNEL_SIZE: usize = 1024;
+pub const DISCOVERY_INTERVAL_MS: u64 = 1000;
+pub const PING_INTERVAL_MS: u64 = 1000;
+pub const PING_TIMEOUT_MS: u64 = 10_000;
+pub const DISOVERY_MSG_TIMEOUT_MS: u64 = 10_000;
+pub const CONNECTIVITY_CHECK_INTERNAL_MS: u64 = 5000;
+pub const INBOUND_RPC_TIMEOUT_MS: u64 = 10_000;
+pub const MAX_CONCURRENT_OUTBOUND_RPCS: u32 = 100;
+pub const MAX_CONCURRENT_INBOUND_RPCS: u32 = 100;
+pub const PING_FAILURES_TOLERATED: u64 = 10;
+pub const MAX_CONCURRENT_NETWORK_REQS: u32 = 100;
+pub const MAX_CONCURRENT_NETWORK_NOTIFS: u32 = 100;
+
+/// The type of the transport layer, i.e., running on memory or TCP stream,
+/// with or without Noise encryption.
+pub enum TransportType {
+    Memory,
+    MemoryNoise,
+    Tcp,
+    TcpNoise,
+}
+
+/// Builds the Network module with custom configuration values.
+/// Methods can be chained in order to set the configuration values.
+/// The Mempool and Consensus network handles (sender/events pairs) are
+/// constructed by calling [`NetworkBuilder::build`]. New instances of
+/// `NetworkBuilder` are obtained via [`NetworkBuilder::new`].
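+///
+/// A usage sketch (values are illustrative; see the tests in
+/// `validator_network::test` for a complete setup):
+///
+/// ```ignore
+/// let ((mempool_sender, mempool_events), (consensus_sender, consensus_events), listen_addr) =
+///     NetworkBuilder::new(executor, peer_id, addr)
+///         .transport(TransportType::Memory)
+///         .signing_keys((signing_private_key, signing_public_key))
+///         .identity_keys((identity_private_key, identity_public_key))
+///         .trusted_peers(trusted_peers)
+///         .mempool_protocols(vec![ProtocolId::from_static(MEMPOOL_DIRECT_SEND_PROTOCOL)])
+///         .direct_send_protocols(vec![ProtocolId::from_static(MEMPOOL_DIRECT_SEND_PROTOCOL)])
+///         .build();
+/// ```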
+pub struct NetworkBuilder { + executor: TaskExecutor, + peer_id: PeerId, + addr: Multiaddr, + advertised_address: Option, + seed_peers: HashMap, + trusted_peers: Arc>>, + transport: TransportType, + channel_size: usize, + mempool_protocols: Vec, + consensus_protocols: Vec, + direct_send_protocols: Vec, + rpc_protocols: Vec, + discovery_interval_ms: u64, + discovery_msg_timeout_ms: u64, + ping_interval_ms: u64, + ping_timeout_ms: u64, + ping_failures_tolerated: u64, + connectivity_check_interval_ms: u64, + inbound_rpc_timeout_ms: u64, + max_concurrent_outbound_rpcs: u32, + max_concurrent_inbound_rpcs: u32, + max_concurrent_network_reqs: u32, + max_concurrent_network_notifs: u32, + signing_keys: Option<(PrivateKey, PublicKey)>, + identity_keys: Option<(X25519PrivateKey, X25519PublicKey)>, +} + +impl NetworkBuilder { + /// Return a new NetworkBuilder initialized with default configuration values. + pub fn new(executor: TaskExecutor, peer_id: PeerId, addr: Multiaddr) -> NetworkBuilder { + NetworkBuilder { + executor, + peer_id, + addr, + advertised_address: None, + seed_peers: HashMap::new(), + trusted_peers: Arc::new(RwLock::new(HashMap::new())), + channel_size: NETWORK_CHANNEL_SIZE, + mempool_protocols: vec![], + consensus_protocols: vec![], + direct_send_protocols: vec![], + rpc_protocols: vec![], + transport: TransportType::Memory, + discovery_interval_ms: DISCOVERY_INTERVAL_MS, + discovery_msg_timeout_ms: DISOVERY_MSG_TIMEOUT_MS, + ping_interval_ms: PING_INTERVAL_MS, + ping_timeout_ms: PING_TIMEOUT_MS, + ping_failures_tolerated: PING_FAILURES_TOLERATED, + connectivity_check_interval_ms: CONNECTIVITY_CHECK_INTERNAL_MS, + inbound_rpc_timeout_ms: INBOUND_RPC_TIMEOUT_MS, + max_concurrent_outbound_rpcs: MAX_CONCURRENT_OUTBOUND_RPCS, + max_concurrent_inbound_rpcs: MAX_CONCURRENT_INBOUND_RPCS, + max_concurrent_network_reqs: MAX_CONCURRENT_NETWORK_REQS, + max_concurrent_network_notifs: MAX_CONCURRENT_NETWORK_NOTIFS, + signing_keys: None, + identity_keys: None, + } + } + + /// Set transport type, i.e., Memory or Tcp transports. + pub fn transport(&mut self, transport: TransportType) -> &mut Self { + self.transport = transport; + self + } + + /// Set and address to advertise, if different from the listen address + pub fn advertised_address(&mut self, advertised_address: Multiaddr) -> &mut Self { + self.advertised_address = Some(advertised_address); + self + } + + /// Set trusted peers. + pub fn trusted_peers( + &mut self, + trusted_peers: HashMap, + ) -> &mut Self { + *self.trusted_peers.write().unwrap() = trusted_peers; + self + } + + /// Set signing keys of local node. + pub fn signing_keys(&mut self, keys: (PrivateKey, PublicKey)) -> &mut Self { + self.signing_keys = Some(keys); + self + } + + /// Set identity keys of local node. + pub fn identity_keys(&mut self, keys: (X25519PrivateKey, X25519PublicKey)) -> &mut Self { + self.identity_keys = Some(keys); + self + } + + /// Set seed peers to bootstrap discovery + pub fn seed_peers(&mut self, seed_peers: HashMap>) -> &mut Self { + self.seed_peers = seed_peers + .into_iter() + .map(|(peer_id, seed_addrs)| { + let mut peer_info = PeerInfo::new(); + peer_info.set_epoch(0); + peer_info.set_addrs( + seed_addrs + .into_iter() + .map(|addr| addr.as_ref().into()) + .collect(), + ); + (peer_id, peer_info) + }) + .collect(); + self + } + + /// Set discovery ticker interval + pub fn discovery_interval_ms(&mut self, discovery_interval_ms: u64) -> &mut Self { + self.discovery_interval_ms = discovery_interval_ms; + self + } + + /// Set ping interval. 
+ pub fn ping_interval_ms(&mut self, ping_interval_ms: u64) -> &mut Self { + self.ping_interval_ms = ping_interval_ms; + self + } + + /// Set number of ping failures tolerated. + pub fn ping_failures_tolerated(&mut self, ping_failures_tolerated: u64) -> &mut Self { + self.ping_failures_tolerated = ping_failures_tolerated; + self + } + + /// Set ping timeout. + pub fn ping_timeout_ms(&mut self, ping_timeout_ms: u64) -> &mut Self { + self.ping_timeout_ms = ping_timeout_ms; + self + } + + /// Set discovery message timeout. + pub fn discovery_msg_timeout_ms(&mut self, discovery_msg_timeout_ms: u64) -> &mut Self { + self.discovery_msg_timeout_ms = discovery_msg_timeout_ms; + self + } + + /// Set connectivity check ticker interval + pub fn connectivity_check_interval_ms( + &mut self, + connectivity_check_interval_ms: u64, + ) -> &mut Self { + self.connectivity_check_interval_ms = connectivity_check_interval_ms; + self + } + + /// Set inbound rpc timeout. + pub fn inbound_rpc_timeout_ms(&mut self, inbound_rpc_timeout_ms: u64) -> &mut Self { + self.inbound_rpc_timeout_ms = inbound_rpc_timeout_ms; + self + } + + /// The maximum number of concurrent outbound rpc requests we will service. + pub fn max_concurrent_outbound_rpcs(&mut self, max_concurrent_outbound_rpcs: u32) -> &mut Self { + self.max_concurrent_outbound_rpcs = max_concurrent_outbound_rpcs; + self + } + + /// The maximum number of concurrent inbound rpc requests we will service. + pub fn max_concurrent_inbound_rpcs(&mut self, max_concurrent_inbound_rpcs: u32) -> &mut Self { + self.max_concurrent_inbound_rpcs = max_concurrent_inbound_rpcs; + self + } + + /// The maximum number of concurrent NetworkRequests we will service in NetworkProvider. + pub fn max_concurrent_network_reqs(&mut self, max_concurrent_network_reqs: u32) -> &mut Self { + self.max_concurrent_network_reqs = max_concurrent_network_reqs; + self + } + + /// The maximum number of concurrent Notifications from each actor we will service in + /// NetworkProvider. + pub fn max_concurrent_network_notifs( + &mut self, + max_concurrent_network_notifs: u32, + ) -> &mut Self { + self.max_concurrent_network_notifs = max_concurrent_network_notifs; + self + } + + /// Set the size of the channels between different network actors. + pub fn channel_size(&mut self, channel_size: usize) -> &mut Self { + self.channel_size = channel_size; + self + } + + /// Set the protocol IDs that Mempool subscribes. + pub fn mempool_protocols(&mut self, protocols: Vec) -> &mut Self { + self.mempool_protocols = protocols; + self + } + + /// Set the protocol IDs that Consensus subscribes. + pub fn consensus_protocols(&mut self, protocols: Vec) -> &mut Self { + self.consensus_protocols = protocols; + self + } + + /// Set the protocol IDs that DirectSend actor subscribes. + pub fn direct_send_protocols(&mut self, protocols: Vec) -> &mut Self { + self.direct_send_protocols = protocols; + self + } + + /// Set the protocol IDs that RPC actor subscribes. 
+ pub fn rpc_protocols(&mut self, protocols: Vec) -> &mut Self { + self.rpc_protocols = protocols; + self + } + + fn supported_protocols(&self) -> Vec { + self.direct_send_protocols + .iter() + .chain(&self.rpc_protocols) + .chain(&vec![ + ProtocolId::from_static(DISCOVERY_PROTOCOL_NAME), + ProtocolId::from_static(PING_PROTOCOL_NAME), + ]) + .cloned() + .collect() + } + + /// Create the configured `NetworkBuilder` + /// Return the constructed Mempool and Consensus Sender+Events + pub fn build( + &mut self, + ) -> ( + (MempoolNetworkSender, MempoolNetworkEvents), + (ConsensusNetworkSender, ConsensusNetworkEvents), + Multiaddr, + ) { + let identity = Identity::new(self.peer_id, self.supported_protocols()); + // Build network based on the transport type + let own_identity_keys = self.identity_keys.take().unwrap(); + let trusted_peers = self.trusted_peers.clone(); + match self.transport { + TransportType::Memory => self.build_with_transport(build_memory_transport(identity)), + TransportType::MemoryNoise => self.build_with_transport(build_memory_noise_transport( + identity, + own_identity_keys, + trusted_peers, + )), + TransportType::Tcp => self.build_with_transport(build_tcp_transport(identity)), + TransportType::TcpNoise => self.build_with_transport(build_tcp_noise_transport( + identity, + own_identity_keys, + trusted_peers, + )), + } + } + + /// Given a transport build and launch the NetworkProvider and all subcomponents + /// Return the constructed Mempool and Consensus Sender+Events + fn build_with_transport( + &mut self, + transport: BoxedTransport< + (Identity, impl StreamMultiplexer + 'static), + impl ::std::error::Error + Send + Sync + 'static, + >, + ) -> ( + (MempoolNetworkSender, MempoolNetworkEvents), + (ConsensusNetworkSender, ConsensusNetworkEvents), + Multiaddr, + ) { + // Construct Mempool and Consensus network interfaces + let (network_reqs_tx, network_reqs_rx) = + channel::new(self.channel_size, &counters::PENDING_NETWORK_REQUESTS); + let (mempool_tx, mempool_rx) = + channel::new(self.channel_size, &counters::PENDING_MEMPOOL_NETWORK_EVENTS); + let (consensus_tx, consensus_rx) = channel::new( + self.channel_size, + &counters::PENDING_CONSENSUS_NETWORK_EVENTS, + ); + + let mempool_network_sender = MempoolNetworkSender::new(network_reqs_tx.clone()); + let mempool_network_events = MempoolNetworkEvents::new(mempool_rx); + let consensus_network_sender = ConsensusNetworkSender::new(network_reqs_tx.clone()); + let consensus_network_events = ConsensusNetworkEvents::new(consensus_rx); + // Initialize and start NetworkProvider. 
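+        // (The channels created below wire NetworkProvider to the RPC,
+        // DirectSend, and ConnectivityManager actors; `network_reqs_rx`
+        // carries requests from the Mempool/Consensus senders created above.)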
+ let (pm_reqs_tx, pm_reqs_rx) = + channel::new(self.channel_size, &counters::PENDING_PEER_MANAGER_REQUESTS); + let (pm_net_notifs_tx, pm_net_notifs_rx) = channel::new( + self.channel_size, + &counters::PENDING_PEER_MANAGER_NET_NOTIFICATIONS, + ); + let (ds_reqs_tx, ds_reqs_rx) = + channel::new(self.channel_size, &counters::PENDING_DIRECT_SEND_REQUESTS); + let (ds_net_notifs_tx, ds_net_notifs_rx) = channel::new( + self.channel_size, + &counters::PENDING_DIRECT_SEND_NOTIFICATIONS, + ); + let (conn_mgr_reqs_tx, conn_mgr_reqs_rx) = channel::new( + self.channel_size, + &counters::PENDING_CONNECTIVITY_MANAGER_REQUESTS, + ); + let (rpc_reqs_tx, rpc_reqs_rx) = + channel::new(self.channel_size, &counters::PENDING_RPC_REQUESTS); + let (rpc_net_notifs_tx, rpc_net_notifs_rx) = + channel::new(self.channel_size, &counters::PENDING_RPC_NOTIFICATIONS); + + let mempool_handlers = self + .mempool_protocols + .iter() + .map(|p| (p.clone(), mempool_tx.clone())); + let consensus_handlers = self + .consensus_protocols + .iter() + .map(|p| (p.clone(), consensus_tx.clone())); + let upstream_handlers = mempool_handlers.chain(consensus_handlers).collect(); + + let validator_network = NetworkProvider::new( + pm_net_notifs_rx, + rpc_reqs_tx, + rpc_net_notifs_rx, + ds_reqs_tx, + ds_net_notifs_rx, + conn_mgr_reqs_tx.clone(), + network_reqs_rx, + upstream_handlers, + self.max_concurrent_network_reqs, + self.max_concurrent_network_notifs, + ); + self.executor + .spawn(validator_network.start().boxed().unit_error().compat()); + + // Initialize and start PeerManager. + let (pm_ds_notifs_tx, pm_ds_notifs_rx) = channel::new( + self.channel_size, + &counters::PENDING_PEER_MANAGER_DIRECT_SEND_NOTIFICATIONS, + ); + let (pm_rpc_notifs_tx, pm_rpc_notifs_rx) = channel::new( + self.channel_size, + &counters::PENDING_PEER_MANAGER_RPC_NOTIFICATIONS, + ); + let (pm_discovery_notifs_tx, pm_discovery_notifs_rx) = channel::new( + self.channel_size, + &counters::PENDING_PEER_MANAGER_DISCOVERY_NOTIFICATIONS, + ); + let (pm_ping_notifs_tx, pm_ping_notifs_rx) = channel::new( + self.channel_size, + &counters::PENDING_PEER_MANAGER_PING_NOTIFICATIONS, + ); + let (pm_conn_mgr_notifs_tx, pm_conn_mgr_notifs_rx) = channel::new( + self.channel_size, + &counters::PENDING_PEER_MANAGER_CONNECTIVITY_MANAGER_NOTIFICATIONS, + ); + + let direct_send_handlers = self + .direct_send_protocols + .iter() + .map(|p| (p.clone(), pm_ds_notifs_tx.clone())); + let rpc_handlers = self + .rpc_protocols + .iter() + .map(|p| (p.clone(), pm_rpc_notifs_tx.clone())); + let discovery_handler = vec![( + ProtocolId::from_static(DISCOVERY_PROTOCOL_NAME), + pm_discovery_notifs_tx.clone(), + )]; + let ping_handler = vec![( + ProtocolId::from_static(PING_PROTOCOL_NAME), + pm_ping_notifs_tx.clone(), + )]; + let protocol_handlers = direct_send_handlers + .chain(rpc_handlers) + .chain(discovery_handler) + .chain(ping_handler) + .collect(); + + let peer_mgr = PeerManager::new( + transport, + self.executor.clone(), + self.peer_id, + self.addr.clone(), + pm_reqs_rx, + protocol_handlers, + vec![ + pm_net_notifs_tx, + pm_conn_mgr_notifs_tx, + pm_ping_notifs_tx, + pm_discovery_notifs_tx, + ], + ); + let listen_addr = peer_mgr.listen_addr().clone(); + self.executor + .spawn(peer_mgr.start().boxed().unit_error().compat()); + + // Initialize and start DirectSend actor. 
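+        // (DirectSend serves the fire-and-forget message protocols; the Rpc
+        // actor started below serves the request/response protocols.)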
+ let ds = DirectSend::new( + self.executor.clone(), + ds_reqs_rx, + ds_net_notifs_tx, + pm_ds_notifs_rx, + PeerManagerRequestSender::new(pm_reqs_tx.clone()), + ); + self.executor + .spawn(ds.start().boxed().unit_error().compat()); + + // Initialize and start RPC actor. + let rpc = Rpc::new( + rpc_reqs_rx, + pm_rpc_notifs_rx, + PeerManagerRequestSender::new(pm_reqs_tx.clone()), + rpc_net_notifs_tx, + Duration::from_millis(self.inbound_rpc_timeout_ms), + self.max_concurrent_outbound_rpcs, + self.max_concurrent_inbound_rpcs, + ); + self.executor + .spawn(rpc.start().boxed().unit_error().compat()); + + let conn_mgr = ConnectivityManager::new( + self.trusted_peers.clone(), + Compat01As03::new(Interval::new_interval(Duration::from_millis( + self.connectivity_check_interval_ms, + ))) + .fuse(), + PeerManagerRequestSender::new(pm_reqs_tx.clone()), + pm_conn_mgr_notifs_rx, + conn_mgr_reqs_rx, + ); + self.executor + .spawn(conn_mgr.start().boxed().unit_error().compat()); + + // Setup signer from keys. + let (signing_private_key, signing_public_key) = self.signing_keys.take().unwrap(); + let signer = ValidatorSigner::new(self.peer_id, signing_public_key, signing_private_key); + // Initialize and start Discovery actor. + let discovery = Discovery::new( + self.peer_id, + vec![self + .advertised_address + .clone() + .unwrap_or_else(|| listen_addr.clone())], + signer, + self.seed_peers.clone(), + self.trusted_peers.clone(), + Compat01As03::new(Interval::new_interval(Duration::from_millis( + self.discovery_interval_ms, + ))) + .fuse(), + PeerManagerRequestSender::new(pm_reqs_tx.clone()), + pm_discovery_notifs_rx, + conn_mgr_reqs_tx.clone(), + Duration::from_millis(self.discovery_msg_timeout_ms), + ); + self.executor + .spawn(discovery.start().boxed().unit_error().compat()); + + // Initialize and start HealthChecker. + let health_checker = HealthChecker::new( + Compat01As03::new(Interval::new_interval(Duration::from_millis( + self.ping_interval_ms, + ))) + .fuse(), + PeerManagerRequestSender::new(pm_reqs_tx.clone()), + pm_ping_notifs_rx, + Duration::from_millis(self.ping_timeout_ms), + self.ping_failures_tolerated, + ); + self.executor + .spawn(health_checker.start().boxed().unit_error().compat()); + + ( + (mempool_network_sender, mempool_network_events), + (consensus_network_sender, consensus_network_events), + listen_addr, + ) + } +} diff --git a/network/src/validator_network/test.rs b/network/src/validator_network/test.rs new file mode 100644 index 0000000000000..4bf5e25760936 --- /dev/null +++ b/network/src/validator_network/test.rs @@ -0,0 +1,323 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Integration tests for validator_network. 
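+//!
+//! The end-to-end tests below build two complete network stacks over the
+//! in-memory transport (a listener and a dialer), point the dialer at the
+//! listener via `seed_peers`, and exercise one protocol at a time.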
+ +use crate::{ + common::NetworkPublicKeys, + proto::{ConsensusMsg, MempoolSyncMsg, RequestBlock, RespondBlock, SignedTransaction}, + validator_network::{ + network_builder::{NetworkBuilder, TransportType}, + Event, CONSENSUS_RPC_PROTOCOL, MEMPOOL_DIRECT_SEND_PROTOCOL, + }, + ProtocolId, +}; +use crypto::{ + signing::{self, generate_keypair}, + x25519, +}; +use futures::{executor::block_on, future::join, StreamExt}; +use parity_multiaddr::Multiaddr; +use protobuf::Message as proto_msg; +use std::{collections::HashMap, time::Duration}; +use tokio::runtime::Runtime; +use types::{ + account_address::{AccountAddress, ADDRESS_LENGTH}, + test_helpers::transaction_test_helpers::get_test_signed_txn, + PeerId, +}; + +#[test] +fn test_network_builder() { + let runtime = Runtime::new().unwrap(); + let peer_id = PeerId::random(); + let addr: Multiaddr = "/memory/0".parse().unwrap(); + let mempool_sync_protocol = ProtocolId::from_static(MEMPOOL_DIRECT_SEND_PROTOCOL); + let consensus_get_blocks_protocol = ProtocolId::from_static(b"get_blocks"); + let synchronizer_get_chunks_protocol = ProtocolId::from_static(b"get_chunks"); + let (signing_private_key, signing_public_key) = signing::generate_keypair(); + let (identity_private_key, identity_public_key) = x25519::generate_keypair(); + + let ( + (_mempool_network_sender, _mempool_network_events), + (_consensus_network_sender, _consensus_network_events), + _listen_addr, + ) = NetworkBuilder::new(runtime.executor(), peer_id, addr) + .transport(TransportType::Memory) + .signing_keys((signing_private_key, signing_public_key)) + .identity_keys((identity_private_key, identity_public_key)) + .trusted_peers( + vec![( + peer_id, + NetworkPublicKeys { + signing_public_key, + identity_public_key, + }, + )] + .into_iter() + .collect(), + ) + .channel_size(8) + .consensus_protocols(vec![consensus_get_blocks_protocol.clone()]) + .mempool_protocols(vec![mempool_sync_protocol.clone()]) + .direct_send_protocols(vec![mempool_sync_protocol]) + .rpc_protocols(vec![ + consensus_get_blocks_protocol, + synchronizer_get_chunks_protocol, + ]) + .build(); +} + +#[test] +fn test_mempool_sync() { + ::logger::try_init_for_testing(); + let runtime = Runtime::new().unwrap(); + let mempool_sync_protocol = ProtocolId::from_static(MEMPOOL_DIRECT_SEND_PROTOCOL); + + // Setup peer ids. + let listener_peer_id = PeerId::random(); + let dialer_peer_id = PeerId::random(); + // Setup signing public keys. + let (listener_signing_private_key, listener_signing_public_key) = signing::generate_keypair(); + let (dialer_signing_private_key, dialer_signing_public_key) = signing::generate_keypair(); + // Setup identity public keys. 
+ let (listener_identity_private_key, listener_identity_public_key) = x25519::generate_keypair(); + let (dialer_identity_private_key, dialer_identity_public_key) = x25519::generate_keypair(); + + // Set up the listener network + let listener_addr: Multiaddr = "/memory/0".parse().unwrap(); + + let trusted_peers: HashMap<_, _> = vec![ + ( + listener_peer_id, + NetworkPublicKeys { + signing_public_key: listener_signing_public_key, + identity_public_key: listener_identity_public_key, + }, + ), + ( + dialer_peer_id, + NetworkPublicKeys { + signing_public_key: dialer_signing_public_key, + identity_public_key: dialer_identity_public_key, + }, + ), + ] + .into_iter() + .collect(); + + let ((_, mut listener_mp_net_events), _, listener_addr) = + NetworkBuilder::new(runtime.executor(), listener_peer_id, listener_addr) + .signing_keys((listener_signing_private_key, listener_signing_public_key)) + .identity_keys((listener_identity_private_key, listener_identity_public_key)) + .trusted_peers(trusted_peers.clone()) + .transport(TransportType::Memory) + .channel_size(8) + .mempool_protocols(vec![mempool_sync_protocol.clone()]) + .direct_send_protocols(vec![mempool_sync_protocol.clone()]) + .build(); + + // Set up the dialer network + let dialer_addr: Multiaddr = "/memory/0".parse().unwrap(); + + let ((mut dialer_mp_net_sender, mut dialer_mp_net_events), _, _dialer_addr) = + NetworkBuilder::new(runtime.executor(), dialer_peer_id, dialer_addr) + .transport(TransportType::Memory) + .signing_keys((dialer_signing_private_key, dialer_signing_public_key)) + .identity_keys((dialer_identity_private_key, dialer_identity_public_key)) + .trusted_peers(trusted_peers.clone()) + .seed_peers( + [(listener_peer_id, vec![listener_addr])] + .iter() + .cloned() + .collect(), + ) + .channel_size(8) + .mempool_protocols(vec![mempool_sync_protocol.clone()]) + .direct_send_protocols(vec![mempool_sync_protocol.clone()]) + .build(); + + // The dialer dials the listener and sends a mempool sync message + let mut mempool_msg = MempoolSyncMsg::new(); + mempool_msg.set_peer_id(dialer_peer_id.into()); + let sender = AccountAddress::new([0; ADDRESS_LENGTH]); + let keypair = generate_keypair(); + let txn = get_test_signed_txn(sender, 0, keypair.0, keypair.1, None); + mempool_msg.set_transactions(::protobuf::RepeatedField::from_vec(vec![txn.clone()])); + + let f_dialer = async move { + // Wait until dialing finished and NewPeer event received + match dialer_mp_net_events.next().await.unwrap().unwrap() { + Event::NewPeer(peer_id) => { + assert_eq!(peer_id, listener_peer_id.into()); + } + event => panic!("Unexpected event {:?}", event), + } + + // Dialer sends a mempool sync message + dialer_mp_net_sender + .send_to(listener_peer_id.into(), mempool_msg) + .await + .unwrap(); + }; + + // The listener receives a mempool sync message + let f_listener = async move { + // The listener receives a NewPeer event first + match listener_mp_net_events.next().await.unwrap().unwrap() { + Event::NewPeer(peer_id) => { + assert_eq!(peer_id, dialer_peer_id.into()); + } + event => panic!("Unexpected event {:?}", event), + } + + // The listener then receives the mempool sync message + match listener_mp_net_events.next().await.unwrap().unwrap() { + Event::Message((peer_id, msg)) => { + assert_eq!(peer_id, dialer_peer_id.into()); + let dialer_peer_id_bytes = Vec::from(&dialer_peer_id); + assert_eq!(msg.peer_id, dialer_peer_id_bytes); + let transactions: Vec = msg.transactions.into(); + assert_eq!(transactions, vec![txn]); + } + event => panic!("Unexpected event 
{:?}", event), + } + }; + + block_on(join(f_dialer, f_listener)); +} + +#[test] +fn test_consensus_rpc() { + ::logger::try_init_for_testing(); + let runtime = Runtime::new().unwrap(); + let rpc_protocol = ProtocolId::from_static(CONSENSUS_RPC_PROTOCOL); + + // Setup peer ids. + let listener_peer_id = PeerId::random(); + let dialer_peer_id = PeerId::random(); + // Setup signing public keys. + let (listener_signing_private_key, listener_signing_public_key) = signing::generate_keypair(); + let (dialer_signing_private_key, dialer_signing_public_key) = signing::generate_keypair(); + // Setup identity public keys. + let (listener_identity_private_key, listener_identity_public_key) = x25519::generate_keypair(); + let (dialer_identity_private_key, dialer_identity_public_key) = x25519::generate_keypair(); + + // Set up the listener network + let listener_addr: Multiaddr = "/memory/0".parse().unwrap(); + + let trusted_peers: HashMap<_, _> = vec![ + ( + listener_peer_id, + NetworkPublicKeys { + signing_public_key: listener_signing_public_key, + identity_public_key: listener_identity_public_key, + }, + ), + ( + dialer_peer_id, + NetworkPublicKeys { + signing_public_key: dialer_signing_public_key, + identity_public_key: dialer_identity_public_key, + }, + ), + ] + .into_iter() + .collect(); + + let (_, (_, mut listener_con_net_events), listener_addr) = + NetworkBuilder::new(runtime.executor(), listener_peer_id, listener_addr) + .signing_keys((listener_signing_private_key, listener_signing_public_key)) + .identity_keys((listener_identity_private_key, listener_identity_public_key)) + .trusted_peers(trusted_peers.clone()) + .transport(TransportType::Memory) + .channel_size(8) + .consensus_protocols(vec![rpc_protocol.clone()]) + .rpc_protocols(vec![rpc_protocol.clone()]) + .build(); + + // Set up the dialer network + let dialer_addr: Multiaddr = "/memory/0".parse().unwrap(); + + let (_, (mut dialer_con_net_sender, mut dialer_con_net_events), _dialer_addr) = + NetworkBuilder::new(runtime.executor(), dialer_peer_id, dialer_addr) + .transport(TransportType::Memory) + .signing_keys((dialer_signing_private_key, dialer_signing_public_key)) + .identity_keys((dialer_identity_private_key, dialer_identity_public_key)) + .trusted_peers(trusted_peers.clone()) + .seed_peers( + [(listener_peer_id, vec![listener_addr])] + .iter() + .cloned() + .collect(), + ) + .channel_size(8) + .consensus_protocols(vec![rpc_protocol.clone()]) + .rpc_protocols(vec![rpc_protocol.clone()]) + .build(); + + let block_id = vec![0_u8; 32]; + let mut req_block_msg = RequestBlock::new(); + req_block_msg.set_block_id(block_id.into()); + + let res_block_msg = RespondBlock::new(); + + // The dialer dials the listener and sends a RequestBlock rpc request + let req_block_msg_clone = req_block_msg.clone(); + let res_block_msg_clone = res_block_msg.clone(); + let f_dialer = async move { + // Wait until dialing finished and NewPeer event received + match dialer_con_net_events.next().await.unwrap().unwrap() { + Event::NewPeer(peer_id) => { + assert_eq!(peer_id, listener_peer_id.into()); + } + event => panic!("Unexpected event {:?}", event), + } + + // Dialer sends a RequestBlock rpc request. + let res_block_msg = dialer_con_net_sender + .request_block( + listener_peer_id.into(), + req_block_msg_clone, + Duration::from_secs(10), + ) + .await + .unwrap(); + assert_eq!(res_block_msg, res_block_msg_clone); + }; + + // The listener receives a RequestBlock rpc request and sends a RespondBlock + // rpc response. 
+    let req_block_msg_clone = req_block_msg.clone();
+    let res_block_msg_clone = res_block_msg.clone();
+    let f_listener = async move {
+        // The listener receives a NewPeer event first
+        match listener_con_net_events.next().await.unwrap().unwrap() {
+            Event::NewPeer(peer_id) => {
+                assert_eq!(peer_id, dialer_peer_id.into());
+            }
+            event => panic!("Unexpected event {:?}", event),
+        }
+
+        // The listener then handles the RequestBlock rpc request.
+        match listener_con_net_events.next().await.unwrap().unwrap() {
+            Event::RpcRequest((peer_id, mut req_msg, res_tx)) => {
+                assert_eq!(peer_id, dialer_peer_id.into());
+
+                // Check the request
+                assert!(req_msg.has_request_block());
+                let req_block_msg = req_msg.take_request_block();
+                assert_eq!(req_block_msg, req_block_msg_clone);
+
+                // Send the serialized response back.
+                let mut res_msg = ConsensusMsg::new();
+                res_msg.set_respond_block(res_block_msg_clone);
+                let res_data = res_msg.write_to_bytes().unwrap().into();
+                res_tx.send(Ok(res_data)).unwrap();
+            }
+            event => panic!("Unexpected event {:?}", event),
+        }
+    };
+
+    block_on(join(f_dialer, f_listener));
+}
diff --git a/rust-toolchain b/rust-toolchain
new file mode 100644
index 0000000000000..f677ace0c5d00
--- /dev/null
+++ b/rust-toolchain
@@ -0,0 +1 @@
+nightly-2019-05-22
diff --git a/rustfmt.toml b/rustfmt.toml
new file mode 100644
index 0000000000000..5a458a086ee9d
--- /dev/null
+++ b/rustfmt.toml
@@ -0,0 +1,10 @@
+comment_width = 100
+edition = "2018"
+merge_imports = true
+use_field_init_shorthand = true
+wrap_comments = true
+
+# Ignore generated files.
+ignore = [
+    "language/compiler/src/parser/syntax.rs",
+]
diff --git a/scripts/cargo_check_dependencies.sh b/scripts/cargo_check_dependencies.sh
new file mode 100755
index 0000000000000..5fe6394791fd7
--- /dev/null
+++ b/scripts/cargo_check_dependencies.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+# Copyright (c) The Libra Core Contributors
+# SPDX-License-Identifier: Apache-2.0
+
+# This script assumes it runs in the same directory as a Cargo.toml file
+# and checks whether that Cargo.toml file can build without some of its
+# dependencies, using repeated `cargo check --all-targets` attempts.
+#
+# In order to run this in a directory containing multiple Cargo.toml files,
+# we suggest:
+# find ./ -name Cargo.toml -execdir /path/to/$(basename $0) \;
+
+# Requirements:
+# https://github.com/killercup/cargo-edit << `cargo install cargo-edit`
+# awk, bash, git
+# This will make one local commit per removable dependency. It is advised to
+# squash those dependency-removing commits into a single one.
+
+if [ ! -f Cargo.toml ]; then
+    echo "Cargo.toml not found! Are you running this script in the right directory?"
+    exit 1
+fi
+
+dependencies=$(awk 'x==1 {print $1} /\[dependencies\]/ {x=1}' Cargo.toml | grep -v '^$' | grep -v '^\[.+\]$')
+echo "$dependencies"
+for i in $dependencies; do
+    echo "testing removal of $i"
+    cargo rm "$i"
+    cargo check --all-targets
+    if (( $? == 0 )); then
+        echo "removal succeeded, committing"
+        git commit -m "Removing $i" --all
+    else
+        echo "removal failed, rolling back"
+        git reset --hard
+    fi
+done
diff --git a/scripts/cargo_update_outdated.sh b/scripts/cargo_update_outdated.sh
new file mode 100755
index 0000000000000..a2532c2b934ed
--- /dev/null
+++ b/scripts/cargo_update_outdated.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+# Copyright (c) The Libra Core Contributors
+# SPDX-License-Identifier: Apache-2.0
+
+# This script modifies local cargo files to reflect compatibility (semver)
+# upgrades to direct dependencies.
It does not allow non-compatible +# updates as those should be done manually (reading the output of cargo outdated) + +# It requires cargo-edit and cargo-outdated +# Example usage: +# libra$ ./scripts/cargo_update_outdated.sh +# libra$ git commit --all -m "Update dependencies" +set -e + +# check install for outdated & edit +if ! $(cargo install --list | grep -qe 'cargo-outdated') +then + cargo install cargo-outdated +fi + +if ! $(cargo install --list | grep -qe 'cargo-edit') +then + cargo install cargo-edit +fi + +for upgrade in $(cargo outdated | awk 'NF >2 && $2 ~ /[0-9\.]+/ && $3 ~ /[0-9\.]+/ {print $1"@"$3}' | uniq |tr '\n' " ") +do + echo $upgrade + cargo -q upgrade $upgrade --all > /dev/null +done diff --git a/scripts/cli/start_cli_testnet.sh b/scripts/cli/start_cli_testnet.sh new file mode 100755 index 0000000000000..a45b3dca44724 --- /dev/null +++ b/scripts/cli/start_cli_testnet.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +print_help() +{ + echo "Build client binary and connect to testnet." + echo "\`$0 -r|--release\` to use release build or" + echo "\`$0\` to use debug build." +} + +source "$HOME/.cargo/env" + +SCRIPT_PATH="$(dirname $0)" + +RUN_PARAMS="--host ac.testnet.libra.org --port 80 -s $SCRIPT_PATH/trusted_peers.config.toml" + +case $1 in + -h | --help) + print_help;exit 0;; + -r | --release) + echo "Building and running client in release mode." + cargo run -p client --release -- $RUN_PARAMS + ;; + '') + echo "Building and running client in debug mode." + cargo run -p client -- $RUN_PARAMS + ;; + *) echo "Invalid option"; print_help; exit 0; +esac diff --git a/scripts/cli/trusted_peers.config.toml b/scripts/cli/trusted_peers.config.toml new file mode 100644 index 0000000000000..e4ae247088b28 --- /dev/null +++ b/scripts/cli/trusted_peers.config.toml @@ -0,0 +1,49 @@ +[peers.9102bd7b1ad7e8f31023c500371cc7d2971758b450cfa89c003efb3ab192a4b8] +network_signing_pubkey = "20000000000000005f5ecda9576edd942ed22aa4735939092161445177cd456fd087c7bc1d6de403" +network_identity_pubkey = "2000000000000000b5eb9a2e5814c66df6c01a1dc94252a4ae6733e93a58187c5eb48d1f53be0b28" +consensus_pubkey = "2000000000000000576e91b04632683a11c3be3dc47a19f9f0a31ae947211f59c5fe02dfa2d07d68" + +[peers.dfb9c683d1788857e961160f28d4c9c79b23f042c80f770f37f0f93ee5fa6a96] +network_signing_pubkey = "2000000000000000246ca919a3b39c95110e3bee891136ab087a9b3b9e84fa90cbf8f19c8abe62e3" +network_identity_pubkey = "20000000000000008aa297d686dd2444de86ea3a68353d74af74b9659990d06ccaf4344e2b629b33" +consensus_pubkey = "20000000000000003ca1400fb865befa8a21c58e90fc636ef2f84993a8396cb0e10008f876a00afd" + +[peers.26873decd9330065988b0acf5027662b5097fb50913ae2a2652b50a9869df4fb] +network_signing_pubkey = "20000000000000001512ea6e18f7c3069372d6883220aadb8fa525f9b6bb4d0eb473fc94a19ca87b" +network_identity_pubkey = "2000000000000000d19114154b3cbb76434d5fb5f50d215eab6ddf2b67907850118c8b3525896b11" +consensus_pubkey = "2000000000000000ada40b3c2039b0fe06e31f88c13505acc3d0731d9e6a09410f463a5c02ce5ef7" + +[peers.19f93314cbe8c0925a4492eb2f2f197ce6e11717449c218f50e043e37fa7a5f3] +network_signing_pubkey = "2000000000000000f6c7c31b68157b839cfc662cd9330ea5d27ed03d0b1ec9b2970c05fc66cd80d0" +network_identity_pubkey = "2000000000000000f0799c6e2b843066d5a23f42f775dce008fe92b9a0e8e9bf1208e83f5cec883a" +consensus_pubkey = "20000000000000000397615aa5cc4cedadaa870511f381423e914012f5e341f474f3f608e0224e1d" + +[peers.f9770caa0be0c0ad427f204c22a2c2d7b22ee373a1b9cf6fd768fbf48a079554] +network_signing_pubkey = 
"2000000000000000d844e3eb78ae751157af04c844c80bf55bc2bd7d2d7feea9b22b6ca833e6ac30" +network_identity_pubkey = "20000000000000003f64b2233f8eb629f4b00724099895b819dd23e8986fdb5774bf9e176c9c8f26" +consensus_pubkey = "20000000000000006d9e7d0d6c1c0acf7884ab8f8258f9f9dca74ca20653422c09652e348cf66b78" + +[peers.3b7c756cce9ad7d801b078a08ee91df5f8122e44011b4fdf6d6c20c016823b8f] +network_signing_pubkey = "2000000000000000d00667dcedc2c8359a261df5844408fe17fb768c33e211835f98893cfa304c91" +network_identity_pubkey = "2000000000000000d7eae84059a32ed1e2723909f8e05a3fe4ff96e3e496c8cb64917cd14e73ee3d" +consensus_pubkey = "2000000000000000ef84908705e82f835d665a94aa39990b101de671935a108c7d979fec91cbac10" + +[peers.6687e9a6e6c3de0dc4f7e91eacbc676a228a9c0c46450bbccbd1072780bfcb30] +network_signing_pubkey = "2000000000000000dbe17f5fce01ad52dc78574d63124c108102291a042e332b4c82cf30c62ddca6" +network_identity_pubkey = "20000000000000005c538501145dbc5e24b822edb1d0543ee63f00a022d2bd7b3e72f5468b12e82a" +consensus_pubkey = "200000000000000062767f364921a4446a8d141298362497ab74b6cf697f2aa49eb2916b41a1b1cc" + +[peers.c28b953590c58117ae8431456ea28f67c2f9e1733078b208e1a7bd5bd4081e9e] +network_signing_pubkey = "2000000000000000865a2c72cee4ede366bfdc391420885e4dc49f924506778d214dc3cf6d09c8c4" +network_identity_pubkey = "2000000000000000727bef5a30048db7000bbc54e6eb8dfe74b1fbac4fc627c8b3c78153e2e19805" +consensus_pubkey = "2000000000000000b62aa0cde6584521a9c5ba81fa0fe659eba8987356db81fc6f54e812ed0c2437" + +[peers.4d78ab90b759ecacafe4e687c5db9cc2936a7a29c84aa8be777f54db519d756b] +network_signing_pubkey = "2000000000000000b0c0773494d1e87dd15e34a2636dfcaed771b01353227c69ab71c175a5ed437a" +network_identity_pubkey = "20000000000000008383c5dbe5bc2888534e063c626084e58ba57661c62aa557022b807ee9838c02" +consensus_pubkey = "2000000000000000cec7be0f4808b68823ffe3db481564568a37f91ccdaa3bfc5d0d31b664e53695" + +[peers.4995559c4844b7e4101c486035329402a8ba4976c1be23080bac5bddf6a60f0d] +network_signing_pubkey = "2000000000000000e1a61336ee0a122e1c86e228dfc5067f24941b76a34da400af4e3a4da8da04a6" +network_identity_pubkey = "200000000000000018ce05c58c4109cb123a9bfdf57b541b84814e7504f1aaeb09d23e05d20fb155" +consensus_pubkey = "2000000000000000d29b1294c223ec10749df7fee80c2ddf39c1fb32f1a89ae028e835b0502c19a0" diff --git a/scripts/clippy.args b/scripts/clippy.args new file mode 100644 index 0000000000000..08404192c63f5 --- /dev/null +++ b/scripts/clippy.args @@ -0,0 +1,17 @@ +# Copyright (c) The Libra Core Contributors +# SPDX-License-Identifier: Apache-2.0 + +# Allowed lints +allowed_lints=( + "-A" "renamed_and_removed_lints" + "-A" "clippy::match_bool" + "-A" "clippy::get_unwrap" + "-A" "clippy::new_without_default" + + # Added with compiler bump nightly-2019-04-13 + "-A" "clippy::identity-conversion" + # Clippy seems to complain about async functions + "-A" "clippy::needless_lifetimes" +) + +echo "${allowed_lints[@]}" diff --git a/scripts/clippy.sh b/scripts/clippy.sh new file mode 100755 index 0000000000000..b030667a6afd6 --- /dev/null +++ b/scripts/clippy.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# Copyright (c) The Libra Core Contributors +# SPDX-License-Identifier: Apache-2.0 + +set -e + +SCRIPT_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +# Workaround https://github.com/rust-lang/rust-clippy/issues/2604 +if [[ "$1" == *"c"* ]]; then + cargo clean +fi + +# Run 'clippy' on all targets, ensuring that all warnings trigger a +# failure. 
+cargo clippy --all --all-targets -- -D warnings $(source "$SCRIPT_PATH/clippy.args") diff --git a/scripts/dev_setup.sh b/scripts/dev_setup.sh new file mode 100755 index 0000000000000..c369ad4460723 --- /dev/null +++ b/scripts/dev_setup.sh @@ -0,0 +1,138 @@ +#!/bin/bash +# This script sets up the environment for the Libra build by installing necessary dependencies. +# +# Usage ./dev_setup.sh +# v - verbose, print all statements + +SCRIPT_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +cd "$SCRIPT_PATH/.." + +set -e +OPTIONS="$1" + +if [[ $OPTIONS == *"v"* ]]; then + set -x +fi + +if [ ! -f Cargo.toml ]; then + echo "Unknown location. Please run this from the libra repository. Abort." + exit 1 +fi + +PACKAGE_MANAGER= +if [[ "$OSTYPE" == "linux-gnu" ]]; then + if which yum &>/dev/null; then + PACKAGE_MANAGER="yum" + elif which apt-get &>/dev/null; then + PACKAGE_MANAGER="apt-get" + else + echo "Unable to find supported package manager (yum or apt-get). Abort" + exit 1 + fi +elif [[ "$OSTYPE" == "darwin"* ]]; then + if which brew &>/dev/null; then + PACKAGE_MANAGER="brew" + else + echo "Missing package manager Homebrew (https://brew.sh/). Abort" + exit 1 + fi +else + echo "Unknown OS. Abort." + exit 1 +fi + +cat < " +read -e input +if [[ "$input" != "y"* ]]; then + echo "Exiting..." + exit 0 +fi + +# Install Rust +echo "Installing Rust......" +if rustup --version &>/dev/null; then + echo "Rust is already installed" +else + curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable + CARGO_ENV="$HOME/.cargo/env" + source "$CARGO_ENV" +fi + +# Run update in order to download and install the checked in toolchain +rustup update + +# Add all the components that we need +rustup component add rustfmt +rustup component add clippy + +echo "Installing CMake......" +if which cmake &>/dev/null; then + echo "CMake is already installed" +else + if [[ "$PACKAGE_MANAGER" == "yum" ]]; then + sudo yum install cmake + elif [[ "$PACKAGE_MANAGER" == "apt-get" ]]; then + sudo apt-get install cmake + elif [[ "$PACKAGE_MANAGER" == "brew" ]]; then + brew install cmake + fi +fi + +echo "Installing Go......" +if which go &>/dev/null; then + echo "Go is already installed" +else + if [[ "$PACKAGE_MANAGER" == "yum" ]]; then + sudo yum install golang + elif [[ "$PACKAGE_MANAGER" == "apt-get" ]]; then + sudo apt-get install golang + elif [[ "$PACKAGE_MANAGER" == "brew" ]]; then + brew install go + fi +fi + +echo "Installing Protobuf......" +if which protoc &>/dev/null; then + echo "Protobuf is already installed" +else + if [[ "$PACKAGE_MANAGER" == "yum" ]]; then + sudo yum install protobuf + elif [[ "$PACKAGE_MANAGER" == "apt-get" ]]; then + sudo apt-get install protobuf-compiler + elif [[ "$PACKAGE_MANAGER" == "brew" ]]; then + brew install protobuf + fi +fi + +if [[ -f /etc/debian_version ]]; then + PROTOC_VERSION=`protoc --version | cut -d" " -f2` + REQUIRED_PROTOC_VERSION="3.6.0" + PROTOC_VERSION_CHECK=`dpkg --compare-versions $REQUIRED_PROTOC_VERSION "gt" $PROTOC_VERSION` + + if [ $? -eq "0" ]; then + echo "protoc version is too old. Update protoc to 3.6.0 or above. 
Abort" + exit 1 + fi +fi + +cat </dev/null 2>&1; then + echo "Disallowed Nightly Features Found:" + git grep -e"#\!\[feature(.*)\]" "${allowed_features[@]}" -- "*.rs" + exit 1 +else + exit 0 +fi diff --git a/storage/README.md b/storage/README.md new file mode 100644 index 0000000000000..11650c2d1f5e0 --- /dev/null +++ b/storage/README.md @@ -0,0 +1,99 @@ +--- +id: storage +title: Storage +custom_edit_url: https://github.com/libra/libra/edit/master/storage/README.md +--- + +# Storage + +The storage module provides reliable and efficient persistent storage for the +entire set of data on the Libra Blockchain, as well as the necessary data used +internally by Libra Core. + +## Overview + +The storage module is designed to serve two primary purposes: + +1. Persist the blockchain data, specifically the transactions and their outputs + that have been agreed by validators via consensus protocol. +2. Provide a response with Merkle proofs to any query that asks for a part of the + blockchain data. A client can easily verify the integrity of the response if + they have obtained the correct root hash. + +The Libra Blockchain can be viewed as a Merkle tree consisting of the following +components: + +![data](assets/data.png) + +### Ledger History + +Ledger history is represented by a Merkle accumulator. Each time a transaction +`T` is added to the blockchain, a *TransactionInfo* structure containing the +transaction `T`, the root hash for the state Merkle tree after the execution of +`T` and the root hash for the event Merkle tree generated by `T` is appended to +the accumulator. + +### Ledger State + +The ledger state at each version is represented by a sparse Merkle tree that has the +state of all accounts. The keys are the 256-bit hash of the addresses, and their +corresponding value is the state of the entire account serialized as a binary +blob. While a tree of size `2^256` is an intractable representation, subtrees +consisting entirely of empty nodes are replaced with a placeholder value, and +subtrees consisting of exactly one leaf are replaced with a single node. + +While each *TransactionInfo* structure points to a different state tree, the new +tree can reuse unchanged portion of the previous tree, forming a persistent data +structure. + +### Events + +Each transaction emits a list of events and those events form a Merkle accumulator. +Similar to the state Merkle tree, the root hash of the event accumulator of a +transaction is recorded in the corresponding *TransactionInfo* structure. + +### Ledger Info and Signatures + +A *LedgerInfo* structure that has the root hash of the ledger history +accumulator at some version and other metadata is a binding commitment to +the ledger history up to this version. Validators sign the corresponding +*LedgerInfo* structure every time they agree on a set of transactions and their +execution outcome. For each *LedgerInfo* structure that is stored, a set of +signatures on this structure from validators are also stored, so +clients can verify the structure if they have obtained the public key of each +validator. + +## Implementation Details + +The storage module uses [RocksDB](https://rocksdb.org/) as its physical storage +engine. Since the storage module needs to store multiple types of data, and +key-value pairs in RocksDB are byte arrays, there is a wrapper on top of RocksDB +to deal with the serialization of keys and values. This wrapper enforces that all data in and +out of the DB is structured according to predefined schemas. 
+The core module that implements the main functionalities is called *LibraDB*.
+While we use a single RocksDB instance to store the entire set of data, related
+data are grouped into logical stores — for example, the ledger store, the state
+store, and the transaction store.
+
+For the sparse Merkle tree that represents the ledger state, we optimize the
+disk layout by using branch nodes with 16 children that represent 4-level
+subtrees, and extension nodes that represent paths without branches. However,
+we still simulate a binary tree when computing the root hash and proofs. This
+modification results in proofs that are shorter than the ones generated by
+Ethereum's Merkle Patricia tree.
+
+## How is this module organized?
+```
+    storage
+    └── accumulator     # Implementation of Merkle accumulator.
+    └── libradb         # Implementation of LibraDB.
+    └── schemadb        # Schematized wrapper on top of RocksDB.
+    └── scratchpad      # In-memory representation of Libra core data structures used by execution.
+    └── sparse_merkle   # Implementation of sparse Merkle tree.
+    └── state_view      # An abstraction layer representing a snapshot of state where the Move VM reads data.
+    └── storage_client  # A Rust wrapper on top of GRPC clients.
+    └── storage_proto   # All interfaces provided by the storage module.
+    └── storage_service # Storage module as a GRPC service.
+```
+
diff --git a/storage/accumulator/Cargo.toml b/storage/accumulator/Cargo.toml
new file mode 100644
index 0000000000000..3ba0c072fdd3c
--- /dev/null
+++ b/storage/accumulator/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "accumulator"
+version = "0.1.0"
+authors = ["Libra Association "]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+proptest = "0.9.1"
+
+crypto = { path = "../../crypto/legacy_crypto" }
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+types = { path = "../../types" }
+
+[dev-dependencies]
+rand = "0.6.5"
diff --git a/storage/accumulator/src/lib.rs b/storage/accumulator/src/lib.rs
new file mode 100644
index 0000000000000..30bba5222ab98
--- /dev/null
+++ b/storage/accumulator/src/lib.rs
@@ -0,0 +1,321 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module provides algorithms for accessing and updating a Merkle Accumulator structure
+//! persisted in a key-value store. Note that this doesn't write to the storage directly; rather,
+//! it reads from it via the `HashReader` trait and yields writes via an in-memory `HashMap`.
+//!
+//! # Merkle Accumulator
+//! Given an ever-growing (append-only) series of "leaf" hashes, we construct an evolving Merkle
+//! tree for which proofs of inclusion/exclusion of a leaf hash at a leaf index in a snapshot
+//! of the tree (represented by its root hash) can be given.
+//!
+//! # Leaf Nodes
+//! Leaf nodes carry the hash values to be stored and proved. They are only appended to the tree,
+//! never deleted or updated.
+//!
+//! # Internal Nodes
+//! A non-leaf node carries the hash value derived from both its left and right children.
+//!
+//! # Placeholder Nodes
+//! To make sure each leaf node has a Merkle proof towards the root, placeholder nodes are added
+//! so that along the route from a leaf to the root, each node has a sibling. Placeholder nodes
+//! have the hash value `ACCUMULATOR_PLACEHOLDER_HASH`.
+//!
+//! A placeholder node can appear as either a leaf node or a non-leaf node, but there is at most
+//! one placeholder leaf at any time.
+//!
+//! # Frozen Nodes & Non-frozen Nodes
+//!
As leaves are added to the tree, placeholder nodes get replaced by non-placeholder nodes, and +//! when a node has all its descendants being non-placeholder, it becomes "Frozen" -- its hash value +//! won't change again in the event of new leaves being added. All leaves appended (not counting the +//! one possible placeholder leaf) are by definition Frozen. +//! +//! Other nodes, which have one or more placeholder descendants are Non-Frozen. As new elements are +//! appended to the accumulator the hash value of these nodes will change. +//! +//! # Leaf Count +//! Given a count of the number of leaves in a Merkle Accumulator it is possible to determine the +//! shape of the accumulator -- which nodes are filled and which nodes are placeholder nodes. +//! +//! Example: +//! Logical view of a Merkle Accumulator with 5 leaves: +//! ```text +//! Non-fzn +//! / \ +//! / \ +//! / \ +//! Fzn2 Non-fzn +//! / \ / \ +//! / \ / \ +//! Fzn1 Fzn3 Non-fzn [Placeholder] +//! / \ / \ / \ +//! L0 L1 L2 L3 L4 [Placeholder] +//! ``` +//! +//! # Position and Physical Representation +//! As a Merkle Accumulator tree expands to the right and upwards, we number newly frozen nodes +//! monotonically. (One way to do it is simply to use in-order index of nodes.) We call the +//! stated numbers identifying nodes below simply "Position". +//! +//! And with that we can map a Merkle Accumulator into a key-value storage: key is the position of a +//! node, value is hash value it carries. +//! +//! We store only Frozen nodes, and generate non-Frozen nodes on the fly when accessing the tree. +//! This way, the physical representation of the tree is append-only, i.e. once written to physical +//! storage, nodes won't be either modified or deleted. +//! +//! Here is what we persist for the logical tree in the above example: +//! +//! ```text +//! Fzn2(3) +//! / \ +//! / \ +//! Fzn1(1) Fzn3(5) +//! / \ / \ +//! L0(0) L1(2) L2(4) L3(6) L4(8) +//! ``` +//! +//! When the next leaf node is persisted, the physical representation will be: +//! +//! ```text +//! Fzn2(3) +//! / \ +//! / \ +//! Fzn1(1) Fzn3(5) Fzn4(9) +//! / \ / \ / \ +//! L0(0) L1(2) L2(4) L3(6) L4(8) L5(10) +//! ``` +//! +//! The numbering corresponds to the in-order traversal of the tree. +//! +//! To think in key-value pairs: +//! ```text +//! |<-key->|<--value-->| +//! | 0 | hash_L0 | +//! | 1 | hash_Fzn1 | +//! | 2 | hash_L1 | +//! | ... | ... | +//! ``` + +use crypto::hash::{CryptoHash, CryptoHasher, HashValue, ACCUMULATOR_PLACEHOLDER_HASH}; +use failure::prelude::*; +use std::marker::PhantomData; +use types::proof::{ + position::Position, treebits::NodeDirection, AccumulatorProof, MerkleTreeInternalNode, +}; + +/// Defines the interface between `MerkleAccumulator` and underlying storage. +pub trait HashReader { + /// Return `HashValue` carried by the node at `Position`. + fn get(&self, position: Position) -> Result; +} + +/// A `Node` in a `MerkleAccumulator` tree is a `HashValue` at a `Position` +type Node = (Position, HashValue); + +/// In this live Merkle Accumulator algorithms. +pub struct MerkleAccumulator { + reader: PhantomData, + hasher: PhantomData, +} + +impl MerkleAccumulator +where + R: HashReader, + H: CryptoHasher, +{ + /// Given an existing Merkle Accumulator (represented by `num_existing_leaves` and a `reader` + /// that is able to fetch all existing frozen nodes), and a list of leaves to be appended, + /// returns the result root hash and new nodes to be frozen. 
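+    ///
+    /// A minimal usage sketch (illustrative; `MyReader`/`MyHasher` are assumed
+    /// implementations of `HashReader` and `CryptoHasher`):
+    ///
+    /// ```ignore
+    /// let (root_hash, frozen_nodes) =
+    ///     MerkleAccumulator::<MyReader, MyHasher>::append(&reader, num_leaves, &new_leaves)?;
+    /// // Persist the `(Position, HashValue)` pairs in `frozen_nodes`, then
+    /// // expose `root_hash` as the accumulator root at the new version.
+    /// ```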
+    pub fn append(
+        reader: &R,
+        num_existing_leaves: u64,
+        new_leaves: &[HashValue],
+    ) -> Result<(HashValue, Vec<Node>)> {
+        MerkleAccumulatorView::<R, H>::new(reader, num_existing_leaves).append(new_leaves)
+    }
+
+    /// Get a proof of inclusion of the leaf at `leaf_index` in this Merkle Accumulator of
+    /// `num_leaves` leaves in total. Siblings are read via `reader` (or generated dynamically
+    /// if they are non-frozen).
+    ///
+    /// See [`types::proof::AccumulatorProof`] for the proof format.
+    pub fn get_proof(reader: &R, num_leaves: u64, leaf_index: u64) -> Result<AccumulatorProof> {
+        MerkleAccumulatorView::<R, H>::new(reader, num_leaves).get_proof(leaf_index)
+    }
+}
+
+/// Actual implementation of the Merkle Accumulator algorithms, which carries the `reader` and
+/// `num_leaves` on an instance for convenience.
+struct MerkleAccumulatorView<'a, R, H> {
+    reader: &'a R,
+    num_leaves: u64,
+    hasher: PhantomData<H>,
+}
+
+impl<'a, R, H> MerkleAccumulatorView<'a, R, H>
+where
+    R: HashReader,
+    H: CryptoHasher,
+{
+    fn new(reader: &'a R, num_leaves: u64) -> Self {
+        Self {
+            reader,
+            num_leaves,
+            hasher: PhantomData,
+        }
+    }
+
+    /// Implementation for the public interface `MerkleAccumulator::append`.
+    fn append(&self, new_leaves: &[HashValue]) -> Result<(HashValue, Vec<Node>)> {
+        // Deal with the case where new_leaves is empty.
+        if new_leaves.is_empty() {
+            if self.num_leaves == 0 {
+                return Ok((*ACCUMULATOR_PLACEHOLDER_HASH, Vec::new()));
+            } else {
+                let root_hash = self.get_hash(Position::get_root_position(self.num_leaves - 1))?;
+                return Ok((root_hash, Vec::new()));
+            }
+        }
+
+        let num_new_leaves = new_leaves.len();
+        let last_new_leaf_idx = self.num_leaves + num_new_leaves as u64 - 1;
+        let root_level = Position::get_root_position(last_new_leaf_idx).get_level() as usize;
+        let mut to_freeze = Vec::with_capacity(Self::max_to_freeze(num_new_leaves, root_level));
+
+        // Create one new node for each new leaf hash.
+        let mut current_level = self.gen_leaf_level(new_leaves);
+        Self::record_to_freeze(
+            &mut to_freeze,
+            &current_level,
+            false, /* has_non_frozen */
+        );
+
+        // Loop starting from the leaf level, upwards till root_level - 1,
+        // making new nodes of the parent level and recording frozen ones.
+        let mut has_non_frozen = false;
+        for _ in 0..root_level {
+            let (parent_level, placeholder_used) = self.gen_parent_level(&current_level)?;
+
+            // If a placeholder node is used to generate the right-most node of a certain level,
+            // that level and all its parent levels have a non-frozen right-most node.
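+            // Worked example (illustrative, matching the 5-leaf tree in the module docs):
+            // appending L5 pairs it with L4, freezing their parent Fzn4; one level up, the
+            // new node still pairs with a placeholder, so it and every ancestor above it
+            // stay non-frozen until more leaves arrive.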
+            has_non_frozen |= placeholder_used;
+            Self::record_to_freeze(&mut to_freeze, &parent_level, has_non_frozen);
+
+            current_level = parent_level;
+        }
+
+        assert_eq!(current_level.len(), 1, "must conclude in single root node");
+        Ok((current_level.first().expect("unexpected None").1, to_freeze))
+    }
+
+    /// Upper bound of the number of frozen nodes:
+    ///   the new leaves and the resulting frozen internal nodes form a complete binary subtree:
+    ///     num_new_leaves * 2 - 1 < num_new_leaves * 2
+    ///   and the full route from the root of that subtree to the accumulator root turns frozen:
+    ///     height - (log2(num_new_leaves) + 1) < height - 1 = root_level
+    fn max_to_freeze(num_new_leaves: usize, root_level: usize) -> usize {
+        num_new_leaves * 2 + root_level
+    }
+
+    fn hash_internal_node(left: HashValue, right: HashValue) -> HashValue {
+        MerkleTreeInternalNode::<H>::new(left, right).hash()
+    }
+
+    /// Given leaf level hashes, create leaf level nodes.
+    fn gen_leaf_level(&self, new_leaves: &[HashValue]) -> Vec<Node> {
+        new_leaves
+            .iter()
+            .enumerate()
+            .map(|(i, hash)| (Position::from_leaf_index(self.num_leaves + i as u64), *hash))
+            .collect()
+    }
+
+    /// Given a level of new nodes (frozen or not), return the new nodes on its parent level, and
+    /// a boolean value indicating whether a placeholder node was used to construct the last node.
+    fn gen_parent_level(&self, current_level: &[Node]) -> Result<(Vec<Node>, bool)> {
+        let mut parent_level: Vec<Node> = Vec::with_capacity(current_level.len() / 2 + 1);
+        let mut iter = current_level.iter().peekable();
+
+        // The first node may be a right child; in that case pair it with its existing sibling.
+        let (first_pos, first_hash) = iter.peek().expect("Current level is empty");
+        if first_pos.get_direction_for_self() == NodeDirection::Right {
+            parent_level.push((
+                first_pos.get_parent(),
+                Self::hash_internal_node(self.reader.get(first_pos.get_sibling())?, *first_hash),
+            ));
+            iter.next();
+        }
+
+        // Walk through in pairs of siblings, using the placeholder as the last right sibling if
+        // necessary.
+        let mut placeholder_used = false;
+        while let Some((left_pos, left_hash)) = iter.next() {
+            let right_hash = match iter.next() {
+                Some((_, h)) => h,
+                None => {
+                    placeholder_used = true;
+                    &ACCUMULATOR_PLACEHOLDER_HASH
+                }
+            };
+
+            parent_level.push((
+                left_pos.get_parent(),
+                Self::hash_internal_node(*left_hash, *right_hash),
+            ));
+        }
+
+        Ok((parent_level, placeholder_used))
+    }
+
+    /// Append a level of new nodes into the output vector, skipping the last one if it's a
+    /// non-frozen node.
+    fn record_to_freeze(to_freeze: &mut Vec<Node>, level: &[Node], has_non_frozen: bool) {
+        to_freeze.extend(
+            level
+                .iter()
+                .take(level.len() - has_non_frozen as usize)
+                .cloned(),
+        )
+    }
+
+    fn get_hash(&self, position: Position) -> Result<HashValue> {
+        if position.is_placeholder(self.num_leaves - 1) {
+            Ok(*ACCUMULATOR_PLACEHOLDER_HASH)
+        } else if position.is_freezable(self.num_leaves - 1) {
+            self.reader.get(position)
+        } else {
+            // A non-frozen non-placeholder node: compute it from its children on the fly.
+            Ok(Self::hash_internal_node(
+                self.get_hash(position.get_left_child())?,
+                self.get_hash(position.get_right_child())?,
+            ))
+        }
+    }
+
+    /// Implementation for the public interface `MerkleAccumulator::get_proof`.
+    fn get_proof(&self, leaf_index: u64) -> Result<AccumulatorProof> {
+        ensure!(
+            leaf_index < self.num_leaves,
+            "invalid leaf_index {}, num_leaves {}",
+            leaf_index,
+            self.num_leaves
+        );
+
+        let leaf_pos = Position::from_leaf_index(leaf_index);
+        let root_pos = Position::get_root_position(self.num_leaves - 1);
+
+        let siblings: Vec<HashValue> = leaf_pos
+            .iter_ancestor_sibling()
+            .take(root_pos.get_level() as usize)
+            .map(|p| self.get_hash(p))
+            .collect::<Result<Vec<_>>>()?
+            .into_iter()
+            .rev()
+            .collect();
+
+        Ok(AccumulatorProof::new(siblings))
+    }
+}
+
+#[cfg(test)]
+mod tests;
diff --git a/storage/accumulator/src/tests/mod.rs b/storage/accumulator/src/tests/mod.rs
new file mode 100644
index 0000000000000..e709b43538a80
--- /dev/null
+++ b/storage/accumulator/src/tests/mod.rs
@@ -0,0 +1,154 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use crypto::hash::TestOnlyHasher;
+use std::collections::HashMap;
+
+type TestAccumulator = MerkleAccumulator<MockHashStore, TestOnlyHasher>;
+
+struct MockHashStore {
+    store: HashMap<Position, HashValue>,
+}
+
+impl HashReader for MockHashStore {
+    fn get(&self, position: Position) -> Result<HashValue> {
+        self.store
+            .get(&position)
+            .cloned()
+            .ok_or_else(|| format_err!("Position {:?} absent.", position))
+    }
+}
+
+struct Subtree {
+    root_position: u64,
+    max_position: u64,
+    hash: HashValue,
+    frozen: bool,
+    num_frozen_nodes: usize,
+}
+
+impl MockHashStore {
+    pub fn new() -> Self {
+        MockHashStore {
+            store: HashMap::new(),
+        }
+    }
+
+    pub fn put_many(&mut self, writes: &[(Position, HashValue)]) {
+        self.store.extend(writes.iter().cloned())
+    }
+
+    fn verify_in_store_if_frozen(
+        &self,
+        position: u64,
+        hash: HashValue,
+        frozen: bool,
+    ) -> Result<()> {
+        if frozen {
+            ensure!(
+                hash == *self
+                    .store
+                    .get(&Position::from_inorder_index(position))
+                    .ok_or_else(|| format_err!("unexpected None at position {}", position))?,
+                "Hash mismatch for node at position {}",
+                position
+            );
+        }
+        Ok(())
+    }
+
+    /// Return the placeholder hash if the left and right hashes are both the placeholder hash,
+    /// otherwise delegate to `MerkleAccumulatorView::hash_internal_node`.
+    ///
+    /// This is needed because a real accumulator subtree consisting of only placeholder nodes is
+    /// trimmed (never generated nor accessed).
+    fn hash_internal(left: HashValue, right: HashValue) -> HashValue {
+        if left == *ACCUMULATOR_PLACEHOLDER_HASH && right == *ACCUMULATOR_PLACEHOLDER_HASH {
+            *ACCUMULATOR_PLACEHOLDER_HASH
+        } else {
+            MerkleAccumulatorView::<MockHashStore, TestOnlyHasher>::hash_internal_node(left, right)
+        }
+    }
+
+    fn verify_subtree(&self, leaves: &[HashValue], min_position: u64) -> Result<Subtree> {
+        assert!(leaves.len().is_power_of_two());
+
+        let me = if leaves.len() == 1 {
+            // leaf
+            let root_position = min_position;
+            let max_position = min_position;
+            let hash = leaves[0];
+            let frozen = hash != *ACCUMULATOR_PLACEHOLDER_HASH;
+            let num_frozen_nodes = frozen as usize;
+
+            Subtree {
+                root_position,
+                max_position,
+                hash,
+                frozen,
+                num_frozen_nodes,
+            }
+        } else {
+            let subtree_width = leaves.len() / 2;
+
+            let left = self.verify_subtree(&leaves[..subtree_width], min_position)?;
+            let root_position = left.max_position + 1;
+            let right = self.verify_subtree(&leaves[subtree_width..], root_position + 1)?;
+
+            let max_position = right.max_position;
+            let hash = Self::hash_internal(left.hash, right.hash);
+            let frozen = left.frozen && right.frozen;
+            let num_frozen_nodes = left.num_frozen_nodes + right.num_frozen_nodes + frozen as usize;
+
+            Subtree {
+                root_position,
+                max_position,
+                hash,
+                frozen,
+                num_frozen_nodes,
+            }
+        };
+
+        self.verify_in_store_if_frozen(me.root_position, me.hash, me.frozen)?;
+
+        Ok(me)
+    }
+
+    /// (Naively) Verify `self.store` has in it nodes that represent an accumulator of `leaves` and
+    /// only those nodes.
+    ///
+    /// 1. expand the accumulator tree to a virtual full binary tree by adding placeholder nodes
+    /// 2. recursively:
+    ///    a. in-order number each node, call it "position"
+    ///    b. calculate each internal node hash out of its children
+    ///    c. sum up frozen nodes
+    ///    d. verify frozen nodes are in store at the above mentioned "position"
+    /// 3. verify the number of nodes in store matches exactly the number of frozen nodes
+    fn verify(&self, leaves: &[HashValue]) -> Result<HashValue> {
+        if leaves.is_empty() {
+            ensure!(self.store.is_empty(), "non-empty store for empty tree.");
+            Ok(*ACCUMULATOR_PLACEHOLDER_HASH)
+        } else {
+            // Pad `leaves` with dummies, to form a full tree.
+            let mut full_tree_leaves = leaves.to_vec();
+            full_tree_leaves.resize(
+                leaves.len().next_power_of_two(),
+                *ACCUMULATOR_PLACEHOLDER_HASH,
+            );
+
+            let tree = self.verify_subtree(&full_tree_leaves, 0)?;
+
+            ensure!(
+                self.store.len() == tree.num_frozen_nodes,
+                "mismatch: items in store - {} vs expected num of frozen nodes - {}",
+                self.store.len(),
+                tree.num_frozen_nodes,
+            );
+            Ok(tree.hash)
+        }
+    }
+}
+
+mod proof_test;
+mod write_test;
diff --git a/storage/accumulator/src/tests/proof_test.rs b/storage/accumulator/src/tests/proof_test.rs
new file mode 100644
index 0000000000000..5a80609812836
--- /dev/null
+++ b/storage/accumulator/src/tests/proof_test.rs
@@ -0,0 +1,64 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use proptest::{collection::vec, prelude::*};
+use types::proof::verify_test_accumulator_element;
+
+#[test]
+fn test_error_on_bad_parameters() {
+    let store = MockHashStore::new();
+    assert!(TestAccumulator::get_proof(&store, 0, 0).is_err());
+    assert!(TestAccumulator::get_proof(&store, 100, 101).is_err());
+}
+
+#[test]
+fn test_one_leaf() {
+    let hash = HashValue::random();
+    let mut store = MockHashStore::new();
+    let (root_hash, writes) = TestAccumulator::append(&store, 0, &[hash]).unwrap();
+    store.put_many(&writes);
+
+    verify(&store, 1, root_hash, &[hash], 0)
+}
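+// The property checked below (a sketch of the intent): a proof generated against any
+// root, not just the latest one, must verify for every leaf that the root covers.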
+proptest! {
+    #![proptest_config(ProptestConfig::with_cases(10))]
+
+    #[test]
+    fn test_proof(
+        batch1 in vec(any::<HashValue>(), 1..100),
+        batch2 in vec(any::<HashValue>(), 1..100),
+    ) {
+        let total_leaves = batch1.len() + batch2.len();
+        let batch1_size = batch1.len() as u64;
+        let mut store = MockHashStore::new();
+
+        // insert all leaves in two batches
+        let (root_hash1, writes1) = TestAccumulator::append(&store, 0, &batch1).unwrap();
+        store.put_many(&writes1);
+        let (root_hash2, writes2) = TestAccumulator::append(&store, batch1_size, &batch2).unwrap();
+        store.put_many(&writes2);
+
+        // verify proofs for all leaves towards the current root
+        verify(&store, total_leaves, root_hash2, &batch1, 0);
+        verify(&store, total_leaves, root_hash2, &batch2, batch1_size);
+
+        // verify proofs for all leaves of a subtree towards the subtree root
+        verify(&store, batch1.len(), root_hash1, &batch1, 0);
+    }
+}
+
+fn verify(
+    store: &MockHashStore,
+    num_leaves: usize,
+    root_hash: HashValue,
+    leaves: &[HashValue],
+    first_leaf_idx: u64,
+) {
+    leaves.iter().enumerate().for_each(|(i, hash)| {
+        let leaf_index = first_leaf_idx + i as u64;
+        let proof = TestAccumulator::get_proof(store, num_leaves as u64, leaf_index).unwrap();
+        verify_test_accumulator_element(root_hash, *hash, leaf_index, &proof).unwrap();
+    });
+}
diff --git a/storage/accumulator/src/tests/write_test.rs b/storage/accumulator/src/tests/write_test.rs
new file mode 100644
index 0000000000000..c1ccc62fd9e16
--- /dev/null
+++ b/storage/accumulator/src/tests/write_test.rs
@@ -0,0 +1,70 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use crypto::hash::ACCUMULATOR_PLACEHOLDER_HASH;
+use proptest::{collection::vec, prelude::*};
+
+#[test]
+fn test_append_empty_on_empty() {
+    let store = MockHashStore::new();
+    assert_eq!(
+        TestAccumulator::append(&store, 0, &[]).unwrap(),
+        (*ACCUMULATOR_PLACEHOLDER_HASH, Vec::new())
+    );
+}
+
+#[test]
+fn test_append_one() {
+    let mut store = MockHashStore::new();
+    store.verify(&[]).unwrap();
+
+    let mut leaves = Vec::new();
+    for _ in 0..100 {
+        let hash = HashValue::random();
+        let (root_hash, writes) =
+            TestAccumulator::append(&store, leaves.len() as u64, &[hash]).unwrap();
+        store.put_many(&writes);
+
+        leaves.push(hash);
+        let expected_root_hash = store.verify(&leaves).unwrap();
+
+        assert_eq!(root_hash, expected_root_hash)
+    }
+}
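+// The proptests below cross-check `append` against `MockHashStore::verify`, an
+// independent naive reconstruction of the full tree, after every batch.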
+proptest! {
+    #![proptest_config(ProptestConfig::with_cases(10))]
+
+    #[test]
+    fn test_append_many(batches in vec(vec(any::<HashValue>(), 10), 10)) {
+        let mut store = MockHashStore::new();
+
+        let mut leaves: Vec<HashValue> = Vec::new();
+        let mut num_leaves = 0;
+        for hashes in batches.iter() {
+            let (root_hash, writes) =
+                TestAccumulator::append(&store, num_leaves, &hashes).unwrap();
+            store.put_many(&writes);
+
+            num_leaves += hashes.len() as u64;
+            leaves.extend(hashes.iter());
+            let expected_root_hash = store.verify(&leaves).unwrap();
+            assert_eq!(root_hash, expected_root_hash)
+        }
+    }
+
+    #[test]
+    fn test_append_empty(leaves in vec(any::<HashValue>(), 100)) {
+        let mut store = MockHashStore::new();
+
+        let (root_hash, writes) = TestAccumulator::append(&store, 0, &leaves).unwrap();
+        store.put_many(&writes);
+
+        let (root_hash2, writes2) =
+            TestAccumulator::append(&store, leaves.len() as u64, &[]).unwrap();
+
+        assert_eq!(root_hash, root_hash2);
+        assert!(writes2.is_empty());
+    }
+}
diff --git a/storage/data.png b/storage/data.png
new file mode 100644
index 0000000000000..7aef8c74d60b7
Binary files /dev/null and b/storage/data.png differ
diff --git a/storage/libradb/Cargo.toml b/storage/libradb/Cargo.toml
new file mode 100644
index 0000000000000..46c110f42b892
--- /dev/null
+++ b/storage/libradb/Cargo.toml
@@ -0,0 +1,30 @@
+[package]
+name = "libradb"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+byteorder = "1.3.1"
+itertools = "0.7.3"
+lazy_static = "1.2.0"
+proptest = "0.9.2"
+rand = "0.4.2"
+tempfile = "3.0.6"
+
+accumulator = { path = "../accumulator" }
+canonical_serialization = { path = "../../common/canonical_serialization" }
+crypto = { path = "../../crypto/legacy_crypto" }
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+logger = { path = "../../common/logger" }
+metrics = { path = "../../common/metrics" }
+proto_conv = { path = "../../common/proto_conv" }
+schemadb = { path = "../schemadb" }
+sparse_merkle = { path = "../sparse_merkle" }
+storage_proto = { path = "../storage_proto" }
+types = { path = "../../types" }
+
+[dev-dependencies]
+rusty-fork = "0.2.1"
diff --git a/storage/libradb/src/errors.rs b/storage/libradb/src/errors.rs
new file mode 100644
index 0000000000000..354f35af86979
--- /dev/null
+++ b/storage/libradb/src/errors.rs
@@ -0,0 +1,14 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module defines error types used by [`LibraDB`](crate::LibraDB).
+
+use failure::Fail;
+
+/// This enum defines errors commonly used among [`LibraDB`](crate::LibraDB) APIs.
+#[derive(Debug, Fail)]
+pub enum LibraDbError {
+    /// A requested item is not found.
+    #[fail(display = "{} not found.", _0)]
+    NotFound(String),
+}
diff --git a/storage/libradb/src/event_store/mod.rs b/storage/libradb/src/event_store/mod.rs
new file mode 100644
index 0000000000000..7ff29b8c752e5
--- /dev/null
+++ b/storage/libradb/src/event_store/mod.rs
@@ -0,0 +1,253 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This file defines event store APIs that are related to the event accumulator and the events
+//! themselves.
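+//!
+//! As a reading aid, the schemas used below are keyed roughly as follows (inferred from their
+//! call sites in this file):
+//!
+//! ```text
+//! EventSchema:             (version, index)       -> ContractEvent
+//! EventByAccessPathSchema: (access_path, seq_num) -> (version, index)
+//! EventAccumulatorSchema:  (version, position)    -> HashValue
+//! ```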
+#![allow(unused)] + +use super::LibraDB; +use crate::{ + errors::LibraDbError, + schema::{ + event::EventSchema, event_accumulator::EventAccumulatorSchema, + event_by_access_path::EventByAccessPathSchema, + }, +}; +use accumulator::{HashReader, MerkleAccumulator}; +use crypto::{ + hash::{CryptoHash, EventAccumulatorHasher}, + HashValue, +}; +use failure::prelude::*; +use schemadb::{schema::ValueCodec, ReadOptions, SchemaBatch, DB}; +use std::sync::Arc; +use types::{ + access_path::AccessPath, + contract_event::ContractEvent, + proof::{position::Position, AccumulatorProof, EventProof}, + transaction::Version, +}; + +pub(crate) struct EventStore { + db: Arc, +} + +impl EventStore { + pub fn new(db: Arc) -> Self { + Self { db } + } + + /// Get all of the events given a transaction version. + /// We don't need a proof for this because it's only used to get all events + /// for a version which can be proved from the root hash of the event tree. + pub fn get_events_by_version(&self, version: Version) -> Result> { + let mut events = vec![]; + + let mut iter = self.db.iter::(ReadOptions::default())?; + // Grab the first event and then iterate until we get all events for this version. + iter.seek(&version)?; + while let Some(((ver, index), event)) = iter.next().transpose()? { + if ver != version { + break; + } + events.push(event); + } + + Ok(events) + } + + /// Get the event raw data given transaction version and the index of the event queried. + pub fn get_event_with_proof_by_version_and_index( + &self, + version: Version, + index: u64, + ) -> Result<(ContractEvent, AccumulatorProof)> { + // Get event content. + let event = self + .db + .get::(&(version, index))? + .ok_or_else(|| LibraDbError::NotFound(format!("Event {} of Txn {}", index, version)))?; + + // Get the number of events in total for the transaction at `version`. + let mut iter = self.db.iter::(ReadOptions::default())?; + iter.seek_for_prev(&(version + 1))?; + let num_events = match iter.next().transpose()? { + Some(((ver, index), _)) if ver == version => index + 1, + _ => unreachable!(), // since we've already got at least one event above + }; + + // Get proof. + let proof = + Accumulator::get_proof(&EventHashReader::new(self, version), num_events, index)?; + + Ok((event, proof)) + } + + fn get_txn_ver_by_seq_num(&self, access_path: &AccessPath, seq_num: u64) -> Result { + let (ver, _) = self + .db + .get::(&(access_path.clone(), seq_num))? + .ok_or_else(|| format_err!("Index entry should exist for seq_num {}", seq_num))?; + Ok(ver) + } + + /// Get the latest sequence number on `access_path` considering all transactions with versions + /// no greater than `ledger_version`. + pub fn get_latest_sequence_number( + &self, + ledger_version: Version, + access_path: &AccessPath, + ) -> Result> { + let mut iter = self + .db + .iter::(ReadOptions::default())?; + iter.seek_for_prev(&(access_path.clone(), u64::max_value())); + if let Some(res) = iter.next() { + let ((path, mut seq), (ver, _idx)) = res?; + if path == *access_path { + if ver <= ledger_version { + return Ok(Some(seq)); + } + + // Queries tend to base on very recent ledger infos, so first try to linear search + // from the most recent end, for limited tries. + // TODO: Optimize: Physical store use reverse order. 
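+                // Illustrative scenario (hypothetical numbers): the latest seq_num is 100,
+                // written at version 1000, while the query is against ledger_version 998.
+                // Stepping back a few sequence numbers usually reaches one written at a
+                // version <= 998; only if this short scan fails do we fall back to the
+                // binary search below.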
+ let mut n_try_recent = 10; + #[cfg(test)] + let mut n_try_recent = 1; + while seq > 0 && n_try_recent > 0 { + seq -= 1; + n_try_recent -= 1; + let ver = self.get_txn_ver_by_seq_num(access_path, seq)?; + if ver <= ledger_version { + return Ok(Some(seq)); + } + } + + // Fall back to binary search if the above short linear search didn't work out. + let (mut begin, mut end) = (0, seq); + while begin < end { + let mid = end - (end - begin) / 2; + let ver = self.get_txn_ver_by_seq_num(access_path, mid)?; + if ver <= ledger_version { + begin = mid; + } else { + end = mid - 1; + } + } + return Ok(Some(begin)); + } + } + Ok(None) + } + + /// Given access path and start sequence number, return events identified by transaction index + /// and index among all events yielded by the same transaction. Result won't contain records + /// with a txn_version > `ledger_version` and is in ascending order. + pub fn lookup_events_by_access_path( + &self, + access_path: &AccessPath, + start_seq_num: u64, + limit: u64, + ledger_version: u64, + ) -> Result< + Vec<( + u64, // sequence number + Version, // transaction version it belongs to + u64, // index among events for the same transaction + )>, + > { + let mut iter = self + .db + .iter::(ReadOptions::default())?; + iter.seek(&(access_path.clone(), start_seq_num))?; + + let mut result = Vec::new(); + let mut cur_seq = start_seq_num; + for res in iter.take(limit as usize) { + let ((path, seq), (ver, idx)) = res?; + if path != *access_path || ver > ledger_version { + break; + } + ensure!( + seq == cur_seq, + "DB corrupt: Sequence number not continuous, expected: {}, actual: {}.", + cur_seq, + seq + ); + result.push((seq, ver, idx)); + cur_seq += 1; + } + + Ok(result) + } + + /// Save contract events yielded by the transaction at `version` and return root hash of the + /// event accumulator formed by these events. + pub fn put_events( + &self, + version: u64, + events: &[ContractEvent], + batch: &mut SchemaBatch, + ) -> Result { + // EventSchema and EventByAccessPathSchema updates + events + .iter() + .enumerate() + .map(|(idx, event)| { + batch.put::(&(version, idx as u64), event)?; + batch.put::( + &(event.access_path().clone(), event.sequence_number()), + &(version, idx as u64), + )?; + Ok(()) + }) + .collect::>()?; + + // EventAccumulatorSchema updates + let event_hashes: Vec = events.iter().map(ContractEvent::hash).collect(); + let (root_hash, writes) = EmptyAccumulator::append(&EmptyReader, 0, &event_hashes)?; + writes + .into_iter() + .map(|(pos, hash)| batch.put::(&(version, pos), &hash)) + .collect::>()?; + + Ok(root_hash) + } +} + +type Accumulator<'a> = MerkleAccumulator, EventAccumulatorHasher>; + +struct EventHashReader<'a> { + store: &'a EventStore, + version: Version, +} + +impl<'a> EventHashReader<'a> { + fn new(store: &'a EventStore, version: Version) -> Self { + Self { store, version } + } +} + +impl<'a> HashReader for EventHashReader<'a> { + fn get(&self, position: Position) -> Result { + self.store + .db + .get::(&(self.version, position))? + .ok_or_else(|| format_err!("Hash at position {:?} not found.", position)) + } +} + +type EmptyAccumulator = MerkleAccumulator; + +struct EmptyReader; + +// Asserts `get()` is never called. 
+impl HashReader for EmptyReader { + fn get(&self, _position: Position) -> Result { + unreachable!() + } +} + +#[cfg(test)] +mod test; diff --git a/storage/libradb/src/event_store/test.rs b/storage/libradb/src/event_store/test.rs new file mode 100644 index 0000000000000..ecaec1dd48d80 --- /dev/null +++ b/storage/libradb/src/event_store/test.rs @@ -0,0 +1,236 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crate::LibraDB; +use crypto::{hash::ACCUMULATOR_PLACEHOLDER_HASH, utils::keypair_strategy}; +use itertools::Itertools; +use proptest::{ + collection::{hash_set, vec}, + prelude::*, + strategy::Union, +}; +use rand::{Rng, StdRng}; +use std::collections::HashMap; +use tempfile::tempdir; +use types::{ + account_address::AccountAddress, contract_event::ContractEvent, + proof::verify_event_accumulator_element, proptest_types::renumber_events, +}; + +fn save(store: &EventStore, version: Version, events: &[ContractEvent]) -> HashValue { + let mut batch = SchemaBatch::new(); + let root_hash = store.put_events(version, events, &mut batch).unwrap(); + store.db.write_schemas(batch).unwrap(); + + root_hash +} + +#[test] +fn test_put_empty() { + let tmp_dir = tempdir().unwrap(); + let db = LibraDB::new(&tmp_dir); + let store = &db.event_store; + let mut batch = SchemaBatch::new(); + assert_eq!( + store.put_events(0, &[], &mut batch).unwrap(), + *ACCUMULATOR_PLACEHOLDER_HASH + ); +} + +#[test] +fn test_error_on_get_from_empty() { + let tmp_dir = tempdir().unwrap(); + let db = LibraDB::new(&tmp_dir); + let store = &db.event_store; + + assert!(store + .get_event_with_proof_by_version_and_index(100, 0) + .is_err()); +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(10))] + + #[test] + fn test_put_get_verify(events in vec(any::().no_shrink(), 1..100)) { + let tmp_dir = tempdir().unwrap(); + let db = LibraDB::new(&tmp_dir); + let store = &db.event_store; + + let root_hash = save(store, 100, &events); + + // get and verify each and every event with proof + for (idx, expected_event) in events.iter().enumerate() { + let (event, proof) = store + .get_event_with_proof_by_version_and_index(100, idx as u64) + .unwrap(); + prop_assert_eq!(&event, expected_event); + verify_event_accumulator_element(root_hash, event.hash(), idx as u64, &proof).unwrap(); + } + // error on index >= num_events + prop_assert!(store + .get_event_with_proof_by_version_and_index(100, events.len() as u64) + .is_err()); + } + +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(1))] + + #[test] + fn test_get_all_events_by_version( + events1 in vec(any::().no_shrink(), 1..100), + events2 in vec(any::().no_shrink(), 1..100), + events3 in vec(any::().no_shrink(), 1..100), + ) { + + let tmp_dir = tempdir().unwrap(); + let db = LibraDB::new(&tmp_dir); + let store = &db.event_store; + // Save 3 chunks at different versions + save(store, 99 /*version*/, &events1); + save(store, 100 /*version*/, &events2); + save(store, 101 /*version*/, &events3); + + // Now get all events at each version and verify that it matches what is expected. 
+ let events_99 = store.get_events_by_version(99 /*version*/).unwrap(); + prop_assert_eq!(events_99, events1); + + let events_100 = store.get_events_by_version(100 /*version*/).unwrap(); + prop_assert_eq!(events_100, events2); + + let events_101 = store.get_events_by_version(101 /*version*/).unwrap(); + prop_assert_eq!(events_101, events3); + + // Now query a version that doesn't exist and verify that no results come back + let events_102 = store.get_events_by_version(102 /*version*/).unwrap(); + prop_assert_eq!(events_102.len(), 0); + } +} + +fn traverse_events_by_access_path( + store: &EventStore, + access_path: &AccessPath, + ledger_version: Version, +) -> Vec { + const LIMIT: u64 = 3; + + let mut seq_num = 0; + + let mut event_keys = Vec::new(); + let mut last_batch_len = LIMIT; + loop { + let mut batch = store + .lookup_events_by_access_path(access_path, seq_num, LIMIT, ledger_version) + .unwrap(); + if last_batch_len < LIMIT { + assert!(batch.is_empty()); + } + if batch.is_empty() { + break; + } + + last_batch_len = batch.len() as u64; + let first_seq = batch.first().unwrap().0; + let last_seq = batch.last().unwrap().0; + + assert!(last_batch_len <= LIMIT); + assert_eq!(seq_num, first_seq); + assert_eq!(seq_num + last_batch_len - 1, last_seq); + + event_keys.extend(batch.iter()); + seq_num = last_seq + 1; + } + + event_keys + .into_iter() + .map(|(_seq, ver, idx)| { + store + .get_event_with_proof_by_version_and_index(ver, idx) + .unwrap() + .0 + }) + .collect() +} + +fn arb_event_batches() -> impl Strategy, Vec>)> { + ( + vec(any::(), 3), + hash_set(any::>(), 3), + (0..100usize), + ) + .prop_flat_map(|(addresses, event_paths, num_batches)| { + let all_possible_access_paths = addresses + .iter() + .cartesian_product(event_paths.iter()) + .map(|(address, event_path)| AccessPath::new(*address, event_path.clone())) + .collect::>(); + let access_path_strategy = + Union::new(all_possible_access_paths.clone().into_iter().map(Just)); + + ( + Just(all_possible_access_paths), + vec( + vec(ContractEvent::strategy_impl(access_path_strategy), 0..10), + num_batches, + ), + ) + }) + .prop_map(|(all_possible_access_paths, event_batches)| { + let mut seq_num_by_access_path = HashMap::new(); + let numbered_event_batches = event_batches + .into_iter() + .map(|events| renumber_events(&events, &mut seq_num_by_access_path)) + .collect::>(); + + (all_possible_access_paths, numbered_event_batches) + }) +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(10))] + + #[test] + fn test_get_events_by_access_path((addresses, event_batches) in arb_event_batches().no_shrink()) { + test_get_events_by_access_path_impl(addresses, event_batches); + } +} + +fn test_get_events_by_access_path_impl( + access_paths: Vec, + event_batches: Vec>, +) { + // Put into db. + let tmp_dir = tempdir().unwrap(); + let db = LibraDB::new(&tmp_dir); + let store = &db.event_store; + + let mut batch = SchemaBatch::new(); + event_batches.iter().enumerate().for_each(|(ver, events)| { + store.put_events(ver as u64, events, &mut batch).unwrap(); + }); + db.commit(batch); + let ledger_version_plus_one = event_batches.len() as u64; + + // Calculate expected event sequence per access_path. 
+ let mut events_by_access_path = HashMap::new(); + event_batches.into_iter().for_each(|batch| { + batch.into_iter().for_each(|e| { + let mut events = events_by_access_path + .entry(e.access_path().clone()) + .or_insert_with(Vec::new); + assert_eq!(events.len() as u64, e.sequence_number()); + events.push(e.clone()); + }) + }); + + // Fetch and check. + events_by_access_path + .into_iter() + .for_each(|(path, events)| { + let traversed = traverse_events_by_access_path(&store, &path, ledger_version_plus_one); + assert_eq!(events, traversed); + }); +} diff --git a/storage/libradb/src/ledger_store/ledger_info_test.rs b/storage/libradb/src/ledger_store/ledger_info_test.rs new file mode 100644 index 0000000000000..d32ca065fe3f1 --- /dev/null +++ b/storage/libradb/src/ledger_store/ledger_info_test.rs @@ -0,0 +1,67 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crate::LibraDB; +use proptest::{collection::vec, prelude::*}; +use tempfile::tempdir; +use types::ledger_info::LedgerInfo; + +prop_compose! { + fn arb_partial_ledger_info()(accu_hash in any::(), + consensus_hash in any::(), + timestamp in any::()) -> (HashValue, HashValue, u64) { + (accu_hash, consensus_hash, timestamp) + } +} + +prop_compose! { + fn arb_ledger_infos_with_sigs()( + partial_ledger_infos_with_sigs in vec( + any_with::((1..3).into()).no_shrink(), 1..100 + ), + start_version in 0..10000u64, + ) -> Vec { + partial_ledger_infos_with_sigs + .iter() + .enumerate() + .map(|(i, p)| { + let ledger_info = p.ledger_info(); + LedgerInfoWithSignatures::new( + LedgerInfo::new( + start_version + i as u64, + ledger_info.transaction_accumulator_hash(), + ledger_info.consensus_data_hash(), + HashValue::zero(), + ledger_info.epoch_num(), + ledger_info.timestamp_usecs(), + ), + p.signatures().clone(), + ) + }) + .collect() + } +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(10))] + + #[test] + fn test_ledger_info_put_get_verify( + ledger_infos_with_sigs in arb_ledger_infos_with_sigs() + ) { + let tmp_dir = tempdir().unwrap(); + let db = LibraDB::new(&tmp_dir); + let store = &db.ledger_store; + let start_version = ledger_infos_with_sigs.first().unwrap().ledger_info().version(); + + let mut batch = SchemaBatch::new(); + ledger_infos_with_sigs + .iter() + .map(|info| store.put_ledger_info(info, &mut batch)) + .collect::>>() + .unwrap(); + db.commit(batch).unwrap(); + prop_assert_eq!(db.ledger_store.get_ledger_infos(start_version).unwrap(), ledger_infos_with_sigs); + } +} diff --git a/storage/libradb/src/ledger_store/mod.rs b/storage/libradb/src/ledger_store/mod.rs new file mode 100644 index 0000000000000..90f24cc433319 --- /dev/null +++ b/storage/libradb/src/ledger_store/mod.rs @@ -0,0 +1,177 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This file defines ledger store APIs that are related to the main ledger accumulator, from the +//! root(LedgerInfo) to leaf(TransactionInfo). 
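+//!
+//! As a reading aid, the schemas used below are keyed roughly as follows (inferred from their
+//! call sites in this file):
+//!
+//! ```text
+//! LedgerInfoSchema:             version  -> LedgerInfoWithSignatures
+//! TransactionInfoSchema:        version  -> TransactionInfo
+//! TransactionAccumulatorSchema: position -> HashValue
+//! ```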
+
+use crate::{
+    errors::LibraDbError,
+    schema::{
+        ledger_info::LedgerInfoSchema, transaction_accumulator::TransactionAccumulatorSchema,
+        transaction_info::TransactionInfoSchema,
+    },
+};
+use accumulator::{HashReader, MerkleAccumulator};
+use crypto::{
+    hash::{CryptoHash, TransactionAccumulatorHasher},
+    HashValue,
+};
+use failure::prelude::*;
+use itertools::Itertools;
+use schemadb::{ReadOptions, SchemaBatch, DB};
+use std::sync::Arc;
+use types::{
+    ledger_info::LedgerInfoWithSignatures,
+    proof::{
+        position::{FrozenSubTreeIterator, Position},
+        AccumulatorProof,
+    },
+    transaction::{TransactionInfo, Version},
+};
+
+pub(crate) struct LedgerStore {
+    db: Arc<DB>,
+}
+
+impl LedgerStore {
+    pub fn new(db: Arc<DB>) -> Self {
+        Self { db }
+    }
+
+    /// Return the ledger infos, each carrying at least 2f+1 signatures, starting from
+    /// `start_version` up to the most recent one.
+    /// Note: ledger infos and signatures are only available at the last version of each earlier
+    /// epoch and at the latest version of the current epoch.
+    #[cfg(test)]
+    fn get_ledger_infos(&self, start_version: Version) -> Result<Vec<LedgerInfoWithSignatures>> {
+        let mut iter = self.db.iter::<LedgerInfoSchema>(ReadOptions::default())?;
+        iter.seek(&start_version)?;
+        Ok(iter.map(|kv| Ok(kv?.1)).collect::<Result<Vec<_>>>()?)
+    }
+
+    pub fn get_latest_ledger_info_option(&self) -> Result<Option<LedgerInfoWithSignatures>> {
+        let mut iter = self.db.iter::<LedgerInfoSchema>(ReadOptions::default())?;
+        iter.seek_to_last();
+        Ok(iter.next().transpose()?.map(|kv| kv.1))
+    }
+
+    pub fn get_latest_ledger_info(&self) -> Result<LedgerInfoWithSignatures> {
+        self.get_latest_ledger_info_option()?
+            .ok_or_else(|| LibraDbError::NotFound(String::from("Genesis LedgerInfo")).into())
+    }
+
+    /// Get the transaction info at `version`.
+    pub fn get_transaction_info(&self, version: Version) -> Result<TransactionInfo> {
+        self.db
+            .get::<TransactionInfoSchema>(&version)?
+            .ok_or_else(|| format_err!("No TransactionInfo at version {}", version))
+    }
+
+    pub fn get_latest_transaction_info_option(&self) -> Result<Option<(Version, TransactionInfo)>> {
+        let mut iter = self
+            .db
+            .iter::<TransactionInfoSchema>(ReadOptions::default())?;
+        iter.seek_to_last();
+        iter.next().transpose()
+    }
+
+    /// Get the latest transaction info together with its version. Note that during node syncing,
+    /// this version can be greater than what's in the latest LedgerInfo.
+    pub fn get_latest_transaction_info(&self) -> Result<(Version, TransactionInfo)> {
+        self.get_latest_transaction_info_option()?
+            .ok_or_else(|| LibraDbError::NotFound(String::from("Genesis TransactionInfo.")).into())
+    }
+
+    /// Get the transaction info at `version` with a proof towards the root of the ledger at
+    /// `ledger_version`.
+    pub fn get_transaction_info_with_proof(
+        &self,
+        version: Version,
+        ledger_version: Version,
+    ) -> Result<(TransactionInfo, AccumulatorProof)> {
+        Ok((
+            self.get_transaction_info(version)?,
+            self.get_transaction_proof(version, ledger_version)?,
+        ))
+    }
+
+    /// Get the proof for the transaction at `version` towards the root of the ledger at
+    /// `ledger_version`.
+    pub fn get_transaction_proof(
+        &self,
+        version: Version,
+        ledger_version: Version,
+    ) -> Result<AccumulatorProof> {
+        Accumulator::get_proof(self, ledger_version + 1 /* num_leaves */, version)
+    }
+
+    /// Write `txn_infos` to `batch`. Assign `first_version` to the version number of the first
+    /// transaction, and so on.
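+    ///
+    /// A hypothetical call while assembling a commit batch:
+    ///
+    /// ```text
+    /// let root_hash = ledger_store.put_transaction_infos(first_version, &txn_infos, &mut batch)?;
+    /// ```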
+ pub fn put_transaction_infos( + &self, + first_version: u64, + txn_infos: &[TransactionInfo], + batch: &mut SchemaBatch, + ) -> Result { + // write txn_info + (first_version..first_version + txn_infos.len() as u64) + .zip_eq(txn_infos.iter()) + .map(|(version, txn_info)| batch.put::(&version, txn_info)) + .collect::>()?; + + // write hash of txn_info into the accumulator + let txn_hashes: Vec = txn_infos.iter().map(TransactionInfo::hash).collect(); + let (root_hash, writes) = Accumulator::append( + self, + first_version, /* num_existing_leaves */ + &txn_hashes, + )?; + writes + .iter() + .map(|(pos, hash)| batch.put::(pos, hash)) + .collect::>()?; + Ok(root_hash) + } + + /// Write `ledger_info` to `batch`. + pub fn put_ledger_info( + &self, + ledger_info_with_sigs: &LedgerInfoWithSignatures, + batch: &mut SchemaBatch, + ) -> Result<()> { + batch.put::( + &ledger_info_with_sigs.ledger_info().version(), + ledger_info_with_sigs, + ) + } + + /// From left to right, get frozen subtree root hashes of the transaction accumulator. + pub fn get_ledger_frozen_subtree_hashes(&self, version: Version) -> Result> { + FrozenSubTreeIterator::new(version + 1) + .map(|pos| { + self.db + .get::(&pos)? + .ok_or_else(|| { + LibraDbError::NotFound(format!( + "Txn Accumulator node at pos {}", + pos.to_inorder_index() + )) + .into() + }) + }) + .collect::>>() + } +} + +type Accumulator = MerkleAccumulator; + +impl HashReader for LedgerStore { + fn get(&self, position: Position) -> Result { + self.db + .get::(&position)? + .ok_or_else(|| format_err!("Does not exist.")) + } +} + +#[cfg(test)] +mod ledger_info_test; +#[cfg(test)] +mod transaction_info_test; diff --git a/storage/libradb/src/ledger_store/transaction_info_test.rs b/storage/libradb/src/ledger_store/transaction_info_test.rs new file mode 100644 index 0000000000000..b803560d4a6c3 --- /dev/null +++ b/storage/libradb/src/ledger_store/transaction_info_test.rs @@ -0,0 +1,69 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crate::LibraDB; +use proptest::{collection::vec, prelude::*}; +use tempfile::tempdir; +use types::proof::verify_transaction_accumulator_element; + +fn verify( + db: &LibraDB, + txn_infos: &[TransactionInfo], + first_version: Version, + ledger_version: Version, + root_hash: HashValue, +) { + txn_infos + .iter() + .enumerate() + .for_each(|(idx, expected_txn_info)| { + let version = first_version + idx as u64; + + let (txn_info, proof) = db + .ledger_store + .get_transaction_info_with_proof(version, ledger_version) + .unwrap(); + + assert_eq!(&txn_info, expected_txn_info); + verify_transaction_accumulator_element(root_hash, txn_info.hash(), version, &proof) + .unwrap(); + }) +} + +fn save(db: &LibraDB, first_version: Version, txn_infos: &[TransactionInfo]) -> HashValue { + let mut batch = SchemaBatch::new(); + let root_hash = db + .ledger_store + .put_transaction_infos(first_version, &txn_infos, &mut batch) + .unwrap(); + db.commit(batch).unwrap(); + root_hash +} + +proptest! 
{
+    #![proptest_config(ProptestConfig::with_cases(10))]
+
+    #[test]
+    fn test_transaction_info_put_get_verify(
+        batch1 in vec(any::<TransactionInfo>(), 1..100),
+        batch2 in vec(any::<TransactionInfo>(), 1..100),
+    ) {
+
+        let tmp_dir = tempdir().unwrap();
+        let db = LibraDB::new(&tmp_dir);
+
+        // insert two batches of transaction infos
+        let root_hash1 = save(&db, 0, &batch1);
+        let ledger_version1 = batch1.len() as u64 - 1;
+        let root_hash2 = save(&db, batch1.len() as u64, &batch2);
+        let ledger_version2 = batch1.len() as u64 + batch2.len() as u64 - 1;
+
+        // retrieve all leaves and verify against the latest root hash
+        verify(&db, &batch1, 0, ledger_version2, root_hash2);
+        verify(&db, &batch2, batch1.len() as u64, ledger_version2, root_hash2);
+
+        // retrieve batch1 and verify against the root hash after batch1 was inserted
+        verify(&db, &batch1, 0, ledger_version1, root_hash1);
+    }
+}
diff --git a/storage/libradb/src/lib.rs b/storage/libradb/src/lib.rs
new file mode 100644
index 0000000000000..9e87f251c2e05
--- /dev/null
+++ b/storage/libradb/src/lib.rs
@@ -0,0 +1,690 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This crate provides [`LibraDB`] which represents physical storage of the core Libra data
+//! structures.
+//!
+//! It relays read/write operations on the physical storage via [`schemadb`] to the underlying
+//! key-value storage system, and implements Libra data structures on top of it.
+
+// Used in other crates for testing.
+pub mod mock_genesis;
+// Used in this and other crates for testing.
+pub mod test_helper;
+
+pub mod errors;
+
+mod event_store;
+mod ledger_store;
+pub mod schema;
+mod state_store;
+mod transaction_store;
+
+#[cfg(test)]
+mod libradb_test;
+
+use crate::{
+    event_store::EventStore, ledger_store::LedgerStore, schema::*, state_store::StateStore,
+    transaction_store::TransactionStore,
+};
+use crypto::{
+    hash::{CryptoHash, SPARSE_MERKLE_PLACEHOLDER_HASH},
+    HashValue,
+};
+use failure::prelude::*;
+use itertools::{izip, zip_eq};
+use lazy_static::lazy_static;
+use logger::prelude::*;
+use metrics::OpMetrics;
+use schemadb::{ColumnFamilyOptions, ColumnFamilyOptionsMap, SchemaBatch, DB, DEFAULT_CF_NAME};
+use std::{iter::Iterator, path::Path, sync::Arc, time::Instant};
+use storage_proto::ExecutorStartupInfo;
+use types::{
+    access_path::AccessPath,
+    account_address::AccountAddress,
+    account_config::get_account_resource_or_default,
+    account_state_blob::{AccountStateBlob, AccountStateWithProof},
+    contract_event::EventWithProof,
+    get_with_proof::{RequestItem, ResponseItem},
+    ledger_info::LedgerInfoWithSignatures,
+    proof::{AccountStateProof, EventProof, SignedTransactionProof, SparseMerkleProof},
+    transaction::{
+        SignedTransactionWithProof, TransactionInfo, TransactionListWithProof, TransactionToCommit,
+        Version,
+    },
+    validator_change::ValidatorChangeEventWithProof,
+};
+
+lazy_static! {
+    static ref OP_COUNTER: OpMetrics = OpMetrics::new_and_registered("storage");
+}
+
+/// This holds a handle to the underlying DB responsible for physical storage and provides APIs
+/// for access to the core Libra data structures.
+pub struct LibraDB {
+    db: Arc<DB>,
+    ledger_store: LedgerStore,
+    transaction_store: TransactionStore,
+    state_store: StateStore,
+    event_store: EventStore,
+}
+
+impl LibraDB {
+    /// This creates an empty LibraDB instance on disk or opens one if it already exists.
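+    ///
+    /// A minimal sketch (the path is illustrative):
+    ///
+    /// ```text
+    /// let db = LibraDB::new("/tmp/example_libradb");
+    /// ```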
+    pub fn new<P: AsRef<Path> + Clone>(db_root_path: P) -> Self {
+        let cf_opts_map: ColumnFamilyOptionsMap = [
+            (
+                /* LedgerInfo CF = */ DEFAULT_CF_NAME,
+                ColumnFamilyOptions::default(),
+            ),
+            (ACCOUNT_STATE_CF_NAME, ColumnFamilyOptions::default()),
+            (EVENT_ACCUMULATOR_CF_NAME, ColumnFamilyOptions::default()),
+            (EVENT_BY_ACCESS_PATH_CF_NAME, ColumnFamilyOptions::default()),
+            (EVENT_CF_NAME, ColumnFamilyOptions::default()),
+            (SIGNATURE_CF_NAME, ColumnFamilyOptions::default()),
+            (SIGNED_TRANSACTION_CF_NAME, ColumnFamilyOptions::default()),
+            (STATE_MERKLE_NODE_CF_NAME, ColumnFamilyOptions::default()),
+            (
+                TRANSACTION_ACCUMULATOR_CF_NAME,
+                ColumnFamilyOptions::default(),
+            ),
+            (TRANSACTION_INFO_CF_NAME, ColumnFamilyOptions::default()),
+            (VALIDATOR_CF_NAME, ColumnFamilyOptions::default()),
+        ]
+        .iter()
+        .cloned()
+        .collect();
+
+        let path = db_root_path.as_ref().join("libradb");
+        let instant = Instant::now();
+        let db = Arc::new(
+            DB::open(path.clone(), cf_opts_map)
+                .unwrap_or_else(|e| panic!("LibraDB open failed: {:?}", e)),
+        );
+
+        info!(
+            "Opened LibraDB at {:?} in {} ms",
+            path,
+            instant.elapsed().as_millis()
+        );
+
+        LibraDB {
+            db: Arc::clone(&db),
+            event_store: EventStore::new(Arc::clone(&db)),
+            ledger_store: LedgerStore::new(Arc::clone(&db)),
+            state_store: StateStore::new(Arc::clone(&db)),
+            transaction_store: TransactionStore::new(Arc::clone(&db)),
+        }
+    }
+
+    // ================================== Public API ==================================
+    /// Returns the account state corresponding to the given version and account address, with a
+    /// proof based on `ledger_version`.
+    fn get_account_state_with_proof(
+        &self,
+        address: AccountAddress,
+        version: Version,
+        ledger_version: Version,
+    ) -> Result<AccountStateWithProof> {
+        ensure!(
+            version <= ledger_version,
+            "The queried version {} should be equal to or older than ledger version {}.",
+            version,
+            ledger_version
+        );
+        let latest_version = self.get_latest_version()?;
+        ensure!(
+            ledger_version <= latest_version,
+            "The ledger version {} is greater than the latest version currently in ledger: {}",
+            ledger_version,
+            latest_version
+        );
+
+        let (txn_info, txn_info_accumulator_proof) = self
+            .ledger_store
+            .get_transaction_info_with_proof(version, ledger_version)?;
+        let (account_state_blob, sparse_merkle_proof) = self
+            .state_store
+            .get_account_state_with_proof_by_state_root(address, txn_info.state_root_hash())?;
+        Ok(AccountStateWithProof::new(
+            version,
+            account_state_blob,
+            AccountStateProof::new(txn_info_accumulator_proof, txn_info, sparse_merkle_proof),
+        ))
+    }
+
+    /// Returns events specified by `access_path` with sequence numbers in the range designated by
+    /// `start_seq_num`, `ascending` and `limit`. If `ascending` is true, this query returns up to
+    /// `limit` events that were emitted at or after `start_seq_num`. Otherwise it returns up to
+    /// `limit` events in the reverse order. Both cases are inclusive.
+    fn get_events_by_event_access_path(
+        &self,
+        access_path: &AccessPath,
+        start_seq_num: u64,
+        ascending: bool,
+        limit: u64,
+        ledger_version: Version,
+    ) -> Result<(Vec<EventWithProof>, Option<AccountStateWithProof>)> {
+        let get_latest = !ascending && start_seq_num == u64::max_value();
+        let cursor = if get_latest {
+            // Caller wants the latest; figure out the latest seq_num.
+            // In the case of no events on that path, use 0 and expect an empty result below.
+            self.event_store
+                .get_latest_sequence_number(ledger_version, access_path)?
+ .unwrap_or(0) + } else { + start_seq_num + }; + + // Convert requested range and order to a range in ascending order. + let (first_seq, real_limit) = get_first_seq_num_and_limit(ascending, cursor, limit)?; + + // Query the index. + let mut event_keys = self.event_store.lookup_events_by_access_path( + access_path, + first_seq, + real_limit, + ledger_version, + )?; + + // When descending, it's possible that user is asking for something beyond the latest + // sequence number, in which case we will consider it a bad request and return an empty + // list. + // For example, if the latest sequence number is 100, and the caller is asking for 110 to + // 90, we will get 90 to 100 from the index lookup above. Seeing that the last item + // is 100 instead of 110 tells us 110 is out of bound. + if !ascending { + if let Some((seq_num, _, _)) = event_keys.last() { + if *seq_num < cursor { + event_keys = Vec::new(); + } + } + } + + let mut events_with_proof = event_keys + .into_iter() + .map(|(seq, ver, idx)| { + let (event, event_proof) = self + .event_store + .get_event_with_proof_by_version_and_index(ver, idx)?; + ensure!( + seq == event.sequence_number(), + "Index broken, expected seq:{}, actual:{}", + seq, + event.sequence_number() + ); + let (txn_info, txn_info_proof) = self + .ledger_store + .get_transaction_info_with_proof(ver, ledger_version)?; + let proof = EventProof::new(txn_info_proof, txn_info, event_proof); + Ok(EventWithProof::new(ver, idx, event, proof)) + }) + .collect::>>()?; + if !ascending { + events_with_proof.reverse(); + } + + // There are two cases where we need to return proof_of_latest_event to let the caller know + // the latest sequence number: + // 1. The user asks for the latest event by using u64::max() as the cursor, apparently + // he doesn't know the latest sequence number. + // 2. We are going to return less than `real_limit` items. (Two cases can lead to that: + // a. the cursor is beyond the latest sequence number; b. in ascending order we don't have + // enough items to return because the latest sequence number is hit). In this case we + // need to return the proof to convince the caller we didn't hide any item from him. Note + // that we use `real_limit` instead of `limit` here because it takes into account the case + // of hitting 0 in descending order, which is valid and doesn't require the proof. + let proof_of_latest_event = if get_latest || events_with_proof.len() < real_limit as usize { + Some(self.get_account_state_with_proof( + access_path.address, + ledger_version, + ledger_version, + )?) + } else { + None + }; + + Ok((events_with_proof, proof_of_latest_event)) + } + + /// Returns a signed transaction that is the `seq_num`-th one associated with the given account. + /// If the signed transaction with given `seq_num` doesn't exist, returns `None`. + // TODO(gzh): Use binary search for now. We may create seq_num index in the future. + fn get_txn_by_account_and_seq( + &self, + address: AccountAddress, + seq_num: u64, + ledger_version: Version, + fetch_events: bool, + ) -> Result> { + // If txn with seq_num n is at some version, the corresponding account state at the + // same version will be the first account state that has seq_num n + 1. 
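+        // Example (hypothetical): if the txn with seq_num 5 landed at version 42, account
+        // states at versions >= 42 report seq_num >= 6, so the binary search below finds
+        // version 42 as the first version whose account seq_num reaches 6.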
+ let seq_num = seq_num + 1; + let (mut start_version, mut end_version) = (0, ledger_version); + while start_version < end_version { + let mid_version = start_version + (end_version - start_version) / 2; + let account_seq_num = self.get_account_seq_num_by_version(address, mid_version)?; + if account_seq_num >= seq_num { + end_version = mid_version; + } else { + start_version = mid_version + 1; + } + } + assert_eq!(start_version, end_version); + + let seq_num_found = self.get_account_seq_num_by_version(address, start_version)?; + if seq_num_found < seq_num { + return Ok(None); + } else if seq_num_found > seq_num { + // log error + bail!("internal error: seq_num is not continuous.") + } + // start_version cannot be 0 (genesis version). + assert_eq!( + self.get_account_seq_num_by_version(address, start_version - 1)?, + seq_num_found - 1 + ); + self.get_transaction_with_proof(start_version, ledger_version, fetch_events) + .map(Some) + } + + /// Gets the latest version number available in the ledger. + fn get_latest_version(&self) -> Result { + Ok(self + .ledger_store + .get_latest_ledger_info()? + .ledger_info() + .version()) + } + + /// Persist transactions. Called by the executor module when either syncing nodes or committing + /// blocks during normal operation. + /// + /// When `ledger_info_with_sigs` is provided, verify that the transaction accumulator root hash + /// it carries is generated after the `txns_to_commit` are applied. + pub fn save_transactions( + &self, + txns_to_commit: &[TransactionToCommit], + first_version: Version, + ledger_info_with_sigs: &Option, + ) -> Result<()> { + let num_txns = txns_to_commit.len() as u64; + // ledger_info_with_sigs could be None if we are doing state synchronization. In this case + // txns_to_commit should not be empty. Otherwise it is okay to commit empty blocks. + ensure!( + ledger_info_with_sigs.is_some() || num_txns > 0, + "txns_to_commit is empty while ledger_info_with_sigs is None.", + ); + + let cur_state_root_hash = if first_version == 0 { + *SPARSE_MERKLE_PLACEHOLDER_HASH + } else { + self.ledger_store + .get_transaction_info(first_version - 1)? + .state_root_hash() + }; + + if let Some(x) = ledger_info_with_sigs { + let last_version = x.ledger_info().version(); + ensure!( + first_version + num_txns - 1 == last_version, + "Transaction batch not applicable: first_version {}, num_txns {}, last_version {}", + first_version, + num_txns, + last_version + ); + } + + // Gather db mutations to `batch`. + let mut batch = SchemaBatch::new(); + + let new_root_hash = self.save_transactions_impl( + txns_to_commit, + first_version, + cur_state_root_hash, + &mut batch, + )?; + + // If expected ledger info is provided, verify result root hash and save the ledger info. + if let Some(x) = ledger_info_with_sigs { + let expected_root_hash = x.ledger_info().transaction_accumulator_hash(); + ensure!( + new_root_hash == expected_root_hash, + "Root hash calculated doesn't match expected. {:?} vs {:?}", + new_root_hash, + expected_root_hash, + ); + + self.ledger_store.put_ledger_info(x, &mut batch)?; + } + + // Persist. + self.commit(batch)?; + // Only increment counter if commit(batch) succeeds. + OP_COUNTER.inc_by("committed_txns", txns_to_commit.len()); + Ok(()) + } + + fn save_transactions_impl( + &self, + txns_to_commit: &[TransactionToCommit], + first_version: u64, + cur_state_root_hash: HashValue, + mut batch: &mut SchemaBatch, + ) -> Result { + let last_version = first_version + txns_to_commit.len() as u64 - 1; + + // Account state updates. 
Gather account state root hashes + let account_state_sets = txns_to_commit + .iter() + .map(|txn_to_commit| txn_to_commit.account_states().clone()) + .collect::>(); + let state_root_hashes = self.state_store.put_account_state_sets( + account_state_sets, + cur_state_root_hash, + &mut batch, + )?; + + // Event updates. Gather event accumulator root hashes. + let event_root_hashes = zip_eq(first_version..=last_version, txns_to_commit) + .map(|(ver, txn_to_commit)| { + self.event_store + .put_events(ver, txn_to_commit.events(), &mut batch) + }) + .collect::>>()?; + + // Transaction updates. Gather transaction hashes. + zip_eq(first_version..=last_version, txns_to_commit) + .map(|(ver, txn_to_commit)| { + self.transaction_store + .put_transaction(ver, txn_to_commit.signed_txn(), &mut batch) + }) + .collect::>()?; + let txn_hashes = txns_to_commit + .iter() + .map(|txn_to_commit| txn_to_commit.signed_txn().hash()) + .collect::>(); + let gas_amounts = txns_to_commit + .iter() + .map(TransactionToCommit::gas_used) + .collect::>(); + + // Transaction accumulator updates. Get result root hash. + let txn_infos = izip!( + txn_hashes, + state_root_hashes, + event_root_hashes, + gas_amounts + ) + .map(|(t, s, e, g)| TransactionInfo::new(t, s, e, g)) + .collect::>(); + assert_eq!(txn_infos.len(), txns_to_commit.len()); + + let new_root_hash = + self.ledger_store + .put_transaction_infos(first_version, &txn_infos, &mut batch)?; + + Ok(new_root_hash) + } + + /// This backs the `UpdateToLatestLedger` public read API which returns the latest + /// [`LedgerInfoWithSignatures`] together with items requested and proofs relative to the same + /// ledger info. + pub fn update_to_latest_ledger( + &self, + _client_known_version: u64, + request_items: Vec, + ) -> Result<( + Vec, + LedgerInfoWithSignatures, + Vec, + )> { + // Get the latest ledger info and signatures + let ledger_info_with_sigs = self.ledger_store.get_latest_ledger_info()?; + let ledger_version = ledger_info_with_sigs.ledger_info().version(); + + // Fulfill all request items + let response_items = request_items + .into_iter() + .map(|request_item| match request_item { + RequestItem::GetAccountState { address } => Ok(ResponseItem::GetAccountState { + account_state_with_proof: self.get_account_state_with_proof( + address, + ledger_version, + ledger_version, + )?, + }), + RequestItem::GetAccountTransactionBySequenceNumber { + account, + sequence_number, + fetch_events, + } => { + let signed_transaction_with_proof = self.get_txn_by_account_and_seq( + account, + sequence_number, + ledger_version, + fetch_events, + )?; + + let proof_of_current_sequence_number = match signed_transaction_with_proof { + Some(_) => None, + None => Some(self.get_account_state_with_proof( + account, + ledger_version, + ledger_version, + )?), + }; + + Ok(ResponseItem::GetAccountTransactionBySequenceNumber { + signed_transaction_with_proof, + proof_of_current_sequence_number, + }) + } + + RequestItem::GetEventsByEventAccessPath { + access_path, + start_event_seq_num, + ascending, + limit, + } => { + let (events_with_proof, proof_of_latest_event) = self + .get_events_by_event_access_path( + &access_path, + start_event_seq_num, + ascending, + limit, + ledger_version, + )?; + Ok(ResponseItem::GetEventsByEventAccessPath { + events_with_proof, + proof_of_latest_event, + }) + } + RequestItem::GetTransactions { + start_version, + limit, + fetch_events, + } => { + let txn_list_with_proof = + self.get_transactions(start_version, limit, ledger_version, fetch_events)?; + + 
+                    Ok(ResponseItem::GetTransactions {
+                        txn_list_with_proof,
+                    })
+                }
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok((
+            response_items,
+            ledger_info_with_sigs,
+            vec![], /* TODO: validator_change_events */
+        ))
+    }
+
+    // =========================== Execution Internal APIs ========================================
+
+    /// Gets an account state by account address, out of the ledger state indicated by the state
+    /// Merkle tree root hash.
+    ///
+    /// This is used by the executor module internally.
+    pub fn get_account_state_with_proof_by_state_root(
+        &self,
+        address: AccountAddress,
+        state_root: HashValue,
+    ) -> Result<(Option<AccountStateBlob>, SparseMerkleProof)> {
+        self.state_store
+            .get_account_state_with_proof_by_state_root(address, state_root)
+    }
+
+    /// Gets information needed from storage during the startup of the executor module.
+    ///
+    /// This is used by the executor module internally.
+    pub fn get_executor_startup_info(&self) -> Result<Option<ExecutorStartupInfo>> {
+        // Get the latest ledger info. Return None if not bootstrapped.
+        let ledger_info_with_sigs = match self.ledger_store.get_latest_ledger_info_option()? {
+            Some(x) => x,
+            None => return Ok(None),
+        };
+        let ledger_info = ledger_info_with_sigs.ledger_info().clone();
+
+        let (latest_version, txn_info) = self.ledger_store.get_latest_transaction_info()?;
+
+        let account_state_root_hash = txn_info.state_root_hash();
+
+        let ledger_frozen_subtree_hashes = self
+            .ledger_store
+            .get_ledger_frozen_subtree_hashes(latest_version)?;
+
+        Ok(Some(ExecutorStartupInfo {
+            ledger_info,
+            latest_version,
+            account_state_root_hash,
+            ledger_frozen_subtree_hashes,
+        }))
+    }
+
+    // ======================= State Synchronizer Internal APIs ===================================
+    /// Gets a batch of transactions for the purpose of synchronizing state to another node.
+    ///
+    /// This is used by the State Synchronizer module internally.
+    pub fn get_transactions(
+        &self,
+        start_version: Version,
+        limit: u64,
+        ledger_version: Version,
+        fetch_events: bool,
+    ) -> Result<TransactionListWithProof> {
+        if start_version > ledger_version || limit == 0 {
+            return Ok(TransactionListWithProof::new_empty());
+        }
+
+        let limit = std::cmp::min(limit, ledger_version - start_version + 1);
+        let txn_and_txn_info_list = (start_version..start_version + limit)
+            .into_iter()
+            .map(|version| {
+                Ok((
+                    self.transaction_store.get_transaction(version)?,
+                    self.ledger_store.get_transaction_info(version)?,
+                ))
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let proof_of_first_transaction = Some(
+            self.ledger_store
+                .get_transaction_proof(start_version, ledger_version)?,
+        );
+        let proof_of_last_transaction = if limit == 1 {
+            None
+        } else {
+            Some(
+                self.ledger_store
+                    .get_transaction_proof(start_version + limit - 1, ledger_version)?,
+            )
+        };
+        let events = if fetch_events {
+            Some(
+                (start_version..start_version + limit)
+                    .into_iter()
+                    .map(|version| Ok(self.event_store.get_events_by_version(version)?))
+                    .collect::<Result<Vec<_>>>()?,
+            )
+        } else {
+            None
+        };
+
+        Ok(TransactionListWithProof::new(
+            txn_and_txn_info_list,
+            events,
+            Some(start_version),
+            proof_of_first_transaction,
+            proof_of_last_transaction,
+        ))
+    }
+
+    // ================================== Private APIs ==================================
+    /// Write the whole schema batch, including all data necessary to mutate the ledger
+    /// state for some transactions, leveraging rocksdb atomicity support.
+ fn commit(&self, batch: SchemaBatch) -> Result<()> { + self.db.write_schemas(batch) + } + + fn get_account_seq_num_by_version( + &self, + address: AccountAddress, + version: Version, + ) -> Result { + let (account_state_blob, _proof) = self + .state_store + .get_account_state_with_proof_by_state_root( + address, + self.ledger_store + .get_transaction_info(version)? + .state_root_hash(), + )?; + + // If an account does not exist, we treat it as if it has sequence number 0. + Ok(get_account_resource_or_default(&account_state_blob)?.sequence_number()) + } + + fn get_transaction_with_proof( + &self, + version: Version, + ledger_version: Version, + fetch_events: bool, + ) -> Result { + let proof = { + let (txn_info, txn_info_accumulator_proof) = self + .ledger_store + .get_transaction_info_with_proof(version, ledger_version)?; + SignedTransactionProof::new(txn_info_accumulator_proof, txn_info) + }; + let signed_transaction = self.transaction_store.get_transaction(version)?; + + // If events were requested, also fetch those. + let events = if fetch_events { + Some(self.event_store.get_events_by_version(version)?) + } else { + None + }; + + Ok(SignedTransactionWithProof { + version, + signed_transaction, + events, + proof, + }) + } +} + +// Convert requested range and order to a range in ascending order. +fn get_first_seq_num_and_limit(ascending: bool, cursor: u64, limit: u64) -> Result<(u64, u64)> { + ensure!(limit > 0, "limit should > 0, got {}", limit); + + Ok(if ascending { + (cursor, limit) + } else if limit <= cursor { + (cursor - limit + 1, limit) + } else { + (0, cursor + 1) + }) +} diff --git a/storage/libradb/src/libradb_test.rs b/storage/libradb/src/libradb_test.rs new file mode 100644 index 0000000000000..63deedd3b40e8 --- /dev/null +++ b/storage/libradb/src/libradb_test.rs @@ -0,0 +1,400 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crate::{ + mock_genesis::{db_with_mock_genesis, GENESIS_INFO}, + test_helper::arb_blocks_to_commit, +}; +use crypto::hash::CryptoHash; +use proptest::prelude::*; +use rusty_fork::{rusty_fork_id, rusty_fork_test, rusty_fork_test_name}; +use std::collections::HashMap; +use types::{contract_event::ContractEvent, ledger_info::LedgerInfo}; + +fn test_save_blocks_impl( + input: Vec<(Vec, LedgerInfoWithSignatures)>, +) -> Result<()> { + let tmp_dir = tempfile::tempdir()?; + let db = db_with_mock_genesis(&tmp_dir)?; + + let num_batches = input.len(); + let mut cur_ver = 0; + for (batch_idx, (txns_to_commit, ledger_info_with_sigs)) in input.iter().enumerate() { + db.save_transactions( + &txns_to_commit, + cur_ver + 1, /* first_version */ + &Some(ledger_info_with_sigs.clone()), + )?; + + assert_eq!( + db.ledger_store.get_latest_ledger_info()?, + *ledger_info_with_sigs + ); + verify_committed_transactions( + &db, + &txns_to_commit, + cur_ver, + ledger_info_with_sigs, + batch_idx + 1 == num_batches, /* is_latest */ + )?; + + cur_ver += txns_to_commit.len() as u64; + } + + let first_batch = input.first().unwrap().0.clone(); + let first_batch_ledger_info = input.first().unwrap().1.clone(); + let latest_ledger_info = input.last().unwrap().1.clone(); + // Verify an old batch with the latest LedgerInfo. + verify_committed_transactions( + &db, + &first_batch, + 0, + &latest_ledger_info, + false, /* is_latest */ + )?; + // Verify an old batch with an old LedgerInfo. 
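+    // (`is_latest = true` is sound here: relative to this older LedgerInfo,
+    // the first batch is the latest committed state.)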
+ verify_committed_transactions( + &db, + &first_batch, + 0, + &first_batch_ledger_info, + true, /* is_latest */ + )?; + + Ok(()) +} + +fn test_sync_transactions_impl( + input: Vec<(Vec, LedgerInfoWithSignatures)>, +) -> Result<()> { + let tmp_dir = tempfile::tempdir()?; + let db = db_with_mock_genesis(&tmp_dir)?; + + let num_batches = input.len(); + let mut cur_ver = 0; + for (batch_idx, (txns_to_commit, ledger_info_with_sigs)) in input.into_iter().enumerate() { + // if batch has more than 2 transactions, save them in two batches + let batch1_len = txns_to_commit.len() / 2; + if batch1_len > 0 { + db.save_transactions( + &txns_to_commit[0..batch1_len], + cur_ver + 1, /* first_version */ + &None, + )?; + } + db.save_transactions( + &txns_to_commit[batch1_len..], + cur_ver + batch1_len as u64 + 1, /* first_version */ + &Some(ledger_info_with_sigs.clone()), + )?; + + verify_committed_transactions( + &db, + &txns_to_commit, + cur_ver, + &ledger_info_with_sigs, + batch_idx + 1 == num_batches, /* is_latest */ + )?; + cur_ver += txns_to_commit.len() as u64; + } + + Ok(()) +} + +fn get_events_by_access_path( + db: &LibraDB, + ledger_info: &LedgerInfo, + access_path: &AccessPath, + first_seq_num: u64, + last_seq_num: u64, + ascending: bool, + is_latest: bool, +) -> Result> { + const LIMIT: u64 = 3; + + let mut cursor = if ascending { + first_seq_num + } else if is_latest { + // Test the ability to get the latest. + u64::max_value() + } else { + last_seq_num + }; + + let mut ret = Vec::new(); + loop { + let (events_with_proof, proof_of_latest_event) = db.get_events_by_event_access_path( + access_path, + cursor, + ascending, + LIMIT, + ledger_info.version(), + )?; + + let num_events = events_with_proof.len() as u64; + if ascending && num_events < LIMIT || !ascending && cursor == u64::max_value() { + let proof_of_latest_event = proof_of_latest_event.unwrap(); + proof_of_latest_event.verify( + ledger_info, + ledger_info.version(), + access_path.address, + )?; + // TODO: decode and see event seq_num once things get more real. 
+ } else { + assert!(proof_of_latest_event.is_none()); + } + + if cursor == u64::max_value() { + cursor = last_seq_num; + } + let expected_seq_nums: Vec<_> = if ascending { + (cursor..cursor + num_events).collect() + } else { + (cursor + 1 - num_events..=cursor).rev().collect() + }; + + let events: Vec<_> = itertools::zip_eq(events_with_proof, expected_seq_nums) + .map(|(e, seq_num)| { + e.verify( + ledger_info, + access_path, + seq_num, + e.transaction_version, + e.event_index, + ) + .unwrap(); + e.event + }) + .collect(); + + let num_results = events.len() as u64; + if num_results == 0 { + break; + } + assert_eq!(events.first().unwrap().sequence_number(), cursor); + + if ascending { + if cursor + num_results > last_seq_num { + ret.extend( + events + .into_iter() + .take((last_seq_num - cursor + 1) as usize), + ); + break; + } else { + ret.extend(events.into_iter()); + cursor += num_results; + } + } else { + // descending + if first_seq_num + num_results > cursor { + ret.extend( + events + .into_iter() + .take((cursor - first_seq_num + 1) as usize), + ); + break; + } else { + ret.extend(events.into_iter()); + cursor -= num_results; + } + } + } + + if !ascending { + ret.reverse(); + } + + Ok(ret) +} + +fn verify_events_by_access_path( + db: &LibraDB, + events: &[ContractEvent], + ledger_info: &LedgerInfo, + is_latest: bool, +) -> Result<()> { + let mut events_by_access_path = HashMap::new(); + events.iter().for_each(|e| { + let list = events_by_access_path + .entry(e.access_path().clone()) + .or_insert_with(Vec::new); + list.push(e.clone()) + }); + + events_by_access_path + .into_iter() + .map(|(access_path, events)| { + let first_seq = events + .first() + .expect("Shouldn't be empty") + .sequence_number(); + let last_seq = events.last().expect("Shouldn't be empty").sequence_number(); + + let traversed = get_events_by_access_path( + db, + ledger_info, + &access_path, + first_seq, + last_seq, + /* ascending = */ true, + is_latest, + )?; + assert_eq!(events, traversed); + + let rev_traversed = get_events_by_access_path( + db, + ledger_info, + &access_path, + first_seq, + last_seq, + /* ascending = */ false, + is_latest, + )?; + assert_eq!(events, rev_traversed); + Ok(()) + }) + .collect::>>()?; + + Ok(()) +} + +fn verify_committed_transactions( + db: &LibraDB, + txns_to_commit: &[TransactionToCommit], + first_version: Version, + ledger_info_with_sigs: &LedgerInfoWithSignatures, + is_latest: bool, +) -> Result<()> { + let ledger_info = ledger_info_with_sigs.ledger_info(); + let ledger_version = ledger_info.version(); + + let mut cur_ver = first_version; + for txn_to_commit in txns_to_commit { + cur_ver += 1; + + let txn_info = db.ledger_store.get_transaction_info(cur_ver)?; + + // Verify transaction hash. + assert_eq!( + txn_info.signed_transaction_hash(), + txn_to_commit.signed_txn().hash() + ); + let txn_list_with_proof = + db.get_transactions(cur_ver, 1, ledger_version, true /* fetch_events */)?; + txn_list_with_proof.verify(ledger_info, Some(cur_ver))?; + + // Fetch and verify account states. + for (addr, expected_blob) in txn_to_commit.account_states() { + let account_state_with_proof = + db.get_account_state_with_proof(*addr, cur_ver, ledger_version)?; + assert_eq!(account_state_with_proof.blob, Some(expected_blob.clone())); + account_state_with_proof.verify(ledger_info, cur_ver, *addr)?; + } + } + + // Fetch and verify events. + // TODO: verify events are saved to correct transaction version. 
+ verify_events_by_access_path( + db, + &txns_to_commit + .iter() + .flat_map(|t| t.events().to_vec()) + .collect::>(), + ledger_info, + is_latest, + )?; + + Ok(()) +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(10))] + + #[test] + fn test_save_blocks(input in arb_blocks_to_commit()) { + test_save_blocks_impl(input).unwrap(); + } + + #[test] + fn test_sync_transactions(input in arb_blocks_to_commit()) { + test_sync_transactions_impl(input).unwrap(); + } +} + +#[test] +fn test_bootstrap() { + let tmp_dir = tempfile::tempdir().unwrap(); + let db = LibraDB::new(&tmp_dir); + + let genesis_txn_info = GENESIS_INFO.0.clone(); + let genesis_ledger_info_with_sigs = GENESIS_INFO.1.clone(); + let genesis_txn = GENESIS_INFO.2.clone(); + + db.save_transactions( + &[genesis_txn], + 0, /* first_version */ + &Some(genesis_ledger_info_with_sigs.clone()), + ) + .unwrap(); + + assert_eq!(db.get_latest_version().unwrap(), 0); + assert_eq!( + db.ledger_store.get_latest_ledger_info().unwrap(), + genesis_ledger_info_with_sigs + ); + assert_eq!( + db.ledger_store.get_transaction_info(0).unwrap(), + genesis_txn_info + ); +} + +rusty_fork_test! { +#[test] +fn test_committed_txns_counter() { + let tmp_dir = tempfile::tempdir().unwrap(); + let db = LibraDB::new(&tmp_dir); + + let genesis_ledger_info_with_sigs = GENESIS_INFO.1.clone(); + let genesis_txn = GENESIS_INFO.2.clone(); + + db.save_transactions(&[genesis_txn], + 0 /* first_version */, + &Some(genesis_ledger_info_with_sigs.clone())) + .unwrap(); + assert_eq!(OP_COUNTER.counter("committed_txns").get(), 1); +} +} + +#[test] +fn test_bootstrapping_already_bootstrapped_db() { + let tmp_dir = tempfile::tempdir().unwrap(); + let db = db_with_mock_genesis(&tmp_dir).unwrap(); + let ledger_info = db.ledger_store.get_latest_ledger_info().unwrap(); + + let genesis_ledger_info_with_sigs = GENESIS_INFO.1.clone(); + let genesis_txn = GENESIS_INFO.2.clone(); + assert!(db + .save_transactions(&[genesis_txn], 0, &Some(genesis_ledger_info_with_sigs)) + .is_ok()); + assert_eq!( + ledger_info, + db.ledger_store.get_latest_ledger_info().unwrap() + ); +} + +#[test] +fn test_get_first_seq_num_and_limit() { + assert!(get_first_seq_num_and_limit(true, 0, 0).is_err()); + + // ascending + assert_eq!(get_first_seq_num_and_limit(true, 0, 4).unwrap(), (0, 4)); + assert_eq!(get_first_seq_num_and_limit(true, 0, 1).unwrap(), (0, 1)); + + // descending + assert_eq!(get_first_seq_num_and_limit(false, 2, 1).unwrap(), (2, 1)); + assert_eq!(get_first_seq_num_and_limit(false, 2, 2).unwrap(), (1, 2)); + assert_eq!(get_first_seq_num_and_limit(false, 2, 3).unwrap(), (0, 3)); + assert_eq!(get_first_seq_num_and_limit(false, 2, 4).unwrap(), (0, 3)); +} diff --git a/storage/libradb/src/mock_genesis/mod.rs b/storage/libradb/src/mock_genesis/mod.rs new file mode 100644 index 0000000000000..f1af0c6eac90d --- /dev/null +++ b/storage/libradb/src/mock_genesis/mod.rs @@ -0,0 +1,109 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module provides helpers to initialize [`LibraDB`] with fake generic state in tests. 
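+//!
+//! A hedged sketch of the intended use (mirrors the tests in this crate;
+//! `tempfile` is a dev-dependency assumed from those tests):
+//!
+//! ```ignore
+//! let tmp_dir = tempfile::tempdir()?;
+//! let db = db_with_mock_genesis(&tmp_dir)?;
+//! // `db` now contains exactly the mock genesis transaction at version 0.
+//! ```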
+ +use crate::LibraDB; +use crypto::{ + hash::{CryptoHash, ACCUMULATOR_PLACEHOLDER_HASH, GENESIS_BLOCK_ID}, + signing::generate_keypair, + HashValue, +}; +use failure::Result; +use lazy_static::lazy_static; +use std::collections::HashMap; +use types::{ + account_address::AccountAddress, + account_state_blob::AccountStateBlob, + ledger_info::{LedgerInfo, LedgerInfoWithSignatures}, + proof::SparseMerkleLeafNode, + transaction::{Program, RawTransaction, TransactionInfo, TransactionToCommit}, +}; + +fn gen_mock_genesis() -> ( + TransactionInfo, + LedgerInfoWithSignatures, + TransactionToCommit, +) { + let (privkey, pubkey) = generate_keypair(); + let some_addr = AccountAddress::from(pubkey); + let raw_txn = RawTransaction::new( + some_addr, + /* sequence_number = */ 0, + Program::new(vec![], vec![], vec![]), + /* max_gas_amount = */ 0, + /* gas_unit_price = */ 0, + /* expiration_time = */ std::time::Duration::new(0, 0), + ); + let signed_txn = raw_txn.sign(&privkey, pubkey).expect("Signing failed."); + let signed_txn_hash = signed_txn.hash(); + + let some_blob = AccountStateBlob::from(vec![1u8]); + let account_states = vec![(some_addr, some_blob.clone())] + .into_iter() + .collect::>(); + + let txn_to_commit = TransactionToCommit::new( + signed_txn, + account_states.clone(), + vec![], /* events */ + 0, /* gas_used */ + ); + + // The genesis state tree has a single leaf node, so the root hash is the hash of that node. + let state_root_hash = SparseMerkleLeafNode::new(some_addr.hash(), some_blob.hash()).hash(); + let txn_info = TransactionInfo::new( + signed_txn_hash, + state_root_hash, + *ACCUMULATOR_PLACEHOLDER_HASH, + 0, + ); + + let ledger_info = LedgerInfo::new( + 0, + txn_info.hash(), + HashValue::random(), + *GENESIS_BLOCK_ID, + 0, + 0, + ); + let ledger_info_with_sigs = + LedgerInfoWithSignatures::new(ledger_info, HashMap::new() /* signatures */); + + (txn_info, ledger_info_with_sigs, txn_to_commit) +} + +lazy_static! { + /// Tuple containing information about the mock genesis state. + /// + /// Tests can use this as input to generate the mock genesis state and verify against it. It is + /// defined as ([`TransactionInfo`], [`LedgerInfoWithSignatures`], + /// [`TransactionToCommit`]): + /// + /// - [`TransactionToCommit`] is the mock genesis transaction. + /// - [`TransactionInfo`] is calculated out of the mock genesis transaction. + /// - [`LedgerInfoWithSignatures`] contains the hash of the above mock transaction info and + /// other mocked information including validator signatures. + pub static ref GENESIS_INFO: ( + TransactionInfo, + LedgerInfoWithSignatures, + TransactionToCommit + ) = gen_mock_genesis(); +} + +/// This creates an empty db at input `dir` and initializes it with mock genesis info. +/// +/// The resulting db will have only one transaction at version 0 (the mock genesis transaction) and +/// related outputs (the mock genesis state) in it. 
+pub fn db_with_mock_genesis>(dir: &P) -> Result { + let genesis_ledger_info_with_sigs = GENESIS_INFO.1.clone(); + let genesis_txn = GENESIS_INFO.2.clone(); + + let db = LibraDB::new(dir); + db.save_transactions( + &[genesis_txn], + 0, /* first_version */ + &Some(genesis_ledger_info_with_sigs), + )?; + Ok(db) +} diff --git a/storage/libradb/src/schema/account_state/mod.rs b/storage/libradb/src/schema/account_state/mod.rs new file mode 100644 index 0000000000000..f8509f0fc1e17 --- /dev/null +++ b/storage/libradb/src/schema/account_state/mod.rs @@ -0,0 +1,49 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines physical storage schema for account data blob pointed by leaves of state +//! Merkle tree. +//! Account state blob is identified by its hash. +//! ```text +//! |<----key--->|<-------value------->| +//! | hash | account state blob | +//! ``` + +use crate::schema::ACCOUNT_STATE_CF_NAME; +use crypto::HashValue; +use failure::prelude::*; +use schemadb::{ + define_schema, + schema::{KeyCodec, ValueCodec}, +}; +use types::account_state_blob::AccountStateBlob; + +define_schema!( + AccountStateSchema, + HashValue, + AccountStateBlob, + ACCOUNT_STATE_CF_NAME +); + +impl KeyCodec for HashValue { + fn encode_key(&self) -> Result> { + Ok(self.to_vec()) + } + + fn decode_key(data: &[u8]) -> Result { + Ok(HashValue::from_slice(data)?) + } +} + +impl ValueCodec for AccountStateBlob { + fn encode_value(&self) -> Result> { + Ok(self.clone().into()) + } + + fn decode_value(data: &[u8]) -> Result { + Ok(data.to_vec().into()) + } +} + +#[cfg(test)] +mod test; diff --git a/storage/libradb/src/schema/account_state/test.rs b/storage/libradb/src/schema/account_state/test.rs new file mode 100644 index 0000000000000..18f7e33f5c554 --- /dev/null +++ b/storage/libradb/src/schema/account_state/test.rs @@ -0,0 +1,14 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crypto::HashValue; +use schemadb::schema::assert_encode_decode; + +#[test] +fn test_account_state_row() { + assert_encode_decode::( + &HashValue::random(), + &vec![0x01, 0x02, 0x03].into(), + ); +} diff --git a/storage/libradb/src/schema/event/mod.rs b/storage/libradb/src/schema/event/mod.rs new file mode 100644 index 0000000000000..54f720516d1f4 --- /dev/null +++ b/storage/libradb/src/schema/event/mod.rs @@ -0,0 +1,67 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines physical storage schema for the contract events. +//! +//! An event is keyed by the version of the transaction it belongs to and the index of it among all +//! events yielded by the same transaction. +//! ```text +//! |<-------key----->|<---value--->| +//! | version | index | event bytes | +//! 
``` + +use crate::schema::{ensure_slice_len_eq, EVENT_CF_NAME}; +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use failure::prelude::*; +use proto_conv::{FromProtoBytes, IntoProtoBytes}; +use schemadb::{ + define_schema, + schema::{KeyCodec, SeekKeyCodec, ValueCodec}, +}; +use std::mem::size_of; +use types::{contract_event::ContractEvent, transaction::Version}; + +define_schema!(EventSchema, Key, ContractEvent, EVENT_CF_NAME); + +type Index = u64; +type Key = (Version, Index); + +impl KeyCodec for Key { + fn encode_key(&self) -> Result> { + let (version, index) = *self; + + let mut encoded_key = Vec::with_capacity(size_of::() + size_of::()); + encoded_key.write_u64::(version)?; + encoded_key.write_u64::(index)?; + Ok(encoded_key) + } + + fn decode_key(data: &[u8]) -> Result { + ensure_slice_len_eq(data, size_of::())?; + + let version_size = size_of::(); + + let version = (&data[..version_size]).read_u64::()?; + let index = (&data[version_size..]).read_u64::()?; + Ok((version, index)) + } +} + +impl ValueCodec for ContractEvent { + fn encode_value(&self) -> Result> { + self.clone().into_proto_bytes() + } + + fn decode_value(data: &[u8]) -> Result { + Self::from_proto_bytes(data) + } +} + +impl SeekKeyCodec for Version { + fn encode_seek_key(&self) -> Result> { + Ok(self.to_be_bytes().to_vec()) + } +} + +#[cfg(test)] +mod test; diff --git a/storage/libradb/src/schema/event/test.rs b/storage/libradb/src/schema/event/test.rs new file mode 100644 index 0000000000000..c0c1caf7fe71b --- /dev/null +++ b/storage/libradb/src/schema/event/test.rs @@ -0,0 +1,17 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use proptest::prelude::*; +use schemadb::schema::assert_encode_decode; + +proptest! { + #[test] + fn test_encode_decode( + version in any::(), + index in any::(), + event in any::(), + ) { + assert_encode_decode::(&(version, index), &event); + } +} diff --git a/storage/libradb/src/schema/event_accumulator/mod.rs b/storage/libradb/src/schema/event_accumulator/mod.rs new file mode 100644 index 0000000000000..9b97172734213 --- /dev/null +++ b/storage/libradb/src/schema/event_accumulator/mod.rs @@ -0,0 +1,65 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines physical storage schema for the event accumulator. +//! +//! Each version has its own event accumulator and a hash value is stored on each position within an +//! accumulator. See `storage/accumulator/lib.rs` for details. +//! ```text +//! |<--------key------->|<-value->| +//! | version | position | hash | +//! 
``` + +use crate::schema::{ensure_slice_len_eq, EVENT_ACCUMULATOR_CF_NAME}; +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use crypto::hash::HashValue; +use failure::prelude::*; +use schemadb::{ + define_schema, + schema::{KeyCodec, ValueCodec}, +}; +use std::mem::size_of; +use types::{proof::position::Position, transaction::Version}; + +define_schema!( + EventAccumulatorSchema, + Key, + HashValue, + EVENT_ACCUMULATOR_CF_NAME +); + +type Key = (Version, Position); + +impl KeyCodec for Key { + fn encode_key(&self) -> Result> { + let (version, position) = self; + + let mut encoded_key = Vec::with_capacity(size_of::() + size_of::()); + encoded_key.write_u64::(*version)?; + encoded_key.write_u64::(position.to_inorder_index())?; + Ok(encoded_key) + } + + fn decode_key(data: &[u8]) -> Result { + ensure_slice_len_eq(data, size_of::())?; + + let version_size = size_of::(); + + let version = (&data[..version_size]).read_u64::()?; + let position = (&data[version_size..]).read_u64::()?; + Ok((version, Position::from_inorder_index(position))) + } +} + +impl ValueCodec for HashValue { + fn encode_value(&self) -> Result> { + Ok(self.to_vec()) + } + + fn decode_value(data: &[u8]) -> Result { + Self::from_slice(data) + } +} + +#[cfg(test)] +mod test; diff --git a/storage/libradb/src/schema/event_accumulator/test.rs b/storage/libradb/src/schema/event_accumulator/test.rs new file mode 100644 index 0000000000000..18be337269dee --- /dev/null +++ b/storage/libradb/src/schema/event_accumulator/test.rs @@ -0,0 +1,13 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use schemadb::schema::assert_encode_decode; + +#[test] +fn test_encode_decode() { + assert_encode_decode::( + &(100, Position::from_inorder_index(100)), + &HashValue::random(), + ); +} diff --git a/storage/libradb/src/schema/event_by_access_path/mod.rs b/storage/libradb/src/schema/event_by_access_path/mod.rs new file mode 100644 index 0000000000000..35b26e406b130 --- /dev/null +++ b/storage/libradb/src/schema/event_by_access_path/mod.rs @@ -0,0 +1,82 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines physical storage schema for an event index via which a ContractEvent ( +//! represented by a tuple so that it can be fetched from `EventSchema`) +//! can be found by tuple. +//! +//! ```text +//! |<----------key-------->|<----value---->| +//! | access_path | seq_num | txn_ver | idx | +//! 
``` + +use crate::schema::{ensure_slice_len_eq, ensure_slice_len_gt, EVENT_BY_ACCESS_PATH_CF_NAME}; +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use canonical_serialization::{SimpleDeserializer, SimpleSerializer}; +use failure::prelude::*; +use schemadb::{ + define_schema, + schema::{KeyCodec, ValueCodec}, +}; +use std::mem::size_of; +use types::{access_path::AccessPath, transaction::Version}; + +define_schema!( + EventByAccessPathSchema, + Key, + Value, + EVENT_BY_ACCESS_PATH_CF_NAME +); + +type SeqNum = u64; +type Key = (AccessPath, SeqNum); + +type Index = u64; +type Value = (Version, Index); + +impl KeyCodec for Key { + fn encode_key(&self) -> Result> { + let (ref access_path, seq_num) = *self; + + let mut encoded = SimpleSerializer::>::serialize(access_path)?; + encoded.write_u64::(seq_num)?; + + Ok(encoded) + } + + fn decode_key(data: &[u8]) -> Result { + let version_size = size_of::(); + ensure_slice_len_gt(data, version_size)?; + let access_path_len = data.len() - version_size; + + let access_path = SimpleDeserializer::deserialize(&data[..access_path_len])?; + let seq_num = (&data[access_path_len..]).read_u64::()?; + + Ok((access_path, seq_num)) + } +} + +impl ValueCodec for Value { + fn encode_value(&self) -> Result> { + let (version, index) = *self; + + let mut encoded = Vec::with_capacity(size_of::() + size_of::()); + encoded.write_u64::(version)?; + encoded.write_u64::(index)?; + + Ok(encoded) + } + + fn decode_value(data: &[u8]) -> Result { + ensure_slice_len_eq(data, size_of::())?; + + let version_size = size_of::(); + + let version = (&data[..version_size]).read_u64::()?; + let index = (&data[version_size..]).read_u64::()?; + Ok((version, index)) + } +} + +#[cfg(test)] +mod test; diff --git a/storage/libradb/src/schema/event_by_access_path/test.rs b/storage/libradb/src/schema/event_by_access_path/test.rs new file mode 100644 index 0000000000000..64ddf5881b09a --- /dev/null +++ b/storage/libradb/src/schema/event_by_access_path/test.rs @@ -0,0 +1,18 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use proptest::prelude::*; +use schemadb::schema::assert_encode_decode; + +proptest! { + #[test] + fn test_encode_decode( + access_path in any::(), + seq_num in any::(), + version in any::(), + index in any::(), + ) { + assert_encode_decode::(&(access_path, seq_num), &(version, index)); + } +} diff --git a/storage/libradb/src/schema/ledger_info/mod.rs b/storage/libradb/src/schema/ledger_info/mod.rs new file mode 100644 index 0000000000000..ed09a2aa113d0 --- /dev/null +++ b/storage/libradb/src/schema/ledger_info/mod.rs @@ -0,0 +1,56 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines physical storage schema for LedgerInfoWithSignatures structure. +//! +//! Serialized LedgerInfoWithSignatures identified by version. +//! ```text +//! |<--key-->|<---------------value------------->| +//! | version | ledger_info_with_signatures bytes | +//! ``` +//! +//! `Version` is serialized in big endian so that records in RocksDB will be in order of it's +//! numeric value. 
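+//!
+//! For orientation (a hedged note on plain big-endian `u64` semantics): version
+//! 255 encodes as `00 00 00 00 00 00 00 FF` and version 256 as
+//! `00 00 00 00 00 00 01 00`, so lexicographic byte order in RocksDB coincides
+//! with numeric version order; a little-endian encoding would sort 256 before 255.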
+ +use crate::schema::ensure_slice_len_eq; +use byteorder::{BigEndian, ReadBytesExt}; +use failure::prelude::*; +use proto_conv::{FromProtoBytes, IntoProtoBytes}; +use schemadb::{ + define_schema, + schema::{KeyCodec, ValueCodec}, + DEFAULT_CF_NAME, +}; +use std::mem::size_of; +use types::{ledger_info::LedgerInfoWithSignatures, transaction::Version}; + +define_schema!( + LedgerInfoSchema, + Version, + LedgerInfoWithSignatures, + DEFAULT_CF_NAME +); + +impl KeyCodec for Version { + fn encode_key(&self) -> Result> { + Ok(self.to_be_bytes().to_vec()) + } + + fn decode_key(data: &[u8]) -> Result { + ensure_slice_len_eq(data, size_of::())?; + Ok((&data[..]).read_u64::()?) + } +} + +impl ValueCodec for LedgerInfoWithSignatures { + fn encode_value(&self) -> Result> { + self.clone().into_proto_bytes() + } + + fn decode_value(data: &[u8]) -> Result { + Self::from_proto_bytes(data) + } +} + +#[cfg(test)] +mod test; diff --git a/storage/libradb/src/schema/ledger_info/test.rs b/storage/libradb/src/schema/ledger_info/test.rs new file mode 100644 index 0000000000000..b06ac2b19802e --- /dev/null +++ b/storage/libradb/src/schema/ledger_info/test.rs @@ -0,0 +1,16 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use proptest::prelude::*; +use schemadb::schema::assert_encode_decode; +use types::ledger_info::LedgerInfoWithSignatures; + +proptest! { + #[test] + fn test_encode_decode( + ledger_info_with_sigs in any_with::((1..10).into()) + ) { + assert_encode_decode::(&0, &ledger_info_with_sigs); + } +} diff --git a/storage/libradb/src/schema/mod.rs b/storage/libradb/src/schema/mod.rs new file mode 100644 index 0000000000000..dc4a12af35ab5 --- /dev/null +++ b/storage/libradb/src/schema/mod.rs @@ -0,0 +1,52 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines representation of Libra core data structures at physical level via schemas +//! that implement [`schemadb::schema::Schema`]. +//! +//! All schemas are `pub(crate)` so not shown in rustdoc, refer to the source code to see details. 
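+//!
+//! Orientation note (hedged): each submodule below defines one schema bound to
+//! its own column family constant (e.g. `signed_transaction` to
+//! `SIGNED_TRANSACTION_CF_NAME`); `ledger_info` is the exception and lives in
+//! RocksDB's default column family.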
+ +pub(crate) mod account_state; +pub(crate) mod event; +pub(crate) mod event_accumulator; +pub(crate) mod event_by_access_path; +pub(crate) mod ledger_info; +pub(crate) mod signed_transaction; +pub(crate) mod state_merkle_node; +pub(crate) mod transaction_accumulator; +pub(crate) mod transaction_info; +pub(crate) mod validator; + +use failure::prelude::*; +use schemadb::ColumnFamilyName; + +pub(super) const ACCOUNT_STATE_CF_NAME: ColumnFamilyName = "account_state"; +pub(super) const EVENT_ACCUMULATOR_CF_NAME: ColumnFamilyName = "event_accumulator"; +pub(super) const EVENT_BY_ACCESS_PATH_CF_NAME: ColumnFamilyName = "event_by_access_path"; +pub(super) const EVENT_CF_NAME: ColumnFamilyName = "event"; +pub(super) const SIGNATURE_CF_NAME: ColumnFamilyName = "signature"; +pub(super) const SIGNED_TRANSACTION_CF_NAME: ColumnFamilyName = "signed_transaction"; +pub(super) const STATE_MERKLE_NODE_CF_NAME: ColumnFamilyName = "state_merkle_node"; +pub(super) const TRANSACTION_ACCUMULATOR_CF_NAME: ColumnFamilyName = "transaction_accumulator"; +pub(super) const TRANSACTION_INFO_CF_NAME: ColumnFamilyName = "transaction_info"; +pub(super) const VALIDATOR_CF_NAME: ColumnFamilyName = "validator"; + +fn ensure_slice_len_eq(data: &[u8], len: usize) -> Result<()> { + ensure!( + data.len() == len, + "Unexpected data len {}, expected {}.", + data.len(), + len, + ); + Ok(()) +} + +fn ensure_slice_len_gt(data: &[u8], len: usize) -> Result<()> { + ensure!( + data.len() > len, + "Unexpected data len {}, expected greater than {}.", + data.len(), + len, + ); + Ok(()) +} diff --git a/storage/libradb/src/schema/signed_transaction/mod.rs b/storage/libradb/src/schema/signed_transaction/mod.rs new file mode 100644 index 0000000000000..0b4a51f1b539d --- /dev/null +++ b/storage/libradb/src/schema/signed_transaction/mod.rs @@ -0,0 +1,55 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines physical storage schema for signed transactions. +//! +//! Serialized signed transaction bytes identified by version. +//! ```text +//! |<--key-->|<--value-->| +//! | version | txn bytes | +//! ``` +//! +//! `Version` is serialized in big endian so that records in RocksDB will be in order of it's +//! numeric value. + +use crate::schema::{ensure_slice_len_eq, SIGNED_TRANSACTION_CF_NAME}; +use byteorder::{BigEndian, ReadBytesExt}; +use failure::prelude::*; +use proto_conv::{FromProtoBytes, IntoProtoBytes}; +use schemadb::{ + define_schema, + schema::{KeyCodec, ValueCodec}, +}; +use std::mem::size_of; +use types::transaction::{SignedTransaction, Version}; + +define_schema!( + SignedTransactionSchema, + Version, + SignedTransaction, + SIGNED_TRANSACTION_CF_NAME +); + +impl KeyCodec for Version { + fn encode_key(&self) -> Result> { + Ok(self.to_be_bytes().to_vec()) + } + + fn decode_key(data: &[u8]) -> Result { + ensure_slice_len_eq(data, size_of::())?; + Ok((&data[..]).read_u64::()?) 
+ } +} + +impl ValueCodec for SignedTransaction { + fn encode_value(&self) -> Result> { + self.clone().into_proto_bytes() + } + + fn decode_value(data: &[u8]) -> Result { + Self::from_proto_bytes(data) + } +} + +#[cfg(test)] +mod test; diff --git a/storage/libradb/src/schema/signed_transaction/test.rs b/storage/libradb/src/schema/signed_transaction/test.rs new file mode 100644 index 0000000000000..1aa541addca1f --- /dev/null +++ b/storage/libradb/src/schema/signed_transaction/test.rs @@ -0,0 +1,14 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use proptest::prelude::*; +use schemadb::schema::assert_encode_decode; +use types::transaction::SignedTransaction; + +proptest! { + #[test] + fn test_encode_decode(txn in any::()) { + assert_encode_decode::(&0u64, &txn); + } +} diff --git a/storage/libradb/src/schema/state_merkle_node/mod.rs b/storage/libradb/src/schema/state_merkle_node/mod.rs new file mode 100644 index 0000000000000..a199aa4cbbab1 --- /dev/null +++ b/storage/libradb/src/schema/state_merkle_node/mod.rs @@ -0,0 +1,48 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines physical storage schema for nodes in state Merkle tree. +//! Node is identified by node_hash = Hash(serialized_node). +//! ```text +//! |<----key--->|<-----value----->| +//! | node_hash | serialized_node | +//! ``` + +use crate::schema::STATE_MERKLE_NODE_CF_NAME; +use crypto::HashValue; +use failure::prelude::*; +use schemadb::{ + define_schema, + schema::{KeyCodec, ValueCodec}, +}; +use sparse_merkle::node_type::Node; + +define_schema!( + StateMerkleNodeSchema, + HashValue, + Node, + STATE_MERKLE_NODE_CF_NAME +); + +impl KeyCodec for HashValue { + fn encode_key(&self) -> Result> { + Ok(self.to_vec()) + } + + fn decode_key(data: &[u8]) -> Result { + Ok(HashValue::from_slice(&data[..])?) + } +} + +impl ValueCodec for Node { + fn encode_value(&self) -> Result> { + Ok(self.encode()?) + } + + fn decode_value(data: &[u8]) -> Result { + Ok(Node::decode(&data[..])?) + } +} + +#[cfg(test)] +mod test; diff --git a/storage/libradb/src/schema/state_merkle_node/test.rs b/storage/libradb/src/schema/state_merkle_node/test.rs new file mode 100644 index 0000000000000..c00843269a329 --- /dev/null +++ b/storage/libradb/src/schema/state_merkle_node/test.rs @@ -0,0 +1,15 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crypto::HashValue; +use schemadb::schema::assert_encode_decode; +use sparse_merkle::node_type::Node; + +#[test] +fn test_state_merkle_node_schema() { + assert_encode_decode::( + &HashValue::random(), + &Node::new_leaf(HashValue::random(), HashValue::random()), + ); +} diff --git a/storage/libradb/src/schema/transaction_accumulator/mod.rs b/storage/libradb/src/schema/transaction_accumulator/mod.rs new file mode 100644 index 0000000000000..c621d77cf11e6 --- /dev/null +++ b/storage/libradb/src/schema/transaction_accumulator/mod.rs @@ -0,0 +1,55 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines physical storage schema for the transaction accumulator. +//! +//! A hash value is stored on each position. +//! See `storage/accumulator/lib.rs` for details. +//! ```text +//! |<---key--->|<-value->| +//! | position | hash | +//! 
``` + +use crate::schema::{ensure_slice_len_eq, TRANSACTION_ACCUMULATOR_CF_NAME}; +use byteorder::{BigEndian, ReadBytesExt}; +use crypto::HashValue; +use failure::prelude::*; +use schemadb::{ + define_schema, + schema::{KeyCodec, ValueCodec}, +}; +use std::mem::size_of; +use types::proof::position::Position; + +define_schema!( + TransactionAccumulatorSchema, + Position, + HashValue, + TRANSACTION_ACCUMULATOR_CF_NAME +); + +impl KeyCodec for Position { + fn encode_key(&self) -> Result> { + Ok(self.to_inorder_index().to_be_bytes().to_vec()) + } + + fn decode_key(data: &[u8]) -> Result { + ensure_slice_len_eq(data, size_of::())?; + Ok(Position::from_inorder_index( + (&data[..]).read_u64::()?, + )) + } +} + +impl ValueCodec for HashValue { + fn encode_value(&self) -> Result> { + Ok(self.to_vec()) + } + + fn decode_value(data: &[u8]) -> Result { + Self::from_slice(data) + } +} + +#[cfg(test)] +mod test; diff --git a/storage/libradb/src/schema/transaction_accumulator/test.rs b/storage/libradb/src/schema/transaction_accumulator/test.rs new file mode 100644 index 0000000000000..371568a365d40 --- /dev/null +++ b/storage/libradb/src/schema/transaction_accumulator/test.rs @@ -0,0 +1,13 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use schemadb::schema::assert_encode_decode; + +#[test] +fn test_encode_decode() { + assert_encode_decode::( + &Position::from_inorder_index(100), + &HashValue::random(), + ); +} diff --git a/storage/libradb/src/schema/transaction_info/mod.rs b/storage/libradb/src/schema/transaction_info/mod.rs new file mode 100644 index 0000000000000..a9940fb8e8be0 --- /dev/null +++ b/storage/libradb/src/schema/transaction_info/mod.rs @@ -0,0 +1,59 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines physical storage schema for TransactionInfo structure. +//! +//! Serialized signed transaction bytes identified by version. +//! ```text +//! |<--key-->|<-----value---->| +//! | version | txn_info bytes | +//! ``` +//! +//! `Version` is serialized in big endian so that records in RocksDB will be in order of it's +//! numeric value. + +use crate::schema::TRANSACTION_INFO_CF_NAME; +use byteorder::{BigEndian, ReadBytesExt}; +use failure::prelude::*; +use proto_conv::{FromProtoBytes, IntoProtoBytes}; +use schemadb::{ + define_schema, + schema::{KeyCodec, ValueCodec}, +}; +use std::mem::size_of; +use types::transaction::{TransactionInfo, Version}; + +define_schema!( + TransactionInfoSchema, + Version, + TransactionInfo, + TRANSACTION_INFO_CF_NAME +); + +impl KeyCodec for Version { + fn encode_key(&self) -> Result> { + Ok(self.to_be_bytes().to_vec()) + } + + fn decode_key(data: &[u8]) -> Result { + ensure!( + data.len() == size_of::(), + "Bad num of bytes: {}", + data.len() + ); + Ok((&data[..]).read_u64::()?) 
+ } +} + +impl ValueCodec for TransactionInfo { + fn encode_value(&self) -> Result> { + self.clone().into_proto_bytes() + } + + fn decode_value(data: &[u8]) -> Result { + Self::from_proto_bytes(data) + } +} + +#[cfg(test)] +mod test; diff --git a/storage/libradb/src/schema/transaction_info/test.rs b/storage/libradb/src/schema/transaction_info/test.rs new file mode 100644 index 0000000000000..335b7d1e1ae13 --- /dev/null +++ b/storage/libradb/src/schema/transaction_info/test.rs @@ -0,0 +1,18 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crypto::HashValue; +use schemadb::schema::assert_encode_decode; +use types::transaction::TransactionInfo; + +#[test] +fn test_encode_decode() { + let txn_info = TransactionInfo::new( + HashValue::random(), + HashValue::random(), + HashValue::random(), + 7, + ); + assert_encode_decode::(&0u64, &txn_info); +} diff --git a/storage/libradb/src/schema/validator/mod.rs b/storage/libradb/src/schema/validator/mod.rs new file mode 100644 index 0000000000000..5f8c3c161a588 --- /dev/null +++ b/storage/libradb/src/schema/validator/mod.rs @@ -0,0 +1,69 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines physical storage schema for validator sets +//! Among all versions at some of them we change the validator set, at which +//! point we call it a new epoch. Each validator identified by their `public_key` +//! that's part of the epoch starting at `version` is stored in a row. +//! ```text +//! |<---------key-------->|<-value->| +//! | version | public_key | null | +//! ``` + +use crate::schema::{ensure_slice_len_eq, VALIDATOR_CF_NAME}; +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use crypto::PublicKey; +use failure::prelude::*; +use schemadb::{ + define_schema, + schema::{KeyCodec, ValueCodec}, +}; +use std::{io::Write, mem::size_of}; +use types::transaction::Version; + +define_schema!(ValidatorSchema, Key, Value, VALIDATOR_CF_NAME); + +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct Key { + /// version at which epoch starts + pub(crate) version: Version, + /// public_key of validator + pub(crate) public_key: PublicKey, +} + +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct Value; + +impl KeyCodec for Key { + fn encode_key(&self) -> Result> { + let public_key_serialized = self.public_key.to_slice(); + let mut encoded_key = + Vec::with_capacity(size_of::() + PublicKey::LENGTH * size_of::()); + encoded_key.write_u64::(self.version)?; + encoded_key.write_all(&public_key_serialized)?; + Ok(encoded_key) + } + + fn decode_key(data: &[u8]) -> Result { + let version = (&data[..size_of::()]).read_u64::()?; + let public_key = PublicKey::from_slice(&data[size_of::()..])?; + Ok(Key { + version, + public_key, + }) + } +} + +impl ValueCodec for Value { + fn encode_value(&self) -> Result> { + Ok(vec![]) + } + + fn decode_value(data: &[u8]) -> Result { + ensure_slice_len_eq(data, 0)?; + Ok(Value) + } +} + +#[cfg(test)] +mod test; diff --git a/storage/libradb/src/schema/validator/test.rs b/storage/libradb/src/schema/validator/test.rs new file mode 100644 index 0000000000000..6d082af134b57 --- /dev/null +++ b/storage/libradb/src/schema/validator/test.rs @@ -0,0 +1,46 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crypto::signing::generate_keypair; +use itertools::Itertools; +use rand::{thread_rng, Rng}; +use schemadb::schema::assert_encode_decode; +use types::transaction::Version; + 
+fn row_with_arbitrary_validator(version: Version) -> (Key, Value) { + let (_private_key, public_key) = generate_keypair(); + ( + Key { + version, + public_key, + }, + Value, + ) +} + +#[test] +fn test_encode_decode() { + let (k, v) = row_with_arbitrary_validator(1); + assert_encode_decode::(&k, &v); +} + +#[test] +fn test_order() { + let mut versions: Vec = (0..1024).collect(); + thread_rng().shuffle(&mut versions); + + let encoded_sorted: Vec> = versions + .into_iter() + .map(|v| row_with_arbitrary_validator(v).0.encode_key().unwrap()) + .sorted(); + + let decoded_versions: Vec = encoded_sorted + .iter() + .map(|k| Key::decode_key(k).unwrap().version) + .collect(); + + let ordered_versions: Vec = (0..1024).collect(); + + assert_eq!(decoded_versions, ordered_versions) +} diff --git a/storage/libradb/src/state_store/mod.rs b/storage/libradb/src/state_store/mod.rs new file mode 100644 index 0000000000000..bd596927652b9 --- /dev/null +++ b/storage/libradb/src/state_store/mod.rs @@ -0,0 +1,97 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This file defines state store APIs that are related account state Merkle tree. + +#[cfg(test)] +mod state_store_test; + +use crate::schema::{account_state::AccountStateSchema, state_merkle_node::StateMerkleNodeSchema}; +use crypto::{hash::CryptoHash, HashValue}; +use failure::prelude::*; +use schemadb::{SchemaBatch, DB}; +use sparse_merkle::{node_type::Node, SparseMerkleTree, TreeReader}; +use std::{collections::HashMap, sync::Arc}; +use types::{ + account_address::AccountAddress, + account_state_blob::AccountStateBlob, + proof::{verify_sparse_merkle_element, SparseMerkleProof}, +}; + +pub(crate) struct StateStore { + db: Arc, +} + +impl StateStore { + pub fn new(db: Arc) -> Self { + Self { db } + } + + /// Get the account state blob given account address and root hash of state Merkle tree + pub fn get_account_state_with_proof_by_state_root( + &self, + address: AccountAddress, + root_hash: HashValue, + ) -> Result<(Option, SparseMerkleProof)> { + let (blob, proof) = + SparseMerkleTree::new(self).get_with_proof(address.hash(), root_hash)?; + debug_assert!( + verify_sparse_merkle_element(root_hash, address.hash(), &blob, &proof).is_ok(), + "Invalid proof." + ); + Ok((blob, proof)) + } + + /// Put the results generated by `keyed_blob_sets` to `batch` and return the result root hashes + /// for each write set. + pub fn put_account_state_sets( + &self, + account_state_sets: Vec>, + root_hash: HashValue, + batch: &mut SchemaBatch, + ) -> Result> { + let keyed_blob_sets = account_state_sets + .into_iter() + .map(|account_states| { + account_states + .into_iter() + .map(|(addr, blob)| (addr.hash(), blob)) + .collect::>() + }) + .collect::>(); + + let (new_root_hash_vec, tree_update_batch) = + SparseMerkleTree::new(self).put_keyed_blob_sets(keyed_blob_sets, root_hash)?; + let (node_batch, blob_batch) = tree_update_batch.into(); + node_batch + .iter() + .map(|(node_hash, node)| batch.put::(node_hash, node)) + .collect::>>()?; + blob_batch + .iter() + .map(|(blob_hash, blob)| batch.put::(blob_hash, blob)) + .collect::>>()?; + Ok(new_root_hash_vec) + } +} + +impl TreeReader for StateStore { + fn get_node(&self, node_hash: HashValue) -> Result { + Ok(self + .db + .get::(&node_hash)? + .ok_or_else(|| format_err!("Failed to find node with hash {:?}", node_hash))?) + } + + fn get_blob(&self, blob_hash: HashValue) -> Result { + Ok(self + .db + .get::(&blob_hash)? 
+ .ok_or_else(|| { + format_err!( + "Failed to find account state blob with hash {:?}", + blob_hash + ) + })?) + } +} diff --git a/storage/libradb/src/state_store/state_store_test.rs b/storage/libradb/src/state_store/state_store_test.rs new file mode 100644 index 0000000000000..4174dd21e1f54 --- /dev/null +++ b/storage/libradb/src/state_store/state_store_test.rs @@ -0,0 +1,141 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crate::LibraDB; +use crypto::hash::{CryptoHash, SPARSE_MERKLE_PLACEHOLDER_HASH}; +use tempfile::tempdir; +use types::{ + account_address::{AccountAddress, ADDRESS_LENGTH}, + account_state_blob::AccountStateBlob, + proof::verify_sparse_merkle_element, +}; + +fn put_account_state_set( + store: &StateStore, + account_state_set: Vec<(AccountAddress, AccountStateBlob)>, + root_hash: HashValue, + batch: &mut SchemaBatch, +) -> HashValue { + store + .put_account_state_sets( + vec![account_state_set.into_iter().collect::>()], + root_hash, + batch, + ) + .unwrap()[0] +} + +#[test] +fn test_empty_store() { + let tmp_dir = tempdir().unwrap(); + let db = LibraDB::new(&tmp_dir); + let store = &db.state_store; + let address = AccountAddress::new([1u8; ADDRESS_LENGTH]); + let root = *SPARSE_MERKLE_PLACEHOLDER_HASH; + let (value, proof) = store + .get_account_state_with_proof_by_state_root(address, root) + .unwrap(); + assert!(value.is_none()); + assert!(verify_sparse_merkle_element(root, address.hash(), &None, &proof).is_ok()); +} + +#[test] +fn test_state_store_reader_writer() { + let tmp_dir = tempdir().unwrap(); + let db = LibraDB::new(&tmp_dir); + let store = &db.state_store; + let address1 = AccountAddress::new([1u8; ADDRESS_LENGTH]); + let address2 = AccountAddress::new([2u8; ADDRESS_LENGTH]); + let address3 = AccountAddress::new([3u8; ADDRESS_LENGTH]); + let value1 = AccountStateBlob::from(vec![0x01]); + let value1_update = AccountStateBlob::from(vec![0x00]); + let value2 = AccountStateBlob::from(vec![0x02]); + let value3 = AccountStateBlob::from(vec![0x03]); + let mut root = *SPARSE_MERKLE_PLACEHOLDER_HASH; + + // Verify initial states. + { + let (value, proof) = store + .get_account_state_with_proof_by_state_root(address1, root) + .unwrap(); + assert!(value.is_none()); + assert!(verify_sparse_merkle_element(root, address1.hash(), &value, &proof).is_ok()); + } + { + let (value, proof) = store + .get_account_state_with_proof_by_state_root(address2, root) + .unwrap(); + assert!(value.is_none()); + assert!(verify_sparse_merkle_element(root, address2.hash(), &value, &proof).is_ok()); + } + { + let (value, proof) = store + .get_account_state_with_proof_by_state_root(address3, root) + .unwrap(); + assert!(value.is_none()); + assert!(verify_sparse_merkle_element(root, address3.hash(), &value, &proof).is_ok()); + } + + // Insert address1 with value 1 and verify new states. 
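+    // (Note: for addresses not written yet, a `None` blob plus the returned
+    // proof still verifies, as a proof of non-inclusion against the root.)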
+ let mut batch1 = SchemaBatch::new(); + root = put_account_state_set(&store, vec![(address1, value1.clone())], root, &mut batch1); + db.commit(batch1).unwrap(); + { + let (value, proof) = store + .get_account_state_with_proof_by_state_root(address1, root) + .unwrap(); + assert_eq!(value, Some(value1)); + assert!(verify_sparse_merkle_element(root, address1.hash(), &value, &proof).is_ok()); + } + { + let (value, proof) = store + .get_account_state_with_proof_by_state_root(address2, root) + .unwrap(); + assert!(value.is_none()); + assert!(verify_sparse_merkle_element(root, address2.hash(), &value, &proof).is_ok()); + } + { + let (value, proof) = store + .get_account_state_with_proof_by_state_root(address3, root) + .unwrap(); + assert!(value.is_none()); + assert!(verify_sparse_merkle_element(root, address3.hash(), &value, &proof).is_ok()); + } + + // Insert address 1 with updated value1, address2 with value 2 and address3 with value3 and + // verify new states. + let mut batch2 = SchemaBatch::new(); + root = put_account_state_set( + &store, + vec![ + (address1, value1_update.clone()), + (address2, value2.clone()), + (address3, value3.clone()), + ], + root, + &mut batch2, + ); + db.commit(batch2).unwrap(); + { + let (value, proof) = store + .get_account_state_with_proof_by_state_root(address1, root) + .unwrap(); + assert_eq!(value, Some(value1_update)); + assert!(verify_sparse_merkle_element(root, address1.hash(), &value, &proof).is_ok()); + } + { + let (value, proof) = store + .get_account_state_with_proof_by_state_root(address2, root) + .unwrap(); + assert_eq!(value, Some(value2)); + assert!(verify_sparse_merkle_element(root, address2.hash(), &value, &proof).is_ok()); + } + { + let (value, proof) = store + .get_account_state_with_proof_by_state_root(address3, root) + .unwrap(); + assert_eq!(value, Some(value3)); + assert!(verify_sparse_merkle_element(root, address3.hash(), &value, &proof).is_ok()); + } +} diff --git a/storage/libradb/src/test_helper.rs b/storage/libradb/src/test_helper.rs new file mode 100644 index 0000000000000..025142ed16bdf --- /dev/null +++ b/storage/libradb/src/test_helper.rs @@ -0,0 +1,119 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module provides reusable helpers in tests. + +use super::*; +use crate::mock_genesis::{db_with_mock_genesis, GENESIS_INFO}; +use crypto::hash::CryptoHash; +use itertools::zip_eq; +use proptest::{collection::vec, prelude::*}; +use types::{ledger_info::LedgerInfo, proptest_types::arb_txn_to_commit_batch}; + +fn to_blocks_to_commit( + txns_to_commit_vec: Vec>, + partial_ledger_info_with_sigs_vec: Vec, +) -> Result, LedgerInfoWithSignatures)>> { + // Use temporary LibraDB and STORE LEVEL APIs to calculate hashes on a per transaction basis. + // Result is used to test the batch PUBLIC API for saving everything, i.e. 
`save_transactions()` + let tmp_dir = tempfile::tempdir()?; + let db = db_with_mock_genesis(&tmp_dir)?; + + let genesis_txn_info = GENESIS_INFO.0.clone(); + let genesis_ledger_info_with_sigs = GENESIS_INFO.1.clone(); + let genesis_ledger_info = genesis_ledger_info_with_sigs.ledger_info(); + let mut cur_state_root_hash = genesis_txn_info.state_root_hash(); + let mut cur_ver = 0; + let mut cur_txn_accu_hash = genesis_ledger_info.transaction_accumulator_hash(); + let blocks_to_commit = zip_eq(txns_to_commit_vec, partial_ledger_info_with_sigs_vec) + .map(|(txns_to_commit, partial_ledger_info_with_sigs)| { + for txn_to_commit in txns_to_commit.iter() { + cur_ver += 1; + let mut batch = SchemaBatch::new(); + + let txn_hash = txn_to_commit.signed_txn().hash(); + let state_root_hash = db.state_store.put_account_state_sets( + vec![txn_to_commit.account_states().clone()], + cur_state_root_hash, + &mut batch, + )?[0]; + let event_root_hash = + db.event_store + .put_events(cur_ver, txn_to_commit.events(), &mut batch)?; + + let txn_info = TransactionInfo::new( + txn_hash, + state_root_hash, + event_root_hash, + txn_to_commit.gas_used(), + ); + let txn_accu_hash = + db.ledger_store + .put_transaction_infos(cur_ver, &[txn_info], &mut batch)?; + db.commit(batch)?; + + cur_state_root_hash = state_root_hash; + cur_txn_accu_hash = txn_accu_hash; + } + + let ledger_info = LedgerInfo::new( + cur_ver, + cur_txn_accu_hash, + partial_ledger_info_with_sigs + .ledger_info() + .consensus_data_hash(), + partial_ledger_info_with_sigs + .ledger_info() + .consensus_block_id(), + partial_ledger_info_with_sigs.ledger_info().epoch_num(), + partial_ledger_info_with_sigs + .ledger_info() + .timestamp_usecs(), + ); + let ledger_info_with_sigs = LedgerInfoWithSignatures::new( + ledger_info, + partial_ledger_info_with_sigs.signatures().clone(), + ); + Ok((txns_to_commit, ledger_info_with_sigs)) + }) + .collect::>>()?; + + Ok(blocks_to_commit) +} + +/// This returns a [`proptest`](https://altsysrq.github.io/proptest-book/intro.html) +/// [`Strategy`](https://docs.rs/proptest/0/proptest/strategy/trait.Strategy.html) that yields an +/// arbitrary number of arbitrary batches of transactions to commit. +/// +/// It is used in tests for both transaction block committing during normal running and +/// transaction syncing during start up. +pub fn arb_blocks_to_commit( +) -> impl Strategy, LedgerInfoWithSignatures)>> { + vec(0..3usize, 1..10usize) + .prop_flat_map(|batch_sizes| { + let total_txns = batch_sizes.iter().sum(); + let total_batches = batch_sizes.len(); + ( + Just(batch_sizes), + arb_txn_to_commit_batch(3, 3, total_txns), + vec( + any_with::((1..3).into()), + total_batches, + ), + ) + }) + .prop_map( + |(batch_sizes, all_txns_to_commit, partial_ledger_info_with_sigs_vec)| { + // split txns_to_commit to batches + let txns_to_commit_batches = batch_sizes + .iter() + .scan(0, |end, batch_size| { + *end += batch_size; + Some(all_txns_to_commit[*end - batch_size..*end].to_vec()) + }) + .collect::>(); + to_blocks_to_commit(txns_to_commit_batches, partial_ledger_info_with_sigs_vec) + .unwrap() + }, + ) +} diff --git a/storage/libradb/src/transaction_store/mod.rs b/storage/libradb/src/transaction_store/mod.rs new file mode 100644 index 0000000000000..fd820ad6b0751 --- /dev/null +++ b/storage/libradb/src/transaction_store/mod.rs @@ -0,0 +1,41 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This file defines transaction store APIs that are related to committed signed transactions. 
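+//!
+//! A hedged sketch of the put/get round trip (this mirrors the test in this
+//! module; `db` and `store` as constructed there):
+//!
+//! ```ignore
+//! let mut batch = SchemaBatch::new();
+//! store.put_transaction(version, &signed_txn, &mut batch)?;
+//! db.commit(batch)?;
+//! assert_eq!(store.get_transaction(version)?, signed_txn);
+//! ```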
+ +use super::schema::signed_transaction::*; +use crate::errors::LibraDbError; +use failure::prelude::*; +use schemadb::{SchemaBatch, DB}; +use std::sync::Arc; +use types::transaction::{SignedTransaction, Version}; + +pub(crate) struct TransactionStore { + db: Arc, +} + +impl TransactionStore { + pub fn new(db: Arc) -> Self { + Self { db } + } + + /// Get signed transaction given `version` + pub fn get_transaction(&self, version: Version) -> Result { + self.db + .get::(&version)? + .ok_or_else(|| LibraDbError::NotFound(format!("Txn {}", version)).into()) + } + + /// Save signed transaction at `version` + pub fn put_transaction( + &self, + version: Version, + signed_transaction: &SignedTransaction, + batch: &mut SchemaBatch, + ) -> Result<()> { + batch.put::(&version, signed_transaction) + } +} + +#[cfg(test)] +mod test; diff --git a/storage/libradb/src/transaction_store/test.rs b/storage/libradb/src/transaction_store/test.rs new file mode 100644 index 0000000000000..e47dbb677d214 --- /dev/null +++ b/storage/libradb/src/transaction_store/test.rs @@ -0,0 +1,32 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crate::LibraDB; +use proptest::{collection::vec, prelude::*}; +use tempfile::tempdir; + +proptest! { + #![proptest_config(ProptestConfig::with_cases(10))] + + #[test] + fn test_put_get(txns in vec(any::(), 1..10)) { + let tmp_dir = tempdir().unwrap(); + let db = LibraDB::new(&tmp_dir); + let store = &db.transaction_store; + + prop_assert!(store.get_transaction(0).is_err()); + + let mut batch = SchemaBatch::new(); + for (i, txn) in txns.iter().enumerate() { + store.put_transaction(i as u64, &txn, &mut batch).unwrap(); + } + db.commit(batch).unwrap(); + + for (i, txn) in txns.iter().enumerate() { + prop_assert_eq!(store.get_transaction(i as u64).unwrap(), txn.clone()); + } + + prop_assert!(store.get_transaction(txns.len() as u64).is_err()); + } +} diff --git a/storage/schemadb/Cargo.toml b/storage/schemadb/Cargo.toml new file mode 100644 index 0000000000000..5896924fe2911 --- /dev/null +++ b/storage/schemadb/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "schemadb" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +failure = { path = "../../common/failure_ext", package = "failure_ext" } + +[dependencies.rocksdb] +git = "https://github.com/pingcap/rust-rocksdb.git" +rev = "4f8c6b3a48c7acb0c81f9bb91e7d8493b1c5a73e" + +[dev-dependencies] +byteorder = "1.3.1" +tempfile = "3.0.6" diff --git a/storage/schemadb/src/lib.rs b/storage/schemadb/src/lib.rs new file mode 100644 index 0000000000000..e2084d55dd68f --- /dev/null +++ b/storage/schemadb/src/lib.rs @@ -0,0 +1,283 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This library implements a schematized DB on top of [RocksDB](https://rocksdb.org/). It makes +//! sure all data passed in and out are structured according to predefined schemas and prevents +//! access to raw keys and values. This library also enforces a set of Libra specific DB options, +//! like custom comparators and schema-to-column-family mapping. +//! +//! It requires that different kinds of key-value pairs be stored in separate column +//! families. To use this library to store a kind of key-value pairs, the user needs to use the +//! [`define_schema!`] macro to define the schema name, the types of key and value, and name of the +//! column family. 
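+//!
+//! A hedged sketch of that pattern (the identifiers here are illustrative
+//! placeholders, not items of this crate):
+//!
+//! ```ignore
+//! define_schema!(ExampleSchema, ExampleKey, ExampleValue, "example_cf");
+//! // ExampleKey is expected to implement KeyCodec<ExampleSchema> and
+//! // ExampleValue to implement ValueCodec<ExampleSchema>.
+//! ```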
+
+#[macro_use]
+pub mod schema;
+
+use crate::schema::{KeyCodec, Schema, SeekKeyCodec, ValueCodec};
+use failure::prelude::*;
+use rocksdb::{
+    rocksdb_options::ColumnFamilyDescriptor, CFHandle, DBOptions, Writable, WriteOptions,
+};
+use std::{collections::HashMap, iter::Iterator, marker::PhantomData, path::Path};
+
+/// Type alias to `rocksdb::ColumnFamilyOptions`. See [`rocksdb doc`](https://github.com/pingcap/rust-rocksdb/blob/master/src/rocksdb_options.rs)
+pub type ColumnFamilyOptions = rocksdb::ColumnFamilyOptions;
+/// Type alias to `rocksdb::ReadOptions`. See [`rocksdb doc`](https://github.com/pingcap/rust-rocksdb/blob/master/src/rocksdb_options.rs)
+pub type ReadOptions = rocksdb::ReadOptions;
+
+/// Type alias to improve readability.
+pub type ColumnFamilyName = &'static str;
+/// Type alias to improve readability.
+pub type ColumnFamilyOptionsMap = HashMap<ColumnFamilyName, ColumnFamilyOptions>;
+
+/// Name for the `default` column family that's always open by RocksDB. We use it to store
+/// [`LedgerInfo`](../types/ledger_info/struct.LedgerInfo.html).
+pub const DEFAULT_CF_NAME: ColumnFamilyName = "default";
+
+#[derive(Debug)]
+enum WriteOp {
+    Value(Vec<u8>),
+    Deletion,
+}
+
+/// `SchemaBatch` holds a collection of updates that can be applied to a DB atomically. The updates
+/// will be applied in the order in which they are added to the `SchemaBatch`.
+#[derive(Debug, Default)]
+pub struct SchemaBatch {
+    rows: Vec<(ColumnFamilyName, Vec<u8> /* key */, WriteOp)>,
+}
+
+impl SchemaBatch {
+    /// Creates an empty batch.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Adds an insert/update operation to the batch.
+    pub fn put<S: Schema>(&mut self, key: &S::Key, value: &S::Value) -> Result<()> {
+        let key = <S::Key as KeyCodec<S>>::encode_key(key)?;
+        let value = <S::Value as ValueCodec<S>>::encode_value(value)?;
+        self.rows
+            .push((S::COLUMN_FAMILY_NAME, key, WriteOp::Value(value)));
+        Ok(())
+    }
+
+    /// Adds a delete operation to the batch.
+    pub fn delete<S: Schema>(&mut self, key: &S::Key) -> Result<()> {
+        let key = <S::Key as KeyCodec<S>>::encode_key(key)?;
+        self.rows
+            .push((S::COLUMN_FAMILY_NAME, key, WriteOp::Deletion));
+        Ok(())
+    }
+}
+
+/// DB Iterator parameterized on [`Schema`] that seeks with [`Schema::Key`] and yields
+/// [`Schema::Key`] and [`Schema::Value`]
+pub struct SchemaIterator<'a, S> {
+    db_iter: rocksdb::DBIterator<&'a rocksdb::DB>,
+    phantom: PhantomData<S>,
+}
+
+impl<'a, S> SchemaIterator<'a, S>
+where
+    S: Schema,
+{
+    fn new(db_iter: rocksdb::DBIterator<&'a rocksdb::DB>) -> Self {
+        SchemaIterator {
+            db_iter,
+            phantom: PhantomData,
+        }
+    }
+
+    /// Seeks to the first key.
+    pub fn seek_to_first(&mut self) -> bool {
+        self.db_iter.seek(rocksdb::SeekKey::Start)
+    }
+
+    /// Seeks to the last key.
+    pub fn seek_to_last(&mut self) -> bool {
+        self.db_iter.seek(rocksdb::SeekKey::End)
+    }
+
+    /// Seeks to the first key whose binary representation is equal to or greater than that of the
+    /// `seek_key`.
+    pub fn seek<SK>(&mut self, seek_key: &SK) -> Result<bool>
+    where
+        SK: SeekKeyCodec<S>,
+    {
+        let key = <SK as SeekKeyCodec<S>>::encode_seek_key(seek_key)?;
+        Ok(self.db_iter.seek(rocksdb::SeekKey::Key(&key)))
+    }
+
+    /// Seeks to the last key whose binary representation is less than or equal to that of the
+    /// `seek_key`.
+    ///
+    /// See example in [`RocksDB doc`](https://github.com/facebook/rocksdb/wiki/SeekForPrev).
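+    ///
+    /// A sketch of how this differs from `seek`, assuming only keys `1` and `3` exist:
+    ///
+    /// ```text
+    /// iter.seek(&2)?;          // positions the iterator at 3, the first key >= 2
+    /// iter.seek_for_prev(&2)?; // positions the iterator at 1, the last key <= 2
+    /// ```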
+    pub fn seek_for_prev<SK>(&mut self, seek_key: &SK) -> Result<bool>
+    where
+        SK: SeekKeyCodec<S>,
+    {
+        let key = <SK as SeekKeyCodec<S>>::encode_seek_key(seek_key)?;
+        Ok(self.db_iter.seek_for_prev(rocksdb::SeekKey::Key(&key)))
+    }
+}
+
+impl<'a, S> Iterator for SchemaIterator<'a, S>
+where
+    S: Schema,
+{
+    type Item = Result<(S::Key, S::Value)>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.db_iter.kv().map(|(raw_key, raw_value)| {
+            self.db_iter.next();
+            Ok((
+                <S::Key as KeyCodec<S>>::decode_key(&raw_key)?,
+                <S::Value as ValueCodec<S>>::decode_value(&raw_value)?,
+            ))
+        })
+    }
+}
+
+/// Checks the existence of an underlying RocksDB instance by checking for the `CURRENT` file,
+/// the same way RocksDB itself detects db existence.
+fn db_exists(path: &Path) -> bool {
+    let rocksdb_current_file = path.join("CURRENT");
+    rocksdb_current_file.is_file()
+}
+
+/// All the RocksDB methods return `std::result::Result<T, String>`. Since our methods return
+/// `failure::Result<T>`, manual conversion is needed.
+fn convert_rocksdb_err(msg: String) -> failure::Error {
+    format_err!("RocksDB internal error: {}.", msg)
+}
+
+/// This DB is a schematized RocksDB wrapper where all data passed in and out are typed according to
+/// [`Schema`]s.
+#[derive(Debug)]
+pub struct DB {
+    inner: rocksdb::DB,
+}
+
+impl DB {
+    /// Creates a db with all the column families provided if it doesn't exist at `path`;
+    /// otherwise, tries to open it with all the column families.
+    pub fn open<P: AsRef<Path>>(path: P, mut cf_opts_map: ColumnFamilyOptionsMap) -> Result<DB> {
+        let mut db_opts = DBOptions::new();
+
+        // If db exists, just open it with all cfs.
+        if db_exists(path.as_ref()) {
+            return DB::open_cf(db_opts, &path, cf_opts_map.into_iter().collect());
+        }
+
+        // If db doesn't exist, create a db first with all column families.
+        db_opts.create_if_missing(true);
+
+        // For now we set the max total WAL size to be 1G. This config can be useful when column
+        // families are updated at non-uniform frequencies.
+        db_opts.set_max_total_wal_size(1 << 30);
+
+        let mut db = DB::open_cf(
+            db_opts,
+            path,
+            vec![cf_opts_map
+                .remove_entry(&DEFAULT_CF_NAME)
+                .ok_or_else(|| format_err!("No \"default\" column family name found"))?],
+        )?;
+        cf_opts_map
+            .into_iter()
+            .map(|(cf_name, cf_opts)| db.create_cf((cf_name, cf_opts)))
+            .collect::<Result<Vec<_>>>()?;
+        Ok(db)
+    }
+
+    fn open_cf<'a, P, T>(opts: DBOptions, path: P, cfds: Vec<T>) -> Result<DB>
+    where
+        P: AsRef<Path>,
+        T: Into<ColumnFamilyDescriptor<'a>>,
+    {
+        let inner = rocksdb::DB::open_cf(
+            opts,
+            path.as_ref().to_str().ok_or_else(|| {
+                format_err!("Path {:?} can not be converted to string.", path.as_ref())
+            })?,
+            cfds,
+        )
+        .map_err(convert_rocksdb_err)?;
+
+        Ok(DB { inner })
+    }
+
+    fn create_cf<'a, T>(&mut self, cfd: T) -> Result<()>
+    where
+        T: Into<ColumnFamilyDescriptor<'a>>,
+    {
+        let _cf_handle = self.inner.create_cf(cfd).map_err(convert_rocksdb_err)?;
+        Ok(())
+    }
+
+    /// Reads single record by key.
+    pub fn get<S: Schema>(&self, schema_key: &S::Key) -> Result<Option<S::Value>> {
+        let k = <S::Key as KeyCodec<S>>::encode_key(&schema_key)?;
+        let cf_handle = self.get_cf_handle(S::COLUMN_FAMILY_NAME)?;
+
+        self.inner
+            .get_cf(cf_handle, &k)
+            .map_err(convert_rocksdb_err)?
+            .map(|raw_value| <S::Value as ValueCodec<S>>::decode_value(&raw_value))
+            .transpose()
+    }
+
+    /// Writes single record.
+    pub fn put<S: Schema>(&self, key: &S::Key, value: &S::Value) -> Result<()> {
+        let k = <S::Key as KeyCodec<S>>::encode_key(&key)?;
+        let v = <S::Value as ValueCodec<S>>::encode_value(&value)?;
+        let cf_handle = self.get_cf_handle(S::COLUMN_FAMILY_NAME)?;
+
+        self.inner
+            .put_cf_opt(cf_handle, &k, &v, &default_write_options())
+            .map_err(convert_rocksdb_err)
+    }
+
+    /// Returns a [`SchemaIterator`] on a certain schema.
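+    ///
+    /// A usage sketch (for some already-defined schema `ExampleSchema`; error handling elided):
+    ///
+    /// ```text
+    /// let mut iter = db.iter::<ExampleSchema>(Default::default())?;
+    /// iter.seek_to_first();
+    /// for row in iter {
+    ///     let (key, value) = row?; // rows decode to (ExampleSchema::Key, ExampleSchema::Value)
+    /// }
+    /// ```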
+    pub fn iter<S: Schema>(&self, opts: ReadOptions) -> Result<SchemaIterator<S>> {
+        let cf_handle = self.get_cf_handle(S::COLUMN_FAMILY_NAME)?;
+        Ok(SchemaIterator::new(self.inner.iter_cf_opt(cf_handle, opts)))
+    }
+
+    /// Writes a group of records wrapped in a [`SchemaBatch`].
+    pub fn write_schemas(&self, batch: SchemaBatch) -> Result<()> {
+        let db_batch = rocksdb::WriteBatch::new();
+        for (cf_name, key, write_op) in &batch.rows {
+            let cf_handle = self.get_cf_handle(cf_name)?;
+            match write_op {
+                WriteOp::Value(value) => db_batch.put_cf(cf_handle, &key, &value),
+                WriteOp::Deletion => db_batch.delete_cf(cf_handle, &key),
+            }
+            .map_err(convert_rocksdb_err)?;
+        }
+
+        self.inner
+            .write_opt(&db_batch, &default_write_options())
+            .map_err(convert_rocksdb_err)
+    }
+
+    fn get_cf_handle(&self, cf_name: ColumnFamilyName) -> Result<&CFHandle> {
+        self.inner.cf_handle(cf_name).ok_or_else(|| {
+            format_err!(
+                "DB::cf_handle not found for column family name: {}",
+                cf_name
+            )
+        })
+    }
+}
+
+/// For now we always use synchronous writes. This makes sure that once the operation returns
+/// `Ok(())` the data is persisted even if the machine crashes. In the future we might consider
+/// selectively turning this off for some non-critical writes to improve performance.
fn default_write_options() -> WriteOptions {
+    let mut opts = WriteOptions::new();
+    opts.set_sync(true);
+    opts
+}
diff --git a/storage/schemadb/src/schema.rs b/storage/schemadb/src/schema.rs
new file mode 100644
index 0000000000000..bac720bfccaef
--- /dev/null
+++ b/storage/schemadb/src/schema.rs
@@ -0,0 +1,141 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module provides traits that define the behavior of a schema and its associated key and
+//! value types, along with helpers to define a new schema with ease.
+use crate::ColumnFamilyName;
+use failure::Result;
+use std::fmt::Debug;
+
+/// Macro for defining a SchemaDB schema.
+///
+/// `define_schema!` allows a schema to be defined in the following syntax:
+/// ```
+/// use failure::Result;
+/// use schemadb::{
+///     define_schema,
+///     schema::{KeyCodec, SeekKeyCodec, ValueCodec},
+/// };
+///
+/// // Define key type and value type for a schema with derived traits (Clone, Debug, Eq, PartialEq)
+/// #[derive(Clone, Debug, Eq, PartialEq)]
+/// pub struct Key;
+/// #[derive(Clone, Debug, Eq, PartialEq)]
+/// pub struct Value;
+///
+/// // Implement KeyCodec/ValueCodec traits for key and value types
+/// impl KeyCodec<ExampleSchema> for Key {
+///     fn encode_key(&self) -> Result<Vec<u8>> {
+///         Ok(vec![])
+///     }
+///
+///     fn decode_key(data: &[u8]) -> Result<Self> {
+///         Ok(Key)
+///     }
+/// }
+///
+/// impl ValueCodec<ExampleSchema> for Value {
+///     fn encode_value(&self) -> Result<Vec<u8>> {
+///         Ok(vec![])
+///     }
+///
+///     fn decode_value(data: &[u8]) -> Result<Self> {
+///         Ok(Value)
+///     }
+/// }
+///
+/// // And finally define a schema type and associate it with key and value types, as well as the
+/// // column family name, by generating code that implements the `Schema` trait for the type.
+/// define_schema!(ExampleSchema, Key, Value, "example_cf_name");
+///
+/// // SeekKeyCodec is automatically implemented for KeyCodec,
+/// // so you can seek an iterator with the Key type:
+/// // iter.seek(&Key);
+///
+/// // Or if seek-by-prefix is desired, you can implement your own SeekKey
+/// #[derive(Clone, Eq, PartialEq, Debug)]
+/// pub struct PrefixSeekKey;
+///
+/// impl SeekKeyCodec<ExampleSchema> for PrefixSeekKey {
+///     fn encode_seek_key(&self) -> Result<Vec<u8>> {
+///         Ok(vec![])
+///     }
+/// }
+/// // and seek like this:
+/// // iter.seek(&PrefixSeekKey);
+/// ```
+#[macro_export]
+macro_rules! define_schema {
+    ($schema_type: ident, $key_type: ty, $value_type: ty, $cf_name: expr) => {
+        pub(crate) struct $schema_type;
+
+        impl $crate::schema::Schema for $schema_type {
+            const COLUMN_FAMILY_NAME: $crate::ColumnFamilyName = $cf_name;
+            type Key = $key_type;
+            type Value = $value_type;
+        }
+    };
+}
+
+/// This trait defines a type that can serve as a [`Schema::Key`].
+pub trait KeyCodec<S: Schema + ?Sized>: Sized + PartialEq + Debug {
+    /// Converts `self` to bytes to be stored in DB.
+    fn encode_key(&self) -> Result<Vec<u8>>;
+    /// Converts bytes fetched from DB to `Self`.
+    fn decode_key(data: &[u8]) -> Result<Self>;
+}
+
+/// This trait defines a type that can serve as a [`Schema::Value`].
+pub trait ValueCodec<S: Schema + ?Sized>: Sized + PartialEq + Debug {
+    /// Converts `self` to bytes to be stored in DB.
+    fn encode_value(&self) -> Result<Vec<u8>>;
+    /// Converts bytes fetched from DB to `Self`.
+    fn decode_value(data: &[u8]) -> Result<Self>;
+}
+
+/// This defines a type that can be used to seek a [`SchemaIterator`](crate::SchemaIterator), via
+/// interfaces like [`seek`](crate::SchemaIterator::seek).
+pub trait SeekKeyCodec<S: Schema + ?Sized>: Sized {
+    /// Converts `self` to bytes which is used to seek the underlying raw iterator.
+    fn encode_seek_key(&self) -> Result<Vec<u8>>;
+}
+
+/// All keys can automatically be used as seek keys.
+impl<S, K> SeekKeyCodec<S> for K
+where
+    S: Schema,
+    K: KeyCodec<S>,
+{
+    /// Delegates to [`KeyCodec::encode_key`].
+    fn encode_seek_key(&self) -> Result<Vec<u8>> {
+        <K as KeyCodec<S>>::encode_key(&self)
+    }
+}
+
+/// This trait defines a schema: an association of a column family name, the key type and the value
+/// type.
+pub trait Schema {
+    /// The column family name associated with this struct.
+    /// Note: all schemas within the same SchemaDB must have distinct column family names.
+    const COLUMN_FAMILY_NAME: ColumnFamilyName;
+
+    /// Type of the key.
+    type Key: KeyCodec<Self>;
+    /// Type of the value.
+    type Value: ValueCodec<Self>;
+}
+
+/// Helper used in tests to assert a (key, value) pair for a certain [`Schema`] is able to convert
+/// to bytes and convert back.
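+///
+/// A usage sketch, for some schema `ExampleSchema` whose key and value types implement the
+/// codec traits above:
+///
+/// ```text
+/// assert_encode_decode::<ExampleSchema>(&key, &value);
+/// ```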
+pub fn assert_encode_decode(key: &S::Key, value: &S::Value) { + { + let encoded = key.encode_key().expect("Encoding key should work."); + let decoded = S::Key::decode_key(&encoded).expect("Decoding key should work."); + assert_eq!(*key, decoded); + } + { + let encoded = value.encode_value().expect("Encoding value should work."); + let decoded = S::Value::decode_value(&encoded).expect("Decoding value should work."); + assert_eq!(*value, decoded); + } +} diff --git a/storage/schemadb/tests/db.rs b/storage/schemadb/tests/db.rs new file mode 100644 index 0000000000000..463168933d41e --- /dev/null +++ b/storage/schemadb/tests/db.rs @@ -0,0 +1,272 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use byteorder::{LittleEndian, ReadBytesExt}; +use failure::Result; +use schemadb::{ + define_schema, + schema::{KeyCodec, Schema, ValueCodec}, + ColumnFamilyOptions, ColumnFamilyOptionsMap, SchemaBatch, DB, DEFAULT_CF_NAME, +}; + +// Creating two schemas that share exactly the same structure but are stored in different column +// families. Also note that the key and value are of the same type `TestField`. By implementing +// both the `KeyCodec<>` and `ValueCodec<>` traits for both schemas, we are able to use it +// everywhere. +define_schema!(TestSchema1, TestField, TestField, "TestCF1"); +define_schema!(TestSchema2, TestField, TestField, "TestCF2"); + +#[derive(Debug, Eq, PartialEq)] +struct TestField(u32); + +impl TestField { + fn to_bytes(&self) -> Result> { + Ok(self.0.to_le_bytes().to_vec()) + } + + fn from_bytes(data: &[u8]) -> Result { + let mut reader = std::io::Cursor::new(data); + Ok(TestField(reader.read_u32::()?)) + } +} + +impl KeyCodec for TestField { + fn encode_key(&self) -> Result> { + self.to_bytes() + } + + fn decode_key(data: &[u8]) -> Result { + Self::from_bytes(data) + } +} + +impl ValueCodec for TestField { + fn encode_value(&self) -> Result> { + self.to_bytes() + } + + fn decode_value(data: &[u8]) -> Result { + Self::from_bytes(data) + } +} + +impl KeyCodec for TestField { + fn encode_key(&self) -> Result> { + self.to_bytes() + } + + fn decode_key(data: &[u8]) -> Result { + Self::from_bytes(data) + } +} + +impl ValueCodec for TestField { + fn encode_value(&self) -> Result> { + self.to_bytes() + } + + fn decode_value(data: &[u8]) -> Result { + Self::from_bytes(data) + } +} + +fn open_db(dir: &tempfile::TempDir) -> DB { + let cf_opts_map: ColumnFamilyOptionsMap = [ + (DEFAULT_CF_NAME, ColumnFamilyOptions::default()), + ( + TestSchema1::COLUMN_FAMILY_NAME, + ColumnFamilyOptions::default(), + ), + ( + TestSchema2::COLUMN_FAMILY_NAME, + ColumnFamilyOptions::default(), + ), + ] + .iter() + .cloned() + .collect(); + DB::open(&dir, cf_opts_map).expect("Failed to open DB.") +} + +struct TestDB { + _tmpdir: tempfile::TempDir, + db: DB, +} + +impl TestDB { + fn new() -> Self { + let tmpdir = tempfile::tempdir().expect("Failed to create temporary directory."); + let db = open_db(&tmpdir); + + TestDB { + _tmpdir: tmpdir, + db, + } + } +} + +impl std::ops::Deref for TestDB { + type Target = DB; + + fn deref(&self) -> &Self::Target { + &self.db + } +} + +#[test] +fn test_schema_put_get() { + let db = TestDB::new(); + + db.put::(&TestField(0), &TestField(0)).unwrap(); + db.put::(&TestField(1), &TestField(1)).unwrap(); + db.put::(&TestField(2), &TestField(2)).unwrap(); + db.put::(&TestField(2), &TestField(3)).unwrap(); + db.put::(&TestField(3), &TestField(4)).unwrap(); + db.put::(&TestField(4), &TestField(5)).unwrap(); + + assert_eq!( + 
db.get::(&TestField(0)).unwrap(), + Some(TestField(0)), + ); + assert_eq!( + db.get::(&TestField(1)).unwrap(), + Some(TestField(1)), + ); + assert_eq!( + db.get::(&TestField(2)).unwrap(), + Some(TestField(2)), + ); + assert_eq!(db.get::(&TestField(3)).unwrap(), None); + + assert_eq!(db.get::(&TestField(1)).unwrap(), None); + assert_eq!( + db.get::(&TestField(2)).unwrap(), + Some(TestField(3)), + ); + assert_eq!( + db.get::(&TestField(3)).unwrap(), + Some(TestField(4)), + ); + assert_eq!( + db.get::(&TestField(4)).unwrap(), + Some(TestField(5)), + ); +} + +fn collect_values(db: &TestDB) -> Vec<(S::Key, S::Value)> { + let mut iter = db + .iter::(Default::default()) + .expect("Failed to create iterator."); + iter.seek_to_first(); + iter.collect::>>().unwrap() +} + +fn gen_expected_values(values: &[(u32, u32)]) -> Vec<(TestField, TestField)> { + values + .iter() + .cloned() + .map(|(x, y)| (TestField(x), TestField(y))) + .collect() +} + +#[test] +fn test_single_schema_batch() { + let db = TestDB::new(); + + let mut db_batch = SchemaBatch::new(); + db_batch + .put::(&TestField(0), &TestField(0)) + .unwrap(); + db_batch + .put::(&TestField(1), &TestField(1)) + .unwrap(); + db_batch + .put::(&TestField(2), &TestField(2)) + .unwrap(); + db_batch + .put::(&TestField(3), &TestField(3)) + .unwrap(); + db_batch.delete::(&TestField(4)).unwrap(); + db_batch.delete::(&TestField(3)).unwrap(); + db_batch + .put::(&TestField(4), &TestField(4)) + .unwrap(); + db_batch + .put::(&TestField(5), &TestField(5)) + .unwrap(); + db.write_schemas(db_batch).unwrap(); + + assert_eq!( + collect_values::(&db), + gen_expected_values(&[(0, 0), (1, 1), (2, 2)]), + ); + assert_eq!( + collect_values::(&db), + gen_expected_values(&[(4, 4), (5, 5)]), + ); +} + +#[test] +fn test_two_schema_batches() { + let db = TestDB::new(); + + let mut db_batch1 = SchemaBatch::new(); + db_batch1 + .put::(&TestField(0), &TestField(0)) + .unwrap(); + db_batch1 + .put::(&TestField(1), &TestField(1)) + .unwrap(); + db_batch1 + .put::(&TestField(2), &TestField(2)) + .unwrap(); + db_batch1.delete::(&TestField(2)).unwrap(); + db.write_schemas(db_batch1).unwrap(); + + assert_eq!( + collect_values::(&db), + gen_expected_values(&[(0, 0), (1, 1)]), + ); + + let mut db_batch2 = SchemaBatch::new(); + db_batch2.delete::(&TestField(3)).unwrap(); + db_batch2 + .put::(&TestField(3), &TestField(3)) + .unwrap(); + db_batch2 + .put::(&TestField(4), &TestField(4)) + .unwrap(); + db_batch2 + .put::(&TestField(5), &TestField(5)) + .unwrap(); + db.write_schemas(db_batch2).unwrap(); + + assert_eq!( + collect_values::(&db), + gen_expected_values(&[(0, 0), (1, 1)]), + ); + assert_eq!( + collect_values::(&db), + gen_expected_values(&[(3, 3), (4, 4), (5, 5)]), + ); +} + +#[test] +fn test_reopen() { + let tmpdir = tempfile::tempdir().expect("Failed to create temporary directory."); + { + let db = open_db(&tmpdir); + db.put::(&TestField(0), &TestField(0)).unwrap(); + assert_eq!( + db.get::(&TestField(0)).unwrap(), + Some(TestField(0)), + ); + } + { + let db = open_db(&tmpdir); + assert_eq!( + db.get::(&TestField(0)).unwrap(), + Some(TestField(0)), + ); + } +} diff --git a/storage/schemadb/tests/iterator.rs b/storage/schemadb/tests/iterator.rs new file mode 100644 index 0000000000000..6abb2a3cf2f03 --- /dev/null +++ b/storage/schemadb/tests/iterator.rs @@ -0,0 +1,214 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use failure::Result; +use schemadb::{ + define_schema, + 
schema::{KeyCodec, Schema, SeekKeyCodec, ValueCodec}, + ColumnFamilyOptions, ColumnFamilyOptionsMap, SchemaIterator, DB, DEFAULT_CF_NAME, +}; + +define_schema!(TestSchema, TestKey, TestValue, "TestCF"); + +#[derive(Debug, Eq, PartialEq)] +struct TestKey(u32, u32, u32); + +#[derive(Debug, Eq, PartialEq)] +struct TestValue(u32); + +impl KeyCodec for TestKey { + fn encode_key(&self) -> Result> { + let mut bytes = vec![]; + bytes.write_u32::(self.0)?; + bytes.write_u32::(self.1)?; + bytes.write_u32::(self.2)?; + Ok(bytes) + } + + fn decode_key(data: &[u8]) -> Result { + let mut reader = std::io::Cursor::new(data); + Ok(TestKey( + reader.read_u32::()?, + reader.read_u32::()?, + reader.read_u32::()?, + )) + } +} + +impl ValueCodec for TestValue { + fn encode_value(&self) -> Result> { + Ok(self.0.to_be_bytes().to_vec()) + } + + fn decode_value(data: &[u8]) -> Result { + let mut reader = std::io::Cursor::new(data); + Ok(TestValue(reader.read_u32::()?)) + } +} + +pub struct KeyPrefix1(u32); + +impl SeekKeyCodec for KeyPrefix1 { + fn encode_seek_key(&self) -> Result> { + Ok(self.0.to_be_bytes().to_vec()) + } +} + +pub struct KeyPrefix2(u32, u32); + +impl SeekKeyCodec for KeyPrefix2 { + fn encode_seek_key(&self) -> Result> { + let mut bytes = vec![]; + bytes.write_u32::(self.0)?; + bytes.write_u32::(self.1)?; + Ok(bytes) + } +} + +fn collect_values(iter: SchemaIterator) -> Vec { + iter.map(|row| (row.unwrap().1).0).collect() +} + +struct TestDB { + _tmpdir: tempfile::TempDir, + db: DB, +} + +impl TestDB { + fn new() -> Self { + let tmpdir = tempfile::tempdir().expect("Failed to create temporary directory."); + let cf_opts_map: ColumnFamilyOptionsMap = [ + (DEFAULT_CF_NAME, ColumnFamilyOptions::default()), + ( + TestSchema::COLUMN_FAMILY_NAME, + ColumnFamilyOptions::default(), + ), + ] + .iter() + .cloned() + .collect(); + let db = DB::open(&tmpdir, cf_opts_map).unwrap(); + + db.put::(&TestKey(1, 0, 0), &TestValue(100)) + .unwrap(); + db.put::(&TestKey(1, 0, 2), &TestValue(102)) + .unwrap(); + db.put::(&TestKey(1, 0, 4), &TestValue(104)) + .unwrap(); + db.put::(&TestKey(1, 1, 0), &TestValue(110)) + .unwrap(); + db.put::(&TestKey(1, 1, 2), &TestValue(112)) + .unwrap(); + db.put::(&TestKey(1, 1, 4), &TestValue(114)) + .unwrap(); + db.put::(&TestKey(2, 0, 0), &TestValue(200)) + .unwrap(); + db.put::(&TestKey(2, 0, 2), &TestValue(202)) + .unwrap(); + + TestDB { + _tmpdir: tmpdir, + db, + } + } +} + +impl TestDB { + fn iter(&self) -> SchemaIterator { + self.db + .iter(Default::default()) + .expect("Failed to create iterator.") + } +} + +impl std::ops::Deref for TestDB { + type Target = DB; + + fn deref(&self) -> &Self::Target { + &self.db + } +} + +#[test] +fn test_seek_to_first() { + let db = TestDB::new(); + let mut iter = db.iter(); + iter.seek_to_first(); + assert_eq!( + collect_values(iter), + [100, 102, 104, 110, 112, 114, 200, 202] + ); +} + +#[test] +fn test_seek_to_last() { + let db = TestDB::new(); + let mut iter = db.iter(); + iter.seek_to_last(); + assert_eq!(collect_values(iter), [202]); +} + +#[test] +fn test_seek_by_existing_key() { + let db = TestDB::new(); + let mut iter = db.iter(); + iter.seek(&TestKey(1, 1, 0)).unwrap(); + assert_eq!(collect_values(iter), [110, 112, 114, 200, 202]); +} + +#[test] +fn test_seek_by_nonexistent_key() { + let db = TestDB::new(); + let mut iter = db.iter(); + iter.seek(&TestKey(1, 1, 1)).unwrap(); + assert_eq!(collect_values(iter), [112, 114, 200, 202]); +} + +#[test] +fn test_seek_for_prev_by_existing_key() { + let db = TestDB::new(); + let mut iter = 
db.iter(); + iter.seek_for_prev(&TestKey(1, 1, 0)).unwrap(); + assert_eq!(collect_values(iter), [110, 112, 114, 200, 202]); +} + +#[test] +fn test_seek_for_prev_by_nonexistent_key() { + let db = TestDB::new(); + let mut iter = db.iter(); + iter.seek_for_prev(&TestKey(1, 1, 1)).unwrap(); + assert_eq!(collect_values(iter), [110, 112, 114, 200, 202]); +} + +#[test] +fn test_seek_by_1prefix() { + let db = TestDB::new(); + let mut iter = db.iter(); + iter.seek(&KeyPrefix1(2)).unwrap(); + assert_eq!(collect_values(iter), [200, 202]); +} + +#[test] +fn test_seek_for_prev_by_1prefix() { + let db = TestDB::new(); + let mut iter = db.iter(); + iter.seek_for_prev(&KeyPrefix1(2)).unwrap(); + assert_eq!(collect_values(iter), [114, 200, 202]); +} + +#[test] +fn test_seek_by_2prefix() { + let db = TestDB::new(); + let mut iter = db.iter(); + iter.seek(&KeyPrefix2(2, 0)).unwrap(); + assert_eq!(collect_values(iter), [200, 202]); +} + +#[test] +fn test_seek_for_prev_by_2prefix() { + let db = TestDB::new(); + let mut iter = db.iter(); + iter.seek_for_prev(&KeyPrefix2(2, 0)).unwrap(); + assert_eq!(collect_values(iter), [114, 200, 202]); +} diff --git a/storage/scratchpad/Cargo.toml b/storage/scratchpad/Cargo.toml new file mode 100644 index 0000000000000..34e849a068288 --- /dev/null +++ b/storage/scratchpad/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "scratchpad" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +itertools = "0.8.0" + +crypto = { path = "../../crypto/legacy_crypto" } +types = { path = "../../types" } + +[dev-dependencies] +failure = { path = "../../common/failure_ext", package = "failure_ext" } diff --git a/storage/scratchpad/src/accumulator/accumulator_test.rs b/storage/scratchpad/src/accumulator/accumulator_test.rs new file mode 100644 index 0000000000000..4a6f846130706 --- /dev/null +++ b/storage/scratchpad/src/accumulator/accumulator_test.rs @@ -0,0 +1,72 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::Accumulator; +use crypto::{ + hash::{CryptoHash, TestOnlyHash, TestOnlyHasher, ACCUMULATOR_PLACEHOLDER_HASH}, + HashValue, +}; +use types::proof::TestAccumulatorInternalNode; + +// Computes the root hash of an accumulator with given elements. +fn compute_root_hash_naive(elements: &[HashValue]) -> HashValue { + if elements.is_empty() { + return *ACCUMULATOR_PLACEHOLDER_HASH; + } + + let mut current_level = elements.to_vec(); + current_level.resize( + elements.len().next_power_of_two(), + *ACCUMULATOR_PLACEHOLDER_HASH, + ); + + while current_level.len() > 1 { + assert!(current_level.len().is_power_of_two()); + + let mut parent_level = vec![]; + for (index, hash) in current_level.iter().enumerate().step_by(2) { + let left_hash = hash; + let right_hash = ¤t_level[index + 1]; + let parent_hash = if *left_hash == *ACCUMULATOR_PLACEHOLDER_HASH + && *right_hash == *ACCUMULATOR_PLACEHOLDER_HASH + { + *ACCUMULATOR_PLACEHOLDER_HASH + } else { + TestAccumulatorInternalNode::new(*left_hash, *right_hash).hash() + }; + parent_level.push(parent_hash); + } + + current_level = parent_level; + } + + assert_eq!(current_level.len(), 1); + current_level.remove(0) +} + +// Helper function to create a list of elements. +fn create_elements(nums: std::ops::Range) -> Vec { + nums.map(|x| x.to_be_bytes().test_only_hash()).collect() +} + +#[test] +fn test_accumulator_append() { + // expected_root_hashes[i] is the root hash of an accumulator that has the first i elements. 
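+    // For instance (a sketch of the naive computation): with elements [a, b, c], the list is
+    // padded to [a, b, c, placeholder] and the expected root is
+    // Hash(Hash(a, b), Hash(c, placeholder)).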
+ let expected_root_hashes: Vec = (0..100) + .map(|x| { + let elements = create_elements(0..x); + compute_root_hash_naive(&elements) + }) + .collect(); + + let elements = create_elements(0..100); + let mut accumulator = Accumulator::::default(); + // Append the elements one at a time and check the root hashes match. + for (i, (element, expected_root_hash)) in + itertools::zip_eq(elements.into_iter(), expected_root_hashes.into_iter()).enumerate() + { + assert_eq!(accumulator.root_hash(), expected_root_hash); + assert_eq!(accumulator.num_elements(), i as u64); + accumulator = accumulator.append(vec![element]); + } +} diff --git a/storage/scratchpad/src/accumulator/mod.rs b/storage/scratchpad/src/accumulator/mod.rs new file mode 100644 index 0000000000000..c1ebbdfd84262 --- /dev/null +++ b/storage/scratchpad/src/accumulator/mod.rs @@ -0,0 +1,180 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module implements an in-memory Merkle Accumulator that is similar to what we use in +//! storage. This accumulator will only store a small portion of the tree -- for any subtree that +//! is full, we store only the root. Also we only store the frozen nodes, therefore this structure +//! will always store up to `Log(n)` number of nodes, where `n` is the total number of elements in +//! the tree. +//! +//! This accumulator is immutable once constructed. If we append new elements to the tree we will +//! obtain a new accumulator instance and the old one remains unchanged. + +#[cfg(test)] +mod accumulator_test; + +use crypto::{ + hash::{CryptoHash, CryptoHasher, ACCUMULATOR_PLACEHOLDER_HASH}, + HashValue, +}; +use std::marker::PhantomData; +use types::proof::{position::Position, treebits::NodeDirection, MerkleTreeInternalNode}; + +/// The Accumulator implementation. +#[derive(Default)] +pub struct Accumulator { + /// Represents the roots of all the full subtrees from left to right in this accumulator. For + /// example, if we have the following accumulator, this vector will have two hashes that + /// correspond to `X` and `e`. + /// ```text + /// root + /// / \ + /// / \ + /// / \ + /// X o + /// / \ / \ + /// / \ / \ + /// o o o placeholder + /// / \ / \ / \ + /// a b c d e placeholder + /// ``` + frozen_subtree_roots: Vec, + + /// The total number of elements in this accumulator. + num_elements: u64, + + phantom: PhantomData, +} + +impl Accumulator +where + H: CryptoHasher, +{ + /// Constructs a new accumulator with roots of existing frozen subtrees. At the beginning this + /// will be an empty vector and `num_elements` will be zero. Later if we restart and the + /// storage have persisted some elements, we will load them from storage. + pub fn new(frozen_subtree_roots: Vec, num_elements: u64) -> Self { + assert_eq!( + frozen_subtree_roots.len(), + num_elements.count_ones() as usize, + "The number of frozen subtrees does not match the number of elements. \ + frozen_subtree_roots.len(): {}. num_elements: {}.", + frozen_subtree_roots.len(), + num_elements, + ); + + Accumulator { + frozen_subtree_roots, + num_elements, + phantom: PhantomData, + } + } + + /// Appends a list of new elements to an existing accumulator. Since the accumulator is + /// immutable, the existing one remains unchanged and a new one representing the result is + /// returned. 
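+    ///
+    /// A usage sketch (`h1`, `h2` and `MyHasher` are placeholders):
+    ///
+    /// ```text
+    /// let zero = Accumulator::<MyHasher>::default();
+    /// let two = zero.append(vec![h1, h2]);
+    /// assert_eq!(zero.num_elements(), 0); // the original is untouched
+    /// assert_eq!(two.num_elements(), 2);
+    /// ```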
+ pub fn append(&self, elements: Vec) -> Self { + let mut frozen_subtree_roots = self.frozen_subtree_roots.clone(); + let mut num_elements = self.num_elements; + for element in elements { + Self::append_one(&mut frozen_subtree_roots, num_elements, element); + num_elements += 1; + } + + Self::new(frozen_subtree_roots, num_elements) + } + + /// Appends one element. This will update `frozen_subtree_roots` to store new frozen root nodes + /// and remove old nodes if they are now part of a larger frozen subtree. + fn append_one( + frozen_subtree_roots: &mut Vec, + num_existing_elements: u64, + element: HashValue, + ) { + // For example, this accumulator originally had N = 7 elements. Appending an element is + // like adding one to this number N: 0b0111 + 1 = 0b1000. Every time we carry a bit to the + // left we merge the rightmost two subtrees and compute their parent. + // ```text + // A + // / \ + // / \ + // o o B + // / \ / \ / \ + // o o o o o o o + // ``` + + // First just append the element. + frozen_subtree_roots.push(element); + + // Next, merge the last two subtrees into one. If `num_existing_elements` has N trailing + // ones, the carry will happen N times. + let num_trailing_ones = (!num_existing_elements).trailing_zeros(); + + for _i in 0..num_trailing_ones { + let right_hash = frozen_subtree_roots.pop().expect("Invalid accumulator."); + let left_hash = frozen_subtree_roots.pop().expect("Invalid accumulator."); + let parent_hash = MerkleTreeInternalNode::::new(left_hash, right_hash).hash(); + frozen_subtree_roots.push(parent_hash); + } + } + + /// Computes the root hash of an accumulator given the frozen subtree roots. + pub fn root_hash(&self) -> HashValue { + if self.frozen_subtree_roots.is_empty() { + return *ACCUMULATOR_PLACEHOLDER_HASH; + } + + // First, start from the rightmost leaf position and move it to the rightmost frozen root. + let max_leaf_index = self.num_elements - 1; + let mut current_position = Position::from_leaf_index(max_leaf_index); + // Move current position up until it reaches the corresponding frozen subtree root. + while current_position.get_parent().is_freezable(max_leaf_index) { + current_position = current_position.get_parent(); + } + + let mut roots = self.frozen_subtree_roots.iter().rev(); + let mut current_hash = *roots + .next() + .expect("We have checked frozen_subtree_roots is not empty."); + + // While current position is not root, find current sibling and compute parent hash. + let root_position = Position::get_root_position(max_leaf_index); + while current_position != root_position { + current_hash = match current_position.get_direction_for_self() { + NodeDirection::Left => { + // If a frozen node is the left child of its parent, its sibling must be + // a placeholder node. + MerkleTreeInternalNode::::new(current_hash, *ACCUMULATOR_PLACEHOLDER_HASH) + .hash() + } + NodeDirection::Right => { + // Otherwise the left sibling must have been frozen. + MerkleTreeInternalNode::::new( + *roots.next().expect("Ran out of subtree roots."), + current_hash, + ) + .hash() + } + }; + current_position = current_position.get_parent(); + } + + current_hash + } + + /// Returns the total number of elements in this accumulator. + pub fn num_elements(&self) -> u64 { + self.num_elements + } +} + +// We manually implement Debug because H (CryptoHasher) does not implement Debug. 
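+// (A derived impl would require `H: Debug`, which `CryptoHasher` does not guarantee; writing it
+// by hand also lets us skip the uninteresting `phantom` field.)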
+impl std::fmt::Debug for Accumulator { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "Accumulator {{ frozen_subtree_roots: {:?}, num_elements: {:?} }}", + self.frozen_subtree_roots, self.num_elements + ) + } +} diff --git a/storage/scratchpad/src/lib.rs b/storage/scratchpad/src/lib.rs new file mode 100644 index 0000000000000..41de0fc696faf --- /dev/null +++ b/storage/scratchpad/src/lib.rs @@ -0,0 +1,12 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This crate provides in-memory representation of Libra core data structures used by the executor. + +mod accumulator; +mod sparse_merkle; + +pub use crate::{ + accumulator::Accumulator, + sparse_merkle::{AccountState, ProofRead, SparseMerkleTree}, +}; diff --git a/storage/scratchpad/src/sparse_merkle/mod.rs b/storage/scratchpad/src/sparse_merkle/mod.rs new file mode 100644 index 0000000000000..c67de2d4f755a --- /dev/null +++ b/storage/scratchpad/src/sparse_merkle/mod.rs @@ -0,0 +1,459 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module implements an in-memory Sparse Merkle Tree that is similar to what we use in +//! storage to represent world state. This tree will store only a small portion of the state -- the +//! part of accounts that have been modified by uncommitted transactions. For example, if we +//! execute a transaction T_i on top of committed state and it modified account A, we will end up +//! having the following tree: +//! ```text +//! S_i +//! / \ +//! o y +//! / \ +//! x A +//! ``` +//! where A has the new state of the account, and y and x are the siblings on the path from root to +//! A in the tree. +//! +//! This Sparse Merkle Tree is immutable once constructed. If the next transaction T_{i+1} modified +//! another account B that lives in the subtree at y, a new tree will be constructed and the +//! structure will look like the following: +//! ```text +//! S_i S_{i+1} +//! / \ / \ +//! / y / \ +//! / _______/ \ +//! // \ +//! o y' +//! / \ / \ +//! x A z B +//! ``` +//! +//! Using this structure, we are able to query the global state, taking into account the output of +//! uncommitted transactions. For example, if we want to execute another transaction T_{i+1}', we +//! can use the tree S_i. If we look for account A, we can find its new value in the tree. +//! Otherwise we know the account does not exist in the tree and we can fall back to storage. As +//! another example, if we want to execute transaction T_{i+2}, we can use the tree S_{i+1} that +//! has updated values for both account A and B. +//! +//! When we commit a transaction, for example T_i, we will first send its write set to storage. +//! Once the writes to storage complete, any node reachable from S_i will be available in storage. +//! Therefore we start from S_i and recursively drop its descendant. For internal or leaf nodes +//! (for example node o in the above example), we do not know if there are other nodes (for example +//! S_{i+1} in the above example) pointing to it, so we replace the node with a subtree node with +//! the same hash. This allows us to clean up memory as transactions are committed. +//! +//! This Sparse Merkle Tree serves a dual purpose. First, to support a leader based consensus +//! algorithm, we need to build a tree of transactions like the following: +//! ```text +//! Committed -> T5 -> T6 -> T7 +//! β””---> T6' -> T7' +//! β””----> T7" +//! ``` +//! 
Once T5 is executed, we will have a tree that stores the modified portion of the state. Later +//! when we execute T6 on top of T5, the output of T5 can be visible to T6. +//! +//! Second, given this tree representation it is straightforward to compute the root hash of S_i +//! once T_i is executed. This allows us to verify the proofs we need when executing T_{i+1}. + +// See https://play.rust-lang.org/?version=stable&mode=debug&edition=2018&gist=e9c4c53eb80b30d09112fcfb07d481e7 +#![allow(clippy::let_and_return)] +// See https://play.rust-lang.org/?version=stable&mode=debug&edition=2018&gist=795cd4f459f1d4a0005a99650726834b +#![allow(clippy::while_let_loop)] + +mod node; + +#[cfg(test)] +mod sparse_merkle_test; + +use self::node::{LeafNode, LeafValue, Node, SparseMerkleNode}; +use crypto::{ + hash::{HashValueBitIterator, SPARSE_MERKLE_PLACEHOLDER_HASH}, + HashValue, +}; +use std::rc::Rc; +use types::{account_state_blob::AccountStateBlob, proof::SparseMerkleProof}; + +/// `AccountState` describes the result of querying an account from this SparseMerkleTree. +#[derive(Debug, Eq, PartialEq)] +pub enum AccountState { + /// The account exists in the tree, therefore we can give its value. + ExistsInScratchPad(AccountStateBlob), + + /// The account does not exist in the tree, but exists in DB. This happens when the search + /// reaches a leaf node that has the requested account, but the node has only the value hash + /// because it was loaded into memory as part of a non-inclusion proof. When we go to DB we + /// don't need to traverse the tree to find the same leaf, instead we can use the value hash to + /// look up the account blob directly. + ExistsInDB, + + /// The account does not exist in either the tree or DB. This happens when the search reaches + /// an empty node, or a leaf node that has a different account. + DoesNotExist, + + /// We do not know if this account exists or not and need to go to DB to find out. This happens + /// when the search reaches a subtree node. + Unknown, +} + +/// The Sparse Merkle Tree implementation. +#[derive(Debug)] +pub struct SparseMerkleTree { + root: Rc, +} + +impl SparseMerkleTree { + /// Constructs a Sparse Merkle Tree with a root hash. This is often used when we restart and + /// the scratch pad and the storage have identical state, so we use a single root hash to + /// represent the entire state. + pub fn new(root_hash: HashValue) -> Self { + SparseMerkleTree { + root: Rc::new(if root_hash != *SPARSE_MERKLE_PLACEHOLDER_HASH { + SparseMerkleNode::new_subtree(root_hash) + } else { + SparseMerkleNode::new_empty() + }), + } + } + + /// Constructs a new Sparse Merkle Tree as if we are updating the existing tree. Since the tree + /// is immutable, the existing tree will remain the same and may share part of the tree with + /// the new one. + pub fn update( + &self, + updates: Vec<(HashValue, AccountStateBlob)>, + proof_reader: &impl ProofRead, + ) -> Result { + let mut root = Rc::clone(&self.root); + for (key, new_blob) in updates { + root = Self::update_one(root, key, new_blob, proof_reader)?; + } + Ok(SparseMerkleTree { root }) + } + + fn update_one( + root: Rc, + key: HashValue, + new_blob: AccountStateBlob, + proof_reader: &impl ProofRead, + ) -> Result, UpdateError> { + let mut current_node = root; + let mut bits = key.iter_bits(); + + // Starting from root, traverse the tree according to key until we find a non-internal + // node. Record all the bits and sibling nodes on the path. 
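+        // For example (a sketch), if the first two bits of `key` are 1 and 0, then after two
+        // iterations `bits_on_path` is `[true, false]` and `siblings_on_path` holds the first
+        // node's left child followed by the second node's right child.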
+ let mut bits_on_path = vec![]; + let mut siblings_on_path = vec![]; + loop { + let next_node = if let Node::Internal(node) = &*current_node.borrow() { + let bit = bits.next().unwrap_or_else(|| { + panic!("Tree is deeper than {} levels.", HashValue::LENGTH_IN_BITS) + }); + bits_on_path.push(bit); + if bit { + siblings_on_path.push(node.clone_left_child()); + node.clone_right_child() + } else { + siblings_on_path.push(node.clone_right_child()); + node.clone_left_child() + } + } else { + break; + }; + current_node = next_node; + } + + // Now we are at the bottom of the tree and current_node can be either a leaf, a subtree or + // empty. We construct a new subtree like we are inserting the key here. + let new_node = + Self::construct_subtree_at_bottom(current_node, key, new_blob, bits, proof_reader)?; + + // Use the new node and all previous siblings on the path to construct the final tree. + Ok(Self::construct_subtree( + bits_on_path.into_iter().rev(), + siblings_on_path.into_iter().rev(), + new_node, + )) + } + + /// This function is called when we are trying to write (key, new_value) to the tree and have + /// traversed the existing tree using some prefix of the key. We should have reached the bottom + /// of the existing tree, so current_node cannot be an internal node. This function will + /// construct a subtree using current_node, the new key-value pair and potentially the + /// key-value pair in the proof. + fn construct_subtree_at_bottom( + current_node: Rc, + key: HashValue, + new_blob: AccountStateBlob, + remaining_bits: HashValueBitIterator, + proof_reader: &impl ProofRead, + ) -> Result, UpdateError> { + match &*current_node.borrow() { + Node::Internal(_) => { + unreachable!("Reached an internal node at the bottom of the tree.") + } + Node::Leaf(node) => Ok(Self::construct_subtree_with_new_leaf( + key, + new_blob, + node, + HashValue::LENGTH_IN_BITS - remaining_bits.len(), + )), + Node::Subtree(_) => { + // When the search reaches an Subtree node, we need proof to to give us more + // information about this part of the tree. + let proof = proof_reader + .get_proof(key) + .ok_or(UpdateError::MissingProof)?; + + // Here the in-memory tree is identical to the tree in storage (we have only the + // root hash of this subtree in memory). So we need to take into account the leaf + // in the proof. + let new_subtree = match proof.leaf() { + Some((existing_key, existing_value_hash)) => { + let existing_leaf = + LeafNode::new(existing_key, LeafValue::BlobHash(existing_value_hash)); + Self::construct_subtree_with_new_leaf( + key, + new_blob, + &existing_leaf, + proof.siblings().len(), + ) + } + None => Rc::new(SparseMerkleNode::new_leaf(key, LeafValue::Blob(new_blob))), + }; + + let num_remaining_bits = remaining_bits.len(); + Ok(Self::construct_subtree( + remaining_bits + .rev() + .skip(HashValue::LENGTH_IN_BITS - proof.siblings().len()), + proof + .siblings() + .iter() + .skip(HashValue::LENGTH_IN_BITS - num_remaining_bits) + .rev() + .map(|sibling_hash| { + Rc::new(if *sibling_hash != *SPARSE_MERKLE_PLACEHOLDER_HASH { + SparseMerkleNode::new_subtree(*sibling_hash) + } else { + SparseMerkleNode::new_empty() + }) + }), + new_subtree, + )) + } + Node::Empty => { + // When we reach an empty node, we just place the leaf node at the same position to + // replace the empty node. 
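+                //
+                //      o                    o
+                //     / \        =>        / \
+                //    x   empty            x   new_leaf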
+ Ok(Rc::new(SparseMerkleNode::new_leaf( + key, + LeafValue::Blob(new_blob), + ))) + } + } + } + + /// Given key, new value, existing leaf and the distance from root to the existing leaf, + /// constructs a new subtree that has either the new leaf or both nodes, depending on whether + /// the key equals the existing leaf's key. + /// + /// 1. If the key equals the existing leaf's key, we simply need to update the leaf to the new + /// value and return it. For example, in the following case this function will return + /// `new_leaf`. + /// ``` text + /// o o + /// / \ / \ + /// o o => o o + /// / \ / \ + /// o existing_leaf o new_leaf + /// ``` + /// + /// 2. Otherwise, we need to construct an "extension" for the common prefix, and at the end of + /// the extension a subtree for both keys. For example, in the following case we assume the + /// existing leaf's key starts with 010010 and key starts with 010011, and this function + /// will return `x`. + /// ```text + /// o o common_prefix_len = 5 + /// / \ / \ distance_from_root_to_existing_leaf = 2 + /// o o o o extension_len = common_prefix_len - distance_from_root_to_existing_leaf = 3 + /// / \ / \ + /// o existing_leaf => o x _ + /// / \ ^ + /// o placeholder | + /// / \ | + /// o placeholder extension + /// / \ | + /// placeholder o - + /// / \ + /// existing_leaf new_leaf + /// ``` + fn construct_subtree_with_new_leaf( + key: HashValue, + new_blob: AccountStateBlob, + existing_leaf: &LeafNode, + distance_from_root_to_existing_leaf: usize, + ) -> Rc { + let new_leaf = Rc::new(SparseMerkleNode::new_leaf(key, LeafValue::Blob(new_blob))); + + if key == existing_leaf.key() { + // This implies that `key` already existed and the proof is an inclusion proof. + return new_leaf; + } + + // This implies that `key` did not exist and was just created. The proof is a non-inclusion + // proof. See above example for how extension_len is computed. + let common_prefix_len = key.common_prefix_bits_len(existing_leaf.key()); + assert!( + common_prefix_len >= distance_from_root_to_existing_leaf, + "common_prefix_len: {}, distance_from_root_to_existing_leaf: {}", + common_prefix_len, + distance_from_root_to_existing_leaf, + ); + let extension_len = common_prefix_len - distance_from_root_to_existing_leaf; + Self::construct_subtree( + key.iter_bits() + .rev() + .skip(HashValue::LENGTH_IN_BITS - common_prefix_len - 1) + .take(extension_len + 1), + std::iter::once(Rc::new(SparseMerkleNode::new_leaf( + existing_leaf.key(), + existing_leaf.value().clone(), + ))) + .chain(std::iter::repeat(Rc::new(SparseMerkleNode::new_empty())).take(extension_len)), + new_leaf, + ) + } + + /// Constructs a subtree with a list of siblings and a leaf. For example, if `bits` are + /// [false, false, true] and `siblings` are [a, b, c], the resulting subtree will look like: + /// ```text + /// x + /// / \ + /// c o + /// / \ + /// o b + /// / \ + /// leaf a + /// ``` + /// and this function will return `x`. Both `bits` and `siblings` start from the bottom. + fn construct_subtree( + bits: impl Iterator, + siblings: impl Iterator>, + leaf: Rc, + ) -> Rc { + itertools::zip_eq(bits, siblings).fold(leaf, |previous_node, (bit, sibling)| { + Rc::new(if bit { + SparseMerkleNode::new_internal(sibling, previous_node) + } else { + SparseMerkleNode::new_internal(previous_node, sibling) + }) + }) + } + + /// Queries a `key` in this `SparseMerkleTree`. 
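+    ///
+    /// A usage sketch (`key` is the hash of an account address):
+    ///
+    /// ```text
+    /// match smt.get(key) {
+    ///     AccountState::ExistsInScratchPad(blob) => { /* use the in-memory value */ }
+    ///     AccountState::ExistsInDB | AccountState::Unknown => { /* consult storage */ }
+    ///     AccountState::DoesNotExist => { /* absent from both the tree and the DB */ }
+    /// }
+    /// ```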
+ pub fn get(&self, key: HashValue) -> AccountState { + let mut current_node = Rc::clone(&self.root); + let mut bits = key.iter_bits(); + + loop { + let next_node = if let Node::Internal(node) = &*current_node.borrow() { + match bits.next() { + Some(bit) => { + if bit { + node.clone_right_child() + } else { + node.clone_left_child() + } + } + None => panic!("Tree is deeper than {} levels.", HashValue::LENGTH_IN_BITS), + } + } else { + break; + }; + current_node = next_node; + } + + let ret = match &*current_node.borrow() { + Node::Leaf(node) => { + if key == node.key() { + match node.value() { + LeafValue::Blob(blob) => AccountState::ExistsInScratchPad(blob.clone()), + LeafValue::BlobHash(_) => AccountState::ExistsInDB, + } + } else { + AccountState::DoesNotExist + } + } + Node::Subtree(_) => AccountState::Unknown, + Node::Empty => AccountState::DoesNotExist, + Node::Internal(_) => { + unreachable!("There is an internal node at the bottom of the tree.") + } + }; + ret + } + + /// Returns the root hash of this tree. + pub fn root_hash(&self) -> HashValue { + self.root.borrow().hash() + } + + /// Prunes a tree by replacing every node reachable from root with a subtree node that has the + /// same hash. If a node is empty or a subtree, we don't need to do anything. For example in + /// the following case, if we drop `S_i`, we will replace o with a subtree node, then `o` no + /// longer has pointers to its children `x` and `A`, so they will be dropped automatically. + /// ```text + /// S_i S_{i+1} S_{i+1} + /// / \ / \ / \ + /// / y / \ drop(S_i) o y' + /// / _______/ \ ========> / \ + /// // \ z B + /// o y' + /// / \ / \ + /// x A z B + /// ``` + pub fn prune(&self) { + let root = Rc::clone(&self.root); + Self::prune_node(root); + } + + fn prune_node(node: Rc) { + let mut borrowed = node.borrow_mut(); + let node_hash = borrowed.hash(); + + match &*borrowed { + Node::Empty => return, + Node::Subtree(_) => return, + Node::Internal(node) => { + let left_child = node.clone_left_child(); + let right_child = node.clone_right_child(); + Self::prune_node(left_child); + Self::prune_node(right_child); + } + Node::Leaf(_) => (), + } + + *borrowed = Node::new_subtree(node_hash); + } +} + +impl Default for SparseMerkleTree { + fn default() -> Self { + SparseMerkleTree::new(*SPARSE_MERKLE_PLACEHOLDER_HASH) + } +} + +/// A type that implements `ProofRead` can provide proof for keys in persistent storage. +pub trait ProofRead { + /// Gets verified proof for this key in persistent storage. + fn get_proof(&self, key: HashValue) -> Option<&SparseMerkleProof>; +} + +/// All errors `update` can possibly return. +#[derive(Debug, Eq, PartialEq)] +pub enum UpdateError { + /// The update intends to insert a key that does not exist in the tree, so the operation needs + /// proof to get more information about the tree, but no proof is provided. + MissingProof, +} diff --git a/storage/scratchpad/src/sparse_merkle/node.rs b/storage/scratchpad/src/sparse_merkle/node.rs new file mode 100644 index 0000000000000..4699e021d4b05 --- /dev/null +++ b/storage/scratchpad/src/sparse_merkle/node.rs @@ -0,0 +1,258 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines all kinds of nodes in the Sparse Merkle Tree maintained in scratch pad. +//! There are four kinds of nodes: +//! +//! - An `InternalNode` is a node that has two children. It is same as the internal node in a +//! standard Merkle tree. +//! +//! - A `LeafNode` represents a single account. 
Similar to what is in storage, a leaf node has a +//! key which is the hash of the account address as well as a value hash which is the hash of the +//! corresponding account blob. The difference is that a `LeafNode` does not always have the value, +//! in the case when the leaf was loaded into memory as part of a non-inclusion proof. +//! +//! - A `SubtreeNode` represents a subtree with one or more leaves. `SubtreeNode`s are generated +//! when we get accounts from storage with proof. It stores the root hash of this subtree. +//! +//! - An `EmptyNode` represents an empty subtree with zero leaf. + +use crypto::{ + hash::{CryptoHash, SPARSE_MERKLE_PLACEHOLDER_HASH}, + HashValue, +}; +use std::{ + cell::{Ref, RefCell, RefMut}, + rc::Rc, +}; +use types::{ + account_state_blob::AccountStateBlob, + proof::{SparseMerkleInternalNode, SparseMerkleLeafNode}, +}; + +/// We wrap the node in `RefCell`. The only case when we will mutably borrow the node is when we +/// drop a subtree originated from this node and commit things to storage. In that case we will +/// replace the an `InternalNode` or a `LeafNode` with a `SubtreeNode`. +#[derive(Debug)] +pub struct SparseMerkleNode { + node: RefCell, +} + +impl SparseMerkleNode { + /// Constructs a new internal node given two children. + pub fn new_internal( + left_child: Rc, + right_child: Rc, + ) -> Self { + SparseMerkleNode { + node: RefCell::new(Node::new_internal(left_child, right_child)), + } + } + + /// Constructs a new leaf node using given key and value. + pub fn new_leaf(key: HashValue, value: LeafValue) -> Self { + SparseMerkleNode { + node: RefCell::new(Node::new_leaf(key, value)), + } + } + + /// Constructs a new subtree node with given root hash. + pub fn new_subtree(hash: HashValue) -> Self { + SparseMerkleNode { + node: RefCell::new(Node::new_subtree(hash)), + } + } + + /// Constructs a new empty node. + pub fn new_empty() -> Self { + SparseMerkleNode { + node: RefCell::new(Node::new_empty()), + } + } + + /// Immutably borrows the wrapped node. + pub fn borrow(&self) -> Ref { + self.node.borrow() + } + + /// Mutably borrows the wrapped node. + pub fn borrow_mut(&self) -> RefMut { + self.node.borrow_mut() + } +} + +/// The underlying node is either `InternalNode`, `LeafNode`, `SubtreeNode` or `EmptyNode`. +#[derive(Debug)] +pub enum Node { + Internal(InternalNode), + Leaf(LeafNode), + Subtree(SubtreeNode), + Empty, +} + +impl Node { + pub fn new_internal( + left_child: Rc, + right_child: Rc, + ) -> Self { + Node::Internal(InternalNode::new(left_child, right_child)) + } + + pub fn new_leaf(key: HashValue, value: LeafValue) -> Self { + Node::Leaf(LeafNode::new(key, value)) + } + + pub fn new_subtree(hash: HashValue) -> Self { + Node::Subtree(SubtreeNode::new(hash)) + } + + pub fn new_empty() -> Self { + Node::Empty + } + + #[cfg(test)] + pub fn is_subtree(&self) -> bool { + if let Node::Subtree(_) = self { + true + } else { + false + } + } + + #[cfg(test)] + pub fn is_empty(&self) -> bool { + if let Node::Empty = self { + true + } else { + false + } + } + + pub fn hash(&self) -> HashValue { + match self { + Node::Internal(node) => node.hash(), + Node::Leaf(node) => node.hash(), + Node::Subtree(node) => node.hash(), + Node::Empty => *SPARSE_MERKLE_PLACEHOLDER_HASH, + } + } +} + +/// An internal node. +#[derive(Debug)] +pub struct InternalNode { + /// The hash of this internal node which is the root hash of the subtree. + hash: HashValue, + + /// Pointer to left child. + left_child: Rc, + + /// Pointer to right child. 
+ right_child: Rc, +} + +impl InternalNode { + fn new(left_child: Rc, right_child: Rc) -> Self { + match (&*left_child.node.borrow(), &*right_child.node.borrow()) { + (Node::Subtree(_), Node::Subtree(_)) => { + panic!("Two subtree children should have been merged into a single subtree node.") + } + (Node::Leaf(_), Node::Empty) => { + panic!("A leaf with an empty sibling should have been merged into a single leaf.") + } + (Node::Empty, Node::Leaf(_)) => { + panic!("A leaf with an empty sibling should have been merged into a single leaf.") + } + _ => (), + } + + let hash = + SparseMerkleInternalNode::new(left_child.borrow().hash(), right_child.borrow().hash()) + .hash(); + InternalNode { + hash, + left_child, + right_child, + } + } + + fn hash(&self) -> HashValue { + self.hash + } + + pub fn clone_left_child(&self) -> Rc { + Rc::clone(&self.left_child) + } + + pub fn clone_right_child(&self) -> Rc { + Rc::clone(&self.right_child) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum LeafValue { + /// The account state blob. + Blob(AccountStateBlob), + + /// The hash of the blob. + BlobHash(HashValue), +} + +/// A `LeafNode` represents a single account in the Sparse Merkle Tree. +#[derive(Debug)] +pub struct LeafNode { + /// The key is the hash of the address. + key: HashValue, + + /// The account blob or its hash. It's possible that we don't know the value here. For example, + /// this leaf was loaded into memory as part of an non-inclusion proof. In that case we + /// only know the value's hash. + value: LeafValue, + + /// The hash of this leaf node which is Hash(key || Hash(value)). + hash: HashValue, +} + +impl LeafNode { + pub fn new(key: HashValue, value: LeafValue) -> Self { + let value_hash = match value { + LeafValue::Blob(ref val) => val.hash(), + LeafValue::BlobHash(ref val_hash) => *val_hash, + }; + let hash = SparseMerkleLeafNode::new(key, value_hash).hash(); + LeafNode { key, value, hash } + } + + pub fn key(&self) -> HashValue { + self.key + } + + pub fn value(&self) -> &LeafValue { + &self.value + } + + fn hash(&self) -> HashValue { + self.hash + } +} + +/// A subtree node. +#[derive(Debug)] +pub struct SubtreeNode { + /// The root hash of the subtree represented by this node. + hash: HashValue, +} + +impl SubtreeNode { + fn new(hash: HashValue) -> Self { + assert_ne!( + hash, *SPARSE_MERKLE_PLACEHOLDER_HASH, + "A subtree should never be empty." 
+ ); + SubtreeNode { hash } + } + + pub fn hash(&self) -> HashValue { + self.hash + } +} diff --git a/storage/scratchpad/src/sparse_merkle/sparse_merkle_test.rs b/storage/scratchpad/src/sparse_merkle/sparse_merkle_test.rs new file mode 100644 index 0000000000000..0bd2bc978018b --- /dev/null +++ b/storage/scratchpad/src/sparse_merkle/sparse_merkle_test.rs @@ -0,0 +1,547 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::{ + node::{LeafNode, LeafValue, SparseMerkleNode}, + AccountState, ProofRead, SparseMerkleTree, +}; +use crypto::{ + hash::{CryptoHash, TestOnlyHash, SPARSE_MERKLE_PLACEHOLDER_HASH}, + HashValue, +}; +use std::{collections::HashMap, rc::Rc}; +use types::{ + account_state_blob::AccountStateBlob, + proof::{verify_sparse_merkle_element, SparseMerkleProof}, +}; + +fn hash_internal(left_child: HashValue, right_child: HashValue) -> HashValue { + types::proof::SparseMerkleInternalNode::new(left_child, right_child).hash() +} + +fn hash_leaf(key: HashValue, value_hash: HashValue) -> HashValue { + types::proof::SparseMerkleLeafNode::new(key, value_hash).hash() +} + +#[derive(Default)] +struct ProofReader(HashMap); + +impl ProofReader { + fn new(key_with_proof: Vec<(HashValue, SparseMerkleProof)>) -> Self { + ProofReader(key_with_proof.into_iter().collect()) + } +} + +impl ProofRead for ProofReader { + fn get_proof(&self, key: HashValue) -> Option<&SparseMerkleProof> { + self.0.get(&key) + } +} + +#[test] +fn test_construct_subtree_zero_siblings() { + let node_hash = HashValue::new([1; HashValue::LENGTH]); + let node = SparseMerkleNode::new_subtree(node_hash); + let subtree_node = + SparseMerkleTree::construct_subtree(std::iter::empty(), std::iter::empty(), Rc::new(node)); + let smt = SparseMerkleTree { root: subtree_node }; + assert_eq!(smt.root_hash(), node_hash); +} + +#[test] +fn test_construct_subtree_three_siblings() { + // x + // / \ + // [4; 32] c y + // / \ + // z b [3; 32] + // / \ + // node a [2; 32] + let key = b"hello".test_only_hash(); + let blob = AccountStateBlob::from(b"world".to_vec()); + let leaf_hash = hash_leaf(key, blob.hash()); + let node = SparseMerkleNode::new_leaf(key, LeafValue::BlobHash(blob.hash())); + let bits = vec![false, false, true]; + let a_hash = HashValue::new([2; HashValue::LENGTH]); + let b_hash = HashValue::new([3; HashValue::LENGTH]); + let c_hash = HashValue::new([4; HashValue::LENGTH]); + let siblings = vec![a_hash, b_hash, c_hash] + .into_iter() + .map(|hash| Rc::new(SparseMerkleNode::new_subtree(hash))); + let subtree_node = + SparseMerkleTree::construct_subtree(bits.into_iter(), siblings, Rc::new(node)); + let smt = SparseMerkleTree { root: subtree_node }; + + let z_hash = hash_internal(leaf_hash, a_hash); + let y_hash = hash_internal(z_hash, b_hash); + let root_hash = hash_internal(c_hash, y_hash); + assert_eq!(smt.root_hash(), root_hash); +} + +#[test] +#[should_panic] +fn test_construct_subtree_panic() { + let node_hash = HashValue::new([1; HashValue::LENGTH]); + let node = SparseMerkleNode::new_subtree(node_hash); + let _subtree_node = SparseMerkleTree::construct_subtree( + std::iter::once(true), + std::iter::empty(), + Rc::new(node), + ); +} + +#[test] +fn test_construct_subtree_with_new_leaf_override_existing_leaf() { + let key = b"hello".test_only_hash(); + let old_blob = AccountStateBlob::from(b"old_old_old".to_vec()); + let new_blob = AccountStateBlob::from(b"new_new_new".to_vec()); + + let existing_leaf = LeafNode::new(key, LeafValue::BlobHash(old_blob.hash())); + + let subtree = 
SparseMerkleTree::construct_subtree_with_new_leaf( + key, + new_blob.clone(), + &existing_leaf, + /* distance_from_root_to_existing_leaf = */ 3, + ); + let smt = SparseMerkleTree { root: subtree }; + + let new_blob_hash = new_blob.hash(); + let root_hash = hash_leaf(key, new_blob_hash); + assert_eq!(smt.root_hash(), root_hash); +} + +#[test] +fn test_construct_subtree_with_new_leaf_create_extension() { + // root root + // / \ / \ + // o o o o + // / \ / \ + // o existing_key => o subtree + // / \ + // y placeholder + // / \ + // x placeholder + // / \ + // existing_key new_key + let existing_key = b"world".test_only_hash(); + let existing_blob = AccountStateBlob::from(b"world".to_vec()); + let existing_blob_hash = existing_blob.hash(); + let new_key = b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".test_only_hash(); + let new_blob = AccountStateBlob::from(b"new_blob!!!!!".to_vec()); + assert_eq!(existing_key[0], 0b0100_0010); + assert_eq!(new_key[0], 0b0100_1011); + + let existing_leaf = LeafNode::new(existing_key, LeafValue::BlobHash(existing_blob.hash())); + + let subtree = SparseMerkleTree::construct_subtree_with_new_leaf( + new_key, + new_blob.clone(), + &existing_leaf, + /* distance_from_root_to_existing_leaf = */ 2, + ); + let smt = SparseMerkleTree { root: subtree }; + + let new_blob_hash = new_blob.hash(); + let existing_leaf_hash = hash_leaf(existing_key, existing_blob_hash); + let new_leaf_hash = hash_leaf(new_key, new_blob_hash); + let x_hash = hash_internal(existing_leaf_hash, new_leaf_hash); + let y_hash = hash_internal(x_hash, *SPARSE_MERKLE_PLACEHOLDER_HASH); + let root_hash = hash_internal(y_hash, *SPARSE_MERKLE_PLACEHOLDER_HASH); + assert_eq!(smt.root_hash(), root_hash); +} + +#[test] +#[should_panic(expected = "Reached an internal node at the bottom of the tree.")] +fn test_construct_subtree_at_bottom_found_internal_node() { + let left_child = Rc::new(SparseMerkleNode::new_subtree(HashValue::new( + [1; HashValue::LENGTH], + ))); + let right_child = Rc::new(SparseMerkleNode::new_empty()); + let current_node = Rc::new(SparseMerkleNode::new_internal(left_child, right_child)); + let key = b"hello".test_only_hash(); + let new_blob = AccountStateBlob::from(b"new_blob".to_vec()); + let remaining_bits = key.iter_bits(); + let proof_reader = ProofReader::default(); + let _subtree_node = SparseMerkleTree::construct_subtree_at_bottom( + current_node, + key, + new_blob, + remaining_bits, + &proof_reader, + ); +} + +#[test] +fn test_construct_subtree_at_bottom_found_leaf_node() { + // root root + // / \ / \ + // o o o o + // / \ / \ + // o existing_key => o subtree + // / \ + // y placeholder + // / \ + // x placeholder + // / \ + // existing_key new_key + let existing_key = b"world".test_only_hash(); + let existing_blob = AccountStateBlob::from(b"world".to_vec()); + let existing_blob_hash = existing_blob.hash(); + let new_key = b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".test_only_hash(); + let new_blob = AccountStateBlob::from(b"new_blob!!!!!".to_vec()); + assert_eq!(existing_key[0], 0b0100_0010); + assert_eq!(new_key[0], 0b0100_1011); + + let current_node = Rc::new(SparseMerkleNode::new_leaf( + existing_key, + LeafValue::BlobHash(existing_blob_hash), + )); + let remaining_bits = { + let mut iter = new_key.iter_bits(); + iter.next(); + iter.next(); + iter + }; + let leaf = Some((existing_key, existing_blob_hash)); + let siblings: Vec<_> = (0..2) + .map(|x| HashValue::new([x; HashValue::LENGTH])) + .collect(); + let proof = SparseMerkleProof::new(leaf, siblings); + let proof_reader = 
ProofReader::new(vec![(new_key, proof)]); + + let subtree = SparseMerkleTree::construct_subtree_at_bottom( + current_node, + new_key, + new_blob.clone(), + remaining_bits, + &proof_reader, + ) + .unwrap(); + let smt = SparseMerkleTree { root: subtree }; + + let existing_leaf_hash = hash_leaf(existing_key, existing_blob_hash); + let new_blob_hash = new_blob.hash(); + let new_leaf_hash = hash_leaf(new_key, new_blob_hash); + let x_hash = hash_internal(existing_leaf_hash, new_leaf_hash); + let y_hash = hash_internal(x_hash, *SPARSE_MERKLE_PLACEHOLDER_HASH); + let root_hash = hash_internal(y_hash, *SPARSE_MERKLE_PLACEHOLDER_HASH); + assert_eq!(smt.root_hash(), root_hash); +} + +#[test] +fn test_construct_subtree_at_bottom_found_empty_node() { + // root root + // / \ / \ + // o o o o + // / \ / \ + // o placeholder => o new_key + let new_key = b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".test_only_hash(); + let new_blob = AccountStateBlob::from(b"new_blob!!!!!".to_vec()); + assert_eq!(new_key[0], 0b0100_1011); + + let current_node = Rc::new(SparseMerkleNode::new_empty()); + let remaining_bits = { + let mut iter = new_key.iter_bits(); + // Skip first two. + iter.next(); + iter.next(); + iter + }; + let proof_reader = ProofReader::default(); + + let subtree = SparseMerkleTree::construct_subtree_at_bottom( + current_node, + new_key, + new_blob.clone(), + remaining_bits, + &proof_reader, + ) + .unwrap(); + let smt = SparseMerkleTree { root: subtree }; + + let new_blob_hash = new_blob.hash(); + let new_leaf_hash = hash_leaf(new_key, new_blob_hash); + assert_eq!(smt.root_hash(), new_leaf_hash); +} + +#[test] +fn test_construct_subtree_at_bottom_found_subtree_node() { + // root root + // / \ / \ + // o o o o + // / \ / \ + // o subtree => o new_subtree + // / \ + // x sibling [5; 32] (from proof) + // / \ + // sibling [6; 32] (from proof) new_leaf + let new_key = b"aaaaaaaa".test_only_hash(); + let new_blob = AccountStateBlob::from(b"new_blob!!!!!".to_vec()); + assert_eq!(new_key[0], 0b0101_1111); + + let current_node = Rc::new(SparseMerkleNode::new_subtree(HashValue::new( + [1; HashValue::LENGTH], + ))); + let remaining_bits = { + let mut iter = new_key.iter_bits(); + // Skip first two. + iter.next(); + iter.next(); + iter + }; + let leaf = None; + let siblings: Vec<_> = (3..7) + .map(|x| HashValue::new([x; HashValue::LENGTH])) + .collect(); + let proof = SparseMerkleProof::new(leaf, siblings); + let proof_reader = ProofReader::new(vec![(new_key, proof)]); + + let new_subtree = SparseMerkleTree::construct_subtree_at_bottom( + current_node, + new_key, + new_blob.clone(), + remaining_bits, + &proof_reader, + ) + .unwrap(); + let smt = SparseMerkleTree { root: new_subtree }; + + let new_blob_hash = new_blob.hash(); + let new_leaf_hash = hash_leaf(new_key, new_blob_hash); + let x_hash = hash_internal(HashValue::new([6; HashValue::LENGTH]), new_leaf_hash); + let new_subtree_hash = hash_internal(x_hash, HashValue::new([5; HashValue::LENGTH])); + assert_eq!(smt.root_hash(), new_subtree_hash); +} + +#[test] +fn test_update_256_siblings_in_proof() { + // root + // / \ + // o placeholder + // / \ + // o placeholder + // / \ + // . placeholder + // . + // . 
(256 levels) + // o + // / \ + // key1 key2 + let key1 = HashValue::new([0; HashValue::LENGTH]); + let key2 = { + let mut buf = key1.to_vec(); + *buf.last_mut().unwrap() |= 1; + HashValue::from_slice(&buf).unwrap() + }; + + let blob1 = AccountStateBlob::from(b"value1".to_vec()); + let blob2 = AccountStateBlob::from(b"value2".to_vec()); + let value1_hash = blob1.hash(); + let value2_hash = blob2.hash(); + let leaf1_hash = hash_leaf(key1, value1_hash); + let leaf2_hash = hash_leaf(key2, value2_hash); + + let mut siblings: Vec<_> = std::iter::repeat(*SPARSE_MERKLE_PLACEHOLDER_HASH) + .take(255) + .collect(); + siblings.push(leaf2_hash); + let proof_of_key1 = SparseMerkleProof::new(Some((key1, value1_hash)), siblings.clone()); + + let old_root_hash = siblings + .iter() + .rev() + .fold(leaf1_hash, |previous_hash, hash| { + hash_internal(previous_hash, *hash) + }); + assert!( + verify_sparse_merkle_element(old_root_hash, key1, &Some(blob1), &proof_of_key1).is_ok() + ); + + let new_blob1 = AccountStateBlob::from(b"value1111111111111".to_vec()); + let proof_reader = ProofReader::new(vec![(key1, proof_of_key1)]); + let smt = SparseMerkleTree::new(old_root_hash); + let new_smt = smt + .update(vec![(key1, new_blob1.clone())], &proof_reader) + .unwrap(); + + let new_blob1_hash = new_blob1.hash(); + let new_leaf1_hash = hash_leaf(key1, new_blob1_hash); + let new_root_hash = siblings + .iter() + .rev() + .fold(new_leaf1_hash, |previous_hash, hash| { + hash_internal(previous_hash, *hash) + }); + assert_eq!(new_smt.root_hash(), new_root_hash); + + assert_eq!( + new_smt.get(key1), + AccountState::ExistsInScratchPad(new_blob1) + ); + assert_eq!(new_smt.get(key2), AccountState::Unknown); +} + +#[test] +fn test_new_subtree() { + let root_hash = HashValue::new([1; HashValue::LENGTH]); + let smt = SparseMerkleTree::new(root_hash); + assert!(smt.root.borrow().is_subtree()); + assert_eq!(smt.root_hash(), root_hash); +} + +#[test] +fn test_new_empty() { + let root_hash = *SPARSE_MERKLE_PLACEHOLDER_HASH; + let smt = SparseMerkleTree::new(root_hash); + assert!(smt.root.borrow().is_empty()); + assert_eq!(smt.root_hash(), root_hash); +} + +#[test] +fn test_update() { + // Before the update, the tree was: + // root + // / \ + // y key3 + // / \ + // x placeholder + // / \ + // key1 key2 + let key1 = b"aaaaa".test_only_hash(); + let key2 = b"bb".test_only_hash(); + let key3 = b"cccc".test_only_hash(); + assert_eq!(key1[0], 0b0000_0100); + assert_eq!(key2[0], 0b0010_0100); + assert_eq!(key3[0], 0b1110_0111); + let value1 = AccountStateBlob::from(b"value1".to_vec()); + let value1_hash = value1.hash(); + let value2_hash = AccountStateBlob::from(b"value2".to_vec()).hash(); + let value3_hash = AccountStateBlob::from(b"value3".to_vec()).hash(); + + // A new key at the "placeholder" position. + let key4 = b"d".test_only_hash(); + assert_eq!(key4[0], 0b0100_1100); + let value4 = AccountStateBlob::from(b"value".to_vec()); + + // Create a proof for this new key. + let leaf1_hash = hash_leaf(key1, value1_hash); + let leaf2_hash = hash_leaf(key2, value2_hash); + let leaf3_hash = hash_leaf(key3, value3_hash); + let x_hash = hash_internal(leaf1_hash, leaf2_hash); + let y_hash = hash_internal(x_hash, *SPARSE_MERKLE_PLACEHOLDER_HASH); + let old_root_hash = hash_internal(y_hash, leaf3_hash); + let proof = SparseMerkleProof::new(None, vec![leaf3_hash, x_hash]); + assert!(verify_sparse_merkle_element(old_root_hash, key4, &None, &proof).is_ok()); + + // Create the old tree and update the tree with new value and proof. 
+ let proof_reader = ProofReader::new(vec![(key4, proof)]); + let old_smt = SparseMerkleTree::new(old_root_hash); + let smt1 = old_smt + .update(vec![(key4, value4.clone())], &proof_reader) + .unwrap(); + + // Now smt1 should look like this: + // root + // / \ + // y key3 (subtree) + // / \ + // x key4 + assert_eq!(smt1.get(key1), AccountState::Unknown); + assert_eq!(smt1.get(key2), AccountState::Unknown); + assert_eq!(smt1.get(key3), AccountState::Unknown); + assert_eq!( + smt1.get(key4), + AccountState::ExistsInScratchPad(value4.clone()) + ); + + let non_existing_key = b"foo".test_only_hash(); + assert_eq!(non_existing_key[0], 0b0111_0110); + assert_eq!(smt1.get(non_existing_key), AccountState::DoesNotExist); + + // Verify root hash. + let value4_hash = value4.hash(); + let leaf4_hash = hash_leaf(key4, value4_hash); + let y_hash = hash_internal(x_hash, leaf4_hash); + let root_hash = hash_internal(y_hash, leaf3_hash); + assert_eq!(smt1.root_hash(), root_hash); + + // Next, we are going to modify key1. Create a proof for key1. + let proof = SparseMerkleProof::new( + Some((key1, value1_hash)), + vec![leaf3_hash, *SPARSE_MERKLE_PLACEHOLDER_HASH, leaf2_hash], + ); + assert!(verify_sparse_merkle_element(old_root_hash, key1, &Some(value1), &proof).is_ok()); + + let value1 = AccountStateBlob::from(b"value11111".to_vec()); + let proof_reader = ProofReader::new(vec![(key1, proof)]); + let smt2 = smt1 + .update(vec![(key1, value1.clone())], &proof_reader) + .unwrap(); + + // Now the tree looks like: + // root + // / \ + // y key3 (subtree) + // / \ + // x key4 + // / \ + // key1 key2 (subtree) + assert_eq!( + smt2.get(key1), + AccountState::ExistsInScratchPad(value1.clone()) + ); + assert_eq!(smt2.get(key2), AccountState::Unknown); + assert_eq!(smt2.get(key3), AccountState::Unknown); + assert_eq!( + smt2.get(key4), + AccountState::ExistsInScratchPad(value4.clone()) + ); + + // Verify root hash. + let value1_hash = value1.hash(); + let leaf1_hash = hash_leaf(key1, value1_hash); + let x_hash = hash_internal(leaf1_hash, leaf2_hash); + let y_hash = hash_internal(x_hash, leaf4_hash); + let root_hash = hash_internal(y_hash, leaf3_hash); + assert_eq!(smt2.root_hash(), root_hash); + + // We now try to create another branch on top of smt1. + let value4 = AccountStateBlob::from(b"new value 4444444444".to_vec()); + // key4 already exists in the tree. + let proof_reader = ProofReader::default(); + let smt22 = smt1 + .update(vec![(key4, value4.clone())], &proof_reader) + .unwrap(); + + assert_eq!(smt22.get(key1), AccountState::Unknown); + assert_eq!(smt22.get(key2), AccountState::Unknown); + assert_eq!(smt22.get(key3), AccountState::Unknown); + assert_eq!( + smt22.get(key4), + AccountState::ExistsInScratchPad(value4.clone()) + ); + + // Now prune smt1. + smt1.prune(); + + // For smt2, only key1 should be available since smt2 was constructed by updating smt1 with + // key1. + assert_eq!( + smt2.get(key1), + AccountState::ExistsInScratchPad(value1.clone()) + ); + assert_eq!(smt2.get(key2), AccountState::Unknown); + assert_eq!(smt2.get(key3), AccountState::Unknown); + assert_eq!(smt2.get(key4), AccountState::Unknown); + + // For smt22, only key4 should be available since smt22 was constructed by updating smt1 with + // key4. 
+ assert_eq!(smt22.get(key1), AccountState::Unknown);
+ assert_eq!(smt22.get(key2), AccountState::Unknown);
+ assert_eq!(smt22.get(key3), AccountState::Unknown);
+ assert_eq!(
+ smt22.get(key4),
+ AccountState::ExistsInScratchPad(value4.clone())
+ );
+}
diff --git a/storage/sparse_merkle/Cargo.toml b/storage/sparse_merkle/Cargo.toml
new file mode 100644
index 0000000000000..981fccb194cb0
--- /dev/null
+++ b/storage/sparse_merkle/Cargo.toml
@@ -0,0 +1,20 @@
+[package]
+name = "sparse_merkle"
+version = "0.1.0"
+authors = ["Libra Association "]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+bincode = "1.1.1"
+
+crypto = { path = "../../crypto/legacy_crypto" }
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+serde = { version = "1.0.89", features = ["derive"] }
+types = { path = "../../types" }
+
+[dev-dependencies]
+proptest = "0.9"
+rand = "0.6.5"
+serde_test = "1.0.87"
diff --git a/storage/sparse_merkle/src/lib.rs b/storage/sparse_merkle/src/lib.rs
new file mode 100644
index 0000000000000..72dfb09c5626c
--- /dev/null
+++ b/storage/sparse_merkle/src/lib.rs
@@ -0,0 +1,535 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module implements [`SparseMerkleTree`] backed by the storage module. The tree itself
+//! doesn't persist anything; it only implements the read/write logic. The write path produces all
+//! the intermediate results in a batch for the storage layer to commit, and the read path returns
+//! results directly. The public APIs are only [`new`](SparseMerkleTree::new),
+//! [`put_keyed_blob_sets`](SparseMerkleTree::put_keyed_blob_sets),
+//! [`put_keyed_blob_set`](SparseMerkleTree::put_keyed_blob_set) and
+//! [`get_with_proof`](SparseMerkleTree::get_with_proof). After each put with a `keyed_blob_set`
+//! based on a known root, the tree will return a new root hash with a [`TreeUpdateBatch`]
+//! containing all newly generated tree nodes and blobs.
+//!
+//! The sparse Merkle tree itself logically is a 256-bit Merkle tree with an optimization
+//! that any subtree containing 0 or 1 leaf node will be replaced by that leaf node or a placeholder
+//! node with default hash value. With this optimization we can save CPU by avoiding hashing on
+//! many sparse levels in the tree. Physically, the tree is structurally similar to the modified
+//! Merkle Patricia tree of Ethereum, with some modifications. Please read the code for details.
+
+#[cfg(test)]
+mod mock_tree_store;
+mod nibble_path;
+mod node_serde;
+pub mod node_type;
+#[cfg(test)]
+mod sparse_merkle_test;
+mod tree_cache;
+
+use crypto::{
+ hash::{CryptoHash, SPARSE_MERKLE_PLACEHOLDER_HASH},
+ HashValue,
+};
+use failure::prelude::*;
+use nibble_path::{skip_common_prefix, NibbleIterator, NibblePath};
+use node_type::{BranchNode, ExtensionNode, LeafNode, Node};
+use std::collections::HashMap;
+use tree_cache::TreeCache;
+use types::{account_state_blob::AccountStateBlob, proof::definition::SparseMerkleProof};
+
+/// The hardcoded maximum height of a [`SparseMerkleTree`] in nibbles.
+const ROOT_NIBBLE_HEIGHT: usize = HashValue::LENGTH * 2;
+
+/// TreeReader defines the interface between [`SparseMerkleTree`] and the underlying storage
+/// holding nodes and blobs. 
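+///
+/// For illustration, a trivial reader backed by two in-memory maps could look like the sketch
+/// below (the `MockTreeStore` used by the tests implements essentially this idea;
+/// `InMemoryReader` is a made-up name):
+///
+/// ```text
+/// struct InMemoryReader {
+///     nodes: HashMap<HashValue, Node>,
+///     blobs: HashMap<HashValue, AccountStateBlob>,
+/// }
+///
+/// impl TreeReader for InMemoryReader {
+///     fn get_node(&self, node_hash: HashValue) -> Result<Node> {
+///         self.nodes.get(&node_hash).cloned().ok_or_else(|| format_err!("missing node"))
+///     }
+///
+///     fn get_blob(&self, blob_hash: HashValue) -> Result<AccountStateBlob> {
+///         self.blobs.get(&blob_hash).cloned().ok_or_else(|| format_err!("missing blob"))
+///     }
+/// }
+/// ```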
+pub trait TreeReader {
+ /// Get a state Merkle tree node given the node hash.
+ fn get_node(&self, node_hash: HashValue) -> Result<Node>;
+ /// Get a state Merkle tree blob given the blob hash.
+ fn get_blob(&self, blob_hash: HashValue) -> Result<AccountStateBlob>;
+}
+
+/// Node batch that will be written into db atomically with other batches.
+pub type NodeBatch = HashMap<HashValue, Node>;
+/// Blob batch that will be written into db atomically with other batches.
+pub type BlobBatch = HashMap<HashValue, AccountStateBlob>;
+
+/// This is a wrapper of [`NodeBatch`] and [`BlobBatch`] that represents the incremental
+/// updates of the tree after applying a write set, which is a vector of account_address and
+/// new_account_state_blob pairs.
+#[derive(Clone, Debug, Default, Eq, PartialEq)]
+pub struct TreeUpdateBatch {
+ node_batch: NodeBatch,
+ blob_batch: BlobBatch,
+}
+
+/// Conversion between tuple type and `TreeUpdateBatch`.
+impl From<TreeUpdateBatch> for (NodeBatch, BlobBatch) {
+ fn from(batch: TreeUpdateBatch) -> Self {
+ (batch.node_batch, batch.blob_batch)
+ }
+}
+
+/// The sparse Merkle tree data structure. See [`crate`] for description.
+pub struct SparseMerkleTree<'a, R: 'a + TreeReader> {
+ reader: &'a R,
+}
+
+impl<'a, R> SparseMerkleTree<'a, R>
+where
+ R: 'a + TreeReader,
+{
+ /// Creates a `SparseMerkleTree` backed by the given [`TreeReader`].
+ pub fn new(reader: &'a R) -> Self {
+ Self { reader }
+ }
+
+ /// Returns new nodes and account state blobs in a batch after applying `keyed_blob_set`. For
+ /// example, if after transaction `T_i` the committed state of the tree in persistent storage
+ /// looks like the following structure:
+ ///
+ /// ```text
+ /// S_i
+ /// / \
+ /// . .
+ /// . .
+ /// / \
+ /// o x
+ /// / \
+ /// A B
+ /// storage (disk)
+ /// ```
+ ///
+ /// where `A` and `B` denote the states of two adjacent accounts, and `x` is a sibling on the
+ /// path from root to A and B in the tree. If a `keyed_blob_set` produced by the next
+ /// transaction `T_{i+1}` modifies other accounts `C` and `D` that live in
+ /// the subtree under `x`, a new partial tree will be constructed in memory and the
+ /// structure will be:
+ ///
+ /// ```text
+ /// S_i | S_{i+1}
+ /// / \ | / \
+ /// . . | . .
+ /// . . | . .
+ /// / \ | / \
+ /// / x | / x'
+ /// o<-------------+- / \
+ /// / \ | C D
+ /// A B |
+ /// storage (disk) | cache (memory)
+ /// ```
+ ///
+ /// With this design, we are able to query the global state in persistent storage and
+ /// generate the proposed tree delta based on a specific root hash and `keyed_blob_set`. For
+ /// example, if we want to execute another transaction `T_{i+1}'`, we can use the tree `S_i` in
+ /// storage and apply the `keyed_blob_set` of transaction `T_{i+1}`. Then if the storage commits
+ /// the returned batch, the state `S_{i+1}` is ready to be read from the tree by calling
+ /// [`get_with_proof`](SparseMerkleTree::get_with_proof). Anything inside the batch is not
+ /// reachable from public interfaces before being committed. 
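+ ///
+ /// A sketch of the intended call pattern, assuming some `reader` implementing
+ /// [`TreeReader`] and a known `old_root` (the names here are illustrative only):
+ ///
+ /// ```text
+ /// let tree = SparseMerkleTree::new(&reader);
+ /// let (new_root, batch) = tree.put_keyed_blob_set(vec![(key, blob)], old_root)?;
+ /// // Persist `batch` atomically; afterwards the new state is readable via
+ /// // tree.get_with_proof(key, new_root).
+ /// ```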
+ pub fn put_keyed_blob_set(
+ &self,
+ keyed_blob_set: Vec<(HashValue, AccountStateBlob)>,
+ root_hash: HashValue,
+ ) -> Result<(HashValue, TreeUpdateBatch)> {
+ let (mut root_hashes, tree_update_batch) =
+ self.put_keyed_blob_sets(vec![keyed_blob_set], root_hash)?;
+ let root_hash = root_hashes.pop().expect("root hash must exist");
+ assert!(
+ root_hashes.is_empty(),
+ "root_hashes can only have 1 root_hash inside"
+ );
+ Ok((root_hash, tree_update_batch))
+ }
+
+ /// Applies a series of `keyed_blob_set`s to the tree;
+ /// [`put_keyed_blob_set`](SparseMerkleTree::put_keyed_blob_set) is a convenience wrapper
+ /// that calls this function with a single set.
+ #[inline]
+ pub fn put_keyed_blob_sets(
+ &self,
+ keyed_blob_sets: Vec<Vec<(HashValue, AccountStateBlob)>>,
+ root_hash: HashValue,
+ ) -> Result<(Vec<HashValue>, TreeUpdateBatch)> {
+ let mut tree_cache = TreeCache::new(self.reader, root_hash);
+ for keyed_blob_set in keyed_blob_sets {
+ assert!(
+ !keyed_blob_set.is_empty(),
+ "Transactions that output empty write set should not be included.",
+ );
+ keyed_blob_set
+ .into_iter()
+ .map(|(key, value)| Self::put(key, value, &mut tree_cache))
+ .collect::<Result<_>>()?;
+ // Freeze the current cache to make all contents in the current cache immutable.
+ tree_cache.freeze();
+ }
+
+ Ok(tree_cache.into())
+ }
+
+ fn put(key: HashValue, blob: AccountStateBlob, tree_cache: &mut TreeCache) -> Result<()> {
+ let nibble_path = NibblePath::new(key.to_vec());
+
+ // Get the root node. If this is the first operation, it would get the root node from the
+ // underlying db. Otherwise it most likely would come from `cache`.
+ let root = tree_cache.get_root_node()?;
+
+ // Get the blob hash and put the value blob into `tree_cache`.
+ let blob_hash = blob.hash();
+ tree_cache.put_blob(blob_hash, blob)?;
+
+ // Start insertion from the root node.
+ let new_root_hash =
+ Self::insert_at(root, &mut nibble_path.nibbles(), blob_hash, tree_cache)?.0;
+ tree_cache.set_root_hash(new_root_hash);
+
+ Ok(())
+ }
+
+ /// Helper function for recursive insertion into the subtree that starts from the current
+ /// `node`. Returns the hash of the created node and a boolean indicating whether the created
+ /// node is a leaf node.
+ ///
+ /// It is safe to use recursion here because the max depth is limited by the key length, which
+ /// for this tree is the length of the hash of an account address.
+ fn insert_at(
+ node: Option<Node>,
+ nibble_iter: &mut NibbleIterator,
+ value_hash: HashValue,
+ tree_cache: &mut TreeCache,
+ ) -> Result<(HashValue, bool)> {
+ match node {
+ Some(Node::Branch(branch_node)) => {
+ Self::insert_at_branch_node(branch_node, nibble_iter, value_hash, tree_cache)
+ }
+ Some(Node::Extension(extension_node)) => {
+ Self::insert_at_extension_node(extension_node, nibble_iter, value_hash, tree_cache)
+ }
+ Some(Node::Leaf(leaf_node)) => {
+ Self::insert_at_leaf_node(leaf_node, nibble_iter, value_hash, tree_cache)
+ }
+ None => Self::create_leaf_node(nibble_iter, value_hash, tree_cache),
+ }
+ }
+
+ /// Helper function for recursive insertion into the subtree that starts from the current
+ /// `branch_node`. Returns the hash of the created node and a boolean indicating whether the
+ /// created node is a leaf node.
+ fn insert_at_branch_node(
+ mut branch_node: BranchNode,
+ nibble_iter: &mut NibbleIterator,
+ value_hash: HashValue,
+ tree_cache: &mut TreeCache,
+ ) -> Result<(HashValue, bool)> {
+ // Delete the current branch node from `tree_cache` if it exists; otherwise it is a noop
+ // since we will update this branch node anyway. 
Even if the new branch node is exactly the
+ // same as the old one, it is okay to delete it and then put the same node back.
+ tree_cache.delete_node(branch_node.hash());
+
+ // Find the next node to visit following the next nibble as index.
+ let next_node_index = nibble_iter.next().expect("Ran out of nibbles");
+
+ // Get the next node from `tree_cache` if it exists; otherwise it will be `None`.
+ let next_node = branch_node
+ .child(next_node_index)
+ .map(|hash| tree_cache.get_node(hash))
+ .transpose()?;
+
+ // Traverse downwards from this branch node recursively to get the hash of the child node
+ // at `next_node_index`.
+ let (new_child_hash, is_new_child_leaf) = match next_node {
+ Some(child) => Self::insert_at(Some(child), nibble_iter, value_hash, tree_cache)?,
+ None => Self::create_leaf_node(nibble_iter, value_hash, tree_cache)?,
+ };
+
+ // Reuse the current `BranchNode` in memory to create a new branch node.
+ branch_node.set_child(next_node_index, (new_child_hash, is_new_child_leaf));
+ let new_node_hash = branch_node.hash();
+
+ // Cache this new branch node with `new_node_hash` as the key.
+ tree_cache.put_node(new_node_hash, branch_node.into())?;
+ Ok((new_node_hash, false /* is_leaf */))
+ }
+
+ /// Helper function for recursive insertion into the subtree that starts from the current
+ /// extension node `extension_node`. Returns the hash of the created node and a boolean
+ /// indicating whether the created node is a leaf node (always `false`).
+ fn insert_at_extension_node(
+ mut extension_node: ExtensionNode,
+ nibble_iter: &mut NibbleIterator,
+ value_hash: HashValue,
+ tree_cache: &mut TreeCache,
+ ) -> Result<(HashValue, bool)> {
+ // We are on an extension node but are trying to insert another node, so we may need
+ // to add a new path.
+
+ // Delete the current extension node from `tree_cache` if it exists; otherwise it is a
+ // noop.
+ tree_cache.delete_node(extension_node.hash());
+
+ // Determine the common prefix between this extension node and the nibble iterator of the
+ // incoming key.
+ let mut extension_nibble_iter = extension_node.nibble_path().nibbles();
+ skip_common_prefix(&mut extension_nibble_iter, nibble_iter);
+
+ // There are two possible cases after matching the prefix:
+ // 1. All the nibbles of the extension node match the nibble path of the incoming node,
+ // so just visit the next node recursively. Note: the next node must be a branch node;
+ // otherwise the tree is corrupted.
+ if extension_nibble_iter.is_finished() {
+ assert!(
+ !nibble_iter.is_finished(),
+ "We should never end the search on an extension node when key length is fixed."
+ );
+ let (inserted_child_hash, _is_leaf) = match tree_cache
+ .get_node(extension_node.child())?
+ {
+ Node::Branch(branch_node) => {
+ Self::insert_at_branch_node(branch_node, nibble_iter, value_hash, tree_cache)?
+ }
+ _ => bail!("Extension node shouldn't have a non-branch node as child"),
+ };
+ extension_node.set_child(inserted_child_hash);
+ let new_node_hash = extension_node.hash();
+ tree_cache.put_node(new_node_hash, extension_node.into())?;
+ return Ok((new_node_hash, false /* is_leaf */));
+ }
+
+ // 2. Not all the nibbles of the extension node match the nibble iterator of the
+ // incoming key; there are several cases. Let us assume `O` denotes a nibble and `X`
+ // denotes the first mismatched nibble. 
The nibble path of the extension node can be
+ // illustrated as `(O...)X(O...)`: the extension node will be replaced by an optional
+ // extension node if needed before the fork, followed by a branch node at the fork and
+ // another possible extension node if needed after the fork. We create new nodes in a
+ // bottom-up order.
+
+ // 1) Cache the visited nibbles. We will use them in step 4).
+ let extension_nibble_iter_before_fork = extension_nibble_iter.visited_nibbles();
+
+ // 2) Create the branch node at the fork, i.e., `X`, as described above.
+ let mut new_branch_node_at_fork = BranchNode::default();
+ let extension_node_index = extension_nibble_iter.next().expect("Ran out of nibbles");
+ let new_leaf_node_index = nibble_iter.next().expect("Ran out of nibbles");
+ assert_ne!(extension_node_index, new_leaf_node_index);
+
+ // 3) Connect the two children to the branch node at the fork; create them if necessary.
+ let extension_nibble_iter_after_fork = extension_nibble_iter.remaining_nibbles();
+ // Check whether the extension node after the fork is necessary.
+ if extension_nibble_iter_after_fork.num_nibbles() != 0 {
+ // `...XO...` case: some nibbles of the extension node are left after the fork.
+ let new_extension_node_after_fork = Node::new_extension(
+ extension_nibble_iter_after_fork.get_nibble_path(),
+ extension_node.child(),
+ );
+ let new_extension_node_after_fork_hash = new_extension_node_after_fork.hash();
+ tree_cache.put_node(
+ new_extension_node_after_fork_hash,
+ new_extension_node_after_fork,
+ )?;
+ new_branch_node_at_fork.set_child(
+ extension_node_index,
+ (new_extension_node_after_fork_hash, false /* is_leaf */),
+ );
+ } else {
+ // `...X` case: the nibble at the fork is the last nibble of the extension node.
+ new_branch_node_at_fork.set_child(
+ extension_node_index,
+ (
+ extension_node.child(),
+ false, /* is_leaf, extension node must have a branch node as child */
+ ),
+ );
+ }
+
+ // Set another child of the new branch node to be the newly inserted leaf node.
+ new_branch_node_at_fork.set_child(
+ new_leaf_node_index,
+ Self::create_leaf_node(nibble_iter, value_hash, tree_cache)?,
+ );
+ let mut new_node_hash = new_branch_node_at_fork.hash();
+ tree_cache.put_node(new_node_hash, new_branch_node_at_fork.into())?;
+
+ // 4) Check whether an extension node before the fork is necessary.
+ if extension_nibble_iter_before_fork.num_nibbles() != 0 {
+ let new_extension_node_before_fork = Node::new_extension(
+ extension_nibble_iter_before_fork.get_nibble_path(),
+ new_node_hash,
+ );
+ new_node_hash = new_extension_node_before_fork.hash();
+ tree_cache.put_node(new_node_hash, new_extension_node_before_fork)?;
+ }
+ Ok((new_node_hash, false /* is_leaf */))
+ }
+
+ /// Helper function for recursive insertion into the subtree that starts from the current
+ /// `leaf_node`. Returns the hash of the created node and a boolean indicating whether
+ /// the created node is a leaf node.
+ fn insert_at_leaf_node(
+ existing_leaf_node: LeafNode,
+ nibble_iter: &mut NibbleIterator,
+ value_hash: HashValue,
+ tree_cache: &mut TreeCache,
+ ) -> Result<(HashValue, bool)> {
+ // We are on a leaf node but trying to insert another node, so we may diverge. Different
+ // from insertion at branch nodes or at extension nodes, we don't delete the existing
+ // leaf node here unless it has the same key as the incoming key.
+
+ // 1. 
Make sure that the existing leaf nibble_path has the same prefix as the already
+ // visited part of the nibble iter of the incoming key, and advance the existing leaf
+ // nibble iterator by the length of that prefix.
+ let mut visited_nibble_iter = nibble_iter.visited_nibbles();
+ let existing_leaf_nibble_path = NibblePath::new(existing_leaf_node.key().to_vec());
+ let mut existing_leaf_nibble_iter = existing_leaf_nibble_path.nibbles();
+ skip_common_prefix(&mut visited_nibble_iter, &mut existing_leaf_nibble_iter);
+ assert!(
+ visited_nibble_iter.is_finished(),
+ "Leaf nodes failed to share the same visited nibbles before index {}",
+ existing_leaf_nibble_iter.visited_nibbles().num_nibbles()
+ );
+
+ // 2. Determine the secondary common prefix between this leaf node and the incoming
+ // key.
+ let mut pruned_existing_leaf_nibble_iter = existing_leaf_nibble_iter.remaining_nibbles();
+ skip_common_prefix(nibble_iter, &mut pruned_existing_leaf_nibble_iter);
+ assert_eq!(
+ nibble_iter.is_finished(),
+ pruned_existing_leaf_nibble_iter.is_finished(),
+ "key lengths mismatch."
+ );
+ // If both iterators are finished, the incoming key already exists in the tree and we
+ // just need to update its value.
+ if nibble_iter.is_finished() {
+ // Implies `&& pruned_existing_leaf_nibble_iter.is_finished()`.
+ // Delete the old leaf and create the new one.
+ // Note: it is necessary to delete the corresponding blob too.
+ tree_cache.delete_blob(existing_leaf_node.value_hash());
+ tree_cache.delete_node(existing_leaf_node.hash());
+ return Ok(Self::create_leaf_node(nibble_iter, value_hash, tree_cache)?);
+ }
+
+ // Neither iterator is finished, which means the incoming key branches off at some point.
+ // Then create a branch node at the fork, a new leaf node for the incoming key, and
+ // create an extension node if the secondary common prefix length is not 0.
+ // We create new nodes in a bottom-up order.
+
+ // 1) Keep the visited nibble iterator. We will use it in step 3).
+ let pruned_existing_leaf_nibble_iter_before_fork =
+ pruned_existing_leaf_nibble_iter.visited_nibbles();
+
+ // 2) Create the branch node at the fork.
+ let existing_leaf_index = pruned_existing_leaf_nibble_iter
+ .next()
+ .expect("Ran out of nibbles");
+ let new_leaf_index = nibble_iter.next().expect("Ran out of nibbles");
+ assert_ne!(existing_leaf_index, new_leaf_index);
+
+ let mut branch_node = BranchNode::default();
+ branch_node.set_child(
+ existing_leaf_index,
+ (existing_leaf_node.hash(), true /* is_leaf */),
+ );
+ branch_node.set_child(
+ new_leaf_index,
+ Self::create_leaf_node(nibble_iter, value_hash, tree_cache)?,
+ );
+
+ let mut new_node_hash = branch_node.hash();
+ tree_cache.put_node(new_node_hash, branch_node.into())?;
+
+ // 3) Create an extension node before the fork if necessary.
+ if pruned_existing_leaf_nibble_iter_before_fork.num_nibbles() != 0 {
+ let new_extension_node = Node::new_extension(
+ pruned_existing_leaf_nibble_iter_before_fork.get_nibble_path(),
+ new_node_hash,
+ );
+ new_node_hash = new_extension_node.hash();
+ tree_cache.put_node(new_node_hash, new_extension_node)?;
+ }
+ Ok((new_node_hash, false /* is_leaf */))
+ }
+
+ /// Helper function for creating leaf nodes. Returns the hash of the newly created leaf node
+ /// and a boolean indicating whether the created node is a leaf node (always `true`). 
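+ ///
+ /// A sketch of the invariant relied on here: since the key length is fixed, by the time a
+ /// leaf is created, the visited and remaining nibbles of `nibble_iter` together always
+ /// spell out the whole key, e.g.:
+ ///
+ /// ```text
+ /// // key bytes: [0xa1, 0xb2, ...] (HashValue::LENGTH bytes in total)
+ /// // nibbles : a, 1, b, 2, ...
+ /// // nibble_iter.get_nibble_path().bytes() == key bytes, so the key can be
+ /// // reconstructed with HashValue::from_slice(...) below.
+ /// ```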
+ fn create_leaf_node(
+ nibble_iter: &NibbleIterator,
+ value_hash: HashValue,
+ tree_cache: &mut TreeCache,
+ ) -> Result<(HashValue, bool)> {
+ // Get the underlying bytes of nibble_iter, which must be a key, i.e., a hashed account
+ // address with `HashValue::LENGTH` bytes.
+ let nibble_path = nibble_iter.get_nibble_path();
+
+ // Now create the new leaf node.
+ let new_leaf = Node::new_leaf(HashValue::from_slice(nibble_path.bytes())?, value_hash);
+
+ // Cache the new leaf node and return its hash.
+ let new_node_hash = new_leaf.hash();
+ tree_cache.put_node(new_node_hash, new_leaf)?;
+ Ok((new_node_hash, true /* is_leaf */))
+ }
+
+ /// Returns the account state blob (if applicable) and the corresponding Merkle proof.
+ pub fn get_with_proof(
+ &self,
+ key: HashValue,
+ root_hash: HashValue,
+ ) -> Result<(Option<AccountStateBlob>, SparseMerkleProof)> {
+ // An empty tree just returns a proof with no sibling hashes.
+ if root_hash == *SPARSE_MERKLE_PLACEHOLDER_HASH {
+ return Ok((None, SparseMerkleProof::new(None, vec![])));
+ }
+ let mut siblings = vec![];
+ let nibble_path = NibblePath::new(key.to_vec());
+ let mut nibble_iter = nibble_path.nibbles();
+ let mut next_hash = root_hash;
+
+ // We limit the number of loops here deliberately to avoid potential cyclic graph bugs
+ // in the tree structure.
+ for _i in 0..ROOT_NIBBLE_HEIGHT {
+ match self.reader.get_node(next_hash)? {
+ Node::Branch(branch_node) => {
+ let queried_child_index = match nibble_iter.next() {
+ Some(nibble) => nibble,
+ // Shouldn't happen since the key length is fixed.
+ None => bail!("ran out of nibbles"),
+ };
+ let (child_for_proof, mut siblings_in_branch) =
+ branch_node.get_child_for_proof_and_siblings(queried_child_index);
+ siblings.append(&mut siblings_in_branch);
+ next_hash = match child_for_proof {
+ Some(hash) => hash,
+ None => return Ok((None, SparseMerkleProof::new(None, siblings))),
+ };
+ }
+ Node::Extension(extension_node) => {
+ let (mut siblings_in_extension, needs_early_return) =
+ extension_node.get_siblings(&mut nibble_iter);
+ siblings.append(&mut siblings_in_extension);
+ if needs_early_return {
+ return Ok((None, SparseMerkleProof::new(None, siblings)));
+ }
+ next_hash = extension_node.child();
+ }
+ Node::Leaf(leaf_node) => {
+ return Ok((
+ if leaf_node.key() == key {
+ Some(self.reader.get_blob(leaf_node.value_hash())?) 
+ } else {
+ None
+ },
+ SparseMerkleProof::new(
+ Some((leaf_node.key(), leaf_node.value_hash())),
+ siblings,
+ ),
+ ));
+ }
+ }
+ }
+ bail!("Sparse Merkle tree has cyclic graph inside.");
+ }
+
+ #[cfg(test)]
+ pub fn get(&self, key: HashValue, root_hash: HashValue) -> Result<Option<AccountStateBlob>> {
+ Ok(self.get_with_proof(key, root_hash)?.0)
+ }
+}
diff --git a/storage/sparse_merkle/src/mock_tree_store.rs b/storage/sparse_merkle/src/mock_tree_store.rs
new file mode 100644
index 0000000000000..be29f7ae31dea
--- /dev/null
+++ b/storage/sparse_merkle/src/mock_tree_store.rs
@@ -0,0 +1,81 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{node_type::Node, TreeReader, TreeUpdateBatch};
+use crypto::HashValue;
+use failure::prelude::*;
+use std::{
+ collections::{hash_map::Entry, HashMap},
+ sync::RwLock,
+};
+use types::account_state_blob::AccountStateBlob;
+
+#[derive(Default)]
+pub(crate) struct MockTreeStore(
+ RwLock<(
+ HashMap<HashValue, Node>,
+ HashMap<HashValue, AccountStateBlob>,
+ )>,
+);
+
+impl TreeReader for MockTreeStore {
+ fn get_node(&self, node_hash: HashValue) -> Result<Node> {
+ Ok(self
+ .0
+ .read()
+ .unwrap()
+ .0
+ .get(&node_hash)
+ .cloned()
+ .ok_or_else(|| format_err!("Failed to find node with hash {:?}", node_hash))?)
+ }
+
+ fn get_blob(&self, blob_hash: HashValue) -> Result<AccountStateBlob> {
+ Ok(self
+ .0
+ .read()
+ .unwrap()
+ .1
+ .get(&blob_hash)
+ .cloned()
+ .ok_or_else(|| format_err!("Failed to find blob with hash {:?}", blob_hash))?)
+ }
+}
+
+impl MockTreeStore {
+ pub fn put_node(&self, key: HashValue, node: Node) -> Result<()> {
+ match self.0.write().unwrap().0.entry(key) {
+ Entry::Occupied(_) => bail!("Key {:?} exists.", key),
+ Entry::Vacant(v) => {
+ v.insert(node);
+ }
+ }
+ Ok(())
+ }
+
+ pub fn put_blob(&self, key: HashValue, blob: AccountStateBlob) -> Result<()> {
+ self.0.write().unwrap().1.insert(key, blob);
+ Ok(())
+ }
+
+ pub fn write_tree_update_batch(&self, batch: TreeUpdateBatch) -> Result<()> {
+ let (node_batch, blob_batch) = batch.into();
+ node_batch
+ .into_iter()
+ .map(|(k, v)| self.put_node(k, v))
+ .collect::<Result<Vec<_>>>()?;
+ blob_batch
+ .into_iter()
+ .map(|(k, v)| self.put_blob(k, v))
+ .collect::<Result<Vec<_>>>()?;
+ Ok(())
+ }
+
+ pub fn num_nodes(&self) -> usize {
+ self.0.read().unwrap().0.len()
+ }
+
+ pub fn num_blobs(&self) -> usize {
+ self.0.read().unwrap().1.len()
+ }
+}
diff --git a/storage/sparse_merkle/src/nibble_path/mod.rs b/storage/sparse_merkle/src/nibble_path/mod.rs
new file mode 100644
index 0000000000000..7f1ce07879c03
--- /dev/null
+++ b/storage/sparse_merkle/src/nibble_path/mod.rs
@@ -0,0 +1,258 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! The NibblePath library simplifies operations with nibbles in a compact format for the modified
+//! sparse Merkle tree by providing powerful iterators that advance by either bit or nibble.
+
+#[cfg(test)]
+mod nibble_path_test;
+
+use crate::ROOT_NIBBLE_HEIGHT;
+use std::{fmt, iter::FromIterator};
+
+/// NibblePath defines a path in a Merkle tree in the unit of nibbles (4 bits).
+#[derive(Clone, Eq, PartialEq)]
+pub struct NibblePath {
+ /// The underlying bytes that store the path, 2 nibbles per byte. If the number of nibbles is
+ /// odd, the second half of the last byte must be 0.
+ bytes: Vec<u8>,
+ /// Indicates the total number of nibbles in bytes. Either `bytes.len() * 2 - 1` or
+ /// `bytes.len() * 2`.
+ num_nibbles: usize,
+}
+
+/// Support debug format by concatenating nibbles literally. For example, [0x12, 0xa0] with 3
+/// nibbles will be printed as "12a". 
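+/// (An odd-length path keeps its last nibble in the high half of the final byte, so the path
+/// in this example can be built with `NibblePath::new_odd(vec![0x12, 0xa0])`.)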
+impl fmt::Debug for NibblePath {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ self.nibbles().map(|x| write!(f, "{:x}", x)).collect()
+ }
+}
+
+/// Convert a vector of bytes into `NibblePath` using the lower 4 bits of each byte as a nibble.
+impl FromIterator<u8> for NibblePath {
+ fn from_iter<I: IntoIterator<Item = u8>>(iter: I) -> Self {
+ let mut bytes = vec![];
+ let mut count = 0;
+ for nibble in iter {
+ assert!(nibble < 16);
+ if count % 2 == 0 {
+ bytes.push(nibble << 4);
+ } else {
+ *bytes.last_mut().expect("Cannot be None") |= nibble & 0x0f;
+ }
+ count += 1;
+ }
+ if count % 2 == 0 {
+ NibblePath::new(bytes)
+ } else {
+ NibblePath::new_odd(bytes)
+ }
+ }
+}
+
+impl NibblePath {
+ /// Creates a new `NibblePath` from a vector of bytes assuming each byte has 2 nibbles.
+ pub fn new(bytes: Vec<u8>) -> Self {
+ let num_nibbles = bytes.len() * 2;
+ assert!(num_nibbles <= ROOT_NIBBLE_HEIGHT);
+ NibblePath { bytes, num_nibbles }
+ }
+
+ /// Similar to `new()` but assumes that the bytes have one less nibble.
+ pub fn new_odd(bytes: Vec<u8>) -> Self {
+ assert_eq!(
+ bytes.last().expect("Should have odd number of nibbles.") & 0x0f,
+ 0,
+ "Last nibble must be 0."
+ );
+ let num_nibbles = bytes.len() * 2 - 1;
+ assert!(num_nibbles <= ROOT_NIBBLE_HEIGHT);
+ NibblePath { bytes, num_nibbles }
+ }
+
+ /// Get the i-th bit.
+ fn get_bit(&self, i: usize) -> bool {
+ assert!(i / 4 < self.num_nibbles);
+ let pos = i / 8;
+ let bit = 7 - i % 8;
+ ((self.bytes[pos] >> bit) & 1) != 0
+ }
+
+ /// Get the i-th nibble, stored at the lower 4 bits.
+ fn get_nibble(&self, i: usize) -> u8 {
+ assert!(i < self.num_nibbles);
+ (self.bytes[i / 2] >> (if i % 2 == 1 { 0 } else { 4 })) & 0xf
+ }
+
+ /// Get a bit iterator that iterates over the whole nibble path.
+ pub fn bits(&self) -> BitIterator {
+ BitIterator {
+ nibble_path: self,
+ pos: (0..self.num_nibbles * 4),
+ }
+ }
+
+ /// Get a nibble iterator that iterates over the whole nibble path.
+ pub fn nibbles(&self) -> NibbleIterator {
+ NibbleIterator::new(self, 0, self.num_nibbles)
+ }
+
+ /// Get the total number of nibbles stored.
+ pub fn num_nibbles(&self) -> usize {
+ self.num_nibbles
+ }
+
+ /// Get the underlying bytes storing nibbles.
+ pub fn bytes(&self) -> &[u8] {
+ &self.bytes[..]
+ }
+}
+
+pub(crate) trait Peekable: Iterator {
+ /// Returns the `next()` value without advancing the iterator.
+ fn peek(&self) -> Option<Self::Item>;
+}
+
+/// BitIterator iterates a nibble path by bit.
+pub struct BitIterator<'a> {
+ nibble_path: &'a NibblePath,
+ pos: std::ops::Range<usize>,
+}
+
+impl<'a> Peekable for BitIterator<'a> {
+ /// Returns the `next()` value without advancing the iterator.
+ fn peek(&self) -> Option<Self::Item> {
+ if self.pos.start < self.pos.end {
+ Some(self.nibble_path.get_bit(self.pos.start))
+ } else {
+ None
+ }
+ }
+}
+
+/// BitIterator spits out a boolean each time. True/false denotes 1/0.
+impl<'a> Iterator for BitIterator<'a> {
+ type Item = bool;
+ fn next(&mut self) -> Option<Self::Item> {
+ self.pos.next().map(|i| self.nibble_path.get_bit(i))
+ }
+}
+
+/// Support iterating bits in reversed order.
+impl<'a> DoubleEndedIterator for BitIterator<'a> {
+ fn next_back(&mut self) -> Option<Self::Item> {
+ self.pos.next_back().map(|i| self.nibble_path.get_bit(i))
+ }
+}
+
+/// NibbleIterator iterates a nibble path by nibble.
+pub struct NibbleIterator<'a> {
+ /// The underlying nibble path that stores the nibbles.
+ nibble_path: &'a NibblePath,
+
+ /// The current index, `pos.start`, will bump by 1 after calling `next()` until `pos.start ==
+ /// pos.end`. 
+ pos: std::ops::Range<usize>,
+
+ /// The start index of the iterator. At the beginning, `pos.start == start`. [start, pos.end)
+ /// defines the range of `nibble_path` this iterator iterates over. `nibble_path` refers to
+ /// the entire underlying buffer, but the range may only be partial.
+ start: usize,
+}
+
+/// NibbleIterator spits out a byte each time. Each byte must be in the range [0, 16).
+impl<'a> Iterator for NibbleIterator<'a> {
+ type Item = u8;
+ fn next(&mut self) -> Option<Self::Item> {
+ self.pos.next().map(|i| self.nibble_path.get_nibble(i))
+ }
+}
+
+impl<'a> Peekable for NibbleIterator<'a> {
+ /// Returns the `next()` value without advancing the iterator.
+ fn peek(&self) -> Option<Self::Item> {
+ if self.pos.start < self.pos.end {
+ Some(self.nibble_path.get_nibble(self.pos.start))
+ } else {
+ None
+ }
+ }
+}
+
+impl<'a> NibbleIterator<'a> {
+ fn new(nibble_path: &'a NibblePath, start: usize, end: usize) -> Self {
+ Self {
+ nibble_path,
+ pos: (start..end),
+ start,
+ }
+ }
+
+ /// Returns a nibble iterator that iterates all visited nibbles.
+ pub fn visited_nibbles(&self) -> NibbleIterator<'a> {
+ Self::new(self.nibble_path, self.start, self.pos.start)
+ }
+
+ /// Returns a nibble iterator that iterates all remaining nibbles.
+ pub fn remaining_nibbles(&self) -> NibbleIterator<'a> {
+ Self::new(self.nibble_path, self.pos.start, self.pos.end)
+ }
+
+ /// Turn it into a `BitIterator`.
+ pub fn bits(&self) -> BitIterator<'a> {
+ BitIterator {
+ nibble_path: self.nibble_path,
+ pos: (self.pos.start * 4..self.pos.end * 4),
+ }
+ }
+
+ /// Cut and return the range of the underlying `nibble_path` that this iterator is iterating
+ /// over as a new `NibblePath`.
+ pub fn get_nibble_path(&self) -> NibblePath {
+ self.visited_nibbles()
+ .chain(self.remaining_nibbles())
+ .collect()
+ }
+
+ /// Get the number of nibbles that this iterator covers.
+ pub fn num_nibbles(&self) -> usize {
+ self.pos.end - self.start
+ }
+
+ /// Return `true` if the iteration is over.
+ pub fn is_finished(&self) -> bool {
+ self.peek().is_none()
+ }
+}
+
+/// Advance both iterators while their next nibbles are the same, until either reaches the end or
+/// they find a mismatch. Return the number of matched nibbles. 
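+///
+/// For example (a sketch; `test_skip_common_prefix` in `nibble_path_test.rs` covers this case):
+///
+/// ```text
+/// let p1 = NibblePath::new(vec![0x12, 0x35]);
+/// let p2 = NibblePath::new(vec![0x12, 0x34]);
+/// let (mut i1, mut i2) = (p1.nibbles(), p2.nibbles());
+/// assert_eq!(skip_common_prefix(&mut i1, &mut i2), 3); // nibbles 1, 2, 3 match
+/// assert_eq!(i1.next(), Some(0x5));
+/// assert_eq!(i2.next(), Some(0x4));
+/// ```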
+pub(crate) fn skip_common_prefix<'a, 'b, I1: 'a, I2: 'b>(x: &'a mut I1, y: &mut I2) -> usize
+where
+ I1: Iterator + Peekable,
+ I2: Iterator + Peekable,
+ <I1 as Iterator>::Item: std::cmp::PartialEq<<I2 as Iterator>::Item>,
+{
+ let mut count = 0;
+ loop {
+ let x_peek = x.peek();
+ let y_peek = y.peek();
+ if x_peek.is_none()
+ || y_peek.is_none()
+ || x_peek.expect("cannot be none") != y_peek.expect("cannot be none")
+ {
+ break;
+ }
+ count += 1;
+ x.next();
+ y.next();
+ }
+ count
+}
diff --git a/storage/sparse_merkle/src/nibble_path/nibble_path_test.rs b/storage/sparse_merkle/src/nibble_path/nibble_path_test.rs
new file mode 100644
index 0000000000000..35ac61a8f250e
--- /dev/null
+++ b/storage/sparse_merkle/src/nibble_path/nibble_path_test.rs
@@ -0,0 +1,231 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+ nibble_path::{skip_common_prefix, NibblePath},
+ ROOT_NIBBLE_HEIGHT,
+};
+use proptest::{collection::vec, prelude::*};
+
+#[test]
+fn test_nibble_path_fmt() {
+ let nibble_path = NibblePath::new(vec![0x12, 0x34, 0x56]);
+ assert_eq!(format!("{:?}", nibble_path), "123456");
+
+ let nibble_path = NibblePath::new_odd(vec![0x12, 0x34, 0x50]);
+ assert_eq!(format!("{:?}", nibble_path), "12345");
+}
+
+#[test]
+fn test_create_nibble_path_success() {
+ let nibble_path = NibblePath::new(vec![0x12, 0x34, 0x56]);
+ assert_eq!(nibble_path.num_nibbles(), 6);
+
+ let nibble_path = NibblePath::new_odd(vec![0x12, 0x34, 0x50]);
+ assert_eq!(nibble_path.num_nibbles(), 5);
+
+ let nibble_path = NibblePath::new(vec![]);
+ assert_eq!(nibble_path.num_nibbles(), 0);
+}
+
+#[test]
+#[should_panic(expected = "Last nibble must be 0.")]
+fn test_create_nibble_path_failure() {
+ let bytes: Vec<u8> = vec![0x12, 0x34, 0x56];
+ let _nibble_path = NibblePath::new_odd(bytes);
+}
+
+#[test]
+#[should_panic(expected = "Should have odd number of nibbles.")]
+fn test_empty_nibble_path() {
+ NibblePath::new_odd(vec![]);
+}
+
+#[test]
+fn test_get_nibble() {
+ let bytes: Vec<u8> = vec![0x12, 0x34];
+ let nibble_path = NibblePath::new(bytes);
+ assert_eq!(nibble_path.get_nibble(0), 0x01);
+ assert_eq!(nibble_path.get_nibble(1), 0x02);
+ assert_eq!(nibble_path.get_nibble(2), 0x03);
+ assert_eq!(nibble_path.get_nibble(3), 0x04);
+}
+
+#[test]
+fn test_nibble_iterator() {
+ let bytes: Vec<u8> = vec![0x12, 0x30];
+ let nibble_path = NibblePath::new_odd(bytes);
+ let mut iter = nibble_path.nibbles();
+ assert_eq!(iter.next().unwrap(), 0x01);
+ assert_eq!(iter.next().unwrap(), 0x02);
+ assert_eq!(iter.next().unwrap(), 0x03);
+ assert_eq!(iter.next(), None);
+}
+
+#[test]
+fn test_get_bit() {
+ let bytes: Vec<u8> = vec![0x01, 0x02];
+ let nibble_path = NibblePath::new(bytes);
+ assert_eq!(nibble_path.get_bit(0), false);
+ assert_eq!(nibble_path.get_bit(1), false);
+ assert_eq!(nibble_path.get_bit(2), false);
+ assert_eq!(nibble_path.get_bit(7), true);
+ assert_eq!(nibble_path.get_bit(8), false);
+ assert_eq!(nibble_path.get_bit(14), true);
+}
+
+#[test]
+fn test_bit_iter() {
+ let bytes: Vec<u8> = vec![0xc3, 0xa0];
+ let nibble_path = NibblePath::new_odd(bytes.clone());
+ let mut iter = nibble_path.bits();
+ // c: 0b1100
+ assert_eq!(iter.next(), Some(true));
+ assert_eq!(iter.next(), Some(true));
+ assert_eq!(iter.next(), Some(false));
+ assert_eq!(iter.next(), Some(false));
+ // 3: 0b0011
+ assert_eq!(iter.next(), Some(false));
+ assert_eq!(iter.next(), Some(false));
+ assert_eq!(iter.next(), Some(true));
+ assert_eq!(iter.next(), Some(true));
+ // a: 0b1010
+ assert_eq!(iter.next_back(), Some(false));
assert_eq!(iter.next_back(), Some(true));
+ assert_eq!(iter.next_back(), Some(false));
+ assert_eq!(iter.next_back(), Some(true));
+
+ assert_eq!(iter.next(), None);
+}
+
+#[test]
+fn test_visited_nibble_iter() {
+ let bytes: Vec<u8> = vec![0x12, 0x34, 0x56];
+ let nibble_path = NibblePath::new(bytes.clone());
+ let mut iter = nibble_path.nibbles();
+ assert_eq!(iter.next().unwrap(), 0x01);
+ assert_eq!(iter.next().unwrap(), 0x02);
+ assert_eq!(iter.next().unwrap(), 0x03);
+ let mut visited_nibble_iter = iter.visited_nibbles();
+ assert_eq!(visited_nibble_iter.next().unwrap(), 0x01);
+ assert_eq!(visited_nibble_iter.next().unwrap(), 0x02);
+ assert_eq!(visited_nibble_iter.next().unwrap(), 0x03);
+}
+
+#[test]
+fn test_skip_common_prefix() {
+ {
+ let nibble_path1 = NibblePath::new(vec![0x12, 0x34, 0x56]);
+ let nibble_path2 = NibblePath::new(vec![0x12, 0x34, 0x56]);
+ let mut iter1 = nibble_path1.nibbles();
+ let mut iter2 = nibble_path2.nibbles();
+ assert_eq!(skip_common_prefix(&mut iter1, &mut iter2), 6);
+ assert!(iter1.is_finished());
+ assert!(iter2.is_finished());
+ }
+ {
+ let nibble_path1 = NibblePath::new(vec![0x12, 0x35]);
+ let nibble_path2 = NibblePath::new(vec![0x12, 0x34, 0x56]);
+ let mut iter1 = nibble_path1.nibbles();
+ let mut iter2 = nibble_path2.nibbles();
+ assert_eq!(skip_common_prefix(&mut iter1, &mut iter2), 3);
+ assert_eq!(
+ iter1.visited_nibbles().get_nibble_path(),
+ iter2.visited_nibbles().get_nibble_path()
+ );
+ assert_eq!(
+ iter1.remaining_nibbles().get_nibble_path(),
+ NibblePath::new_odd(vec![0x50])
+ );
+ assert_eq!(
+ iter2.remaining_nibbles().get_nibble_path(),
+ NibblePath::new_odd(vec![0x45, 0x60])
+ );
+ }
+ {
+ let nibble_path1 = NibblePath::new(vec![0x12, 0x34, 0x56]);
+ let nibble_path2 = NibblePath::new_odd(vec![0x12, 0x30]);
+ let mut iter1 = nibble_path1.nibbles();
+ let mut iter2 = nibble_path2.nibbles();
+ assert_eq!(skip_common_prefix(&mut iter1, &mut iter2), 3);
+ assert_eq!(
+ iter1.visited_nibbles().get_nibble_path(),
+ iter2.visited_nibbles().get_nibble_path()
+ );
+ assert_eq!(
+ iter1.remaining_nibbles().get_nibble_path(),
+ NibblePath::new_odd(vec![0x45, 0x60])
+ );
+ assert!(iter2.is_finished());
+ }
+}
+
+impl Arbitrary for NibblePath {
+ type Parameters = ();
+ type Strategy = BoxedStrategy<Self>;
+ fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
+ arb_nibble_path().boxed()
+ }
+}
+
+prop_compose! {
+ fn arb_nibble_path()(
+ mut bytes in vec(any::<u8>(), 0..ROOT_NIBBLE_HEIGHT/2), is_odd in any::<bool>()
+ ) -> NibblePath {
+ if let Some(last_byte) = bytes.last_mut() {
+ if is_odd {
+ *last_byte &= 0xf0;
+ return NibblePath::new_odd(bytes);
+ }
+ }
+ NibblePath::new(bytes)
+ }
+}
+
+prop_compose! {
+ fn arb_nibble_path_and_current()(nibble_path in any::<NibblePath>())
+ (current in 0..=nibble_path.num_nibbles(),
+ nibble_path in Just(nibble_path)) -> (usize, NibblePath) {
+ (current, nibble_path)
+ }
+}
+
+proptest! 
{
+ #[test]
+ fn test_nibble_iter_roundtrip(nibble_path in arb_nibble_path()) {
+ let nibbles = nibble_path.nibbles();
+ let nibble_path2 = nibbles.collect();
+ prop_assert_eq!(nibble_path, nibble_path2);
+ }
+
+ #[test]
+ fn test_visited_and_remaining_nibbles((current, nibble_path) in arb_nibble_path_and_current()) {
+ let mut nibble_iter = nibble_path.nibbles();
+ let mut visited_nibbles = vec![];
+ for _ in 0..current {
+ visited_nibbles.push(nibble_iter.next().unwrap());
+ }
+ let visited_nibble_path = nibble_iter.visited_nibbles().get_nibble_path();
+ let remaining_nibble_path = nibble_iter.remaining_nibbles().get_nibble_path();
+ let visited_iter = visited_nibble_path.nibbles();
+ let remaining_iter = remaining_nibble_path.nibbles();
+ prop_assert_eq!(visited_nibbles, visited_iter.collect::<Vec<_>>());
+ prop_assert_eq!(nibble_iter.collect::<Vec<_>>(), remaining_iter.collect::<Vec<_>>());
+ }
+
+ #[test]
+ fn test_nibble_iter_to_bit_iter((current, nibble_path) in arb_nibble_path_and_current()) {
+ let mut nibble_iter = nibble_path.nibbles();
+ (0..current).for_each(|_| {
+ nibble_iter.next().unwrap();
+ });
+ let remaining_nibble_path = nibble_iter.remaining_nibbles().get_nibble_path();
+ let remaining_bit_iter = remaining_nibble_path.bits();
+ let bit_iter = nibble_iter.bits();
+ prop_assert_eq!(remaining_bit_iter.collect::<Vec<_>>(), bit_iter.collect::<Vec<_>>());
+ }
+
+}
diff --git a/storage/sparse_merkle/src/node_serde/mod.rs b/storage/sparse_merkle/src/node_serde/mod.rs
new file mode 100644
index 0000000000000..20ee00a2b004f
--- /dev/null
+++ b/storage/sparse_merkle/src/node_serde/mod.rs
@@ -0,0 +1,250 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Handles the customized and optimized serialization and deserialization of node types. We
+//! customize ser/de for explicit specification and space optimization.
+
+#[cfg(test)]
+mod node_serde_test;
+
+use crate::{
+ nibble_path::NibblePath,
+ node_type::{BranchNode, ExtensionNode, LeafNode},
+};
+use crypto::HashValue;
+use serde::{
+ de::{self, Deserialize, Deserializer, SeqAccess, Visitor},
+ ser::{self, Serialize, SerializeTuple, Serializer},
+};
+use std::{collections::hash_map::HashMap, fmt, result::Result};
+
+/// Customized BranchNode serialization/deserialization
+///
+/// A branch node will be serialized to 2 u16 bitmaps and a vector of hashes. The first bitmap
+/// indicates which children exist by setting the bit at its corresponding index; the second bitmap
+/// indicates which children are leaf nodes in the same way; finally a vector of hashes of children
+/// follows in order of index starting at the beginning. For example, if a branch node has 3
+/// children, a leaf node, an extension node, and a branch node, with indices 0, 7, 12,
+/// respectively, the serialization structure will be:
+/// 1st field: `0b0001000010000001` (LSB denotes the child at index 0)
+/// 2nd field: `0b0000000000000001` (LSB denotes the child at index 0)
+/// 3rd field: vec![hash1, hash2, hash3]
+impl Serialize for BranchNode {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ let (child_bitmap, leaf_bitmap) = self.generate_bitmaps();
+ let mut hashes = Vec::with_capacity(self.num_children());
+ for i in 0..16 {
+ // If a child exists, append it to the vector. Note: we don't have to fetch the node
+ // type info (leaf or not) since leaf_bitmap exists exactly for this purpose. 
+ if child_bitmap >> i & 1 != 0 {
+ hashes.push(self.child(i as u8).ok_or_else(|| {
+ ser::Error::custom(format!(
+ "Invalid branch node: \
+ unable to get child {} for BranchNode serialization.",
+ i
+ ))
+ })?);
+ }
+ }
+ let mut tuple = serializer.serialize_tuple(3)?;
+ tuple.serialize_element(&child_bitmap)?;
+ tuple.serialize_element(&leaf_bitmap)?;
+ tuple.serialize_element(&hashes)?;
+ tuple.end()
+ }
+}
+
+struct BranchNodeVisitor;
+
+impl<'de> Visitor<'de> for BranchNodeVisitor {
+ type Value = BranchNode;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ formatter,
+ "This visitor expects to receive two u16 bitmaps, \
+ followed by a vector of HashValues (up to 16)"
+ )
+ }
+
+ fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+ where
+ A: SeqAccess<'de>,
+ {
+ let child_bitmap: u16 = seq
+ .next_element()?
+ .ok_or_else(|| de::Error::invalid_length(0, &self))?;
+ let leaf_bitmap: u16 = seq
+ .next_element()?
+ .ok_or_else(|| de::Error::invalid_length(1, &self))?;
+ if child_bitmap | leaf_bitmap != child_bitmap {
+ return Err(de::Error::custom(
+ "Invalid branch node in deserialization: \
+ leaf_bitmap conflicts with child_bitmap",
+ ));
+ }
+ let hashes: Vec<HashValue> = seq
+ .next_element()?
+ .ok_or_else(|| de::Error::invalid_length(2, &self))?;
+
+ if child_bitmap.count_ones() as usize != hashes.len() {
+ return Err(de::Error::custom(
+ "Invalid branch node in deserialization: \
+ children number doesn't match child_bitmap",
+ ));
+ }
+ let children = (0..16)
+ .filter(|i| child_bitmap >> i & 1 == 1)
+ .zip(hashes.into_iter())
+ .map(|(i, d)| (i, (d, leaf_bitmap >> i & 1 == 1)))
+ .collect::<HashMap<_, _>>();
+ Ok(BranchNode::new(children))
+ }
+}
+
+impl<'de> Deserialize<'de> for BranchNode {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ deserializer.deserialize_tuple(3, BranchNodeVisitor)
+ }
+}
+
+/// Customized ExtensionNode serialization/deserialization
+///
+/// An extension node will be serialized into a vector of bytes and a hash value. The vector of
+/// bytes represents the encoded_path of this extension node. If the encoded path has an even
+/// number of nibbles, we append 0x01 at the end. Otherwise we set the last unused nibble to be
+/// 0x0, since in the previous case the last nibble is always 0x1. After the byte vector, we store
+/// the hash value of the only child node of the current extension node.
+///
+/// For example, if an extension node has an encoded_path of 'a13' with hash1 as its child node
+/// hash, the serialization structure will be:
+/// 1st field: vec![0xa1, 0x30]
+/// 2nd field: hash1
+impl Serialize for ExtensionNode {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ let nibble_path = self.nibble_path();
+ if nibble_path.num_nibbles() == 0 {
+ return Err(ser::Error::custom("Encoded nibble path bytes is empty"));
+ }
+ let mut nibble_bytes = Vec::with_capacity(nibble_path.num_nibbles() / 2 + 1);
+ nibble_bytes.extend(self.nibble_path().bytes());
+ if nibble_path.num_nibbles() % 2 == 1 {
+ // Set the last nibble to 0x0 to be the flag when odd.
+ *nibble_bytes
+ .last_mut()
+ .expect("Have verified encoded path bytes is not empty.") &= 0xf0;
+ } else {
+ // Append a new byte 0x01 to be the flag when even. 
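+ // For example, matching the module-level doc above: the odd path "a13" is
+ // serialized as vec![0xa1, 0x30] (last nibble zeroed as the odd flag), while
+ // an even path "a1" would be serialized as vec![0xa1, 0x01] (0x01 appended).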
+ nibble_bytes.push(1); + } + let mut tuple = serializer.serialize_tuple(2)?; + tuple.serialize_element(&nibble_bytes)?; + tuple.serialize_element(&self.child())?; + tuple.end() + } +} + +struct ExtensionNodeVisitor; + +impl<'de> Visitor<'de> for ExtensionNodeVisitor { + type Value = ExtensionNode; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "This visitor expects to receive a vector of u8") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut nibble_bytes: Vec = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let child: HashValue = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + match nibble_bytes + .last() + .ok_or_else(|| de::Error::custom("Serialization of extension node is empty."))? + & 0x0f + { + 1u8 => { + nibble_bytes.pop(); + Ok(ExtensionNode::new(NibblePath::new(nibble_bytes), child)) + } + 0u8 => Ok(ExtensionNode::new(NibblePath::new_odd(nibble_bytes), child)), + _ => Err(de::Error::custom(format!( + "The last byte of the serialization of extension node is corrupt: {:?}", + nibble_bytes + ))), + } + } +} + +impl<'de> Deserialize<'de> for ExtensionNode { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_tuple(2, ExtensionNodeVisitor) + } +} + +/// Customized LeafNode serialization/deserialization +/// +/// A leaf node will be serialized into two hash values. They are the key to this leaf node and the +/// hash of the value blob under this account. For example, if a leaf node has hash1 as key with +/// hash2 as its value hash, the serialization structure will be: +/// 1st field: hash1 +/// 2nd field: hash2 +impl Serialize for LeafNode { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut tuple = serializer.serialize_tuple(2)?; + tuple.serialize_element(&self.key())?; + tuple.serialize_element(&self.value_hash())?; + tuple.end() + } +} + +struct LeafNodeVisitor; + +impl<'de> Visitor<'de> for LeafNodeVisitor { + type Value = LeafNode; + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "This visitor expects to receive a vector of u8") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let key: HashValue = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let value_hash: HashValue = seq + .next_element()? 
+ .ok_or_else(|| de::Error::invalid_length(1, &self))?; + Ok(LeafNode::new(key, value_hash)) + } +} + +impl<'de> Deserialize<'de> for LeafNode { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_tuple(2, LeafNodeVisitor) + } +} diff --git a/storage/sparse_merkle/src/node_serde/node_serde_test.rs b/storage/sparse_merkle/src/node_serde/node_serde_test.rs new file mode 100644 index 0000000000000..969b7f6b460be --- /dev/null +++ b/storage/sparse_merkle/src/node_serde/node_serde_test.rs @@ -0,0 +1,109 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + nibble_path::NibblePath, + node_type::{BranchNode, ExtensionNode, LeafNode}, +}; +use crypto::HashValue; +use serde_test::{assert_tokens, Token}; + +fn append_hashvalue_tokens(tokens: &mut Vec, hash: HashValue) { + for i in 0..HashValue::LENGTH { + tokens.push(Token::U8(hash[i])); + } +} + +#[test] +fn test_serde_branch_type() { + let child1 = HashValue::random(); + let child2 = HashValue::random(); + let mut branch_node = BranchNode::default(); + branch_node.set_child(1, (child1, true)); + branch_node.set_child(2, (child2, true)); + let mut tokens = vec![]; + tokens.extend_from_slice(&[ + Token::Tuple { len: 3 }, + Token::U16(0b0110), + Token::U16(0b0110), + Token::Seq { len: Some(2) }, + Token::Struct { + name: "HashValue", + len: 1, + }, + Token::Str("hash"), + Token::Tuple { len: 32 }, + ]); + append_hashvalue_tokens(&mut tokens, child1); + tokens.extend_from_slice(&[ + Token::TupleEnd, + Token::StructEnd, + Token::Struct { + name: "HashValue", + len: 1, + }, + Token::Str("hash"), + Token::Tuple { len: 32 }, + ]); + append_hashvalue_tokens(&mut tokens, child2); + tokens.extend_from_slice(&[ + Token::TupleEnd, + Token::StructEnd, + Token::SeqEnd, + Token::TupleEnd, + ]); + assert_tokens(&branch_node, &tokens) +} + +#[test] +fn test_serde_extension_type() { + let child = HashValue::random(); + let path = vec![0xff, 0x10]; + let extension_node = ExtensionNode::new(NibblePath::new_odd(path.clone()), child); + let mut tokens = vec![]; + tokens.extend_from_slice(&[ + Token::Tuple { len: 2 }, + Token::Seq { len: Some(2) }, + Token::U8(path[0]), + Token::U8(path[1]), + Token::SeqEnd, + Token::Struct { + name: "HashValue", + len: 1, + }, + Token::Str("hash"), + Token::Tuple { len: 32 }, + ]); + append_hashvalue_tokens(&mut tokens, child); + tokens.extend_from_slice(&[Token::TupleEnd, Token::StructEnd, Token::TupleEnd]); + assert_tokens(&extension_node, &tokens); +} + +#[test] +fn test_serde_leaf_type() { + let leaf_node = LeafNode::new(HashValue::random(), HashValue::random()); + let mut tokens = vec![]; + tokens.extend_from_slice(&[ + Token::Tuple { len: 2 }, + Token::Struct { + name: "HashValue", + len: 1, + }, + Token::Str("hash"), + Token::Tuple { len: 32 }, + ]); + append_hashvalue_tokens(&mut tokens, leaf_node.key()); + tokens.extend_from_slice(&[ + Token::TupleEnd, + Token::StructEnd, + Token::Struct { + name: "HashValue", + len: 1, + }, + Token::Str("hash"), + Token::Tuple { len: 32 }, + ]); + append_hashvalue_tokens(&mut tokens, leaf_node.value_hash()); + tokens.extend_from_slice(&[Token::TupleEnd, Token::StructEnd, Token::TupleEnd]); + assert_tokens(&leaf_node, &tokens); +} diff --git a/storage/sparse_merkle/src/node_type/mod.rs b/storage/sparse_merkle/src/node_type/mod.rs new file mode 100644 index 0000000000000..8acf9b635fa6c --- /dev/null +++ b/storage/sparse_merkle/src/node_type/mod.rs @@ -0,0 +1,574 @@ +// Copyright (c) The 
Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Node types of [`SparseMerkleTree`](crate::SparseMerkleTree)
+//!
+//! This module defines three types of patricia Merkle tree nodes: [`BranchNode`],
+//! [`ExtensionNode`] and [`LeafNode`] as building blocks of a 256-bit
+//! [`SparseMerkleTree`](crate::SparseMerkleTree). [`BranchNode`] represents a 4-level binary tree
+//! to optimize for IOPS: it compresses a tree with 31 nodes into one node with 16 children at the
+//! lowest level. [`ExtensionNode`] compresses a partial path without any fork into a single node
+//! by storing the partial path inside. [`LeafNode`] stores the full key and the value hash which
+//! is used as the key to query the binary account blob data from storage.
+
+#[cfg(test)]
+mod node_type_test;
+
+use crate::nibble_path::{skip_common_prefix, NibbleIterator, NibblePath};
+use bincode::{deserialize, serialize};
+use crypto::{
+    hash::{
+        CryptoHash, SparseMerkleInternalHasher, SparseMerkleLeafHasher,
+        SPARSE_MERKLE_PLACEHOLDER_HASH,
+    },
+    HashValue,
+};
+use failure::{Fail, Result};
+use serde::{Deserialize, Serialize};
+use std::collections::hash_map::HashMap;
+use types::proof::{SparseMerkleInternalNode, SparseMerkleLeafNode};
+
+pub(crate) type Children = HashMap<u8, (HashValue, bool)>;
+
+/// Represents a 4-level subtree with 16 children at the bottom level. Theoretically, this reduces
+/// the IOPS needed to query the tree by 4x, since we compress 4 levels of a standard Merkle tree
+/// into 1 node. Though we choose the same branch node structure as that of a patricia Merkle
+/// tree, the root hash computation logic is similar to a 4-level sparse Merkle tree except for
+/// some customizations. See the `CryptoHash` trait implementation below for details.
+#[derive(Clone, Debug, Eq, PartialEq, Default)]
+pub struct BranchNode {
+    // key: child index from 0 to 15, inclusive.
+    // value: child node hash and a boolean whose true value indicates the child is a leaf node.
+    children: Children,
+}
+
+/// Node in a patricia Merkle tree. It compresses a path without any fork into a single node
+/// instead of multiple single-child branch nodes.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct ExtensionNode {
+    // The nibble path this extension node encapsulates.
+    nibble_path: NibblePath,
+    // Represents the next node down the path.
+    child: HashValue,
+}
+
+/// Represents an account. It has two fields: `key` is the hash of the account address and
+/// `value_hash` is the hash of the account state blob.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct LeafNode {
+    // The full key of this node.
+    key: HashValue,
+    // The hash of the data blob identified by the key.
+    value_hash: HashValue,
+}
+
+/// The explicit tag is used as a prefix in the encoded format of nodes to distinguish different
+/// node discriminants.
+trait Tag {
+    const TAG: u8;
+}
+
+// We leave 0 reserved.
+impl Tag for BranchNode {
+    const TAG: u8 = 1;
+}
+
+impl Tag for ExtensionNode {
+    const TAG: u8 = 2;
+}
+
+impl Tag for LeafNode {
+    const TAG: u8 = 3;
+}
+
+/// Computes the hash of a branch node according to the
+/// [`SparseMerkleTree`](crate::SparseMerkleTree) data structure in the logical view. `start` and
+/// `nibble_height` determine a subtree whose root hash we want to get.
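+/// As an illustration (ours, not from the original doc): `start == 8` with
+/// `nibble_height == 2` (i.e. a width of `1 << 2 == 4`) selects the subtree covering the
+/// children at indices 8..=11.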
+/// For a branch node with 16 children at the bottom level, we compute its root hash as if it
+/// were a full binary Merkle tree with 16 leaves, as below:
+///
+/// ```text
+///
+///   4 ->              +------ root hash ------+
+///                     |                       |
+///   3 ->        +---- # ----+           +---- # ----+
+///               |           |           |           |
+///   2 ->        #           #           #           #
+///             /   \       /   \       /   \       /   \
+///   1 ->     #     #     #     #     #     #     #     #
+///           / \   / \   / \   / \   / \   / \   / \   / \
+///   0 ->   0   1 2   3 4   5 6   7 8   9 A   B C   D E   F
+///   ^
+/// height
+/// ```
+///
+/// As illustrated above, at nibble height 0, `0..F` in hex denote the 16 children hashes. Each
+/// `#` denotes the hash of its two direct children, which is in turn combined with the hash of
+/// its sibling to generate the hash of its parent. Finally, we obtain the hash of this branch
+/// node.
+///
+/// However, if a branch node doesn't have all 16 children at height 0 but only a few of
+/// them, we apply a modified hashing rule on top of what is stated above:
+/// 1. From top to bottom, a node will be replaced by a leaf child if the subtree rooted at this
+/// node has only one child at height 0 and it is a leaf child.
+/// 2. From top to bottom, a node will be replaced by the placeholder node if the subtree rooted
+/// at this node doesn't have any child at height 0. For example, if a branch node has 3 leaf
+/// nodes at indices 0, 3 and 8, respectively, and 1 branch/extension node at index C, then the
+/// computation graph will be like:
+///
+/// ```text
+///
+///   4 ->              +------ root hash ------+
+///                     |                       |
+///   3 ->        +---- # ----+           +---- # ----+
+///               |           |           |           |
+///   2 ->        #           @           8           #
+///             /   \                               /   \
+///   1 ->     0     3                             #     @
+///                                               / \
+///   0 ->                                       C   @
+///   ^
+/// height
+/// Note: @ denotes placeholder hash.
+/// ```
+impl CryptoHash for BranchNode {
+    // Unused hasher.
+    type Hasher = SparseMerkleInternalHasher;
+
+    fn hash(&self) -> HashValue {
+        self.merkle_hash(
+            0,  /* start index */
+            16, /* the number of leaves in the subtree of which we want the hash of root */
+            self.generate_bitmaps(),
+        )
+    }
+}
+
+/// Computes the hash of an [`ExtensionNode`]. Similar to [`BranchNode`], we generate
+/// the hash by logically expanding it into a sparse Merkle tree. For an extension node with 2
+/// nibbles, we compute the final hash as follows:
+///
+/// ```text
+///
+///                       #(final hash)
+///                      / \
+///                     #   placeholder
+///                    / \
+///                   #   placeholder
+///                  / \
+///       placeholder   #
+///                    / \
+///                   #   placeholder
+///                  / \
+///       placeholder   #
+///                    / \
+///                   #   placeholder
+///                  / \
+///                 #   placeholder
+///                / \
+///            child   placeholder
+/// ```
+///
+/// The final hash is generated by iteratively hashing the concatenation of the two children of
+/// each node, following a bottom-up order. It is worth noting that by definition an
+/// [`ExtensionNode`] is just a path, so each intermediate node has only one child. When the node
+/// is logically expanded into a sparse Merkle tree, the empty positions are replaced by the
+/// default digest.
+impl CryptoHash for ExtensionNode {
+    // Unused hasher.
+    type Hasher = SparseMerkleInternalHasher;
+
+    fn hash(&self) -> HashValue {
+        self.nibble_path.bits().rev().fold(self.child, |hash, bit| {
+            if bit {
+                SparseMerkleInternalNode::new(*SPARSE_MERKLE_PLACEHOLDER_HASH, hash).hash()
+            } else {
+                SparseMerkleInternalNode::new(hash, *SPARSE_MERKLE_PLACEHOLDER_HASH).hash()
+            }
+        })
+    }
+}
+
+/// Computes the hash of a [`LeafNode`].
+impl CryptoHash for LeafNode {
+    // Unused hasher.
+    type Hasher = SparseMerkleLeafHasher;
+
+    fn hash(&self) -> HashValue {
+        SparseMerkleLeafNode::new(self.key, self.value_hash).hash()
+    }
+}
+
+/// The concrete node type of [`SparseMerkleTree`](crate::SparseMerkleTree).
+#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
+pub enum Node {
+    /// A wrapper of [`BranchNode`].
+    Branch(BranchNode),
+    /// A wrapper of [`ExtensionNode`].
+    Extension(ExtensionNode),
+    /// A wrapper of [`LeafNode`].
+    Leaf(LeafNode),
+}
+
+impl From<BranchNode> for Node {
+    fn from(node: BranchNode) -> Self {
+        Node::Branch(node)
+    }
+}
+
+impl From<ExtensionNode> for Node {
+    fn from(node: ExtensionNode) -> Self {
+        Node::Extension(node)
+    }
+}
+
+impl From<LeafNode> for Node {
+    fn from(node: LeafNode) -> Self {
+        Node::Leaf(node)
+    }
+}
+
+impl BranchNode {
+    /// Creates a new branch node.
+    pub fn new(children: HashMap<u8, (HashValue, bool)>) -> Self {
+        Self { children }
+    }
+
+    /// Sets the `n`-th child to the given hash and stores a `bool` indicating whether the child
+    /// passed in is a leaf node.
+    pub fn set_child(&mut self, n: u8, child: (HashValue, bool)) {
+        assert!(n < 16);
+        self.children.insert(n, child);
+    }
+
+    /// Gets the hash of the `n`-th child.
+    pub fn child(&self, n: u8) -> Option<HashValue> {
+        assert!(n < 16);
+        self.children.get(&n).map(|p| p.0)
+    }
+
+    /// Returns an `Option<bool>` indicating whether the `n`-th child is a leaf node. If the
+    /// child doesn't exist, returns `None`.
+    pub fn is_leaf(&self, n: u8) -> Option<bool> {
+        assert!(n < 16);
+        self.children.get(&n).map(|p| p.1)
+    }
+
+    /// Returns the total number of children.
+    pub fn num_children(&self) -> usize {
+        self.children.len()
+    }
+
+    /// Generates `child_bitmap` and `leaf_bitmap` as a pair of `u16`s: the child at index `i`
+    /// exists if `child_bitmap[i]` is set; the child at index `i` is a leaf node if
+    /// `leaf_bitmap[i]` is set.
+    pub fn generate_bitmaps(&self) -> (u16, u16) {
+        let mut child_bitmap = 0_u16;
+        let mut leaf_bitmap = 0_u16;
+        for (k, v) in self.children.iter() {
+            child_bitmap |= 1u16 << k;
+            leaf_bitmap |= (v.1 as u16) << k;
+        }
+        // `leaf_bitmap` must be a subset of `child_bitmap`.
+        assert!(child_bitmap | leaf_bitmap == child_bitmap);
+        (child_bitmap, leaf_bitmap)
+    }
+
+    /// Given a range [start, start + width), returns the sub-bitmaps of that range.
+    fn range_bitmaps(start: u8, width: u8, bitmaps: (u16, u16)) -> (u16, u16) {
+        assert!(start < 16 && start % width == 0);
+        // A range with `start == 8` and `width == 4` will generate the mask 0b0000111100000000.
+        let mask = if width == 16 {
+            0xffff
+        } else {
+            assert!(width <= 16 && (width & (width - 1)) == 0);
+            (1 << width) - 1
+        } << start;
+        (bitmaps.0 & mask, bitmaps.1 & mask)
+    }
+
+    fn merkle_hash(
+        &self,
+        start: u8,
+        width: u8,
+        (child_bitmap, leaf_bitmap): (u16, u16),
+    ) -> HashValue {
+        // Extract the sub-bitmaps covering the range [start, start + width).
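+        // Worked example (ours): with child_bitmap == 0b0001_0000_1000_0001, `start == 8` and
+        // `width == 8` mask the bitmaps with 0xff00, leaving 0b0001_0000_0000_0000, i.e. only
+        // the child at index 12 remains visible to this half of the recursion.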
+        let (range_child_bitmap, range_leaf_bitmap) =
+            BranchNode::range_bitmaps(start, width, (child_bitmap, leaf_bitmap));
+        if range_child_bitmap == 0 {
+            // No child under this subtree
+            *SPARSE_MERKLE_PLACEHOLDER_HASH
+        } else if range_child_bitmap & (range_child_bitmap - 1) == 0
+            && (range_leaf_bitmap != 0 || width == 1)
+        {
+            // Only one child under this subtree and it is a leaf, or we have reached the lowest
+            // level.
+            let only_child_index = range_child_bitmap.trailing_zeros() as u8;
+            self.children
+                .get(&only_child_index)
+                .unwrap_or_else(|| {
+                    panic!(
+                        "Corrupted branch node: child_bitmap indicates \
+                         the existence of a non-existent child at index {}",
+                        only_child_index
+                    )
+                })
+                .0
+        } else {
+            let left_child = self.merkle_hash(start, width / 2, (child_bitmap, leaf_bitmap));
+            let right_child =
+                self.merkle_hash(start + width / 2, width / 2, (child_bitmap, leaf_bitmap));
+            SparseMerkleInternalNode::new(left_child, right_child).hash()
+        }
+    }
+
+    /// Gets the child and its corresponding siblings that are necessary to generate the proof
+    /// for the `n`-th child. If it is an existence proof, the returned child must be the `n`-th
+    /// child; otherwise, the returned child may be another child. See the inline explanation for
+    /// details. When calling this function with n = 11 (node `b` in the following graph), the
+    /// range at each level is illustrated as a pair of square brackets:
+    ///
+    /// ```text
+    ///     4      [f   e   d   c   b   a   9   8   7   6   5   4   3   2   1   0] -> root level
+    ///            ---------------------------------------------------------------
+    ///     3      [f   e   d   c   b   a   9   8] [7   6   5   4   3   2   1   0] width = 8
+    ///                          chs <--β”˜                        shs <--β”˜
+    ///     2      [f   e   d   c] [b   a   9   8] [7   6   5   4] [3   2   1   0] width = 4
+    ///                  shs <--β”˜       β””--> chs
+    ///     1      [f   e] [d   c] [b   a] [9   8] [7   6] [5   4] [3   2] [1   0] width = 2
+    ///                            chs <--β”˜  β””--> shs
+    ///     0      [f] [e] [d] [c] [b] [a] [9] [8] [7] [6] [5] [4] [3] [2] [1] [0] width = 1
+    ///     ^                      chs <--β”˜  β””--> shs
+    ///     |   MSB|<---------------------- uint 16 ---------------------------->|LSB
+    ///  height    chs: `child_half_start`          shs: `sibling_half_start`
+    /// ```
+    pub fn get_child_for_proof_and_siblings(&self, n: u8) -> (Option<HashValue>, Vec<HashValue>) {
+        let mut siblings = vec![];
+        assert!(n < 16);
+        let (child_bitmap, leaf_bitmap) = self.generate_bitmaps();
+
+        // Nibble height from 3 to 0.
+        for h in (0..4).rev() {
+            // Get the number of children of the branch node that each subtree at this height
+            // covers.
+            let width = 1 << h;
+            // Get the index of the first child belonging to the same subtree whose root, let's
+            // say `r`, is at height `h` and contains the n-th child.
+            // Note: `child_half_start` will always equal `n` at height 0.
+            let child_half_start = (0xff << h) & n;
+            // Get the index of the first child belonging to the subtree whose root is the
+            // sibling of `r` at height `h`.
+            let sibling_half_start = child_half_start ^ (1 << h);
+            // Compute the root hash of the subtree rooted at the sibling of `r`.
+            siblings.push(self.merkle_hash(sibling_half_start, width, (child_bitmap, leaf_bitmap)));
+
+            let (range_child_bitmap, range_leaf_bitmap) =
+                BranchNode::range_bitmaps(child_half_start, width, (child_bitmap, leaf_bitmap));
+
+            if range_child_bitmap == 0 {
+                // No child in this range.
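+                // (Our note:) the siblings gathered so far already form a valid non-existence
+                // proof for the n-th child, so we can return early.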
+                return (None, siblings);
+            } else if range_child_bitmap.count_ones() == 1
+                && (range_leaf_bitmap.count_ones() == 1 || width == 1)
+            {
+                // Return the only leaf child under this subtree, or the child at the lowest
+                // level. Even if this leaf child is not the n-th child, it should be returned
+                // instead of `None`, because its existence indirectly proves that the n-th child
+                // doesn't exist. Please read the proof format for details.
+                let only_child_index = range_child_bitmap.trailing_zeros() as u8;
+                return (
+                    Some(
+                        self.children
+                            .get(&only_child_index)
+                            .unwrap_or_else(|| {
+                                panic!(
+                                    "Corrupted branch node: child_bitmap indicates \
+                                     the existence of a non-existent child at index {}",
+                                    only_child_index
+                                )
+                            })
+                            .0,
+                    ),
+                    siblings,
+                );
+            }
+        }
+        unreachable!()
+    }
+}
+
+impl ExtensionNode {
+    /// Creates a new extension node.
+    pub fn new(nibble_path: NibblePath, child: HashValue) -> Self {
+        Self { nibble_path, child }
+    }
+
+    /// Gets the only child.
+    pub fn child(&self) -> HashValue {
+        self.child
+    }
+
+    /// Sets the child.
+    pub fn set_child(&mut self, child_hash: HashValue) {
+        self.child = child_hash;
+    }
+
+    /// Gets the `nibble_path`.
+    pub fn nibble_path(&self) -> &NibblePath {
+        &self.nibble_path
+    }
+
+    /// Gets the siblings from this extension node according to the requested nibble iterator.
+    /// Also returns a boolean indicating whether we can stop traversing and return early.
+    pub fn get_siblings(&self, nibble_iter: &mut NibbleIterator) -> (Vec<HashValue>, bool) {
+        let mut extension_nibble_iter = self.nibble_path().nibbles();
+        let mut siblings = vec![
+            *SPARSE_MERKLE_PLACEHOLDER_HASH;
+            skip_common_prefix(&mut extension_nibble_iter, nibble_iter) * 4 /* 1 nibble == 4 bits */
+        ];
+        // There are two possible cases after matching the prefix:
+        // 1. Not all the nibbles of the extension node match the nibble path of the queried key.
+        // This means the queried key meets a default node when being matched against the
+        // extension node's nibble path, so we can terminate the search early and return a
+        // non-existence proof with the proper number of siblings.
+        if !extension_nibble_iter.is_finished() {
+            let mut extension_bit_iter = extension_nibble_iter.bits();
+            let mut request_bit_iter = nibble_iter.bits();
+            let num_matched_bits =
+                skip_common_prefix(&mut extension_bit_iter, &mut request_bit_iter);
+            assert!(num_matched_bits < 4);
+            // Note: we have to skip 1 bit here to get the right result. For example, assume the
+            // extension node has 2 nibbles (8 bits) and only the first 5 bits are matched. The
+            // siblings of the queried key should include 5 default hashes followed by `#1`,
+            // which is the result of iteratively hashing `n` times, bottom up, starting with
+            // `child`, where `n` equals the number of bits left after matching minus 1.
+ // + //```text + // + // #(final hash) + // / \------------------> 1st bit \ + // # placeholder \ + // / \--------------------> 2nd bit \ + // # placeholder } 1st nibble + // / \----------------------> 3rd bit / + // placeholder # / + // / \--------------------> 4th bit / + // # placeholder + // / \----------------------> 5th bit \ + // placeholder \ \ + // / \--------------------> 6th bit \ + // #1 the queried key } 2nd nibble + // / \----------------------> 7th bit / + // # placeholder / + // / \------------------------> 8th bit / + // child placeholder + // ``` + extension_bit_iter.next(); + siblings.append(&mut vec![*SPARSE_MERKLE_PLACEHOLDER_HASH; num_matched_bits]); + siblings.push(extension_bit_iter.rev().fold(self.child, |hash, bit| { + if bit { + SparseMerkleInternalNode::new(*SPARSE_MERKLE_PLACEHOLDER_HASH, hash).hash() + } else { + SparseMerkleInternalNode::new(hash, *SPARSE_MERKLE_PLACEHOLDER_HASH).hash() + } + })); + (siblings, true /* early termination */) + } else { + // 2. All the nibbles of the extension node match the nibble path of the queried key. + // Then just return the siblings and a `false` telling the callsite to continue to + // traverse the tree. + (siblings, false /* early termination */) + } + } +} + +impl LeafNode { + /// Creates a new leaf node. + pub fn new(key: HashValue, value_hash: HashValue) -> Self { + Self { key, value_hash } + } + + /// Gets the `key`. + pub fn key(&self) -> HashValue { + self.key + } + + /// Gets the associated `value_hash`. + pub fn value_hash(&self) -> HashValue { + self.value_hash + } + + /// Sets the associated `value_hash`. + pub fn set_value_hash(&mut self, value_hash: HashValue) { + self.value_hash = value_hash; + } +} + +impl Node { + /// Creates the [`Branch`](Node::Branch) variant. + pub fn new_branch(children: HashMap) -> Self { + Node::Branch(BranchNode::new(children)) + } + + /// Creates the [`Extension`](Node::Extension) variant. + pub fn new_extension(nibble_path: NibblePath, child: HashValue) -> Self { + Node::Extension(ExtensionNode::new(nibble_path, child)) + } + + /// Creates the [`Leaf`](Node::Leaf) variant. + pub fn new_leaf(key: HashValue, value_hash: HashValue) -> Self { + Node::Leaf(LeafNode::new(key, value_hash)) + } + + /// Serializes to bytes for physical storage. + pub fn encode(&self) -> Result> { + let mut out = vec![]; + match self { + Node::Branch(branch_node) => { + out.push(BranchNode::TAG); + out.extend(serialize(&branch_node)?); + } + Node::Leaf(leaf_node) => { + out.push(LeafNode::TAG); + out.extend(serialize(leaf_node)?); + } + Node::Extension(extension_node) => { + out.push(ExtensionNode::TAG); + out.extend(serialize(extension_node)?); + } + } + Ok(out) + } + + /// Hashes are used to lookup the node in the database. + pub fn hash(&self) -> HashValue { + match self { + Node::Branch(branch_node) => branch_node.hash(), + Node::Extension(extension_node) => extension_node.hash(), + Node::Leaf(leaf_node) => leaf_node.hash(), + } + } + + /// Recovers from serialized bytes in physical storage. + pub fn decode(val: &[u8]) -> Result { + if val.is_empty() { + Err(NodeDecodeError::EmptyInput)? 
+ } + let node_tag = val[0]; + match node_tag { + BranchNode::TAG => Ok(Node::Branch(deserialize(&val[1..].to_vec())?)), + ExtensionNode::TAG => Ok(Node::Extension(deserialize(&val[1..].to_vec())?)), + LeafNode::TAG => Ok(Node::Leaf(deserialize(&val[1..].to_vec())?)), + unknown_tag => Err(NodeDecodeError::UnknownTag { unknown_tag })?, + } + } +} + +/// Error thrown when a [`Node`] fails to be deserialized out of a byte sequence stored in physical +/// storage, via [`Node::decode`]. +#[derive(Debug, Fail, Eq, PartialEq)] +pub enum NodeDecodeError { + /// Input is empty. + #[fail(display = "Missing tag due to empty input")] + EmptyInput, + + /// The first byte of the input is not a known tag representing one of the variants. + #[fail(display = "lead tag byte is unknown: {}", unknown_tag)] + UnknownTag { unknown_tag: u8 }, +} diff --git a/storage/sparse_merkle/src/node_type/node_type_test.rs b/storage/sparse_merkle/src/node_type/node_type_test.rs new file mode 100644 index 0000000000000..fe4226d774582 --- /dev/null +++ b/storage/sparse_merkle/src/node_type/node_type_test.rs @@ -0,0 +1,573 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + nibble_path::NibblePath, + node_type::{BranchNode, Children, Node, NodeDecodeError}, +}; +use crypto::{ + hash::{CryptoHash, SPARSE_MERKLE_PLACEHOLDER_HASH}, + HashValue, +}; +use types::proof::{SparseMerkleInternalNode, SparseMerkleLeafNode}; + +fn hash_internal(left: HashValue, right: HashValue) -> HashValue { + SparseMerkleInternalNode::new(left, right).hash() +} + +fn hash_leaf(key: HashValue, value_hash: HashValue) -> HashValue { + SparseMerkleLeafNode::new(key, value_hash).hash() +} + +#[test] +fn test_encode_decode() { + let leaf_node = Node::new_leaf(HashValue::random(), HashValue::random()); + + let mut children = Children::default(); + children.insert(0, (leaf_node.hash(), true)); + + let nodes = vec![ + Node::new_branch(children), + Node::new_leaf(HashValue::random(), HashValue::random()), + Node::new_extension(NibblePath::new(vec![1, 2, 3, 4]), HashValue::random()), + ]; + for n in &nodes { + let v = n.encode().unwrap(); + assert_eq!(*n, Node::decode(&v).unwrap()); + } + // Error cases + if let Err(e) = Node::decode(&[]) { + assert_eq!( + e.downcast::().unwrap(), + NodeDecodeError::EmptyInput + ); + } + if let Err(e) = Node::decode(&[100]) { + assert_eq!( + e.downcast::().unwrap(), + NodeDecodeError::UnknownTag { unknown_tag: 100 } + ); + } +} + +#[test] +fn test_leaf_hash() { + { + let address = HashValue::random(); + let value_hash = HashValue::random(); + let hash = hash_leaf(address, value_hash); + let leaf_node = Node::new_leaf(address, value_hash); + assert_eq!(leaf_node.hash(), hash); + } +} + +#[test] +fn test_extension_hash() { + { + let mut hash = HashValue::zero(); + let extension_node = Node::new_extension(NibblePath::new(vec![0b_0000_1111]), hash); + + for _ in 0..4 { + hash = hash_internal(*SPARSE_MERKLE_PLACEHOLDER_HASH, hash); + } + for _ in 4..8 { + hash = hash_internal(hash, *SPARSE_MERKLE_PLACEHOLDER_HASH); + } + assert_eq!(extension_node.hash(), hash); + } + { + let mut hash = HashValue::random(); + let extension_node = Node::new_extension(NibblePath::new_odd(vec![0b_1110_0000]), hash); + + hash = hash_internal(hash, *SPARSE_MERKLE_PLACEHOLDER_HASH); + for _ in 1..4 { + hash = hash_internal(*SPARSE_MERKLE_PLACEHOLDER_HASH, hash); + } + assert_eq!(extension_node.hash(), hash); + } +} + +#[test] +fn test_branch_hash_and_proof() { + // leaf case 1 + { + let mut branch_node = 
BranchNode::new(Children::default()); + let hash1 = HashValue::random(); + let hash2 = HashValue::random(); + branch_node.set_child(4, (hash1, true)); + branch_node.set_child(15, (hash2, true)); + // Branch node will have a structure below + // + // root + // / \ + // / \ + // leaf1 leaf2 + // + let root_hash = hash_internal(hash1, hash2); + assert_eq!(branch_node.hash(), root_hash); + + for i in 0..8 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + (Some(hash1), vec![hash2]) + ); + } + for i in 8..16 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + (Some(hash2), vec![hash1]) + ); + } + } + + // leaf case 2 + { + let mut branch_node = BranchNode::new(Children::default()); + let hash1 = HashValue::random(); + let hash2 = HashValue::random(); + branch_node.set_child(4, (hash1, true)); + branch_node.set_child(7, (hash2, true)); + // Branch node will have a structure below + // + // root + // / + // / + // x2 + // \ + // \ + // x1 + // / \ + // / \ + // leaf1 leaf2 + + let hash_x1 = hash_internal(hash1, hash2); + let hash_x2 = hash_internal(*SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x1); + + let root_hash = hash_internal(hash_x2, *SPARSE_MERKLE_PLACEHOLDER_HASH); + assert_eq!(branch_node.hash(), root_hash); + + for i in 0..4 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + (None, vec![*SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x1]) + ); + } + + for i in 4..6 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + ( + Some(hash1), + vec![ + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + hash2 + ] + ) + ); + } + + for i in 6..8 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + ( + Some(hash2), + vec![ + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + hash1 + ] + ) + ); + } + + for i in 8..16 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + (None, vec![hash_x2]) + ); + } + } + + // leaf case 3 + { + let mut branch_node = BranchNode::new(Children::default()); + let hash1 = HashValue::random(); + let hash2 = HashValue::random(); + let hash3 = HashValue::random(); + branch_node.set_child(0, (hash1, true)); + branch_node.set_child(7, (hash2, true)); + branch_node.set_child(8, (hash3, true)); + // Branch node will have a structure below + // + // root + // / \ + // / \ + // x leaf3 + // / \ + // / \ + // leaf1 leaf2 + // + + let hash_x = hash_internal(hash1, hash2); + let root_hash = hash_internal(hash_x, hash3); + + for i in 0..4 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + (Some(hash1), vec![hash3, hash2]) + ); + } + + for i in 4..8 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + (Some(hash2), vec![hash3, hash1]) + ); + } + + for i in 8..16 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + (Some(hash3), vec![hash_x]) + ); + } + + assert_eq!(branch_node.hash(), root_hash); + } + + // non-leaf case 1 + { + let mut branch_node = BranchNode::new(Children::default()); + let hash1 = HashValue::random(); + let hash2 = HashValue::random(); + branch_node.set_child(4, (hash1, false)); + branch_node.set_child(15, (hash2, false)); + // Branch node (B) will have a structure below + // + // root + // / \ + // / \ + // x3 x6 + // \ \ + // \ \ + // x2 x5 + // / \ + // / \ + // x1 x4 + // / \ + // / \ + // non-leaf1 non-leaf2 + // + let hash_x1 = hash_internal(hash1, *SPARSE_MERKLE_PLACEHOLDER_HASH); + let hash_x2 = hash_internal(hash_x1, *SPARSE_MERKLE_PLACEHOLDER_HASH); + let hash_x3 = 
hash_internal(*SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x2); + let hash_x4 = hash_internal(*SPARSE_MERKLE_PLACEHOLDER_HASH, hash2); + let hash_x5 = hash_internal(*SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x4); + let hash_x6 = hash_internal(*SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x5); + let root_hash = hash_internal(hash_x3, hash_x6); + assert_eq!(branch_node.hash(), root_hash); + + for i in 0..4 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + (None, vec![hash_x6, hash_x2]) + ); + } + + assert_eq!( + branch_node.get_child_for_proof_and_siblings(4), + ( + Some(hash1), + vec![ + hash_x6, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH + ] + ) + ); + + assert_eq!( + branch_node.get_child_for_proof_and_siblings(5), + ( + None, + vec![ + hash_x6, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + hash1 + ] + ) + ); + for i in 6..8 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + ( + None, + vec![hash_x6, *SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x1] + ) + ); + } + + for i in 8..12 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + (None, vec![hash_x3, hash_x5]) + ); + } + + for i in 12..14 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + ( + None, + vec![hash_x3, *SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x4] + ) + ); + } + assert_eq!( + branch_node.get_child_for_proof_and_siblings(14), + ( + None, + vec![ + hash_x3, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + hash2 + ] + ) + ); + assert_eq!( + branch_node.get_child_for_proof_and_siblings(15), + ( + Some(hash2), + vec![ + hash_x3, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH + ] + ) + ); + } + + // non-leaf case 2 + { + let mut branch_node = BranchNode::new(Children::default()); + let hash1 = HashValue::random(); + let hash2 = HashValue::random(); + branch_node.set_child(0, (hash1, false)); + branch_node.set_child(7, (hash2, false)); + // Branch node will have a structure below + // + // root + // / + // / + // x5 + // / \ + // / \ + // x2 x4 + // / \ + // / \ + // x1 x3 + // / \ + // / \ + // non-leaf1 non-leaf2 + + let hash_x1 = hash_internal(hash1, *SPARSE_MERKLE_PLACEHOLDER_HASH); + let hash_x2 = hash_internal(hash_x1, *SPARSE_MERKLE_PLACEHOLDER_HASH); + let hash_x3 = hash_internal(*SPARSE_MERKLE_PLACEHOLDER_HASH, hash2); + let hash_x4 = hash_internal(*SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x3); + let hash_x5 = hash_internal(hash_x2, hash_x4); + let root_hash = hash_internal(hash_x5, *SPARSE_MERKLE_PLACEHOLDER_HASH); + assert_eq!(branch_node.hash(), root_hash); + + assert_eq!( + branch_node.get_child_for_proof_and_siblings(0), + ( + Some(hash1), + vec![ + *SPARSE_MERKLE_PLACEHOLDER_HASH, + hash_x4, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + ] + ) + ); + + assert_eq!( + branch_node.get_child_for_proof_and_siblings(1), + ( + None, + vec![ + *SPARSE_MERKLE_PLACEHOLDER_HASH, + hash_x4, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + hash1, + ] + ) + ); + + for i in 2..4 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + ( + None, + vec![*SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x4, hash_x1] + ) + ); + } + + for i in 4..6 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + ( + None, + vec![*SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x2, hash_x3] + ) + ); + } + + assert_eq!( + branch_node.get_child_for_proof_and_siblings(6), + ( + None, + vec![ + *SPARSE_MERKLE_PLACEHOLDER_HASH, + 
hash_x2, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + hash2 + ] + ) + ); + + assert_eq!( + branch_node.get_child_for_proof_and_siblings(7), + ( + Some(hash2), + vec![ + *SPARSE_MERKLE_PLACEHOLDER_HASH, + hash_x2, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + ] + ) + ); + + for i in 8..16 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + (None, vec![hash_x5]) + ); + } + } + + // mixed case + { + let mut branch_node = BranchNode::new(Children::default()); + let hash1 = HashValue::random(); + let hash2 = HashValue::random(); + let hash3 = HashValue::random(); + branch_node.set_child(0, (hash1, true)); + branch_node.set_child(2, (hash2, false)); + branch_node.set_child(7, (hash3, false)); + // Branch node (B) will have a structure below + // + // B (root hash) + // / + // / + // x5 + // / \ + // / \ + // x2 x4 + // / \ \ + // / \ \ + // leaf1 x1 x3 + // / \ + // / \ + // non-leaf2 non-leaf3 + // + let hash_x1 = hash_internal(hash2, *SPARSE_MERKLE_PLACEHOLDER_HASH); + let hash_x2 = hash_internal(hash1, hash_x1); + let hash_x3 = hash_internal(*SPARSE_MERKLE_PLACEHOLDER_HASH, hash3); + let hash_x4 = hash_internal(*SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x3); + let hash_x5 = hash_internal(hash_x2, hash_x4); + let root_hash = hash_internal(hash_x5, *SPARSE_MERKLE_PLACEHOLDER_HASH); + assert_eq!(branch_node.hash(), root_hash); + + for i in 0..2 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + ( + Some(hash1), + vec![*SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x4, hash_x1] + ) + ); + } + + assert_eq!( + branch_node.get_child_for_proof_and_siblings(2), + ( + Some(hash2), + vec![ + *SPARSE_MERKLE_PLACEHOLDER_HASH, + hash_x4, + hash1, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + ] + ) + ); + + assert_eq!( + branch_node.get_child_for_proof_and_siblings(3), + ( + None, + vec![*SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x4, hash1, hash2,] + ) + ); + + for i in 4..6 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + ( + None, + vec![*SPARSE_MERKLE_PLACEHOLDER_HASH, hash_x2, hash_x3] + ) + ); + } + + assert_eq!( + branch_node.get_child_for_proof_and_siblings(6), + ( + None, + vec![ + *SPARSE_MERKLE_PLACEHOLDER_HASH, + hash_x2, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + hash3, + ] + ) + ); + + assert_eq!( + branch_node.get_child_for_proof_and_siblings(7), + ( + Some(hash3), + vec![ + *SPARSE_MERKLE_PLACEHOLDER_HASH, + hash_x2, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + ] + ) + ); + + for i in 8..16 { + assert_eq!( + branch_node.get_child_for_proof_and_siblings(i), + (None, vec![hash_x5]) + ); + } + } +} diff --git a/storage/sparse_merkle/src/sparse_merkle_test.rs b/storage/sparse_merkle/src/sparse_merkle_test.rs new file mode 100644 index 0000000000000..5977bcd6c3168 --- /dev/null +++ b/storage/sparse_merkle/src/sparse_merkle_test.rs @@ -0,0 +1,615 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crypto::{hash::SPARSE_MERKLE_PLACEHOLDER_HASH, HashValue}; +use mock_tree_store::MockTreeStore; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use types::proof::verify_sparse_merkle_element; + +fn modify(original_key: &HashValue, n: usize, value: u8) -> HashValue { + let mut key = original_key.to_vec(); + key[n] = value; + HashValue::from_slice(&key).unwrap() +} + +#[test] +fn test_insert_to_empty_tree() { + let db = MockTreeStore::default(); + let tree = SparseMerkleTree::new(&db); + + // Tree is initially empty. Root is a null node. 
We'll insert a key-value pair which creates a
+    // leaf node.
+    let key = HashValue::random();
+    let value = AccountStateBlob::from(vec![1u8, 2u8, 3u8, 4u8]);
+
+    let (new_root, batch) = tree
+        .put_keyed_blob_set(
+            vec![(key, value.clone())],
+            *SPARSE_MERKLE_PLACEHOLDER_HASH, /* root hash being based on */
+        )
+        .unwrap();
+    db.write_tree_update_batch(batch).unwrap();
+
+    assert_eq!(tree.get(key, new_root).unwrap().unwrap(), value);
+}
+
+#[test]
+fn test_insert_at_leaf_with_branch_created() {
+    let db = MockTreeStore::default();
+    let tree = SparseMerkleTree::new(&db);
+
+    let key1 = HashValue::new([0x00u8; HashValue::LENGTH]);
+    let value1 = AccountStateBlob::from(vec![1u8, 2u8]);
+
+    let (root1, batch) = tree
+        .put_keyed_blob_set(
+            vec![(key1, value1.clone())],
+            *SPARSE_MERKLE_PLACEHOLDER_HASH, /* root hash being based on */
+        )
+        .unwrap();
+    db.write_tree_update_batch(batch).unwrap();
+    assert_eq!(tree.get(key1, root1).unwrap().unwrap(), value1);
+
+    // Insert at the previous leaf node. Should generate a branch node at root.
+    // Change the 1st nibble to 15.
+    let key2 = modify(&key1, 0, 0xf0);
+    let value2 = AccountStateBlob::from(vec![3u8, 4u8]);
+
+    let (root2, batch) = tree
+        .put_keyed_blob_set(
+            vec![(key2, value2.clone())],
+            root1, /* root hash being based on */
+        )
+        .unwrap();
+    db.write_tree_update_batch(batch).unwrap();
+    assert_eq!(tree.get(key1, root1).unwrap().unwrap(), value1);
+    assert!(tree.get(key2, root1).unwrap().is_none());
+    assert_eq!(tree.get(key2, root2).unwrap().unwrap(), value2);
+
+    // Get # of nodes.
+    assert_eq!(db.num_nodes(), 3);
+    assert_eq!(db.num_blobs(), 2);
+
+    let leaf1 = LeafNode::new(key1, value1.hash());
+    let leaf2 = LeafNode::new(key2, value2.hash());
+    let mut branch = BranchNode::default();
+    branch.set_child(0, (leaf1.hash(), true /* is_leaf */));
+    branch.set_child(15, (leaf2.hash(), true /* is_leaf */));
+    assert_eq!(db.get_node(root1).unwrap(), leaf1.into());
+    assert_eq!(db.get_node(leaf2.hash()).unwrap(), leaf2.into());
+    assert_eq!(db.get_node(root2).unwrap(), branch.into());
+}
+
+#[test]
+fn test_insert_at_leaf_with_extension_and_branch_created() {
+    let db = MockTreeStore::default();
+    let tree = SparseMerkleTree::new(&db);
+
+    // 1. Insert the first leaf into the empty tree.
+    let key1 = HashValue::new([0x00u8; HashValue::LENGTH]);
+    let value1 = AccountStateBlob::from(vec![1u8, 2u8]);
+
+    let (root1, batch) = tree
+        .put_keyed_blob_set(
+            vec![(key1, value1.clone())],
+            *SPARSE_MERKLE_PLACEHOLDER_HASH, /* root hash being based on */
+        )
+        .unwrap();
+    db.write_tree_update_batch(batch).unwrap();
+    assert_eq!(tree.get(key1, root1).unwrap().unwrap(), value1);
+
+    // 2. Insert at the previous leaf node. Since the two keys share the first nibble, this
+    //    should generate an extension node at the root with a branch node underneath.
+    //    Change the 2nd nibble to 1.
+ let key2 = modify(&key1, 0, 0x01); + let value2 = AccountStateBlob::from(vec![3u8, 4u8]); + + let (root2, batch) = tree + .put_keyed_blob_set( + vec![(key2, value2.clone())], + root1, /* root hash being based on */ + ) + .unwrap(); + db.write_tree_update_batch(batch).unwrap(); + assert_eq!(tree.get(key1, root1).unwrap().unwrap(), value1); + assert!(tree.get(key2, root1).unwrap().is_none()); + assert_eq!(tree.get(key2, root2).unwrap().unwrap(), value2); + + assert_eq!(db.num_nodes(), 4); + assert_eq!(db.num_blobs(), 2); + + let leaf1 = LeafNode::new(key1, value1.hash()); + let leaf2 = LeafNode::new(key2, value2.hash()); + let mut branch = BranchNode::default(); + branch.set_child(0, (leaf1.hash(), true /* is_leaf */)); + branch.set_child(1, (leaf2.hash(), true /* is_leaf */)); + let extension = ExtensionNode::new(NibblePath::new_odd(vec![0x00]), branch.hash()); + assert_eq!(db.get_node(root1).unwrap(), leaf1.into()); + assert_eq!(db.get_node(branch.child(1).unwrap()).unwrap(), leaf2.into()); + assert_eq!(db.get_node(extension.child()).unwrap(), branch.into()); + assert_eq!(db.get_node(root2).unwrap(), extension.clone().into()); + + // 3. Update leaf2 with new value + let value2_update = AccountStateBlob::from(vec![5u8, 6u8]); + let (root3, batch) = tree + .put_keyed_blob_set( + vec![(key2, value2_update.clone())], + root2, /* root hash being based on */ + ) + .unwrap(); + db.write_tree_update_batch(batch).unwrap(); + assert!(tree.get(key2, root1).unwrap().is_none()); + assert_eq!(tree.get(key2, root2).unwrap().unwrap(), value2); + assert_eq!(tree.get(key2, root3).unwrap().unwrap(), value2_update); + + // Get # of nodes. + assert_eq!(db.num_nodes(), 7); + assert_eq!(db.num_blobs(), 3); +} + +fn setup_extension_case(db: &MockTreeStore, n: usize) -> (HashValue, HashValue) { + assert!(n / 2 < HashValue::LENGTH); + let tree = SparseMerkleTree::new(db); + let key1 = HashValue::new([0xffu8; HashValue::LENGTH]); + let value1 = AccountStateBlob::from(vec![0xff, 0xff]); + + // Change the n-th nibble to 1 so it results in an extension node with num_nibbles == n; + // if n == 0, no extension node will be created. 
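+    // Illustrative (ours): with n == 5, key2 differs from key1 in the 5th nibble (0-based),
+    // so the two keys share a 5-nibble common prefix and the tree gets an extension node with
+    // num_nibbles == 5 above the branch node that splits them.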
+ let key2 = modify(&key1, n / 2, if n % 2 == 0 { 0xef } else { 0xfe }); + let value2 = AccountStateBlob::from(vec![0xee, 0xee]); + + let (root, batch) = tree + .put_keyed_blob_set( + vec![(key1, value1.clone()), (key2, value2.clone())], + *SPARSE_MERKLE_PLACEHOLDER_HASH, /* root hash being based on */ + ) + .unwrap(); + db.write_tree_update_batch(batch).unwrap(); + assert_eq!(db.num_nodes(), 4); + assert_eq!(db.num_blobs(), 2); + + let leaf1 = LeafNode::new(key1, value1.hash()); + let leaf2 = LeafNode::new(key2, value2.hash()); + let mut branch = BranchNode::default(); + branch.set_child(15, (leaf1.hash(), true /* is_leaf */)); + branch.set_child(14, (leaf2.hash(), true /* is_leaf */)); + let branch_hash = branch.hash(); + if n == 0 { + assert_eq!(root, branch_hash) + } else { + match db.get_node(root).unwrap() { + Node::Extension(extension) => assert_eq!(extension.child(), branch_hash), + _ => unreachable!(), + } + } + (root, branch_hash) +} + +#[test] +fn test_insert_at_extension_fork_at_begining() { + let db = MockTreeStore::default(); + let (root, extension_child_hash) = setup_extension_case(&db, 6); + let tree = SparseMerkleTree::new(&db); + + let key1 = HashValue::new([0x00; HashValue::LENGTH]); + let value1 = AccountStateBlob::from(vec![1u8, 2u8]); + + let (root1, batch) = tree + .put_keyed_blob_set( + vec![(key1, value1.clone())], + root, /* root hash being based on */ + ) + .unwrap(); + db.write_tree_update_batch(batch).unwrap(); + + let extension_after_fork = ExtensionNode::new( + NibblePath::new_odd(vec![0xff, 0xff, 0xf0]), + extension_child_hash, + ); + let leaf1 = LeafNode::new(key1, value1.hash()); + let mut branch = BranchNode::default(); + branch.set_child(0, (leaf1.hash(), true /* is_leaf */)); + branch.set_child(15, (extension_after_fork.hash(), false /* is_leaf */)); + + assert_eq!(tree.get(key1, root1).unwrap().unwrap(), value1); + assert_eq!(db.get_node(branch.child(0).unwrap()).unwrap(), leaf1.into()); + assert_eq!( + db.get_node(branch.child(15).unwrap()).unwrap(), + extension_after_fork.into() + ); + assert_eq!(db.get_node(root1).unwrap(), branch.into()); + assert_eq!(db.num_nodes(), 7); + assert_eq!(db.num_blobs(), 3); +} + +#[test] +fn test_insert_at_extension_fork_in_the_middle() { + let db = MockTreeStore::default(); + let (root, extension_child_hash) = setup_extension_case(&db, 5); + let tree = SparseMerkleTree::new(&db); + + let key1 = modify(&HashValue::new([0xff; HashValue::LENGTH]), 1, 0x00); + let value1 = AccountStateBlob::from(vec![1u8, 2u8]); + + let (root1, batch) = tree + .put_keyed_blob_set( + vec![(key1, value1.clone())], + root, /* root hash being based on */ + ) + .unwrap(); + db.write_tree_update_batch(batch).unwrap(); + + let extension_after_fork = + ExtensionNode::new(NibblePath::new(vec![0xff]), extension_child_hash); + let leaf1 = LeafNode::new(key1, value1.hash()); + let mut branch = BranchNode::default(); + branch.set_child(0, (leaf1.hash(), true /* is_leaf */)); + branch.set_child(15, (extension_after_fork.hash(), false /* is_leaf */)); + let extension_before_fork = ExtensionNode::new(NibblePath::new(vec![0xff]), branch.hash()); + + assert_eq!(tree.get(key1, root1).unwrap().unwrap(), value1); + assert_eq!(db.get_node(branch.child(0).unwrap()).unwrap(), leaf1.into()); + assert_eq!( + db.get_node(branch.child(15).unwrap()).unwrap(), + extension_after_fork.into() + ); + assert_eq!( + db.get_node(extension_before_fork.child()).unwrap(), + branch.into() + ); + assert_eq!(db.get_node(root1).unwrap(), extension_before_fork.into()); + 
assert_eq!(db.num_nodes(), 8); + assert_eq!(db.num_blobs(), 3); +} + +#[test] +fn test_insert_at_extension_fork_at_end() { + let db = MockTreeStore::default(); + let (root, extension_child_hash) = setup_extension_case(&db, 4); + let tree = SparseMerkleTree::new(&db); + + let key1 = modify(&HashValue::new([0xff; HashValue::LENGTH]), 1, 0xf0); + let value1 = AccountStateBlob::from(vec![1u8, 2u8]); + + let (root1, batch) = tree + .put_keyed_blob_set( + vec![(key1, value1.clone())], + root, /* root hash being based on */ + ) + .unwrap(); + db.write_tree_update_batch(batch).unwrap(); + + let leaf1 = LeafNode::new(key1, value1.hash()); + let mut branch = BranchNode::default(); + + branch.set_child(0, (leaf1.hash(), true /* is_leaf */)); + branch.set_child(15, (extension_child_hash, false /* is_leaf */)); + let extension_before_fork = + ExtensionNode::new(NibblePath::new_odd(vec![0xff, 0xf0]), branch.hash()); + + assert_eq!(tree.get(key1, root1).unwrap().unwrap(), value1); + assert_eq!(db.get_node(branch.child(0).unwrap()).unwrap(), leaf1.into()); + assert_eq!( + db.get_node(extension_before_fork.child()).unwrap(), + branch.into() + ); + assert_eq!(db.get_node(root1).unwrap(), extension_before_fork.into()); + assert_eq!(db.num_nodes(), 7); + assert_eq!(db.num_blobs(), 3); +} + +#[test] +fn test_insert_at_extension_fork_at_only_nibble() { + let db = MockTreeStore::default(); + let (root, branch_child_hash) = setup_extension_case(&db, 1); + let tree = SparseMerkleTree::new(&db); + + let key1 = HashValue::new([0x00; HashValue::LENGTH]); + let value1 = AccountStateBlob::from(vec![1u8, 2u8]); + + let (root1, batch) = tree + .put_keyed_blob_set( + vec![(key1, value1.clone())], + root, /* root hash being based on */ + ) + .unwrap(); + db.write_tree_update_batch(batch).unwrap(); + + let leaf1 = LeafNode::new(key1, value1.hash()); + let mut branch = BranchNode::default(); + branch.set_child(0, (leaf1.hash(), true /* is_leaf */)); + branch.set_child(15, (branch_child_hash, false /* is_leaf */)); + + assert_eq!(tree.get(key1, root1).unwrap().unwrap(), value1); + assert_eq!(db.get_node(branch.child(0).unwrap()).unwrap(), leaf1.into()); + assert_eq!(db.get_node(root1).unwrap(), branch.into()); + assert_eq!(db.num_nodes(), 6); + assert_eq!(db.num_blobs(), 3); +} + +#[test] +fn test_batch_insertion() { + let db = MockTreeStore::default(); + let tree = SparseMerkleTree::new(&db); + // ```text + // branch(root) + // / \ + // branch 2 + // / | \ + // extension 3 4 + // | + // branch + // / \ + // extension 6 + // | + // branch + // / \ + // 1 5 + // + // Total: 12 nodes, 6 blobs + // ``` + let key1 = HashValue::new([0x00u8; HashValue::LENGTH]); + let value1 = AccountStateBlob::from(vec![1u8]); + + let key2 = modify(&key1, 0, 0xf0); + let value2 = AccountStateBlob::from(vec![2u8]); + let value2_update = AccountStateBlob::from(vec![22u8]); + + let key3 = modify(&key1, 0, 0x03); + let value3 = AccountStateBlob::from(vec![3u8]); + + let key4 = modify(&key1, 0, 0x04); + let value4 = AccountStateBlob::from(vec![4u8]); + + let key5 = modify(&key1, 5, 0x05); + let value5 = AccountStateBlob::from(vec![5u8]); + + let key6 = modify(&key1, 3, 0x06); + let value6 = AccountStateBlob::from(vec![6u8]); + + let (root, batch) = tree + .put_keyed_blob_set( + vec![ + (key1, value1.clone()), + (key2, value2.clone()), + (key3, value3.clone()), + (key4, value4.clone()), + (key5, value5.clone()), + (key6, value6.clone()), + (key2, value2_update.clone()), + ], + *SPARSE_MERKLE_PLACEHOLDER_HASH, /* root hash being based on */ + ) + 
.unwrap(); + db.write_tree_update_batch(batch).unwrap(); + assert_eq!(tree.get(key1, root).unwrap().unwrap(), value1); + assert_eq!(tree.get(key2, root).unwrap().unwrap(), value2_update); + assert_eq!(tree.get(key3, root).unwrap().unwrap(), value3); + assert_eq!(tree.get(key4, root).unwrap().unwrap(), value4); + assert_eq!(tree.get(key5, root).unwrap().unwrap(), value5); + assert_eq!(tree.get(key6, root).unwrap().unwrap(), value6); + + // get # of nodes + assert_eq!(db.num_nodes(), 12); + assert_eq!(db.num_blobs(), 6); +} + +#[test] +fn test_non_existence() { + let db = MockTreeStore::default(); + let tree = SparseMerkleTree::new(&db); + // ```text + // branch(root) + // / \ + // extension 2 + // | + // branch + // / \ + // 1 3 + // Total: 7 nodes, 3 blobs + // ``` + let key1 = HashValue::new([0x00u8; HashValue::LENGTH]); + let value1 = AccountStateBlob::from(vec![1u8]); + + let key2 = modify(&key1, 0, 0xf0); + let value2 = AccountStateBlob::from(vec![2u8]); + + let key3 = modify(&key1, 1, 0x03); + let value3 = AccountStateBlob::from(vec![3u8]); + + let (root, batch) = tree + .put_keyed_blob_set( + vec![ + (key1, value1.clone()), + (key2, value2.clone()), + (key3, value3.clone()), + ], + *SPARSE_MERKLE_PLACEHOLDER_HASH, /* root hash being based on */ + ) + .unwrap(); + db.write_tree_update_batch(batch).unwrap(); + assert_eq!(tree.get(key1, root).unwrap().unwrap(), value1); + assert_eq!(tree.get(key2, root).unwrap().unwrap(), value2); + assert_eq!(tree.get(key3, root).unwrap().unwrap(), value3); + // get # of nodes + assert_eq!(db.num_nodes(), 6); + assert_eq!(db.num_blobs(), 3); + + // test non-existing nodes. + // 1. Non-existing node at branch node + { + let non_existing_key = modify(&key1, 0, 0x10); + let (value, proof) = tree.get_with_proof(non_existing_key, root).unwrap(); + assert_eq!(value, None); + assert!(verify_sparse_merkle_element(root, non_existing_key, &None, &proof).is_ok()); + } + // 2. Non-existing node at extension node + { + let non_existing_key = modify(&key1, 1, 0x30); + let (value, proof) = tree.get_with_proof(non_existing_key, root).unwrap(); + assert_eq!(value, None); + assert!(verify_sparse_merkle_element(root, non_existing_key, &None, &proof).is_ok()); + } + // 3. 
Non-existing node at leaf node + { + let non_existing_key = modify(&key1, 10, 0x01); + let (value, proof) = tree.get_with_proof(non_existing_key, root).unwrap(); + assert_eq!(value, None); + assert!(verify_sparse_merkle_element(root, non_existing_key, &None, &proof).is_ok()); + } +} + +#[test] +fn test_put_keyed_blob_sets() { + let mut keys = vec![]; + let mut values = vec![];; + for _i in 0..100 { + keys.push(HashValue::random()); + values.push(AccountStateBlob::from(HashValue::random().to_vec())); + } + + let mut root_hashes_one_by_one = vec![]; + let mut batch_one_by_one = TreeUpdateBatch::default(); + { + let mut iter = keys.clone().into_iter().zip(values.clone().into_iter()); + let mut root = *SPARSE_MERKLE_PLACEHOLDER_HASH; + let db = MockTreeStore::default(); + let tree = SparseMerkleTree::new(&db); + for _ in 0..10 { + let mut keyed_blob_set = vec![]; + for _ in 0..10 { + keyed_blob_set.push(iter.next().unwrap()); + } + let (new_root, batch) = tree + .put_keyed_blob_set(keyed_blob_set, root /* root hash being based on */) + .unwrap(); + root = new_root; + db.write_tree_update_batch(batch.clone()).unwrap(); + root_hashes_one_by_one.push(root); + batch_one_by_one.node_batch.extend(batch.node_batch); + batch_one_by_one.blob_batch.extend(batch.blob_batch); + } + } + { + let mut iter = keys.into_iter().zip(values.into_iter()); + let root = *SPARSE_MERKLE_PLACEHOLDER_HASH; + let db = MockTreeStore::default(); + let tree = SparseMerkleTree::new(&db); + let mut keyed_blob_sets = vec![]; + for _ in 0..10 { + let mut keyed_blob_set = vec![]; + for _ in 0..10 { + keyed_blob_set.push(iter.next().unwrap()); + } + keyed_blob_sets.push(keyed_blob_set); + } + let (root_hashes, batch) = tree + .put_keyed_blob_sets(keyed_blob_sets, root /* root hash being based on */) + .unwrap(); + assert_eq!(root_hashes, root_hashes_one_by_one); + assert_eq!(batch, batch_one_by_one); + } +} + +fn many_keys_get_proof_and_verify_tree_root(seed: &[u8], num_keys: usize) { + assert!(seed.len() < 32); + let mut actual_seed = [0u8; 32]; + actual_seed[..seed.len()].copy_from_slice(&seed); + let mut rng: StdRng = StdRng::from_seed(actual_seed); + + let db = MockTreeStore::default(); + let tree = SparseMerkleTree::new(&db); + + let mut kvs = vec![]; + for _i in 0..num_keys { + let key = HashValue::random_with_rng(&mut rng); + let value = AccountStateBlob::from(HashValue::random_with_rng(&mut rng).to_vec()); + kvs.push((key, value)); + } + + let (root, batch) = tree + .put_keyed_blob_set( + kvs.clone(), + *SPARSE_MERKLE_PLACEHOLDER_HASH, /* root hash being based on */ + ) + .unwrap(); + db.write_tree_update_batch(batch).unwrap(); + + for (k, v) in &kvs { + let (value, proof) = tree.get_with_proof(*k, root).unwrap(); + assert_eq!(value.unwrap(), *v); + assert!(verify_sparse_merkle_element(root, *k, &Some(v.clone()), &proof).is_ok()); + } +} + +#[test] +fn test_1000_keys() { + let seed: &[_] = &[1, 2, 3, 4]; + many_keys_get_proof_and_verify_tree_root(seed, 1000); +} + +fn many_versions_get_proof_and_verify_tree_root(seed: &[u8], num_versions: usize) { + assert!(seed.len() < 32); + let mut actual_seed = [0u8; 32]; + actual_seed[..seed.len()].copy_from_slice(&seed); + let mut rng: StdRng = StdRng::from_seed(actual_seed); + + let db = MockTreeStore::default(); + let tree = SparseMerkleTree::new(&db); + + let mut kvs = vec![]; + let mut roots = vec![]; + let mut prev_root = *SPARSE_MERKLE_PLACEHOLDER_HASH; + + for _i in 0..num_versions { + let key = HashValue::random_with_rng(&mut rng); + let value = 
AccountStateBlob::from(HashValue::random_with_rng(&mut rng).to_vec()); + let new_value = AccountStateBlob::from(HashValue::random_with_rng(&mut rng).to_vec()); + kvs.push((key, value.clone(), new_value.clone())); + } + + for kvs in kvs.iter().take(num_versions) { + let (new_root, batch) = tree + .put_keyed_blob_set(vec![(kvs.0, kvs.1.clone())], prev_root) + .unwrap(); + roots.push(new_root); + prev_root = new_root; + db.write_tree_update_batch(batch).unwrap(); + } + + // Update value of all keys + for kvs in kvs.iter().take(num_versions) { + let (new_root, batch) = tree + .put_keyed_blob_set(vec![(kvs.0, kvs.2.clone())], prev_root) + .unwrap(); + roots.push(new_root); + prev_root = new_root; + db.write_tree_update_batch(batch).unwrap(); + } + + for (i, (k, v, _)) in kvs.iter().enumerate() { + let random_version = rng.gen_range(i, i + num_versions); + let (value, proof) = tree.get_with_proof(*k, roots[random_version]).unwrap(); + assert_eq!(value.unwrap(), *v); + assert!( + verify_sparse_merkle_element(roots[random_version], *k, &Some(v.clone()), &proof) + .is_ok() + ); + } + + for (i, (k, _, v)) in kvs.iter().enumerate() { + let random_version = rng.gen_range(i + num_versions, 2 * num_versions); + let (value, proof) = tree.get_with_proof(*k, roots[random_version]).unwrap(); + assert_eq!(value.unwrap(), *v); + assert!( + verify_sparse_merkle_element(roots[random_version], *k, &Some(v.clone()), &proof) + .is_ok() + ); + } +} + +#[test] +fn test_1000_versions() { + let seed: &[_] = &[1, 2, 3, 4]; + many_versions_get_proof_and_verify_tree_root(seed, 1000); +} diff --git a/storage/sparse_merkle/src/tree_cache/mod.rs b/storage/sparse_merkle/src/tree_cache/mod.rs new file mode 100644 index 0000000000000..d363450d3eb42 --- /dev/null +++ b/storage/sparse_merkle/src/tree_cache/mod.rs @@ -0,0 +1,244 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! A transaction can have multiple operations on state. For example, it might update values +//! for a few existing keys. Imagine that we have the following tree. +//! +//! ```text +//! root0 +//! / \ +//! / \ +//! key1 => value11 key2 => value21 +//! ``` +//! +//! The next transaction updates `key1`'s value to `value12` and `key2`'s value to `value22`. +//! Let's assume we update key2 first. Then the tree becomes: +//! +//! ```text +//! (on disk) (in memory) +//! root0 root1' +//! / \ / \ +//! / ___ \ _____________/ \ +//! / _/ \ \ +//! / _/ \ \ +//! / / \ \ +//! key1 => value11 key2 => value21 key2 => value22 +//! (on disk) (on disk) (in memory) +//! ``` +//! +//! Note that +//! 1) we created a new version of the tree with `root1'` and the new `key2` node generated; +//! 2) both `root1'` and the new `key2` node are still held in memory within a batch of nodes +//! that will be written into db atomically. +//! +//! Next, we need to update `key1`'s value. This time we are dealing with the tree starting from +//! the new root. Part of the tree is in memory and the rest of it is in database. We'll update the +//! left child and the new root. We should +//! 1) create a new version for `key1` child. +//! 2) update `root1'` directly instead of making another version. +//! The resulting tree should look like: +//! +//! ```text +//! (on disk) (in memory) +//! root0 root1'' +//! / \ / \ +//! / \ / \ +//! / \ / \ +//! / \ / \ +//! / \ / \ +//! key1 => value11 key2 => value21 key1 => value12 key2 => value22 +//! (on disk) (on disk) (in memory) (in memory) +//! ``` +//! +//! 
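+//! In cache terms (our illustration, using the APIs defined below): updating `key2`
+//! first calls `put_node` for the new `root1'` and the new `key2` node; updating `key1`
+//! afterwards calls `delete_node(root1')` -- which really removes it, since `root1'` only
+//! lives in the cache -- followed by `put_node` for `root1''` and the new `key1` node,
+//! while the on-disk `root0` version is left untouched.
+//!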
This means that we need to be able to tell whether to create a new version of a node or to
+//! update an existing node by deleting it and creating a new node directly. `TreeCache` provides
+//! APIs to cache intermediate nodes and blobs in memory and simplify the actual tree
+//! implementation.
+//!
+//! If we are dealing with a single-version tree, any complex tree operation can be seen as a
+//! collection of the following operations:
+//! - Put a new node/blob.
+//! - Delete a node/blob.
+//! When we apply these operations on a multi-version tree:
+//! 1) Put a new node.
+//! 2) When we remove a node, if the node is in the previous on-disk version, we don't need to do
+//! anything. Otherwise we delete it from the tree cache.
+//! Updating a node can be expressed as deleting the node followed by inserting the updated node.
+
+#[cfg(test)]
+mod tree_cache_test;
+
+use crate::{node_type::Node, TreeReader, TreeUpdateBatch};
+use crypto::{hash::SPARSE_MERKLE_PLACEHOLDER_HASH, HashValue};
+use failure::prelude::*;
+use std::{
+    collections::{hash_map::Entry, HashMap},
+    convert::Into,
+};
+use types::account_state_blob::AccountStateBlob;
+
+/// `FrozenTreeCache` is used as a field of `TreeCache` storing all the nodes and blobs that
+/// are generated by earlier transactions, so they have to be immutable. The motivation of
+/// `FrozenTreeCache` is to let `TreeCache` freeze intermediate results from each transaction to
+/// help commit more than one transaction in a row atomically.
+#[derive(Default)]
+struct FrozenTreeCache {
+    /// Immutable node_cache.
+    node_cache: HashMap<HashValue, Node>,
+
+    /// Immutable blob_cache.
+    blob_cache: HashMap<HashValue, AccountStateBlob>,
+
+    /// Frozen root hashes after each earlier transaction.
+    root_hashes: Vec<HashValue>,
+}
+
+/// `TreeCache` is an in-memory cache for per-transaction updates of sparse Merkle nodes and value
+/// blobs.
+pub struct TreeCache<'a, R: 'a + TreeReader> {
+    /// Current root node hash in cache.
+    root_hash: HashValue,
+
+    /// Intermediate nodes keyed by node hash.
+    node_cache: HashMap<HashValue, Node>,
+
+    /// Intermediate value blobs keyed by blob hash. It is reasonable to assume the blob data
+    /// associated with distinct keys could be the same, so the value of each entry is a tuple,
+    /// (blob, blob_counter). The first time a blob is put in, `blob_counter = 1`; after that,
+    /// the corresponding `blob_counter` will increment by 1 each time the same blob is put in
+    /// again. Deletion follows a similar rule: the blob entry will be removed from
+    /// `blob_cache` only when `blob_counter == 1`; otherwise, only `blob_counter` will
+    /// decrement by 1.
+    blob_cache: HashMap<HashValue, (AccountStateBlob, usize)>,
+
+    /// The immutable part of this cache.
+    frozen_cache: FrozenTreeCache,
+
+    /// The underlying persistent storage.
+    reader: &'a R,
+}
+
+impl<'a, R> TreeReader for TreeCache<'a, R>
+where
+    R: 'a + TreeReader,
+{
+    /// Gets a node with given hash. If it doesn't exist in node cache, read from `reader`.
+    fn get_node(&self, node_hash: HashValue) -> Result<Node> {
+        Ok(if let Some(node) = self.node_cache.get(&node_hash) {
+            node.clone()
+        } else if let Some(node) = self.frozen_cache.node_cache.get(&node_hash) {
+            node.clone()
+        } else {
+            self.reader.get_node(node_hash)?
+        })
+    }
+
+    /// Gets the blob with given hash. If it doesn't exist in blob cache, read from `reader`.
+    fn get_blob(&self, blob_hash: HashValue) -> Result<AccountStateBlob> {
+        Ok(if let Some((blob, _)) = self.blob_cache.get(&blob_hash) {
+            blob.clone()
+        } else if let Some(blob) = self.frozen_cache.blob_cache.get(&blob_hash) {
+            blob.clone()
+        } else {
+            self.reader.get_blob(blob_hash)?
+        })
+    }
+}
+
+impl<'a, R> TreeCache<'a, R>
+where
+    R: 'a + TreeReader,
+{
+    /// Constructs a new `TreeCache` instance.
+    pub fn new(reader: &'a R, root_hash: HashValue) -> Self {
+        Self {
+            node_cache: HashMap::default(),
+            blob_cache: HashMap::default(),
+            frozen_cache: FrozenTreeCache::default(),
+            root_hash,
+            reader,
+        }
+    }
+
+    /// Gets the root node.
+    pub fn get_root_node(&self) -> Result<Option<Node>> {
+        Ok(match self.root_hash {
+            root_hash if root_hash != *SPARSE_MERKLE_PLACEHOLDER_HASH => {
+                Some(self.get_node(root_hash)?)
+            }
+            _ => None,
+        })
+    }
+
+    /// Sets the root node hash.
+    pub fn set_root_hash(&mut self, root_hash: HashValue) {
+        self.root_hash = root_hash;
+    }
+
+    /// Puts the node with given hash as key into node_cache.
+    pub fn put_node(&mut self, node_hash: HashValue, new_node: Node) -> Result<()> {
+        if let Entry::Vacant(o) = self.node_cache.entry(node_hash) {
+            o.insert(new_node);
+        } else {
+            bail!("Node with key {:?} already exists in node_cache", node_hash);
+        }
+        Ok(())
+    }
+
+    /// Deletes a node with given hash.
+    pub fn delete_node(&mut self, old_node_hash: HashValue) {
+        // If node cache doesn't have this node, it means the node is in the previous version of
+        // the tree on the disk. Then the code below has no effect.
+        self.node_cache.remove(&old_node_hash);
+    }
+
+    /// Puts a blob with given hash as key into cache. We allow duplicate blobs.
+    pub fn put_blob(&mut self, blob_hash: HashValue, new_blob: AccountStateBlob) -> Result<()> {
+        self.blob_cache
+            .entry(blob_hash)
+            .and_modify(|(_blob, cnt)| *cnt += 1)
+            .or_insert((new_blob, 1));
+        Ok(())
+    }
+
+    /// Deletes the blob with given hash.
+    pub fn delete_blob(&mut self, old_blob_hash: HashValue) {
+        // If cache doesn't have this blob, it means the blob is in the previous version of
+        // the tree on the disk. Then the code below has no effect.
+        if let Entry::Occupied(mut blob) = self.blob_cache.entry(old_blob_hash) {
+            // Decrement the counter; only when it drops to zero do we actually remove the
+            // entry from the map.
+            if blob.get().1 > 1 {
+                blob.get_mut().1 -= 1;
+            } else {
+                blob.remove_entry();
+            }
+        }
+    }
+
+    /// Freezes all the contents in cache to be immutable and clears both `node_cache` and
+    /// `blob_cache`.
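+    ///
+    /// A minimal sketch of the intended per-transaction flow (illustrative only; the exact
+    /// sequence of puts/deletes depends on the caller):
+    ///
+    /// ```text
+    /// let mut cache = TreeCache::new(&reader, root_hash);
+    /// // ... apply one transaction's put_node/delete_node/put_blob/delete_blob calls ...
+    /// cache.freeze(); // snapshot this transaction's nodes, blobs and root hash
+    /// // ... apply the next transaction, then freeze again ...
+    /// let (root_hashes, batch): (Vec<HashValue>, TreeUpdateBatch) = cache.into();
+    /// ```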
+    pub fn freeze(&mut self) {
+        self.frozen_cache.root_hashes.push(self.root_hash);
+        self.frozen_cache.node_cache.extend(self.node_cache.drain());
+        // Throw away the counter.
+        self.frozen_cache
+            .blob_cache
+            .extend(self.blob_cache.drain().map(|(k, v)| (k, v.0)));
+    }
+}
+
+impl<'a, R> Into<(Vec<HashValue>, TreeUpdateBatch)> for TreeCache<'a, R>
+where
+    R: 'a + TreeReader,
+{
+    fn into(self) -> (Vec<HashValue>, TreeUpdateBatch) {
+        (
+            self.frozen_cache.root_hashes,
+            TreeUpdateBatch {
+                node_batch: self.frozen_cache.node_cache,
+                blob_batch: self.frozen_cache.blob_cache,
+            },
+        )
+    }
+}
diff --git a/storage/sparse_merkle/src/tree_cache/tree_cache_test.rs b/storage/sparse_merkle/src/tree_cache/tree_cache_test.rs
new file mode 100644
index 0000000000000..c4d80e2e88dad
--- /dev/null
+++ b/storage/sparse_merkle/src/tree_cache/tree_cache_test.rs
@@ -0,0 +1,64 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use crate::{
+    mock_tree_store::MockTreeStore,
+    node_type::{LeafNode, Node},
+};
+use crypto::{hash::SPARSE_MERKLE_PLACEHOLDER_HASH, HashValue};
+
+#[test]
+fn test_get_node() {
+    let db = MockTreeStore::default();
+    let cache = TreeCache::new(
+        &db,
+        *SPARSE_MERKLE_PLACEHOLDER_HASH, /* root_node_hash */
+    );
+
+    let address = HashValue::random();
+    let value_hash = HashValue::random();
+    let leaf_node = Node::Leaf(LeafNode::new(address, value_hash));
+
+    db.put_node(leaf_node.hash(), leaf_node.clone()).unwrap();
+    assert_eq!(cache.get_node(leaf_node.hash()).unwrap(), leaf_node);
+}
+
+#[test]
+fn test_root_node() {
+    let db = MockTreeStore::default();
+    let mut cache = TreeCache::new(
+        &db,
+        *SPARSE_MERKLE_PLACEHOLDER_HASH, /* root_node_hash */
+    );
+
+    assert_eq!(cache.get_root_node().unwrap(), None);
+
+    let address = HashValue::random();
+    let value_hash = HashValue::random();
+    let leaf_node = Node::Leaf(LeafNode::new(address, value_hash));
+
+    db.put_node(leaf_node.hash(), leaf_node.clone()).unwrap();
+    cache.set_root_hash(leaf_node.hash());
+
+    assert_eq!(cache.get_root_node().unwrap().unwrap(), leaf_node);
+}
+
+#[test]
+fn test_duplicate_blob() {
+    let db = MockTreeStore::default();
+    let mut cache = TreeCache::new(
+        &db,
+        *SPARSE_MERKLE_PLACEHOLDER_HASH, /* root_node_hash */
+    );
+
+    let blob = AccountStateBlob::from(vec![0u8]);
+    let blob_hash = HashValue::random();
+    cache.put_blob(blob_hash, blob.clone()).unwrap();
+    cache.put_blob(blob_hash, blob.clone()).unwrap();
+    assert_eq!(cache.get_blob(blob_hash).unwrap(), blob);
+    cache.delete_blob(blob_hash);
+    assert_eq!(cache.get_blob(blob_hash).unwrap(), blob);
+    cache.delete_blob(blob_hash);
+    assert!(cache.get_blob(blob_hash).is_err());
+}
diff --git a/storage/state_view/Cargo.toml b/storage/state_view/Cargo.toml
new file mode 100644
index 0000000000000..861217216a0d3
--- /dev/null
+++ b/storage/state_view/Cargo.toml
@@ -0,0 +1,11 @@
+[package]
+name = "state_view"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+types = { path = "../../types" }
\ No newline at end of file
diff --git a/storage/state_view/src/lib.rs b/storage/state_view/src/lib.rs
new file mode 100644
index 0000000000000..3b776657fd3a4
--- /dev/null
+++ b/storage/state_view/src/lib.rs
@@ -0,0 +1,22 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This crate defines [`trait StateView`](StateView).
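+//!
+//! As a rough sketch (illustrative only; this type is not part of the crate), a trivial
+//! implementation backed by an in-memory map could look like:
+//!
+//! ```text
+//! struct MapStateView(HashMap<AccessPath, Vec<u8>>);
+//!
+//! impl StateView for MapStateView {
+//!     fn get(&self, access_path: &AccessPath) -> Result<Option<Vec<u8>>> {
+//!         Ok(self.0.get(access_path).cloned())
+//!     }
+//!     fn multi_get(&self, access_paths: &[AccessPath]) -> Result<Vec<Option<Vec<u8>>>> {
+//!         access_paths.iter().map(|ap| self.get(ap)).collect()
+//!     }
+//!     fn is_genesis(&self) -> bool {
+//!         self.0.is_empty()
+//!     }
+//! }
+//! ```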
+
+use failure::prelude::*;
+use types::access_path::AccessPath;
+
+/// `StateView` is a trait that defines a read-only snapshot of the global state. It is passed to
+/// the VM for transaction execution, during which the VM is guaranteed to be able to read
+/// anything at the given state.
+pub trait StateView {
+    /// Gets the state for a single access path.
+    fn get(&self, access_path: &AccessPath) -> Result<Option<Vec<u8>>>;
+
+    /// Gets states for a list of access paths.
+    fn multi_get(&self, access_paths: &[AccessPath]) -> Result<Vec<Option<Vec<u8>>>>;
+
+    /// The VM needs this method to know whether the current state view is for genesis state
+    /// creation. Currently `TransactionPayload::WriteSet` is only valid for genesis state
+    /// creation.
+    fn is_genesis(&self) -> bool;
+}
diff --git a/storage/storage_client/Cargo.toml b/storage/storage_client/Cargo.toml
new file mode 100644
index 0000000000000..e3ea98fe87ded
--- /dev/null
+++ b/storage/storage_client/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "storage_client"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+futures = { version = "0.3.0-alpha.13", package = "futures-preview", features = ["compat"] }
+futures_01 = { version = "0.1.25", package = "futures" }
+grpcio = "0.4.4"
+
+crypto = { path = "../../crypto/legacy_crypto" }
+canonical_serialization = { path = "../../common/canonical_serialization" }
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+proto_conv = { path = "../../common/proto_conv" }
+scratchpad = { path = "../scratchpad" }
+storage_proto = { path = "../storage_proto" }
+state_view = { path = "../state_view" }
+types = { path = "../../types" }
diff --git a/storage/storage_client/src/lib.rs b/storage/storage_client/src/lib.rs
new file mode 100644
index 0000000000000..c025bf8b31283
--- /dev/null
+++ b/storage/storage_client/src/lib.rs
@@ -0,0 +1,346 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This crate implements a client library for storage that wraps the protobuf storage client. The
+//! main motivation is to hide storage implementation details. For example, if we later want to
+//! expand the state store to multiple machines and enable sharding, we only need to tweak the
+//! client library implementation and the protobuf interface; the interface between the rest of
+//! the system and the client library will remain the same, so we won't need to change other
+//! components.
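+//!
+//! A minimal usage sketch (the host and port below are placeholders, not defaults defined by
+//! this crate):
+//!
+//! ```text
+//! let env = Arc::new(Environment::new(1 /* cq_count */));
+//! let client = StorageReadServiceClient::new(env, "localhost", 6184);
+//! let startup_info = client.get_executor_startup_info()?;
+//! ```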
+
+mod state_view;
+
+use crypto::HashValue;
+use failure::prelude::*;
+use futures::{compat::Future01CompatExt, executor::block_on, prelude::*};
+use futures_01::future::Future as Future01;
+use grpcio::{ChannelBuilder, Environment};
+use proto_conv::{FromProto, IntoProto};
+use std::{pin::Pin, sync::Arc};
+use storage_proto::{
+    proto::{storage::GetExecutorStartupInfoRequest, storage_grpc},
+    ExecutorStartupInfo, GetAccountStateWithProofByStateRootRequest,
+    GetAccountStateWithProofByStateRootResponse, GetExecutorStartupInfoResponse,
+    GetTransactionsRequest, GetTransactionsResponse, SaveTransactionsRequest,
+};
+use types::{
+    account_address::AccountAddress,
+    account_state_blob::AccountStateBlob,
+    get_with_proof::{
+        RequestItem, ResponseItem, UpdateToLatestLedgerRequest, UpdateToLatestLedgerResponse,
+    },
+    ledger_info::LedgerInfoWithSignatures,
+    proof::SparseMerkleProof,
+    transaction::{TransactionListWithProof, TransactionToCommit, Version},
+    validator_change::ValidatorChangeEventWithProof,
+};
+
+pub use crate::state_view::VerifiedStateView;
+
+fn convert_grpc_response<T>(
+    response: grpcio::Result<grpcio::ClientUnaryReceiver<T>>,
+) -> impl Future<Output = Result<T>> {
+    future::ready(response.map_err(convert_grpc_err))
+        .map_ok(Future01CompatExt::compat)
+        .and_then(|x| x.map_err(convert_grpc_err))
+}
+
+/// This provides storage read interfaces backed by real storage service.
+#[derive(Clone)]
+pub struct StorageReadServiceClient {
+    client: storage_grpc::StorageClient,
+}
+
+impl StorageReadServiceClient {
+    /// Constructs a `StorageReadServiceClient` with given host and port.
+    pub fn new(env: Arc<Environment>, host: &str, port: u16) -> Self {
+        let channel = ChannelBuilder::new(env).connect(&format!("{}:{}", host, port));
+        let client = storage_grpc::StorageClient::new(channel);
+        StorageReadServiceClient { client }
+    }
+}
+
+impl StorageRead for StorageReadServiceClient {
+    fn update_to_latest_ledger(
+        &self,
+        client_known_version: Version,
+        requested_items: Vec<RequestItem>,
+    ) -> Result<(
+        Vec<ResponseItem>,
+        LedgerInfoWithSignatures,
+        Vec<ValidatorChangeEventWithProof>,
+    )> {
+        block_on(self.update_to_latest_ledger_async(client_known_version, requested_items))
+    }
+
+    fn update_to_latest_ledger_async(
+        &self,
+        client_known_version: Version,
+        requested_items: Vec<RequestItem>,
+    ) -> Pin<
+        Box<
+            dyn Future<
+                    Output = Result<(
+                        Vec<ResponseItem>,
+                        LedgerInfoWithSignatures,
+                        Vec<ValidatorChangeEventWithProof>,
+                    )>,
+                > + Send,
+        >,
+    > {
+        let req = UpdateToLatestLedgerRequest {
+            client_known_version,
+            requested_items,
+        };
+        convert_grpc_response(self.client.update_to_latest_ledger_async(&req.into_proto()))
+            .map(|resp| {
+                let rust_resp = UpdateToLatestLedgerResponse::from_proto(resp?)?;
+                Ok((
+                    rust_resp.response_items,
+                    rust_resp.ledger_info_with_sigs,
+                    rust_resp.validator_change_events,
+                ))
+            })
+            .boxed()
+    }
+
+    fn get_transactions(
+        &self,
+        start_version: Version,
+        batch_size: u64,
+        ledger_version: Version,
+        fetch_events: bool,
+    ) -> Result<TransactionListWithProof> {
+        block_on(self.get_transactions_async(
+            start_version,
+            batch_size,
+            ledger_version,
+            fetch_events,
+        ))
+    }
+
+    fn get_transactions_async(
+        &self,
+        start_version: Version,
+        batch_size: u64,
+        ledger_version: Version,
+        fetch_events: bool,
+    ) -> Pin<Box<dyn Future<Output = Result<TransactionListWithProof>> + Send>> {
+        let req =
+            GetTransactionsRequest::new(start_version, batch_size, ledger_version, fetch_events);
+        convert_grpc_response(self.client.get_transactions_async(&req.into_proto()))
+            .map(|resp| {
+                let rust_resp = GetTransactionsResponse::from_proto(resp?)?;
+                Ok(rust_resp.txn_list_with_proof)
+            })
+            .boxed()
+    }
+
+    fn get_account_state_with_proof_by_state_root(
+        &self,
+        address: AccountAddress,
+        state_root_hash: HashValue,
+    ) -> Result<(Option<AccountStateBlob>, SparseMerkleProof)> {
+        block_on(self.get_account_state_with_proof_by_state_root_async(address, state_root_hash))
+    }
+
+    fn get_account_state_with_proof_by_state_root_async(
+        &self,
+        address: AccountAddress,
+        state_root_hash: HashValue,
+    ) -> Pin<Box<dyn Future<Output = Result<(Option<AccountStateBlob>, SparseMerkleProof)>> + Send>>
+    {
+        let req = GetAccountStateWithProofByStateRootRequest::new(address, state_root_hash);
+        convert_grpc_response(
+            self.client
+                .get_account_state_with_proof_by_state_root_async(&req.into_proto()),
+        )
+        .map(|resp| {
+            let resp = GetAccountStateWithProofByStateRootResponse::from_proto(resp?)?;
+            Ok(resp.into())
+        })
+        .boxed()
+    }
+
+    fn get_executor_startup_info(&self) -> Result<Option<ExecutorStartupInfo>> {
+        block_on(self.get_executor_startup_info_async())
+    }
+
+    fn get_executor_startup_info_async(
+        &self,
+    ) -> Pin<Box<dyn Future<Output = Result<Option<ExecutorStartupInfo>>> + Send>> {
+        let proto_req = GetExecutorStartupInfoRequest::new();
+        convert_grpc_response(self.client.get_executor_startup_info_async(&proto_req))
+            .map(|resp| {
+                let resp = GetExecutorStartupInfoResponse::from_proto(resp?)?;
+                Ok(resp.info)
+            })
+            .boxed()
+    }
+}
+
+/// This provides storage write interfaces backed by real storage service.
+#[derive(Clone)]
+pub struct StorageWriteServiceClient {
+    client: storage_grpc::StorageClient,
+}
+
+impl StorageWriteServiceClient {
+    /// Constructs a `StorageWriteServiceClient` with given host and port.
+    pub fn new(env: Arc<Environment>, host: &str, port: u16) -> Self {
+        let channel = ChannelBuilder::new(env).connect(&format!("{}:{}", host, port));
+        let client = storage_grpc::StorageClient::new(channel);
+        StorageWriteServiceClient { client }
+    }
+}
+
+impl StorageWrite for StorageWriteServiceClient {
+    fn save_transactions(
+        &self,
+        txns_to_commit: Vec<TransactionToCommit>,
+        first_version: Version,
+        ledger_info_with_sigs: Option<LedgerInfoWithSignatures>,
+    ) -> Result<()> {
+        block_on(self.save_transactions_async(txns_to_commit, first_version, ledger_info_with_sigs))
+    }
+
+    fn save_transactions_async(
+        &self,
+        txns_to_commit: Vec<TransactionToCommit>,
+        first_version: Version,
+        ledger_info_with_sigs: Option<LedgerInfoWithSignatures>,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> {
+        let req =
+            SaveTransactionsRequest::new(txns_to_commit, first_version, ledger_info_with_sigs);
+        convert_grpc_response(self.client.save_transactions_async(&req.into_proto()))
+            .map_ok(|_| ())
+            .boxed()
+    }
+}
+
+/// This trait defines interfaces to be implemented by a storage read client.
+///
+/// There is a 1-1 mapping between each interface provided here and a LibraDB API. A method call on
+/// this relays the query to the storage backend behind the scenes, which calls the corresponding
+/// LibraDB API. Both synchronous and asynchronous versions of the APIs are provided.
+pub trait StorageRead: Send + Sync {
+    /// See [`LibraDB::update_to_latest_ledger`].
+    ///
+    /// [`LibraDB::update_to_latest_ledger`]:
+    /// ../libradb/struct.LibraDB.html#method.update_to_latest_ledger
+    fn update_to_latest_ledger(
+        &self,
+        client_known_version: Version,
+        request_items: Vec<RequestItem>,
+    ) -> Result<(
+        Vec<ResponseItem>,
+        LedgerInfoWithSignatures,
+        Vec<ValidatorChangeEventWithProof>,
+    )>;
+
+    /// See [`LibraDB::update_to_latest_ledger`].
+    ///
+    /// [`LibraDB::update_to_latest_ledger`]:
+    /// ../libradb/struct.LibraDB.html#method.update_to_latest_ledger
+    fn update_to_latest_ledger_async(
+        &self,
+        client_known_version: Version,
+        request_items: Vec<RequestItem>,
+    ) -> Pin<
+        Box<
+            dyn Future<
+                    Output = Result<(
+                        Vec<ResponseItem>,
+                        LedgerInfoWithSignatures,
+                        Vec<ValidatorChangeEventWithProof>,
+                    )>,
+                > + Send,
+        >,
+    >;
+
+    /// See [`LibraDB::get_transactions`].
+    ///
+    /// [`LibraDB::get_transactions`]: ../libradb/struct.LibraDB.html#method.get_transactions
+    fn get_transactions(
+        &self,
+        start_version: Version,
+        batch_size: u64,
+        ledger_version: Version,
+        fetch_events: bool,
+    ) -> Result<TransactionListWithProof>;
+
+    /// See [`LibraDB::get_transactions`].
+    ///
+    /// [`LibraDB::get_transactions`]: ../libradb/struct.LibraDB.html#method.get_transactions
+    fn get_transactions_async(
+        &self,
+        start_version: Version,
+        batch_size: u64,
+        ledger_version: Version,
+        fetch_events: bool,
+    ) -> Pin<Box<dyn Future<Output = Result<TransactionListWithProof>> + Send>>;
+
+    /// See [`LibraDB::get_account_state_with_proof_by_state_root`].
+    ///
+    /// [`LibraDB::get_account_state_with_proof_by_state_root`]:
+    /// ../libradb/struct.LibraDB.html#method.get_account_state_with_proof_by_state_root
+    fn get_account_state_with_proof_by_state_root(
+        &self,
+        address: AccountAddress,
+        state_root_hash: HashValue,
+    ) -> Result<(Option<AccountStateBlob>, SparseMerkleProof)>;
+
+    /// See [`LibraDB::get_account_state_with_proof_by_state_root`].
+    ///
+    /// [`LibraDB::get_account_state_with_proof_by_state_root`]:
+    /// ../libradb/struct.LibraDB.html#method.get_account_state_with_proof_by_state_root
+    fn get_account_state_with_proof_by_state_root_async(
+        &self,
+        address: AccountAddress,
+        state_root_hash: HashValue,
+    ) -> Pin<Box<dyn Future<Output = Result<(Option<AccountStateBlob>, SparseMerkleProof)>> + Send>>;
+
+    /// See [`LibraDB::get_executor_startup_info`].
+    ///
+    /// [`LibraDB::get_executor_startup_info`]:
+    /// ../libradb/struct.LibraDB.html#method.get_executor_startup_info
+    fn get_executor_startup_info(&self) -> Result<Option<ExecutorStartupInfo>>;
+
+    /// See [`LibraDB::get_executor_startup_info`].
+    ///
+    /// [`LibraDB::get_executor_startup_info`]:
+    /// ../libradb/struct.LibraDB.html#method.get_executor_startup_info
+    fn get_executor_startup_info_async(
+        &self,
+    ) -> Pin<Box<dyn Future<Output = Result<Option<ExecutorStartupInfo>>> + Send>>;
+}
+
+/// This trait defines interfaces to be implemented by a storage write client.
+///
+/// There is a 1-1 mapping between each interface provided here and a LibraDB API. A method call on
+/// this relays the query to the storage backend behind the scenes, which calls the corresponding
+/// LibraDB API. Both synchronous and asynchronous versions of the APIs are provided.
+pub trait StorageWrite: Send + Sync {
+    /// See [`LibraDB::save_transactions`].
+    ///
+    /// [`LibraDB::save_transactions`]: ../libradb/struct.LibraDB.html#method.save_transactions
+    fn save_transactions(
+        &self,
+        txns_to_commit: Vec<TransactionToCommit>,
+        first_version: Version,
+        ledger_info_with_sigs: Option<LedgerInfoWithSignatures>,
+    ) -> Result<()>;
+
+    /// See [`LibraDB::save_transactions`].
+    ///
+    /// [`LibraDB::save_transactions`]: ../libradb/struct.LibraDB.html#method.save_transactions
+    fn save_transactions_async(
+        &self,
+        txns_to_commit: Vec<TransactionToCommit>,
+        first_version: Version,
+        ledger_info_with_sigs: Option<LedgerInfoWithSignatures>,
+    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send>>;
+}
+
+fn convert_grpc_err(e: grpcio::Error) -> Error {
+    format_err!("grpc error: {}", e)
+}
diff --git a/storage/storage_client/src/state_view.rs b/storage/storage_client/src/state_view.rs
new file mode 100644
index 0000000000000..5739f6f70ea23
--- /dev/null
+++ b/storage/storage_client/src/state_view.rs
@@ -0,0 +1,177 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::StorageRead;
+use crypto::{
+    hash::{CryptoHash, SPARSE_MERKLE_PLACEHOLDER_HASH},
+    HashValue,
+};
+use failure::prelude::*;
+use scratchpad::{AccountState, SparseMerkleTree};
+use state_view::StateView;
+use std::{
+    cell::RefCell,
+    collections::{hash_map::Entry, BTreeMap, HashMap},
+    convert::TryInto,
+    sync::Arc,
+};
+use types::{
+    access_path::AccessPath,
+    account_address::AccountAddress,
+    proof::{definition::SparseMerkleProof, verify_sparse_merkle_element},
+};
+
+/// `VerifiedStateView` is like a snapshot of the global state comprised of state views at two
+/// levels: persistent storage and memory.
+pub struct VerifiedStateView<'a> {
+    /// A gateway implementing persistent storage interface, which can be an RPC client or a
+    /// direct accessor.
+    reader: Arc<dyn StorageRead>,
+
+    /// The most recent state root hash in persistent storage.
+    latest_persistent_state_root: HashValue,
+
+    /// The in-memory version of the sparse Merkle tree whose states haven't been committed.
+    speculative_state: &'a SparseMerkleTree,
+
+    /// The cache of verified account states from `reader` and `speculative_state_view`,
+    /// represented by a hashmap with an account address as key and a pair of an ordered
+    /// account state map and an optional account state proof as value. When the VM queries an
+    /// `access_path`, this cache will first check whether `reader_cache` is hit. If hit, it
+    /// will return the corresponding value of that `access_path`; otherwise, the account state
+    /// will be loaded into the cache from scratchpad or persistent storage in order as a
+    /// deserialized ordered map and then be returned. If the VM queries this account again,
+    /// the cached data can be read directly without bothering the storage layer. The proofs in
+    /// cache are needed by ScratchPad after VM execution to construct an in-memory sparse Merkle
+    /// tree.
+    /// ```text
+    ///                      +----------------------------+
+    ///                      | In-memory SparseMerkleTree <------+
+    ///                      +-------------^--------------+     |
+    ///                                    |                    |
+    ///                                write sets               |
+    ///                                    |     cached account state map
+    ///                            +-------+-------+          proof
+    ///                            |      V M      |            |
+    ///                            +-------^-------+            |
+    ///                                    |                    |
+    ///                   value of `account_address/path`       |
+    ///                                    |                    |
+    ///      +-----------------------------+--------------------+---------+
+    ///      | +--------------------------+--------------------+--------+ |
+    ///      | |  account_to_btree_cache, account_to_proof_cache        | |
+    ///      | +---------------^---------------------------^-----------+ |
+    ///      |                 |                           |             |
+    ///      |     account state blob only        account state blob     |
+    ///      |                 |                         proof           |
+    ///      |                 |                           |             |
+    ///      | +---------------+--------------+ +----------+-----------+ |
+    ///      | |      speculative_state       | |        reader        | |
+    ///      | +------------------------------+ +----------------------+ |
+    ///      +-----------------------------------------------------------+
+    /// ```
+    account_to_btree_cache: RefCell<HashMap<AccountAddress, BTreeMap<Vec<u8>, Vec<u8>>>>,
+    account_to_proof_cache: RefCell<HashMap<HashValue, SparseMerkleProof>>,
+}
+
+impl<'a> VerifiedStateView<'a> {
+    /// Constructs a [`VerifiedStateView`] with persistent state view represented by
+    /// `latest_persistent_state_root` plus a storage reader, and the in-memory speculative state
+    /// on top of it represented by `speculative_state`.
+    pub fn new(
+        reader: Arc<dyn StorageRead>,
+        latest_persistent_state_root: HashValue,
+        speculative_state: &'a SparseMerkleTree,
+    ) -> Self {
+        Self {
+            reader,
+            latest_persistent_state_root,
+            speculative_state,
+            account_to_btree_cache: RefCell::new(HashMap::new()),
+            account_to_proof_cache: RefCell::new(HashMap::new()),
+        }
+    }
+}
+
+impl<'a>
+    Into<(
+        HashMap<AccountAddress, BTreeMap<Vec<u8>, Vec<u8>>>,
+        HashMap<HashValue, SparseMerkleProof>,
+    )> for VerifiedStateView<'a>
+{
+    fn into(
+        self,
+    ) -> (
+        HashMap<AccountAddress, BTreeMap<Vec<u8>, Vec<u8>>>,
+        HashMap<HashValue, SparseMerkleProof>,
+    ) {
+        (
+            self.account_to_btree_cache.into_inner(),
+            self.account_to_proof_cache.into_inner(),
+        )
+    }
+}
+
+impl<'a> StateView for VerifiedStateView<'a> {
+    fn get(&self, access_path: &AccessPath) -> Result<Option<Vec<u8>>> {
+        let address = access_path.address;
+        let path = &access_path.path;
+        match self.account_to_btree_cache.borrow_mut().entry(address) {
+            Entry::Occupied(occupied) => Ok(occupied.get().get(path).cloned()),
+            Entry::Vacant(vacant) => {
+                let address_hash = address.hash();
+                let account_blob_option = match self.speculative_state.get(address_hash) {
+                    AccountState::ExistsInScratchPad(blob) => Some(blob),
+                    AccountState::DoesNotExist => None,
+                    // Whether the account is in the DB or unknown, we have to query the DB,
+                    // since even in the former case we only have the hash of the blob here,
+                    // not the blob data itself.
+                    AccountState::ExistsInDB | AccountState::Unknown => {
+                        let (blob, proof) =
+                            self.reader.get_account_state_with_proof_by_state_root(
+                                address,
+                                self.latest_persistent_state_root,
+                            )?;
+                        verify_sparse_merkle_element(
+                            self.latest_persistent_state_root,
+                            address.hash(),
+                            &blob,
+                            &proof,
+                        )
+                        .map_err(|err| {
+                            format_err!(
+                                "Proof is invalid for address {:?} with state root hash {:?}: {}",
+                                address,
+                                self.latest_persistent_state_root,
+                                err
+                            )
+                        })?;
+                        assert!(self
+                            .account_to_proof_cache
+                            .borrow_mut()
+                            .insert(address_hash, proof)
+                            .is_none());
+                        blob
+                    }
+                };
+                Ok(vacant
+                    .insert(
+                        account_blob_option
+                            .as_ref()
+                            .map(TryInto::try_into)
+                            .transpose()?
+                            .unwrap_or_default(),
+                    )
+                    .get(path)
+                    .cloned())
+            }
+        }
+    }
+
+    fn multi_get(&self, _access_paths: &[AccessPath]) -> Result<Vec<Option<Vec<u8>>>> {
+        unimplemented!();
+    }
+
+    fn is_genesis(&self) -> bool {
+        self.latest_persistent_state_root == *SPARSE_MERKLE_PLACEHOLDER_HASH
+    }
+}
diff --git a/storage/storage_proto/Cargo.toml b/storage/storage_proto/Cargo.toml
new file mode 100644
index 0000000000000..5a2de0cdf965d
--- /dev/null
+++ b/storage/storage_proto/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "storage_proto"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+futures = "0.1.25"
+grpcio = "0.4.4"
+proptest = "0.9.2"
+proptest-derive = "0.1.0"
+protobuf = "2.6"
+
+crypto = { path = "../../crypto/legacy_crypto" }
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+proto_conv = { path = "../../common/proto_conv", features = ["derive"] }
+types = { path = "../../types" }
+
+[build-dependencies]
+build_helpers = { path = "../../common/build_helpers" }
diff --git a/storage/storage_proto/build.rs b/storage/storage_proto/build.rs
new file mode 100644
index 0000000000000..f107fb59075af
--- /dev/null
+++ b/storage/storage_proto/build.rs
@@ -0,0 +1,18 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This compiles all the `.proto` files under the `src/` directory.
+//!
+//! For example, if there is a file `src/a/b/c.proto`, it will generate `src/a/b/c.rs` and
+//! `src/a/b/c_grpc.rs`.
+
+fn main() {
+    let proto_root = "src/proto";
+    let dependent_root = "../../types/src/proto";
+
+    build_helpers::build_helpers::compile_proto(
+        proto_root,
+        vec![dependent_root],
+        false, /* generate_client_code */
+    );
+}
diff --git a/storage/storage_proto/src/lib.rs b/storage/storage_proto/src/lib.rs
new file mode 100644
index 0000000000000..3b775b9dfadd5
--- /dev/null
+++ b/storage/storage_proto/src/lib.rs
@@ -0,0 +1,368 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This crate provides Protocol Buffers definitions for the services provided by the
+//! [`storage_service`](../storage_service/index.html) crate.
+//!
+//! The protocol is documented in the Protocol Buffers source files with the `.proto` extension;
+//! that documentation is not viewable via rustdoc. Refer to the source code to see it.
+//!
+//! The content provided in this documentation falls into two categories:
+//!
+//! 1. Those automatically generated by [`grpc-rs`](https://github.com/pingcap/grpc-rs):
+//!     * In [`proto::storage`] are structs corresponding to our Protocol Buffers messages.
+//!     * In [`proto::storage_grpc`] live the [GRPC](grpc.io) client struct and the service trait
+//! which correspond to our Protocol Buffers services.
+//! 1. Structs we wrote manually as helpers to ease the manipulation of the above category of
+//! structs. By implementing the [`FromProto`](proto_conv::FromProto) and
+//! [`IntoProto`](proto_conv::IntoProto) traits, these structs convert from/to the above category of
+//! structs in a single method call, and in that process data integrity checks can be done. These
+//! live right in the root module of this crate (this page).
+//!
+//! This is provided as a separate crate so that crates that use the storage service via
+//! [`storage_client`](../storage_client/index.html) don't need to depend on the entire
+//! [`storage_service`](../storage_service/index.html).
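+//!
+//! For example, a typical conversion round trip through the helper structs looks like the
+//! following sketch (the field values are placeholders):
+//!
+//! ```text
+//! let request = GetTransactionsRequest::new(
+//!     0,     /* start_version */
+//!     100,   /* batch_size */
+//!     0,     /* ledger_version */
+//!     false, /* fetch_events */
+//! );
+//! let proto = request.clone().into_proto();                       // Rust -> proto
+//! let round_tripped = GetTransactionsRequest::from_proto(proto)?; // proto -> Rust
+//! assert_eq!(request, round_tripped);
+//! ```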
+
+#![allow(clippy::unit_arg)]
+
+pub mod proto;
+
+use crypto::HashValue;
+use failure::prelude::*;
+use proptest_derive::Arbitrary;
+use proto_conv::{FromProto, IntoProto};
+use types::{
+    account_address::AccountAddress,
+    account_state_blob::AccountStateBlob,
+    ledger_info::{LedgerInfo, LedgerInfoWithSignatures},
+    proof::definition::SparseMerkleProof,
+    transaction::{TransactionListWithProof, TransactionToCommit, Version},
+};
+
+/// Helper to construct and parse [`proto::storage::GetAccountStateWithProofByStateRootRequest`]
+///
+/// It does so by implementing [`IntoProto`](#impl-IntoProto) and [`FromProto`](#impl-FromProto),
+/// providing [`into_proto`](IntoProto::into_proto) and [`from_proto`](FromProto::from_proto).
+#[derive(PartialEq, Eq, Clone, FromProto, IntoProto)]
+#[ProtoType(crate::proto::storage::GetAccountStateWithProofByStateRootRequest)]
+pub struct GetAccountStateWithProofByStateRootRequest {
+    /// The account address to query with.
+    pub address: AccountAddress,
+
+    /// The state root hash the query is based on.
+    pub state_root_hash: HashValue,
+}
+
+impl GetAccountStateWithProofByStateRootRequest {
+    /// Constructor.
+    pub fn new(address: AccountAddress, state_root_hash: HashValue) -> Self {
+        Self {
+            address,
+            state_root_hash,
+        }
+    }
+}
+
+/// Helper to construct and parse [`proto::storage::GetAccountStateWithProofByStateRootResponse`]
+///
+/// It does so by implementing [`IntoProto`](#impl-IntoProto) and [`FromProto`](#impl-FromProto),
+/// providing [`into_proto`](IntoProto::into_proto) and [`from_proto`](FromProto::from_proto).
+#[derive(PartialEq, Eq, Clone)]
+pub struct GetAccountStateWithProofByStateRootResponse {
+    /// The account state blob requested.
+    pub account_state_blob: Option<AccountStateBlob>,
+
+    /// The sparse Merkle proof for the account state, relative to the state root hash the query
+    /// is based on.
+    pub sparse_merkle_proof: SparseMerkleProof,
+}
+
+impl GetAccountStateWithProofByStateRootResponse {
+    /// Constructor.
+    pub fn new(
+        account_state_blob: Option<AccountStateBlob>,
+        sparse_merkle_proof: SparseMerkleProof,
+    ) -> Self {
+        Self {
+            account_state_blob,
+            sparse_merkle_proof,
+        }
+    }
+}
+
+impl FromProto for GetAccountStateWithProofByStateRootResponse {
+    type ProtoType = crate::proto::storage::GetAccountStateWithProofByStateRootResponse;
+
+    fn from_proto(mut object: Self::ProtoType) -> Result<Self> {
+        let account_state_blob = if object.has_account_state_blob() {
+            Some(AccountStateBlob::from_proto(
+                object.take_account_state_blob(),
+            )?)
+        } else {
+            None
+        };
+        Ok(Self {
+            account_state_blob,
+            sparse_merkle_proof: SparseMerkleProof::from_proto(object.take_sparse_merkle_proof())?,
+        })
+    }
+}
+
+impl IntoProto for GetAccountStateWithProofByStateRootResponse {
+    type ProtoType = crate::proto::storage::GetAccountStateWithProofByStateRootResponse;
+
+    fn into_proto(self) -> Self::ProtoType {
+        let mut object = Self::ProtoType::new();
+
+        if let Some(account_state_blob) = self.account_state_blob {
+            object.set_account_state_blob(account_state_blob.into_proto());
+        }
+        object.set_sparse_merkle_proof(self.sparse_merkle_proof.into_proto());
+        object
+    }
+}
+
+impl Into<(Option<AccountStateBlob>, SparseMerkleProof)>
+    for GetAccountStateWithProofByStateRootResponse
+{
+    fn into(self) -> (Option<AccountStateBlob>, SparseMerkleProof) {
+        (self.account_state_blob, self.sparse_merkle_proof)
+    }
+}
+
+/// Helper to construct and parse [`proto::storage::SaveTransactionsRequest`]
+///
+/// It does so by implementing [`IntoProto`](#impl-IntoProto) and [`FromProto`](#impl-FromProto),
+/// providing [`into_proto`](IntoProto::into_proto) and [`from_proto`](FromProto::from_proto).
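+///
+/// For instance (a sketch; the values below are placeholders):
+///
+/// ```text
+/// let req = SaveTransactionsRequest::new(txns_to_commit, 1 /* first_version */, None);
+/// let proto = req.into_proto();
+/// ```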
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq)]
+pub struct SaveTransactionsRequest {
+    pub txns_to_commit: Vec<TransactionToCommit>,
+    pub first_version: Version,
+    pub ledger_info_with_signatures: Option<LedgerInfoWithSignatures>,
+}
+
+impl SaveTransactionsRequest {
+    /// Constructor.
+    pub fn new(
+        txns_to_commit: Vec<TransactionToCommit>,
+        first_version: Version,
+        ledger_info_with_signatures: Option<LedgerInfoWithSignatures>,
+    ) -> Self {
+        SaveTransactionsRequest {
+            txns_to_commit,
+            first_version,
+            ledger_info_with_signatures,
+        }
+    }
+}
+
+impl FromProto for SaveTransactionsRequest {
+    type ProtoType = crate::proto::storage::SaveTransactionsRequest;
+
+    fn from_proto(mut object: Self::ProtoType) -> Result<Self> {
+        let txns_to_commit = object
+            .take_txns_to_commit()
+            .into_iter()
+            .map(TransactionToCommit::from_proto)
+            .collect::<Result<Vec<_>>>()?;
+        let first_version = object.get_first_version();
+        let ledger_info_with_signatures = object
+            .ledger_info_with_signatures
+            .take()
+            .map(LedgerInfoWithSignatures::from_proto)
+            .transpose()?;
+
+        Ok(Self {
+            txns_to_commit,
+            first_version,
+            ledger_info_with_signatures,
+        })
+    }
+}
+
+impl IntoProto for SaveTransactionsRequest {
+    type ProtoType = crate::proto::storage::SaveTransactionsRequest;
+
+    fn into_proto(self) -> Self::ProtoType {
+        let mut proto = Self::ProtoType::new();
+        proto.set_txns_to_commit(::protobuf::RepeatedField::from_vec(
+            self.txns_to_commit
+                .into_iter()
+                .map(TransactionToCommit::into_proto)
+                .collect::<Vec<_>>(),
+        ));
+        proto.set_first_version(self.first_version);
+        if let Some(x) = self.ledger_info_with_signatures {
+            proto.set_ledger_info_with_signatures(x.into_proto())
+        }
+
+        proto
+    }
+}
+
+/// Helper to construct and parse [`proto::storage::GetTransactionsRequest`]
+///
+/// It does so by implementing [`IntoProto`](#impl-IntoProto) and [`FromProto`](#impl-FromProto),
+/// providing [`into_proto`](IntoProto::into_proto) and [`from_proto`](FromProto::from_proto).
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq)]
+pub struct GetTransactionsRequest {
+    pub start_version: Version,
+    pub batch_size: u64,
+    pub ledger_version: Version,
+    pub fetch_events: bool,
+}
+
+impl GetTransactionsRequest {
+    /// Constructor.
+    pub fn new(
+        start_version: Version,
+        batch_size: u64,
+        ledger_version: Version,
+        fetch_events: bool,
+    ) -> Self {
+        GetTransactionsRequest {
+            start_version,
+            batch_size,
+            ledger_version,
+            fetch_events,
+        }
+    }
+}
+
+impl FromProto for GetTransactionsRequest {
+    type ProtoType = crate::proto::storage::GetTransactionsRequest;
+
+    fn from_proto(object: Self::ProtoType) -> Result<Self> {
+        Ok(GetTransactionsRequest {
+            start_version: object.get_start_version(),
+            batch_size: object.get_batch_size(),
+            ledger_version: object.get_ledger_version(),
+            fetch_events: object.get_fetch_events(),
+        })
+    }
+}
+
+impl IntoProto for GetTransactionsRequest {
+    type ProtoType = crate::proto::storage::GetTransactionsRequest;
+
+    fn into_proto(self) -> Self::ProtoType {
+        let mut out = Self::ProtoType::new();
+        out.set_start_version(self.start_version);
+        out.set_batch_size(self.batch_size);
+        out.set_ledger_version(self.ledger_version);
+        out.set_fetch_events(self.fetch_events);
+        out
+    }
+}
+
+/// Helper to construct and parse [`proto::storage::GetTransactionsResponse`]
+///
+/// It does so by implementing [`IntoProto`](#impl-IntoProto) and [`FromProto`](#impl-FromProto),
+/// providing [`into_proto`](IntoProto::into_proto) and [`from_proto`](FromProto::from_proto).
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, FromProto, IntoProto)]
+#[ProtoType(crate::proto::storage::GetTransactionsResponse)]
+pub struct GetTransactionsResponse {
+    pub txn_list_with_proof: TransactionListWithProof,
+}
+
+impl GetTransactionsResponse {
+    /// Constructor.
+    pub fn new(txn_list_with_proof: TransactionListWithProof) -> Self {
+        GetTransactionsResponse {
+            txn_list_with_proof,
+        }
+    }
+}
+
+/// Helper to construct and parse [`proto::storage::ExecutorStartupInfo`]
+///
+/// It does so by implementing [`IntoProto`](#impl-IntoProto) and [`FromProto`](#impl-FromProto),
+/// providing [`into_proto`](IntoProto::into_proto) and [`from_proto`](FromProto::from_proto).
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq)]
+pub struct ExecutorStartupInfo {
+    pub ledger_info: LedgerInfo,
+    pub latest_version: Version,
+    pub account_state_root_hash: HashValue,
+    pub ledger_frozen_subtree_hashes: Vec<HashValue>,
+}
+
+impl FromProto for ExecutorStartupInfo {
+    type ProtoType = crate::proto::storage::ExecutorStartupInfo;
+
+    fn from_proto(mut object: Self::ProtoType) -> Result<Self> {
+        let ledger_info = LedgerInfo::from_proto(object.take_ledger_info())?;
+        let latest_version = object.get_latest_version();
+        let account_state_root_hash =
+            HashValue::from_proto(object.take_account_state_root_hash())?;
+        let ledger_frozen_subtree_hashes = object
+            .take_ledger_frozen_subtree_hashes()
+            .into_iter()
+            .map(HashValue::from_proto)
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(Self {
+            ledger_info,
+            latest_version,
+            account_state_root_hash,
+            ledger_frozen_subtree_hashes,
+        })
+    }
+}
+
+impl IntoProto for ExecutorStartupInfo {
+    type ProtoType = crate::proto::storage::ExecutorStartupInfo;
+
+    fn into_proto(self) -> Self::ProtoType {
+        let mut proto = Self::ProtoType::new();
+        proto.set_ledger_info(self.ledger_info.into_proto());
+        proto.set_latest_version(self.latest_version);
+        proto.set_account_state_root_hash(self.account_state_root_hash.into_proto());
+        proto.set_ledger_frozen_subtree_hashes(protobuf::RepeatedField::from_vec(
+            self.ledger_frozen_subtree_hashes
+                .into_iter()
+                .map(HashValue::into_proto)
+                .collect::<Vec<_>>(),
+        ));
+        proto
+    }
+}
+
+/// Helper to construct and parse [`proto::storage::GetExecutorStartupInfoResponse`]
+///
+/// It does so by implementing [`IntoProto`](#impl-IntoProto) and [`FromProto`](#impl-FromProto),
+/// providing [`into_proto`](IntoProto::into_proto) and [`from_proto`](FromProto::from_proto).
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq)]
+pub struct GetExecutorStartupInfoResponse {
+    pub info: Option<ExecutorStartupInfo>,
+}
+
+impl FromProto for GetExecutorStartupInfoResponse {
+    type ProtoType = crate::proto::storage::GetExecutorStartupInfoResponse;
+
+    fn from_proto(mut object: Self::ProtoType) -> Result<Self> {
+        let info = if object.has_info() {
+            Some(ExecutorStartupInfo::from_proto(object.take_info())?)
+        } else {
+            None
+        };
+
+        Ok(Self { info })
+    }
+}
+
+impl IntoProto for GetExecutorStartupInfoResponse {
+    type ProtoType = crate::proto::storage::GetExecutorStartupInfoResponse;
+
+    fn into_proto(self) -> Self::ProtoType {
+        let mut proto = Self::ProtoType::new();
+        if let Some(info) = self.info {
+            proto.set_info(info.into_proto())
+        }
+        proto
+    }
+}
+
+pub mod prelude {
+    pub use super::*;
+}
+
+#[cfg(test)]
+mod tests;
diff --git a/storage/storage_proto/src/proto/mod.rs b/storage/storage_proto/src/proto/mod.rs
new file mode 100644
index 0000000000000..8eef33a729d20
--- /dev/null
+++ b/storage/storage_proto/src/proto/mod.rs
@@ -0,0 +1,7 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use types::proto::{account_state_blob, get_with_proof, ledger_info, proof, transaction};
+
+pub mod storage;
+pub mod storage_grpc;
diff --git a/storage/storage_proto/src/proto/storage.proto b/storage/storage_proto/src/proto/storage.proto
new file mode 100644
index 0000000000000..8e4453df5fcda
--- /dev/null
+++ b/storage/storage_proto/src/proto/storage.proto
@@ -0,0 +1,115 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+syntax = "proto3";
+
+package storage;
+
+import "get_with_proof.proto";
+import "ledger_info.proto";
+import "transaction.proto";
+import "account_state_blob.proto";
+import "proof.proto";
+
+// -----------------------------------------------------------------------------
+// ---------------- Service definition for storage
+// -----------------------------------------------------------------------------
+service Storage {
+  // Write APIs.
+
+  // Persist transactions. Called by Execution when either syncing nodes or
+  // committing blocks during normal operation.
+  rpc SaveTransactions(SaveTransactionsRequest)
+      returns (SaveTransactionsResponse);
+
+  // Read APIs.
+
+  // Used to get a piece of data and return the proof of it. If the client
+  // knows and trusts a ledger info at version v, it should pass v in as the
+  // client_known_version and we will return the latest ledger info together
+  // with the proof that it derives from v.
+  rpc UpdateToLatestLedger(types.UpdateToLatestLedgerRequest)
+      returns (types.UpdateToLatestLedgerResponse);
+
+  // When we receive a request from a peer validator asking for a list of
+  // transactions for state synchronization, this API can be used to serve the
+  // request. Note that the peer should specify a ledger version and all proofs
+  // in the response will be relative to this given ledger version.
+  rpc GetTransactions(GetTransactionsRequest) returns (GetTransactionsResponse);
+
+  rpc GetAccountStateWithProofByStateRoot(
+      GetAccountStateWithProofByStateRootRequest)
+      returns (GetAccountStateWithProofByStateRootResponse);
+
+  // Returns information needed for Executor to start up.
+  rpc GetExecutorStartupInfo(GetExecutorStartupInfoRequest)
+      returns (GetExecutorStartupInfoResponse);
+}
+
+message SaveTransactionsRequest {
+  // Transactions to persist.
+  repeated types.TransactionToCommit txns_to_commit = 1;
+
+  // The version of the first transaction in `txns_to_commit`.
+  uint64 first_version = 2;
+
+  // If this is set, Storage will check that its state after applying the above
+  // transactions matches the info in this LedgerInfo before committing;
+  // otherwise it denies the request.
+  types.LedgerInfoWithSignatures ledger_info_with_signatures = 3;
+}
+
+message SaveTransactionsResponse {}
+
+message GetTransactionsRequest {
+  // The version to start with.
+  uint64 start_version = 1;
+  // The size of the transaction batch.
+  uint64 batch_size = 2;
+  // All the proofs returned in the response should be relative to this
+  // given version.
+  uint64 ledger_version = 3;
+  // Used to return the events associated with each transaction.
+  bool fetch_events = 4;
+}
+
+message GetTransactionsResponse {
+  types.TransactionListWithProof txn_list_with_proof = 1;
+}
+
+message GetAccountStateWithProofByStateRootRequest {
+  /// The account address to query with.
+  bytes address = 1;
+
+  /// The state root hash the query is based on.
+  bytes state_root_hash = 2;
+}
+
+message GetAccountStateWithProofByStateRootResponse {
+  /// The optional account state blob.
+  types.AccountStateBlob account_state_blob = 1;
+
+  /// The sparse Merkle proof for the account state, relative to the state
+  /// root hash the query is based on.
+  types.SparseMerkleProof sparse_merkle_proof = 2;
+}
+
+message GetExecutorStartupInfoRequest {}
+
+message GetExecutorStartupInfoResponse {
+  // When this is empty, Storage needs to be bootstrapped via the bootstrap API.
+  ExecutorStartupInfo info = 1;
+}
+
+message ExecutorStartupInfo {
+  // The latest LedgerInfo. Note that at start up storage can have more
+  // transactions than the latest LedgerInfo indicates due to an incomplete
+  // start up sync.
+  types.LedgerInfo ledger_info = 1;
+  // The latest version. All fields below are based on this version.
+  uint64 latest_version = 2;
+  // The latest account state root hash.
+  bytes account_state_root_hash = 3;
+  // From left to right, root hashes of all frozen subtrees.
+  repeated bytes ledger_frozen_subtree_hashes = 4;
+}
diff --git a/storage/storage_proto/src/tests.rs b/storage/storage_proto/src/tests.rs
new file mode 100644
index 0000000000000..c1d393cde7f8a
--- /dev/null
+++ b/storage/storage_proto/src/tests.rs
@@ -0,0 +1,35 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use proptest::prelude::*;
+use proto_conv::test_helper::assert_protobuf_encode_decode;
+
+proptest! {
+    #![proptest_config(ProptestConfig::with_cases(10))]
+
+    #[test]
+    fn test_save_transactions_request(req in any::<SaveTransactionsRequest>()) {
+        assert_protobuf_encode_decode(&req);
+    }
+
+    #[test]
+    fn test_get_transactions_request(req in any::<GetTransactionsRequest>()) {
+        assert_protobuf_encode_decode(&req);
+    }
+
+    #[test]
+    fn test_get_transactions_response(resp in any::<GetTransactionsResponse>()) {
+        assert_protobuf_encode_decode(&resp);
+    }
+
+    #[test]
+    fn test_executor_startup_info(executor_startup_info in any::<ExecutorStartupInfo>()) {
+        assert_protobuf_encode_decode(&executor_startup_info);
+    }
+
+    #[test]
+    fn test_get_executor_startup_info_response(res in any::<GetExecutorStartupInfoResponse>()) {
+        assert_protobuf_encode_decode(&res);
+    }
+}
diff --git a/storage/storage_service/Cargo.toml b/storage/storage_service/Cargo.toml
new file mode 100644
index 0000000000000..358fe18d113cd
--- /dev/null
+++ b/storage/storage_service/Cargo.toml
@@ -0,0 +1,32 @@
+[package]
+name = "storage_service"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+futures = { version = "0.3.0-alpha.13", package = "futures-preview", features = ["compat"] }
+grpcio = "0.4.4"
+protobuf = "2.6"
+
+canonical_serialization = { path = "../../common/canonical_serialization" }
+config = { path = "../../config" }
+crypto = { path = "../../crypto/legacy_crypto" }
+debug_interface = { path = "../../common/debug_interface" }
+executable_helpers = { path = "../../common/executable_helpers"}
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+grpc_helpers = { path = "../../common/grpc_helpers" }
+libradb = { path = "../libradb" }
+logger = { path = "../../common/logger" }
+metrics = { path = "../../common/metrics" }
+proto_conv = { path = "../../common/proto_conv", features = ["derive"] }
+storage_client = { path = "../storage_client" }
+storage_proto = { path = "../storage_proto" }
+types = { path = "../../types" }
+
+[dev-dependencies]
+itertools = "0.8.0"
+proptest = "0.9.2"
+tempfile = "3.0.4"
diff --git a/storage/storage_service/src/lib.rs b/storage/storage_service/src/lib.rs
new file mode 100644
index 0000000000000..63afc78f71e42
--- /dev/null
+++ b/storage/storage_service/src/lib.rs
@@ -0,0 +1,279 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This crate implements the storage [GRPC](http://grpc.io) service.
+//!
+//! Users of the storage service are expected to use it via the client library provided in
+//! [`storage_client`](../storage_client/index.html) instead of via
+//! [`StorageClient`](../storage_proto/proto/storage_grpc/struct.StorageClient.html) directly.
+
+pub mod mocks;
+
+use config::config::NodeConfig;
+use failure::prelude::*;
+use grpc_helpers::{provide_grpc_response, spawn_service_thread_with_drop_closure, ServerHandle};
+use libradb::LibraDB;
+use logger::prelude::*;
+use metrics::counters::SVC_COUNTERS;
+use proto_conv::{FromProto, IntoProto};
+use std::{
+    ops::Deref,
+    path::Path,
+    sync::{mpsc, Arc, Mutex},
+};
+use storage_proto::proto::{
+    storage::{
+        GetAccountStateWithProofByStateRootRequest, GetAccountStateWithProofByStateRootResponse,
+        GetExecutorStartupInfoRequest, GetExecutorStartupInfoResponse, GetTransactionsRequest,
+        GetTransactionsResponse, SaveTransactionsRequest, SaveTransactionsResponse,
+    },
+    storage_grpc::{create_storage, Storage},
+};
+use types::proto::get_with_proof::{UpdateToLatestLedgerRequest, UpdateToLatestLedgerResponse};
+
+/// Starts storage service according to config.
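+///
+/// The returned [`ServerHandle`] keeps the service alive until it is dropped. A minimal usage
+/// sketch (assuming a `NodeConfig` has already been loaded):
+///
+/// ```text
+/// let handle = start_storage_service(&node_config);
+/// // ... the GRPC service keeps serving until `handle` is dropped ...
+/// ```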
+pub fn start_storage_service(config: &NodeConfig) -> ServerHandle {
+    let (storage_service, shutdown_receiver) = StorageService::new(&config.storage.get_dir());
+    spawn_service_thread_with_drop_closure(
+        create_storage(storage_service),
+        config.storage.address.clone(),
+        config.storage.port,
+        "storage",
+        move || {
+            shutdown_receiver
+                .recv()
+                .expect("Failed to receive on shutdown channel when storage service was dropped")
+        },
+    )
+}
+
+/// The implementation of the storage [GRPC](http://grpc.io) service.
+///
+/// It serves [`LibraDB`] APIs over the network. See API documentation in [`storage_proto`] and
+/// [`LibraDB`].
+#[derive(Clone)]
+pub struct StorageService {
+    db: Arc<LibraDBWrapper>,
+}
+
+/// When dropping the GRPC server we want to wait until LibraDB is dropped first, so the RocksDB
+/// instance held by GRPC threads is closed before the main function of the GRPC server
+/// finishes. Otherwise, if we don't manually guarantee this, some thread(s) may still be
+/// alive holding an Arc pointer to LibraDB after the main function of the GRPC server returns.
+/// Having this wrapper with a channel gives us a way to signal the receiving end that all GRPC
+/// server threads are joined, so RocksDB is closed.
+///
+/// See these links for more details.
+///   https://github.com/pingcap/grpc-rs/issues/227
+///   https://github.com/facebook/rocksdb/issues/649
+struct LibraDBWrapper {
+    db: Option<LibraDB>,
+    shutdown_sender: Mutex<mpsc::Sender<()>>,
+}
+
+impl LibraDBWrapper {
+    pub fn new<P: AsRef<Path>>(path: &P) -> (Self, mpsc::Receiver<()>) {
+        let db = LibraDB::new(path);
+        let (shutdown_sender, shutdown_receiver) = mpsc::channel();
+        (
+            Self {
+                db: Some(db),
+                shutdown_sender: Mutex::new(shutdown_sender),
+            },
+            shutdown_receiver,
+        )
+    }
+}
+
+impl Drop for LibraDBWrapper {
+    fn drop(&mut self) {
+        // Drop inner LibraDB instance.
+        self.db.take();
+        // Send the shutdown message after DB is dropped.
+        self.shutdown_sender
+            .lock()
+            .expect("Failed to lock mutex.")
+            .send(())
+            .expect("Failed to send shutdown message.");
+    }
+}
+
+impl Deref for LibraDBWrapper {
+    type Target = LibraDB;
+
+    fn deref(&self) -> &Self::Target {
+        self.db.as_ref().expect("LibraDB is dropped unexpectedly")
+    }
+}
+
+impl StorageService {
+    /// This opens a [`LibraDB`] at `path` and returns a [`StorageService`] instance serving it.
+    ///
+    /// A receiver side of a channel is also returned, through which one can receive a notice
+    /// after all resources used by the service, including the underlying [`LibraDB`] instance,
+    /// are fully dropped.
+    ///
+    /// Example:
+    /// ```no_run
+    /// # use storage_service::*;
+    /// # use std::path::Path;
+    /// let (service, shutdown_receiver) = StorageService::new(&Path::new("path/to/db"));
+    ///
+    /// drop(service);
+    /// shutdown_receiver.recv().expect("recv() should succeed.");
+    ///
+    /// // LibraDB instance is guaranteed to be properly dropped at this point.
+    /// ```
+    pub fn new<P: AsRef<Path>>(path: &P) -> (Self, mpsc::Receiver<()>) {
+        let (db_wrapper, shutdown_receiver) = LibraDBWrapper::new(path);
+        (
+            Self {
+                db: Arc::new(db_wrapper),
+            },
+            shutdown_receiver,
+        )
+    }
+}
+
+impl StorageService {
+    fn update_to_latest_ledger_inner(
+        &self,
+        req: UpdateToLatestLedgerRequest,
+    ) -> Result<UpdateToLatestLedgerResponse> {
+        let rust_req = types::get_with_proof::UpdateToLatestLedgerRequest::from_proto(req)?;
+
+        let (response_items, ledger_info_with_sigs, validator_change_events) = self
+            .db
+            .update_to_latest_ledger(rust_req.client_known_version, rust_req.requested_items)?;
+
+        let rust_resp = types::get_with_proof::UpdateToLatestLedgerResponse {
+            response_items,
+            ledger_info_with_sigs,
+            validator_change_events,
+        };
+
+        Ok(rust_resp.into_proto())
+    }
+
+    fn get_transactions_inner(
+        &self,
+        req: GetTransactionsRequest,
+    ) -> Result<GetTransactionsResponse> {
+        let rust_req = storage_proto::GetTransactionsRequest::from_proto(req)?;
+
+        let txn_list_with_proof = self.db.get_transactions(
+            rust_req.start_version,
+            rust_req.batch_size,
+            rust_req.ledger_version,
+            rust_req.fetch_events,
+        )?;
+
+        let rust_resp = storage_proto::GetTransactionsResponse::new(txn_list_with_proof);
+
+        Ok(rust_resp.into_proto())
+    }
+
+    fn get_account_state_with_proof_by_state_root_inner(
+        &self,
+        req: GetAccountStateWithProofByStateRootRequest,
+    ) -> Result<GetAccountStateWithProofByStateRootResponse> {
+        let rust_req = storage_proto::GetAccountStateWithProofByStateRootRequest::from_proto(req)?;
+
+        let (account_state_blob, sparse_merkle_proof) =
+            self.db.get_account_state_with_proof_by_state_root(
+                rust_req.address,
+                rust_req.state_root_hash,
+            )?;
+
+        let rust_resp = storage_proto::GetAccountStateWithProofByStateRootResponse {
+            account_state_blob,
+            sparse_merkle_proof,
+        };
+
+        Ok(rust_resp.into_proto())
+    }
+
+    fn save_transactions_inner(
+        &self,
+        req: SaveTransactionsRequest,
+    ) -> Result<SaveTransactionsResponse> {
+        let rust_req = storage_proto::SaveTransactionsRequest::from_proto(req)?;
+        self.db.save_transactions(
+            &rust_req.txns_to_commit,
+            rust_req.first_version,
+            &rust_req.ledger_info_with_signatures,
+        )?;
+        Ok(SaveTransactionsResponse::new())
+    }
+
+    fn get_executor_startup_info_inner(&self) -> Result<GetExecutorStartupInfoResponse> {
+        let info = self.db.get_executor_startup_info()?;
+        let rust_resp = storage_proto::GetExecutorStartupInfoResponse { info };
+        Ok(rust_resp.into_proto())
+    }
+}
+
+impl Storage for StorageService {
+    fn update_to_latest_ledger(
+        &mut self,
+        ctx: grpcio::RpcContext<'_>,
+        req: UpdateToLatestLedgerRequest,
+        sink: grpcio::UnarySink<UpdateToLatestLedgerResponse>,
+    ) {
+        debug!("[GRPC] Storage::update_to_latest_ledger");
+        let _timer = SVC_COUNTERS.req(&ctx);
+        let resp = self.update_to_latest_ledger_inner(req);
+        provide_grpc_response(resp, ctx, sink);
+    }
+
+    fn get_transactions(
+        &mut self,
+        ctx: grpcio::RpcContext,
+        req: GetTransactionsRequest,
+        sink: grpcio::UnarySink<GetTransactionsResponse>,
+    ) {
+        debug!("[GRPC] Storage::get_transactions");
+        let _timer = SVC_COUNTERS.req(&ctx);
+        let resp = self.get_transactions_inner(req);
+        provide_grpc_response(resp, ctx, sink);
+    }
+
+    fn get_account_state_with_proof_by_state_root(
+        &mut self,
+        ctx: grpcio::RpcContext,
+        req: GetAccountStateWithProofByStateRootRequest,
+        sink: grpcio::UnarySink<GetAccountStateWithProofByStateRootResponse>,
+    ) {
+        debug!("[GRPC] Storage::get_account_state_with_proof_by_state_root");
+        let _timer = SVC_COUNTERS.req(&ctx);
+        let resp = self.get_account_state_with_proof_by_state_root_inner(req);
+        provide_grpc_response(resp, ctx, sink);
+    }
+
+    fn save_transactions(
+        &mut self,
+        ctx: grpcio::RpcContext,
+        req: SaveTransactionsRequest,
+        sink: grpcio::UnarySink<SaveTransactionsResponse>,
+    ) {
+        debug!("[GRPC] Storage::save_transactions");
+        let _timer = SVC_COUNTERS.req(&ctx);
+        let resp = self.save_transactions_inner(req);
+        provide_grpc_response(resp, ctx, sink);
+    }
+
+    fn get_executor_startup_info(
+        &mut self,
+        ctx: grpcio::RpcContext,
+        _req: GetExecutorStartupInfoRequest,
+        sink: grpcio::UnarySink<GetExecutorStartupInfoResponse>,
+    ) {
+        debug!("[GRPC] Storage::get_executor_startup_info");
+        let _timer = SVC_COUNTERS.req(&ctx);
+        let resp = self.get_executor_startup_info_inner();
+        provide_grpc_response(resp, ctx, sink);
+    }
+}
+
+#[cfg(test)]
+mod storage_service_test;
diff --git a/storage/storage_service/src/main.rs b/storage/storage_service/src/main.rs
new file mode 100644
index 0000000000000..8bfc77d94c110
--- /dev/null
+++ b/storage/storage_service/src/main.rs
@@ -0,0 +1,61 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use executable_helpers::helpers::{
+    setup_executable, ARG_CONFIG_PATH, ARG_DISABLE_LOGGING, ARG_PEER_ID,
+};
+use std::thread;
+
+use config::config::NodeConfig;
+use debug_interface::{node_debug_service::NodeDebugService, proto::node_debug_interface_grpc};
+use failure::prelude::*;
+use grpc_helpers::spawn_service_thread;
+use logger::prelude::*;
+
+pub struct StorageNode {
+    node_config: NodeConfig,
+}
+
+impl Drop for StorageNode {
+    fn drop(&mut self) {
+        info!("Drop StorageNode");
+    }
+}
+
+impl StorageNode {
+    pub fn new(node_config: NodeConfig) -> Self {
+        StorageNode { node_config }
+    }
+
+    pub fn run(&self) -> Result<()> {
+        info!("Starting storage node");
+
+        let _handle = storage_service::start_storage_service(&self.node_config);
+
+        // Start the debug interface.
+        let debug_service =
+            node_debug_interface_grpc::create_node_debug_interface(NodeDebugService::new());
+        let _debug_handle = spawn_service_thread(
+            debug_service,
+            self.node_config.storage.address.clone(),
+            self.node_config.debug_interface.storage_node_debug_port,
+            "debug_service",
+        );
+
+        info!("Started Storage Service");
+        loop {
+            thread::park();
+        }
+    }
+}
+
+fn main() {
+    let (config, _logger, _args) = setup_executable(
+        "Libra Storage node".to_string(),
+        vec![ARG_PEER_ID, ARG_CONFIG_PATH, ARG_DISABLE_LOGGING],
+    );
+
+    let storage_node = StorageNode::new(config);
+
+    storage_node.run().expect("Unable to run storage node");
+}
diff --git a/storage/storage_service/src/mocks/mock_storage_client.rs b/storage/storage_service/src/mocks/mock_storage_client.rs
new file mode 100644
index 0000000000000..b6e044b38ff0b
--- /dev/null
+++ b/storage/storage_service/src/mocks/mock_storage_client.rs
@@ -0,0 +1,265 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module provides mock storage clients for tests.
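+//!
+//! A test would typically hold [`MockStorageReadClient`] behind the [`StorageRead`] trait it
+//! mocks, along the lines of this sketch:
+//!
+//! ```text
+//! let client: Arc<dyn StorageRead> = Arc::new(MockStorageReadClient);
+//! let (response_items, ledger_info_with_sigs, _events) =
+//!     client.update_to_latest_ledger(0 /* client_known_version */, vec![])?;
+//! ```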
+
+use canonical_serialization::SimpleSerializer;
+use crypto::{signing::generate_keypair, HashValue};
+use failure::prelude::*;
+use futures::prelude::*;
+use proto_conv::{FromProto, IntoProto};
+use std::{collections::BTreeMap, pin::Pin};
+use storage_client::StorageRead;
+use storage_proto::ExecutorStartupInfo;
+use types::{
+    account_address::{AccountAddress, ADDRESS_LENGTH},
+    account_state_blob::AccountStateBlob,
+    get_with_proof::{RequestItem, ResponseItem},
+    ledger_info::LedgerInfoWithSignatures,
+    proof::definition::SparseMerkleProof,
+    proto::{
+        account_state_blob::AccountStateWithProof,
+        get_with_proof::{
+            GetAccountStateResponse, GetTransactionsResponse, RequestItem as ProtoRequestItem,
+            RequestItem_oneof_requested_items, ResponseItem as ProtoResponseItem,
+            UpdateToLatestLedgerRequest, UpdateToLatestLedgerResponse,
+        },
+        ledger_info::LedgerInfoWithSignatures as ProtoLedgerInfoWithSignatures,
+        proof::AccumulatorProof,
+        transaction::TransactionListWithProof,
+        transaction_info::TransactionInfo,
+    },
+    test_helpers::transaction_test_helpers::get_test_signed_txn,
+    transaction::Version,
+    validator_change::ValidatorChangeEventWithProof,
+};
+
+/// This is a mock of the storage read client used in tests.
+///
+/// See the real
+/// [`StorageReadServiceClient`](../../../storage_client/struct.StorageReadServiceClient.html).
+#[derive(Clone)]
+pub struct MockStorageReadClient;
+
+impl StorageRead for MockStorageReadClient {
+    fn update_to_latest_ledger(
+        &self,
+        client_known_version: Version,
+        request_items: Vec<RequestItem>,
+    ) -> Result<(
+        Vec<ResponseItem>,
+        LedgerInfoWithSignatures,
+        Vec<ValidatorChangeEventWithProof>,
+    )> {
+        let request = types::get_with_proof::UpdateToLatestLedgerRequest::new(
+            client_known_version,
+            request_items,
+        );
+        let proto_request = request.into_proto();
+        let proto_response = get_mock_update_to_latest_ledger(&proto_request);
+        let response =
+            types::get_with_proof::UpdateToLatestLedgerResponse::from_proto(proto_response)?;
+        Ok((
+            response.response_items,
+            response.ledger_info_with_sigs,
+            response.validator_change_events,
+        ))
+    }
+
+    fn update_to_latest_ledger_async(
+        &self,
+        client_known_version: Version,
+        request_items: Vec<RequestItem>,
+    ) -> Pin<
+        Box<
+            dyn Future<
+                    Output = Result<(
+                        Vec<ResponseItem>,
+                        LedgerInfoWithSignatures,
+                        Vec<ValidatorChangeEventWithProof>,
+                    )>,
+                > + Send,
+        >,
+    > {
+        futures::future::ok(
+            self.update_to_latest_ledger(client_known_version, request_items)
+                .unwrap(),
+        )
+        .boxed()
+    }
+
+    fn get_transactions(
+        &self,
+        _start_version: Version,
+        _batch_size: u64,
+        _ledger_version: Version,
+        _fetch_events: bool,
+    ) -> Result<types::transaction::TransactionListWithProof> {
+        unimplemented!()
+    }
+
+    fn get_transactions_async(
+        &self,
+        _start_version: Version,
+        _batch_size: u64,
+        _ledger_version: Version,
+        _fetch_events: bool,
+    ) -> Pin<Box<dyn Future<Output = Result<types::transaction::TransactionListWithProof>> + Send>>
+    {
+        unimplemented!()
+    }
+
+    fn get_account_state_with_proof_by_state_root(
+        &self,
+        _address: AccountAddress,
+        _state_root_hash: HashValue,
+    ) -> Result<(Option<AccountStateBlob>, SparseMerkleProof)> {
+        unimplemented!()
+    }
+
+    fn get_account_state_with_proof_by_state_root_async(
+        &self,
+        _address: AccountAddress,
+        _state_root_hash: HashValue,
+    ) -> Pin<
+        Box<dyn Future<Output = Result<(Option<AccountStateBlob>, SparseMerkleProof)>> + Send>,
+    > {
+        unimplemented!();
+    }
+
+    fn get_executor_startup_info(&self) -> Result<Option<ExecutorStartupInfo>> {
+        unimplemented!()
+    }
+
+    fn get_executor_startup_info_async(
+        &self,
+    ) -> Pin<Box<dyn Future<Output = Result<Option<ExecutorStartupInfo>>> + Send>> {
+        unimplemented!()
+    }
+}
+
+fn get_mock_update_to_latest_ledger(
+    req: &UpdateToLatestLedgerRequest,
+) -> UpdateToLatestLedgerResponse {
+    let mut resp = UpdateToLatestLedgerResponse::new();
+    for request_item in req.get_requested_items().iter() {
+        resp.mut_response_items()
+            .push(get_mock_response_item(request_item).unwrap());
+    }
+    let mut ledger_info = types::proto::ledger_info::LedgerInfo::new();
+    ledger_info.set_transaction_accumulator_hash(HashValue::zero().to_vec());
+    ledger_info.set_consensus_data_hash(HashValue::zero().to_vec());
+    ledger_info.set_consensus_block_id(HashValue::zero().to_vec());
+    ledger_info.set_version(7);
+    let mut ledger_info_with_sigs = ProtoLedgerInfoWithSignatures::new();
+    ledger_info_with_sigs.set_ledger_info(ledger_info);
+    resp.set_ledger_info_with_sigs(ledger_info_with_sigs);
+    resp
+}
+
+fn get_mock_response_item(request_item: &ProtoRequestItem) -> Result<ProtoResponseItem> {
+    let mut response_item = ProtoResponseItem::new();
+    if let Some(ref requested_item) = request_item.requested_items {
+        match requested_item {
+            RequestItem_oneof_requested_items::get_account_state_request(_request) => {
+                let mut resp = GetAccountStateResponse::new();
+                let mut version_data = BTreeMap::new();
+
+                let account_resource = types::account_config::AccountResource::new(
+                    100,
+                    0,
+                    types::byte_array::ByteArray::new(vec![]),
+                    0,
+                    0,
+                );
+                version_data.insert(
+                    types::account_config::account_resource_path(),
+                    SimpleSerializer::serialize(&account_resource)?,
+                );
+                let mut account_state_with_proof = AccountStateWithProof::new();
+                let blob = AccountStateBlob::from(
+                    SimpleSerializer::<Vec<u8>>::serialize(&version_data)?,
+                )
+                .into_proto();
+                let proof = {
+                    let ledger_info_to_transaction_info_proof =
+                        types::proof::AccumulatorProof::new(vec![]);
+                    let transaction_info = types::transaction::TransactionInfo::new(
+                        HashValue::zero(),
+                        HashValue::zero(),
+                        HashValue::zero(),
+                        0,
+                    );
+                    let transaction_info_to_account_proof =
+                        types::proof::SparseMerkleProof::new(None, vec![]);
+                    types::proof::AccountStateProof::new(
+                        ledger_info_to_transaction_info_proof,
+                        transaction_info,
+                        transaction_info_to_account_proof,
+                    )
+                    .into_proto()
+                };
+                account_state_with_proof.set_blob(blob);
+                account_state_with_proof.set_proof(proof);
+                resp.set_account_state_with_proof(account_state_with_proof);
+                response_item.set_get_account_state_response(resp);
+            }
+            RequestItem_oneof_requested_items::get_account_transaction_by_sequence_number_request(_request) => {
+                unimplemented!();
+            }
+            RequestItem_oneof_requested_items::get_events_by_event_access_path_request(_request) => {
+                unimplemented!();
+            }
+            RequestItem_oneof_requested_items::get_transactions_request(request) => {
+                let mut ret = TransactionListWithProof::new();
+                let sender = AccountAddress::new([1; ADDRESS_LENGTH]);
+                if request.limit > 0 {
+                    let (txns, infos) = get_mock_txn_data(sender, 0, request.limit - 1);
+                    if !txns.is_empty() {
+                        ret.set_proof_of_first_transaction(get_accumulator_proof());
+                    }
+                    if txns.len() >= 2 {
+                        ret.set_proof_of_last_transaction(get_accumulator_proof());
+                    }
+                    ret.set_transactions(protobuf::RepeatedField::from_vec(txns));
+                    ret.set_infos(protobuf::RepeatedField::from_vec(infos));
+                }
+
+                let mut resp = GetTransactionsResponse::new();
+                resp.set_txn_list_with_proof(ret);
+
+                response_item.set_get_transactions_response(resp);
+            }
+        }
+    }
+    Ok(response_item)
+}
+
+fn get_mock_txn_data(
+    address: AccountAddress,
+    start_seq: u64,
+    end_seq: u64,
+) -> (
+    Vec<types::proto::transaction::SignedTransaction>,
+    Vec<TransactionInfo>,
+) {
+    let (priv_key, pub_key) = generate_keypair();
+    let mut txns = vec![];
+    let mut infos = vec![];
+    for i in start_seq..=end_seq {
+        let signed_txn = get_test_signed_txn(address, i, priv_key.clone(), pub_key, None);
+        txns.push(signed_txn);
+
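+        // Pair each mock transaction with a placeholder TransactionInfo
+        // built from zero hashes (see `get_transaction_info` below).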
+        let info = get_transaction_info().into_proto();
+        infos.push(info);
+    }
+    (txns, infos)
+}
+
+fn get_accumulator_proof() -> AccumulatorProof {
+    types::proof::AccumulatorProof::new(vec![]).into_proto()
+}
+
+fn get_transaction_info() -> types::transaction::TransactionInfo {
+    types::transaction::TransactionInfo::new(
+        HashValue::zero(),
+        HashValue::zero(),
+        HashValue::zero(),
+        0,
+    )
+}
diff --git a/storage/storage_service/src/mocks/mod.rs b/storage/storage_service/src/mocks/mod.rs
new file mode 100644
index 0000000000000..c98e2be630b7b
--- /dev/null
+++ b/storage/storage_service/src/mocks/mod.rs
@@ -0,0 +1,6 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module provides mocks of the storage components for tests.
+
+pub mod mock_storage_client;
diff --git a/storage/storage_service/src/storage_service_test.rs b/storage/storage_service/src/storage_service_test.rs
new file mode 100644
index 0000000000000..50bb089c27929
--- /dev/null
+++ b/storage/storage_service/src/storage_service_test.rs
@@ -0,0 +1,106 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use super::*;
+use config::config::NodeConfigHelpers;
+use grpcio::EnvBuilder;
+use itertools::zip_eq;
+use libradb::{mock_genesis::db_with_mock_genesis, test_helper::arb_blocks_to_commit};
+use proptest::prelude::*;
+use std::collections::HashMap;
+use storage_client::{
+    StorageRead, StorageReadServiceClient, StorageWrite, StorageWriteServiceClient,
+};
+use types::get_with_proof::{RequestItem, ResponseItem};
+
+fn start_test_storage_with_read_write_client(
+    need_to_use_genesis: bool,
+) -> (
+    tempfile::TempDir,
+    ServerHandle,
+    StorageReadServiceClient,
+    StorageWriteServiceClient,
+) {
+    let mut config = NodeConfigHelpers::get_single_node_test_config(/* random_ports = */ true);
+    let tmp_dir = tempfile::tempdir().unwrap();
+    config.storage.dir = tmp_dir.path().to_path_buf();
+
+    // Initialize db with genesis info.
+    if need_to_use_genesis {
+        db_with_mock_genesis(&tmp_dir).unwrap();
+    } else {
+        LibraDB::new(&tmp_dir);
+    }
+    let storage_server_handle = start_storage_service(&config);
+
+    let read_client = StorageReadServiceClient::new(
+        Arc::new(EnvBuilder::new().build()),
+        &config.storage.address,
+        config.storage.port,
+    );
+    let write_client = StorageWriteServiceClient::new(
+        Arc::new(EnvBuilder::new().build()),
+        &config.storage.address,
+        config.storage.port,
+    );
+    (tmp_dir, storage_server_handle, read_client, write_client)
+}
+
+proptest! {
+    #![proptest_config(ProptestConfig::with_cases(10))]
+
+    #[test]
+    fn test_storage_service_basic(blocks in arb_blocks_to_commit().no_shrink()) {
+        let (_tmp_dir, _server_handler, read_client, write_client) =
+            start_test_storage_with_read_write_client(/* need_to_use_genesis = */ true);
+
+        let mut version = 0;
+        for (txns_to_commit, ledger_info_with_sigs) in &blocks {
+            write_client
+                .save_transactions(
+                    txns_to_commit.clone(),
+                    version + 1, /* first_version */
+                    Some(ledger_info_with_sigs.clone()),
+                )
+                .unwrap();
+            version += txns_to_commit.len() as u64;
+            let mut account_states = HashMap::new();
+            // Get the ground truth of account states.
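+            // (Each `TransactionToCommit` carries the post-execution blobs of
+            // the accounts it touched; extending the map in commit order keeps
+            // the latest blob per address as the expected value.)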
+            txns_to_commit.iter().for_each(|txn_to_commit| {
+                account_states.extend(txn_to_commit.account_states().clone())
+            });
+
+            let account_state_request_items = account_states
+                .keys()
+                .map(|address| RequestItem::GetAccountState { address: *address })
+                .collect::<Vec<_>>();
+            let (
+                response_items,
+                response_ledger_info_with_sigs,
+                _validator_change_events,
+            ) = read_client
+                .update_to_latest_ledger(0, account_state_request_items)
+                .unwrap();
+            for ((address, blob), response_item) in zip_eq(account_states, response_items) {
+                match response_item {
+                    ResponseItem::GetAccountState {
+                        account_state_with_proof,
+                    } => {
+                        prop_assert_eq!(&Some(blob), &account_state_with_proof.blob);
+                        prop_assert!(account_state_with_proof
+                            .verify(
+                                response_ledger_info_with_sigs.ledger_info(),
+                                version,
+                                address,
+                            )
+                            .is_ok())
+                    }
+                    _ => unreachable!(),
+                }
+            }
+
+            // Assert ledger info.
+            prop_assert_eq!(ledger_info_with_sigs, &response_ledger_info_with_sigs);
+        }
+    }
+}
diff --git a/terraform/auth.tf b/terraform/auth.tf
new file mode 100644
index 0000000000000..b61b21362268b
--- /dev/null
+++ b/terraform/auth.tf
@@ -0,0 +1,96 @@
+resource "aws_iam_role" "ecsInstanceRole" {
+  name = "${terraform.workspace}-ecsInstanceRole"
+
+  assume_role_policy = <> /etc/ecs/ecs.config
+
+curl -o /tmp/node_exporter.rpm https://copr-be.cloud.fedoraproject.org/results/ibotty/prometheus-exporters/epel-7-x86_64/00935314-golang-github-prometheus-node_exporter/golang-github-prometheus-node_exporter-0.18.1-6.el7.x86_64.rpm
+yum install -y /tmp/node_exporter.rpm
+systemctl start node_exporter
+
+cat > /etc/cron.d/metric_collector <<"EOF"
+* * * * * root docker container ls -q --filter label=com.amazonaws.ecs.container-name | xargs docker inspect --format='{{.State.StartedAt}}' | xargs date +"\%s" -d | xargs echo "ecs_start_time_seconds " > /var/lib/node_exporter/textfile_collector/ecs_stats.prom
+
+* * * * * root docker container ls -q --filter label=com.amazonaws.ecs.container-name | xargs docker inspect --format='{{$tags := .Config.Labels}}build_info{revision="{{index $tags "org.label-schema.vcs-ref"}}", upstream="{{index $tags "vcs-upstream"}}"} 1' > /var/lib/node_exporter/textfile_collector/build_info.prom
+EOF
+
+yum -y install ngrep tcpdump perf gdb nmap-ncat strace htop sysstat
diff --git a/terraform/templates/faucet.json b/terraform/templates/faucet.json
new file mode 100644
index 0000000000000..900de276c1b29
--- /dev/null
+++ b/terraform/templates/faucet.json
@@ -0,0 +1,29 @@
+[
+  {
+    "name": "faucet",
+    "image": "${faucet_image_repo}${faucet_image_tag_str}",
+    "cpu": 2048,
+    "memory": 3883,
+    "essential": true,
+    "portMappings": [
+      {"containerPort": 8000, "hostPort": 8000}
+    ],
+    "environment": [
+      {"name": "AC_PORT", "value": "30307"},
+      {"name": "AC_HOST", "value": "${ac_hosts}"},
+      {"name": "TRUSTED_PEERS", "value": ${trusted_peers}},
+      {"name": "LOG_LEVEL", "value": "${log_level}"}
+    ],
+    "secrets": [
+      {"name": "MINT_KEY", "valueFrom": "${secret}"}
+    ],
+    "logConfiguration": {
+      "logDriver": "awslogs",
+      "options": {
+        "awslogs-group": "${log_group}",
+        "awslogs-region": "${log_region}",
+        "awslogs-stream-prefix": "${log_prefix}"
+      }
+    }
+  }
+]
diff --git a/terraform/templates/grafana-dashboards.yml b/terraform/templates/grafana-dashboards.yml
new file mode 100644
index 0000000000000..2b675e3fb0355
--- /dev/null
+++ b/terraform/templates/grafana-dashboards.yml
@@ -0,0 +1,8 @@
+apiVersion: 1
+
+providers:
+- name: 'default'
+  folder: 'libra'
+  type: file
+  options:
+    path: /var/lib/grafana/dashboards
diff
--git a/terraform/templates/grafana-datasources.yml b/terraform/templates/grafana-datasources.yml new file mode 100644 index 0000000000000..e529a1ab5127c --- /dev/null +++ b/terraform/templates/grafana-datasources.yml @@ -0,0 +1,8 @@ +apiVersion: 1 + +datasources: + - name: Prometheus + type: prometheus + isDefault: true + access: proxy + url: http://${ip}:9090 diff --git a/terraform/templates/prometheus.json b/terraform/templates/prometheus.json new file mode 100644 index 0000000000000..fbf14c1e8ab4d --- /dev/null +++ b/terraform/templates/prometheus.json @@ -0,0 +1,59 @@ +[ + { + "name": "prometheus", + "image": "${prometheus_image}", + "cpu": 1280, + "memory": 2088, + "essential": true, + "portMappings": [ + {"containerPort": 9090, "hostPort": 9090} + ], + "mountPoints": [ + {"sourceVolume": "prometheus-data", "containerPath": "/prometheus"}, + {"sourceVolume": "prometheus-config", "containerPath": "/etc/prometheus/prometheus.yml"}, + {"sourceVolume": "prometheus-alerting-rules", "containerPath": "/etc/prometheus/alerting_rules"}, + {"sourceVolume": "prometheus-consoles", "containerPath": "/usr/share/prometheus/consoles"}, + {"sourceVolume": "prometheus-console-libs", "containerPath": "/usr/share/prometheus/console_libs"} + ], + "command": [ + "--config.file=/etc/prometheus/prometheus.yml", + "--storage.tsdb.path=/prometheus", + "--web.console.libraries=/usr/share/prometheus/console_libs", + "--web.console.templates=/usr/share/prometheus/consoles", + "--web.enable-lifecycle" + ] + }, + { + "name": "alertmanager", + "image": "${alertmanager_image}", + "cpu": 384, + "memory": 768, + "essential": true, + "portMappings": [ + {"containerPort": 9093, "hostPort": 9093} + ], + "mountPoints": [ + {"sourceVolume": "alertmanager-data", "containerPath": "/alertmanager"}, + {"sourceVolume": "alertmanager-config", "containerPath": "/etc/alertmanager/alertmanager.yml"} + ], + "command": [ + "--config.file=/etc/alertmanager/alertmanager.yml", + "--storage.path=/alertmanager" + ] + }, + { + "name": "grafana", + "image": "${grafana_image}", + "cpu": 384, + "memory": 1024, + "essential": true, + "portMappings": [ + {"containerPort": 3000, "hostPort": 9091} + ], + "mountPoints": [ + {"sourceVolume": "grafana-data", "containerPath": "/var/lib/grafana"}, + {"sourceVolume": "grafana-provisioning", "containerPath": "/etc/grafana/provisioning"}, + {"sourceVolume": "grafana-dashboards", "containerPath": "/var/lib/grafana/dashboards"} + ] + } +] diff --git a/terraform/templates/prometheus.yml b/terraform/templates/prometheus.yml new file mode 100644 index 0000000000000..6e53651ef4046 --- /dev/null +++ b/terraform/templates/prometheus.yml @@ -0,0 +1,80 @@ +# my global config +global: + scrape_interval: 15s # Set the scrape interval to every 15 seconds. Default is every 1 minute. + evaluation_interval: 15s # Evaluate rules every 15 seconds. The default is every 1 minute. + # scrape_timeout is set to the global default (10s). + +# Alertmanager configuration +alerting: + alertmanagers: + - static_configs: + - targets: + - ${monitoring_private_ip}:9093 + +# Load rules once and periodically evaluate them according to the global 'evaluation_interval'. +rule_files: + - "alerting_rules/blockchain_alerts.yml" + +# A scrape configuration containing exactly one endpoint to scrape: +# Here it's Prometheus itself. +scrape_configs: + # The job name is added as a label `job=` to any timeseries scraped from this config. + - job_name: 'prometheus' + + # metrics_path defaults to '/metrics' + # scheme defaults to 'http'. 
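+  # Prometheus also scrapes itself, so ingestion and rule-evaluation health
+  # can be graphed alongside the validator metrics below.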
+ + static_configs: + - targets: ['localhost:9090'] + labels: + workspace: '${workspace}' + role: 'prometheus' + + - job_name: 'other_nodes' + static_configs: + %{ for target in split(",", other_nodes) } + - targets: ['${element(split(":", target), 0)}:9100'] + labels: + role: '${element(split(":", target), 1)}' + workspace: '${workspace}' + %{ endfor } + relabel_configs: + - source_labels: ['__address__'] + # NOTICE: not tested with IPv6 + regex: '([0-9\.]+):\d+' + target_label: 'address' + replacement: '$1' + + - job_name: 'validator_nodes' + static_configs: + %{ for target in split(",", validator_nodes) } + - targets: ['${element(split(":", target), 0)}:9100'] + labels: + peer_id: '${substr(element(split(":", target), 1), 0, 8)}' + role: 'validator' + workspace: '${workspace}' + %{ endfor } + + relabel_configs: + - source_labels: ['__address__'] + # NOTICE: not tested with IPv6 + regex: '([0-9\.]+):\d+' + target_label: 'address' + replacement: '$1' + + - job_name: 'validators' + static_configs: + %{ for target in split(",", validator_svcs) } + - targets: ['${element(split(":", target), 0)}:14297'] + labels: + peer_id: '${substr(element(split(":", target), 1), 0, 8)}' + role: 'validator' + workspace: '${workspace}' + %{ endfor } + + relabel_configs: + - source_labels: ['__address__'] + # NOTICE: not tested with IPv6 + regex: '([0-9\.]+):\d+' + target_label: 'address' + replacement: '$1' diff --git a/terraform/templates/prometheus/alerting_rules/blockchain_alerts.yml b/terraform/templates/prometheus/alerting_rules/blockchain_alerts.yml new file mode 100644 index 0000000000000..e88101eb1fbfd --- /dev/null +++ b/terraform/templates/prometheus/alerting_rules/blockchain_alerts.yml @@ -0,0 +1,14 @@ +groups: +- name: "blockchain alerts" + rules: + - alert: High Txn Rate + expr: avg(rate(consensus{op='committed_txns_count'}[1m])) > 50 + for: 1m + labels: + severity: warning + + - alert: Low Consensus-Round Rate + expr: avg(rate(consensus_gauge{op='current_round',job='validators'}[1m])) < 0.2 + for: 1m + labels: + severity: warning diff --git a/terraform/templates/prometheus/console_libs/menu.lib b/terraform/templates/prometheus/console_libs/menu.lib new file mode 100644 index 0000000000000..17a9cf5809866 --- /dev/null +++ b/terraform/templates/prometheus/console_libs/menu.lib @@ -0,0 +1,117 @@ +{{/* vim: set ft=html: */}} + +{{/* Navbar, should be passed . */}} +{{ define "navbar" }} + +{{ end }} + +{{/* LHS menu, should be passed . */}} +{{ define "menu" }} +
+ +
+{{ end }} + +{{/* Helper, pass (args . path name) */}} +{{ define "_menuItem" }} + +{{ end }} diff --git a/terraform/templates/prometheus/console_libs/prom.lib b/terraform/templates/prometheus/console_libs/prom.lib new file mode 100644 index 0000000000000..c85fa9c4a5899 --- /dev/null +++ b/terraform/templates/prometheus/console_libs/prom.lib @@ -0,0 +1,153 @@ +{{/* vim: set ft=html: */}} +{{/* Load Prometheus console library JS/CSS. Should go in */}} +{{ define "prom_console_head" }} + + + + + + + + + + + + + +{{ end }} + +{{/* Top of all pages. */}} +{{ define "head" -}} + + + +{{ template "prom_console_head" }} + + +{{ template "navbar" . }} + +{{ template "menu" . }} +{{ end }} + +{{ define "__prom_query_drilldown_noop" }}{{ . }}{{ end }} +{{ define "humanize" }}{{ humanize . }}{{ end }} +{{ define "humanizeNoSmallPrefix" }}{{ if and (lt . 1.0) (gt . -1.0) }}{{ printf "%.3g" . }}{{ else }}{{ humanize . }}{{ end }}{{ end }} +{{ define "humanize1024" }}{{ humanize1024 . }}{{ end }} +{{ define "humanizeDuration" }}{{ humanizeDuration . }}{{ end }} +{{ define "humanizeTimestamp" }}{{ humanizeTimestamp . }}{{ end }} +{{ define "printf.1f" }}{{ printf "%.1f" . }}{{ end }} +{{ define "printf.3g" }}{{ printf "%.3g" . }}{{ end }} + +{{/* prom_query_drilldown (args expr suffix? renderTemplate?) +Displays the result of the expression, with a link to /graph for it. + +renderTemplate is the name of the template to use to render the value. +*/}} +{{ define "prom_query_drilldown" }} +{{ $expr := .arg0 }}{{ $suffix := (or .arg1 "") }}{{ $renderTemplate := (or .arg2 "__prom_query_drilldown_noop") }} +{{ with query $expr }}{{tmpl $renderTemplate ( . | first | value )}}{{ $suffix }}{{ else }}-{{ end }} +{{ end }} + +{{ define "prom_path" }}/consoles/{{ .Path }}?{{ range $param, $value := .Params }}{{ $param }}={{ $value }}&{{ end }}{{ end }}" + +{{ define "prom_right_table_head" }} +
+
+{{ end }} +{{ define "prom_right_table_tail" }} +
+ +{{ end }} + +{{/* RHS table head, pass job name. Should be used after prom_right_table_head. */}} +{{ define "prom_right_table_job_head" }} + + {{ . }} + {{ template "prom_query_drilldown" (args (printf "sum(up{job='%s'})" .)) }} / {{ template "prom_query_drilldown" (args (printf "count(up{job='%s'})" .)) }} + + + CPU + {{ template "prom_query_drilldown" (args (printf "avg by(job)(irate(process_cpu_seconds_total{job='%s'}[5m]))" .) "s/s" "humanizeNoSmallPrefix") }} + + + Memory + {{ template "prom_query_drilldown" (args (printf "avg by(job)(process_resident_memory_bytes{job='%s'})" .) "B" "humanize1024") }} + +{{ end }} + + +{{ define "prom_content_head" }} +

+
+{{ template "prom_graph_timecontrol" . }} +{{ end }} +{{ define "prom_content_tail" }} +
+
+{{ end }} + +{{ define "prom_graph_timecontrol" }} +
+
+
+ +
+
+ +
+
+
+ + + +
+
+
+ +
+{{ end }} + +{{/* Displays prometheus graph in an html table cell. Pass DOM node name, title, prometheus expression, y axis title. */}} +{{ define "graph_in_table" }} + + {{ .arg1 }} +
+ + +{{ end }} + +{{/* Bottom of all pages. */}} +{{ define "tail" }} + + +{{ end }} diff --git a/terraform/templates/prometheus/console_libs/validator_utils.lib b/terraform/templates/prometheus/console_libs/validator_utils.lib new file mode 100644 index 0000000000000..058a8ee5c1f40 --- /dev/null +++ b/terraform/templates/prometheus/console_libs/validator_utils.lib @@ -0,0 +1,30 @@ +{{ define "validators_table" }} + + + + + + + + + + {{ range query "up{job='validators'}" | sortByLabel "instance" }} + + + + Yes{{ else }} class="alert-danger">No{{ end }} + + + + + + + + + + {{ else }} + + {{ end }} +
ValidatorsUpConnected
Peers
CPU UsedMemory
Available
Revision
{{ reReplaceAll "(.*?://)([^:/]+?)(:\\d+)?/.*" "$2" .Labels.peer_id }}{{ template "prom_query_drilldown" (args (printf "network_gauge{op='connected_peers',job='validators',instance='%s'}" .Labels.instance) "") }}{{ template "prom_query_drilldown" (args (printf "100 * (1 - avg by(address)(irate(node_cpu_seconds_total{job='node_exporter',mode='idle',address='%s'}[5m])))" .Labels.address) "%" "printf.1f") }}{{ template "prom_query_drilldown" (args (printf "node_memory_MemFree_bytes{job='node_exporter',address='%s'} + node_memory_Cached_bytes{job='node_exporter',address='%s'} + node_memory_Buffers_bytes{job='node_exporter',address='%s'}" .Labels.address .Labels.address .Labels.address) "B" "humanize1024") }}{{ with query (printf "build_info{address='%s'}" .Labels.address) }}{{. | first | label "revision" | printf "%.8s"}}{{end}}
No nodes found.
+{{ end }} + diff --git a/terraform/templates/prometheus/consoles/consensus-toplevel.html b/terraform/templates/prometheus/consoles/consensus-toplevel.html new file mode 100644 index 0000000000000..4889e10118a42 --- /dev/null +++ b/terraform/templates/prometheus/consoles/consensus-toplevel.html @@ -0,0 +1,62 @@ +{{ template "head" . }} + +
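+{{/* Consensus console: validator summary table on the right, then graphs for
+     commit health, the pacemaker, and the synchronization manager. */}}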

Consensus Top Level Metrics

+{{ template "prom_right_table_head" }} + + Validators + {{ template "prom_query_drilldown" (args "sum(up{job='validators'})") }} / {{ template "prom_query_drilldown" (args "count(up{job='validators'})") }} + +{{ template "prom_right_table_tail" }} + +{{ template "prom_content_head" . }} + +

Validators

+{{ template "validators_table" }} + +

General Health

+ + + {{ template "graph_in_table" (args "lastCommittedRound" "Committed blocks per sec" "irate(consensus_gauge{op='last_committed_round'}[1m])" "blocks per sec") }} + + {{ template "graph_in_table" (args "lastCommittedVersion" "Committed successful txns per sec" "irate(consensus_gauge{op='last_committed_version'}[1m])" "txns per sec") }} + + {{ template "graph_in_table" (args "pendingBlocks" "Pending blocks" "consensus_gauge{op='num_blocks_in_tree'} - 1" "num pending blocks") }} + + + {{ template "graph_in_table" (args "creationToCommit" "Avg time since block creation to commit (ms)" "irate(consensus_duration_sum{op='creation_to_commit_ms'}[1m])/irate(consensus_duration_count{op='creation_to_commit_ms'}[1m])" "ms") }} + +
+ +

Pacemaker

+ + + {{ template "graph_in_table" (args "qcRoundsRate" "Rounds with QC per sec" "irate(consensus{op='qc_rounds_count'}[1m])" "rounds per sec") }} + + {{ template "graph_in_table" (args "timeoutRoundsCount" "Num of timeout rounds since restart" "consensus{op='timeout_rounds_count'}" "count timeout rounds") }} + + {{ template "graph_in_table" (args "timeoutCount" "Num of timeouts since restart" "consensus{op='timeout_count'}" "count timeouts") }} + + + {{ template "graph_in_table" (args "roundTimeoutVal" "Round timeout val (ms)" "consensus_gauge{op='round_timeout_ms'}" "timeout val (ms)") }} + +
+ +

Synchronization Manager

+ + + {{ template "graph_in_table" (args "syncMgrCount" "State sync since restart" "consensus{op='state_sync_count'}" "count") }} + + {{ template "graph_in_table" (args "txnsReplayedCount" "Txns replayed in state synchronization" "consensus{op='state_sync_txns_replayed'}" "txns count") }} + + {{ template "graph_in_table" (args "syncDuration" "Total time spent on state sync (ms)" "consensus_duration_sum{op='state_sync_duration_ms'}" "total time spent (ms)") }} + + + {{ template "graph_in_table" (args "blockRetrievalCount" "Block retrieval requests since restart" "consensus{op='block_retrieval_count'}" "count") }} + + {{ template "graph_in_table" (args "blockRetrievalDuration" "Total time spent on block retrieval (ms)" "consensus_duration_sum{op='block_retrieval_duration_ms'}" "total time spent (ms)") }} + +
+ +{{ template "prom_content_tail" . }} + +{{ template "tail" }} diff --git a/terraform/templates/prometheus/consoles/index.html b/terraform/templates/prometheus/consoles/index.html new file mode 100644 index 0000000000000..91a19536fc242 --- /dev/null +++ b/terraform/templates/prometheus/consoles/index.html @@ -0,0 +1,32 @@ +{{ template "head" . }} + +{{ template "prom_right_table_head" }} +{{ template "prom_right_table_tail" }} + +{{ template "prom_content_head" . }} +

Overview

+

These are basic consoles for Prometheus on Libra Validators.

+ +

These consoles expect exporters to have the following job labels:

+ + + + + + + + + + + + + + + + + +
ExporterJob label
Validatorvalidators
Node Exportervalidator_nodes
Prometheusprometheus
+ +{{ template "prom_content_tail" . }} + +{{ template "tail" }} diff --git a/terraform/templates/prometheus/consoles/network-toplevel.html b/terraform/templates/prometheus/consoles/network-toplevel.html new file mode 100644 index 0000000000000..4fb1ef066ef07 --- /dev/null +++ b/terraform/templates/prometheus/consoles/network-toplevel.html @@ -0,0 +1,38 @@ +{{ template "head" . }} + +

Network Top Level Metrics

+{{ template "prom_right_table_head" }} + + Validators + {{ template "prom_query_drilldown" (args "sum(up{job='validators'})") }} / {{ template "prom_query_drilldown" (args "count(up{job='validators'})") }} + +{{ template "prom_right_table_tail" }} + +{{ template "prom_content_head" . }} + +

Validators

+{{ template "validators_table" }} + +
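+{{/* `validators_table` is defined in console_libs/validator_utils.lib. */}}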

Network connectivity

+
+  {{ template "graph_in_table" (args "connectedPeers" "Connected Peers" "network_gauge{op='connected_peers'}" "num peers") }}
+
+  {{ template "graph_in_table" (args "outboundMsgQueue" "Pending Outbound Messages" "network_gauge{op='pending_direct_send_outbound_messages'}" "num messages") }}
+
+  {{ template "graph_in_table" (args "outboundRpcQueue" "Pending Outbound RPCs" "network_gauge{op='pending_rpc_requests'}" "num RPCs") }}
+
+
+ +

Intra-node channels

+
+  {{ template "graph_in_table" (args "network2mempool" "Network to Mempool" "network_gauge{op='pending_network_mempool_events'}" "num pending events") }}
+
+  {{ template "graph_in_table" (args "network2consensus" "Network to Consensus" "network_gauge{op='pending_network_consensus_events'}" "num pending events") }}
+
+
+ +{{ template "prom_content_tail" . }} + +{{ template "tail" }} diff --git a/terraform/templates/prometheus/consoles/node-cpu.html b/terraform/templates/prometheus/consoles/node-cpu.html new file mode 100644 index 0000000000000..b79768014d4f6 --- /dev/null +++ b/terraform/templates/prometheus/consoles/node-cpu.html @@ -0,0 +1,60 @@ +{{ template "head" . }} + +{{ template "prom_right_table_head" }} + + CPU(s): {{ template "prom_query_drilldown" (args (printf "scalar(count(count by (cpu)(node_cpu_seconds_total{job='validator_nodes',instance='%s'})))" .Params.instance)) }} + +{{ range printf "sum by (mode)(irate(node_cpu_seconds_total{job='validator_nodes',instance='%s'}[5m])) * 100 / scalar(count(count by (cpu)(node_cpu_seconds_total{job='validator_nodes',instance='%s'})))" .Params.instance .Params.instance | query | sortByLabel "mode" }} + + {{ .Labels.mode | title }} CPU + {{ .Value | printf "%.1f" }}% + +{{ end }} + Misc + + Processes Running + {{ template "prom_query_drilldown" (args (printf "node_procs_running{job='validator_nodes',instance='%s'}" .Params.instance) "" "humanize") }} + + + Processes Blocked + {{ template "prom_query_drilldown" (args (printf "node_procs_blocked{job='validator_nodes',instance='%s'}" .Params.instance) "" "humanize") }} + + + Forks + {{ template "prom_query_drilldown" (args (printf "irate(node_forks_total{job='validator_nodes',instance='%s'}[5m])" .Params.instance) "/s" "humanize") }} + + + Context Switches + {{ template "prom_query_drilldown" (args (printf "irate(node_context_switches_total{job='validator_nodes',instance='%s'}[5m])" .Params.instance) "/s" "humanize") }} + + + Interrupts + {{ template "prom_query_drilldown" (args (printf "irate(node_intr_total{job='validator_nodes',instance='%s'}[5m])" .Params.instance) "/s" "humanize") }} + + + 1m Loadavg + {{ template "prom_query_drilldown" (args (printf "node_load1{job='validator_nodes',instance='%s'}" .Params.instance)) }} + + + +{{ template "prom_right_table_tail" }} + +{{ template "prom_content_head" . }} +

Node CPU - {{ reReplaceAll "(.*?://)([^:/]+?)(:\\d+)?/.*" "$2" .Params.instance }}

+ +

CPU Usage

+
+ +{{ template "prom_content_tail" . }} + +{{ template "tail" }} diff --git a/terraform/templates/prometheus/consoles/node-disk.html b/terraform/templates/prometheus/consoles/node-disk.html new file mode 100644 index 0000000000000..67c83dea79c38 --- /dev/null +++ b/terraform/templates/prometheus/consoles/node-disk.html @@ -0,0 +1,77 @@ +{{ template "head" . }} + +{{ template "prom_content_head" . }} +

Node Disk - {{ reReplaceAll "(.*?://)([^:/]+?)(:\\d+)?/.*" "$2" .Params.instance }}

+ +

Disk I/O Utilization

+
+ +

Filesystem Usage

+
+ + +{{ template "prom_right_table_head" }} + Disks + +{{ range printf "node_disk_io_time_seconds_total{job='validator_nodes',instance='%s'}" .Params.instance | query | sortByLabel "device" }} + {{ .Labels.device }} + + Utilization + {{ template "prom_query_drilldown" (args (printf "irate(node_disk_io_time_seconds_total{job='validator_nodes',instance='%s',device='%s'}[5m]) * 100" .Labels.instance .Labels.device) "%" "printf.1f") }} + + + Throughput + {{ template "prom_query_drilldown" (args (printf "irate(node_disk_read_bytes_total{job='validator_nodes',instance='%s',device='%s'}[5m]) + irate(node_disk_written_bytes_total{job='validator_nodes',instance='%s',device='%s'}[5m])" .Labels.instance .Labels.device .Labels.instance .Labels.device) "B/s" "humanize") }} + + + Avg Read Time + {{ template "prom_query_drilldown" (args (printf "irate(node_disk_read_time_seconds_total{job='validator_nodes',instance='%s',device='%s'}[5m]) / irate(node_disk_reads_completed_total{job='validator_nodes',instance='%s',device='%s'}[5m])" .Labels.instance .Labels.device .Labels.instance .Labels.device) "s" "humanize") }} + + + Avg Write Time + {{ template "prom_query_drilldown" (args (printf "irate(node_disk_write_time_seconds_total{job='validator_nodes',instance='%s',device='%s'}[5m]) / irate(node_disk_writes_completed_total{job='validator_nodes',instance='%s',device='%s'}[5m])" .Labels.instance .Labels.device .Labels.instance .Labels.device) "s" "humanize") }} + +{{ end }} + Filesystem Fullness + +{{ define "roughlyNearZero" }} +{{ if gt .1 . }}~0{{ else }}{{ printf "%.1f" . }}{{ end }} +{{ end }} +{{ range printf "node_filesystem_size_bytes{job='validator_nodes',instance='%s'}" .Params.instance | query | sortByLabel "mountpoint" }} + + {{ .Labels.mountpoint }} + {{ template "prom_query_drilldown" (args (printf "100 - node_filesystem_avail_bytes{job='validator_nodes',instance='%s',mountpoint='%s'} / node_filesystem_size_bytes{job='validator_nodes'} * 100" .Labels.instance .Labels.mountpoint) "%" "roughlyNearZero") }} + +{{ end }} + + +{{ template "prom_right_table_tail" }} + +{{ template "prom_content_tail" . }} + +{{ template "tail" }} diff --git a/terraform/templates/prometheus/consoles/node-overview.html b/terraform/templates/prometheus/consoles/node-overview.html new file mode 100644 index 0000000000000..6afa616542794 --- /dev/null +++ b/terraform/templates/prometheus/consoles/node-overview.html @@ -0,0 +1,122 @@ +{{ template "head" . }} + +{{ template "prom_content_head" . }} +

Node Overview - {{ reReplaceAll "(.*?://)([^:/]+?)(:\\d+)?/.*" "$2" .Params.instance }}

+ +

CPU Usage

+
+ + +

Disk I/O Utilization

+
+ + +

Memory

+
+ + +{{ template "prom_right_table_head" }} + Overview + + User CPU + {{ template "prom_query_drilldown" (args (printf "sum(irate(node_cpu_seconds_total{job='validator_nodes',instance='%s',mode='user'}[5m])) * 100 / count(count by (cpu)(node_cpu_seconds_total{job='validator_nodes',instance='%s'}))" .Params.instance .Params.instance) "%" "printf.1f") }} + + + System CPU + {{ template "prom_query_drilldown" (args (printf "sum(irate(node_cpu_seconds_total{job='validator_nodes',instance='%s',mode='system'}[5m])) * 100 / count(count by (cpu)(node_cpu_seconds_total{job='validator_nodes',instance='%s'}))" .Params.instance .Params.instance) "%" "printf.1f") }} + + + Memory Total + {{ template "prom_query_drilldown" (args (printf "node_memory_MemTotal_bytes{job='validator_nodes',instance='%s'}" .Params.instance) "B" "humanize1024") }} + + + Memory Free + {{ template "prom_query_drilldown" (args (printf "node_memory_MemFree_bytes{job='validator_nodes',instance='%s'}" .Params.instance) "B" "humanize1024") }} + + + Network + +{{ range printf "node_network_receive_bytes_total{job='validator_nodes',instance='%s',device!='lo'}" .Params.instance | query | sortByLabel "device" }} + + {{ .Labels.device }} Received + {{ template "prom_query_drilldown" (args (printf "irate(node_network_receive_bytes_total{job='validator_nodes',instance='%s',device='%s'}[5m])" .Labels.instance .Labels.device) "B/s" "humanize") }} + + + {{ .Labels.device }} Transmitted + {{ template "prom_query_drilldown" (args (printf "irate(node_network_transmit_bytes_total{job='validator_nodes',instance='%s',device='%s'}[5m])" .Labels.instance .Labels.device) "B/s" "humanize") }} + +{{ end }} + + + Disks + +{{ range printf "node_disk_io_time_seconds_total{job='validator_nodes',instance='%s',device!~'^(md\\\\d+$|dm-)'}" .Params.instance | query | sortByLabel "device" }} + + {{ .Labels.device }} Utilization + {{ template "prom_query_drilldown" (args (printf "irate(node_disk_io_time_seconds_total{job='validator_nodes',instance='%s',device='%s'}[5m]) * 100" .Labels.instance .Labels.device) "%" "printf.1f") }} + +{{ end }} +{{ range printf "node_disk_io_time_seconds_total{job='validator_nodes',instance='%s'}" .Params.instance | query | sortByLabel "device" }} + + {{ .Labels.device }} Throughput + {{ template "prom_query_drilldown" (args (printf "irate(node_disk_read_bytes_total{job='validator_nodes',instance='%s',device='%s'}[5m]) + irate(node_disk_written_bytes_total{job='validator_nodes',instance='%s',device='%s'}[5m])" .Labels.instance .Labels.device .Labels.instance .Labels.device) "B/s" "humanize") }} + +{{ end }} + + Filesystem Fullness + +{{ define "roughlyNearZero" }} +{{ if gt .1 . }}~0{{ else }}{{ printf "%.1f" . }}{{ end }} +{{ end }} +{{ range printf "node_filesystem_size_bytes{job='validator_nodes',instance='%s'}" .Params.instance | query | sortByLabel "mountpoint" }} + + {{ .Labels.mountpoint }} + {{ template "prom_query_drilldown" (args (printf "100 - node_filesystem_avail_bytes{job='validator_nodes',instance='%s',mountpoint='%s'} / node_filesystem_size_bytes{job='validator_nodes'} * 100" .Labels.instance .Labels.mountpoint) "%" "roughlyNearZero") }} + +{{ end }} + +{{ template "prom_right_table_tail" }} + +{{ template "prom_content_tail" . 
}} + +{{ template "tail" }} diff --git a/terraform/templates/prometheus/consoles/node.html b/terraform/templates/prometheus/consoles/node.html new file mode 100644 index 0000000000000..315dbedddeb4b --- /dev/null +++ b/terraform/templates/prometheus/consoles/node.html @@ -0,0 +1,34 @@ +{{ template "head" . }} + +{{ template "prom_right_table_head" }} + + Node + {{ template "prom_query_drilldown" (args "sum(up{job='validator_nodes'})") }} / {{ template "prom_query_drilldown" (args "count(up{job='validator_nodes'})") }} + +{{ template "prom_right_table_tail" }} + +{{ template "prom_content_head" . }} +

Node

+ + + + + + + + +{{ range query "up{job='validator_nodes'}" | sortByLabel "instance" }} + + + Yes{{ else }} class="alert-danger">No{{ end }} + + + +{{ else }} + +{{ end }} + + +{{ template "prom_content_tail" . }} + +{{ template "tail" }} diff --git a/terraform/templates/prometheus/consoles/performance.html b/terraform/templates/prometheus/consoles/performance.html new file mode 100644 index 0000000000000..38f7061daef5f --- /dev/null +++ b/terraform/templates/prometheus/consoles/performance.html @@ -0,0 +1,57 @@ +{{ template "head" . }} + +

Performance deep dive

+{{ template "prom_right_table_head" }} + + + + +{{ template "prom_right_table_tail" }} + +{{ template "prom_content_head" . }} + +

Validators

+{{ template "validators_table" }} + +

High level values

+
NodeUpCPU
Used
Memory
Available
{{ reReplaceAll "(.*?://)([^:/]+?)(:\\d+)?/.*" "$2" .Labels.instance }}{{ template "prom_query_drilldown" (args (printf "100 * (1 - avg by(instance)(irate(node_cpu_seconds_total{job='validator_nodes',mode='idle',instance='%s'}[5m])))" .Labels.instance) "%" "printf.1f") }}{{ template "prom_query_drilldown" (args (printf "node_memory_MemFree_bytes{job='validator_nodes',instance='%s'} + node_memory_Cached_bytes{job='validator_nodes',instance='%s'} + node_memory_Buffers_bytes{job='validator_nodes',instance='%s'}" .Labels.instance .Labels.instance .Labels.instance) "B" "humanize1024") }}
No nodes found.
Validators{{ template "prom_query_drilldown" (args "sum(up{job='validators'})") }} / {{ template "prom_query_drilldown" (args "count(up{job='validators'})") }}
+ + {{ template "graph_in_table" (args "lastCommittedRound" "Avg e2e latency (ms)" "irate(mempool_duration_sum{op='e2e.latency'}[1m])/irate(mempool_duration_count{op='e2e.latency'}[1m])" "blocks per sec") }} + + {{ template "graph_in_table" (args "lastCommittedVersion" "Committed successful txns per sec" "irate(consensus_gauge{op='last_committed_version'}[1m])" "txns per sec") }} + + {{ template "graph_in_table" (args "numTxnsPerBlock" "Avg num txns per block" "irate(consensus_duration_sum{op='num_txns_per_block'}[1m])/irate(consensus_duration_count{op='num_txns_per_block'}[1m])" "txns per block") }} + +
+ +

Storage and execution

+ + + {{ template "graph_in_table" (args "txnExecutionTime" "Avg txn execution time (ms)" "irate(consensus_duration_sum{op='txn_execution_duration_ms'}[1m])/irate(consensus_duration_count{op='txn_execution_duration_ms'}[1m])" "ms") }} + + {{ template "graph_in_table" (args "blockExecutionTime" "Avg block execution time (ms)" "irate(consensus_duration_sum{op='block_execution_duration_ms'}[1m])/irate(consensus_duration_count{op='block_execution_duration_ms'}[1m])" "ms") }} + + {{ template "graph_in_table" (args "blockCommitTime" "Avg commit (store) time (ms)" "irate(consensus_duration_sum{op='block_commit_duration_ms'}[1m])/irate(consensus_duration_count{op='block_commit_duration_ms'}[1m])" "ms") }} + +
+ +

AC and Mempool

+ + + {{ template "graph_in_table" (args "preConsensusTime" "Avg time spent before consensus (ms)" "irate(consensus_duration_sum{op='txn_pre_consensus_ms'}[1m])/irate(consensus_duration_count{op='txn_pre_consensus_ms'}[1m])" "ms") }} + +
+ +

Consensus counters

+ + + {{ template "graph_in_table" (args "creationToCommit" "Avg time since block creation to commit (ms)" "irate(consensus_duration_sum{op='creation_to_commit_ms'}[1m])/irate(consensus_duration_count{op='creation_to_commit_ms'}[1m])" "ms") }} + + {{ template "graph_in_table" (args "creationToQC" "Avg time since block creation to QC (ms)" "irate(consensus_duration_sum{op='creation_to_qc_ms'}[1m])/irate(consensus_duration_count{op='creation_to_qc_ms'}[1m])" "ms") }} + +
+ + +{{ template "prom_content_tail" . }} + +{{ template "tail" }} diff --git a/terraform/templates/prometheus/consoles/prometheus-overview.html b/terraform/templates/prometheus/consoles/prometheus-overview.html new file mode 100644 index 0000000000000..83e44b2624fc6 --- /dev/null +++ b/terraform/templates/prometheus/consoles/prometheus-overview.html @@ -0,0 +1,96 @@ +{{ template "head" . }} + +{{ template "prom_right_table_head" }} + + Overview + + + CPU + {{ template "prom_query_drilldown" (args (printf "irate(process_cpu_seconds_total{job='prometheus',instance='%s'}[5m])" .Params.instance) "s/s" "humanizeNoSmallPrefix") }} + + + Memory + {{ template "prom_query_drilldown" (args (printf "process_resident_memory_bytes{job='prometheus',instance='%s'}" .Params.instance) "B" "humanize1024") }} + + + Version + {{ with query (printf "prometheus_build_info{job='prometheus',instance='%s'}" .Params.instance) }}{{. | first | label "version"}}{{end}} + + + + Storage + + + Ingested Samples + {{ template "prom_query_drilldown" (args (printf "irate(prometheus_tsdb_head_samples_appended_total{job='prometheus',instance='%s'}[5m])" .Params.instance) "/s" "humanizeNoSmallPrefix") }} + + + Head Series + {{ template "prom_query_drilldown" (args (printf "prometheus_tsdb_head_series{job='prometheus',instance='%s'}" .Params.instance) "" "humanize") }} + + + Blocks Loaded + {{ template "prom_query_drilldown" (args (printf "prometheus_tsdb_blocks_loaded{job='prometheus',instance='%s'}" .Params.instance) "" "humanize") }} + + + Rules + + + Evaluation Duration + {{ template "prom_query_drilldown" (args (printf "irate(prometheus_evaluator_duration_seconds_sum{job='prometheus',instance='%s'}[5m]) / irate(prometheus_evaluator_duration_seconds_count{job='prometheus',instance='%s'}[5m])" .Params.instance .Params.instance) "" "humanizeDuration") }} + + + Notification Latency + {{ template "prom_query_drilldown" (args (printf "irate(prometheus_notifications_latency_seconds_sum{job='prometheus',instance='%s'}[5m]) / irate(prometheus_notifications_latency_seconds_count{job='prometheus',instance='%s'}[5m])" .Params.instance .Params.instance) "" "humanizeDuration") }} + + + Notification Queue + {{ template "prom_query_drilldown" (args (printf "prometheus_notifications_queue_length{job='prometheus',instance='%s'}" .Params.instance) "" "humanize") }} + + + HTTP Server + +{{ range printf "http_request_duration_microseconds_count{job='prometheus',instance='%s',handler=~'^(query.*|federate|consoles)$'}" .Params.instance | query | sortByLabel "handler" }} + + {{ .Labels.handler }} + {{ template "prom_query_drilldown" (args (printf "irate(http_request_duration_microseconds_count{job='prometheus',instance='%s',handler='%s'}[5m])" .Labels.instance .Labels.handler) "/s" "humanizeNoSmallPrefix") }} + +{{ end }} + +{{ template "prom_right_table_tail" }} + +{{ template "prom_content_head" . }} +
+

Prometheus Overview - {{ .Params.instance }}

+ +

Ingested Samples

+
+ + +

HTTP Server

+
+ +
+{{ template "prom_content_tail" . }} + +{{ template "tail" }} diff --git a/terraform/templates/prometheus/consoles/prometheus.html b/terraform/templates/prometheus/consoles/prometheus.html new file mode 100644 index 0000000000000..e0d026376d1d8 --- /dev/null +++ b/terraform/templates/prometheus/consoles/prometheus.html @@ -0,0 +1,34 @@ +{{ template "head" . }} + +{{ template "prom_right_table_head" }} + + Prometheus + {{ template "prom_query_drilldown" (args "sum(up{job='prometheus'})") }} / {{ template "prom_query_drilldown" (args "count(up{job='prometheus'})") }} + +{{ template "prom_right_table_tail" }} + +{{ template "prom_content_head" . }} +

Prometheus

+ + + + + + + + +{{ range query "up{job='prometheus'}" | sortByLabel "instance" }} + + + + + + +{{ else }} + +{{ end }} +
PrometheusUpIngested SamplesMemory
{{ .Labels.instance }}Yes{{ else }} class="alert-danger">No{{ end }}{{ template "prom_query_drilldown" (args (printf "irate(prometheus_tsdb_head_samples_appended_total{job='prometheus',instance='%s'}[5m])" .Labels.instance) "/s" "humanizeNoSmallPrefix") }}{{ template "prom_query_drilldown" (args (printf "process_resident_memory_bytes{job='prometheus',instance='%s'}" .Labels.instance) "B" "humanize1024")}}
No devices found.
+ +{{ template "prom_content_tail" . }} + +{{ template "tail" }} diff --git a/terraform/templates/prometheus/consoles/validator-overview.html b/terraform/templates/prometheus/consoles/validator-overview.html new file mode 100644 index 0000000000000..0a00b45c1f3e2 --- /dev/null +++ b/terraform/templates/prometheus/consoles/validator-overview.html @@ -0,0 +1,122 @@ +{{ template "head" . }} + +{{ template "prom_content_head" . }} +

Node Overview - {{ reReplaceAll "(.*?://)([^:/]+?)(:\\d+)?/.*" "$2" .Params.address }}

+ +

CPU Usage

+
+ + +

Disk I/O Utilization

+
+ + +

Memory

+
+ + +{{ template "prom_right_table_head" }} + Overview + + User CPU + {{ template "prom_query_drilldown" (args (printf "sum(irate(node_cpu_seconds_total{job='validator_nodes',instance='%s',mode='user'}[5m])) * 100 / count(count by (cpu)(node_cpu_seconds_total{job='validator_nodes',instance='%s'}))" .Params.address .Params.address) "%" "printf.1f") }} + + + System CPU + {{ template "prom_query_drilldown" (args (printf "sum(irate(node_cpu_seconds_total{job='validator_nodes',instance='%s',mode='system'}[5m])) * 100 / count(count by (cpu)(node_cpu_seconds_total{job='validator_nodes',instance='%s'}))" .Params.address .Params.address) "%" "printf.1f") }} + + + Memory Total + {{ template "prom_query_drilldown" (args (printf "node_memory_MemTotal_bytes{job='validator_nodes',instance='%s'}" .Params.address) "B" "humanize1024") }} + + + Memory Free + {{ template "prom_query_drilldown" (args (printf "node_memory_MemFree_bytes{job='validator_nodes',instance='%s'}" .Params.address) "B" "humanize1024") }} + + + Network + +{{ range printf "node_network_receive_bytes_total{job='validator_nodes',instance='%s',device!='lo'}" .Params.address | query | sortByLabel "device" }} + + {{ .Labels.device }} Received + {{ template "prom_query_drilldown" (args (printf "irate(node_network_receive_bytes_total{job='validator_nodes',instance='%s',device='%s'}[5m])" .Labels.instance .Labels.device) "B/s" "humanize") }} + + + {{ .Labels.device }} Transmitted + {{ template "prom_query_drilldown" (args (printf "irate(node_network_transmit_bytes_total{job='validator_nodes',instance='%s',device='%s'}[5m])" .Labels.instance .Labels.device) "B/s" "humanize") }} + +{{ end }} + + + Disks + +{{ range printf "node_disk_io_time_seconds_total{job='validator_nodes',instance='%s',device!~'^(md\\\\d+$|dm-)'}" .Params.address | query | sortByLabel "device" }} + + {{ .Labels.device }} Utilization + {{ template "prom_query_drilldown" (args (printf "irate(node_disk_io_time_seconds_total{job='validator_nodes',instance='%s',device='%s'}[5m]) * 100" .Labels.instance .Labels.device) "%" "printf.1f") }} + +{{ end }} +{{ range printf "node_disk_io_time_seconds_total{job='validator_nodes',instance='%s'}" .Params.address | query | sortByLabel "device" }} + + {{ .Labels.device }} Throughput + {{ template "prom_query_drilldown" (args (printf "irate(node_disk_read_bytes_total{job='validator_nodes',instance='%s',device='%s'}[5m]) + irate(node_disk_written_bytes_total{job='validator_nodes',instance='%s',device='%s'}[5m])" .Labels.instance .Labels.device .Labels.instance .Labels.device) "B/s" "humanize") }} + +{{ end }} + + Filesystem Fullness + +{{ define "roughlyNearZero" }} +{{ if gt .1 . }}~0{{ else }}{{ printf "%.1f" . }}{{ end }} +{{ end }} +{{ range printf "node_filesystem_size_bytes{job='validator_nodes',instance='%s'}" .Params.address | query | sortByLabel "mountpoint" }} + + {{ .Labels.mountpoint }} + {{ template "prom_query_drilldown" (args (printf "100 - node_filesystem_avail_bytes{job='validator_nodes',instance='%s',mountpoint='%s'} / node_filesystem_size_bytes{job='validator_nodes'} * 100" .Labels.instance .Labels.mountpoint) "%" "roughlyNearZero") }} + +{{ end }} + +{{ template "prom_right_table_tail" }} + +{{ template "prom_content_tail" . 
}} + +{{ template "tail" }} diff --git a/terraform/templates/prometheus/consoles/validator-toplevel.html b/terraform/templates/prometheus/consoles/validator-toplevel.html new file mode 100644 index 0000000000000..b0bca6084cdbe --- /dev/null +++ b/terraform/templates/prometheus/consoles/validator-toplevel.html @@ -0,0 +1,56 @@ +{{ define "graph" }} +

{{ .arg0 }}

+
+ +
+{{ end }} + +{{ template "head" . }} + +{{ template "prom_content_head" . }} +

Validators Top Level System {{ reReplaceAll "(.*?://)([^:/]+?)(:\\d+)?/.*" "$2" .Params.instance }}

+ +

Submit Txn Rate (AC)

+
+ +
+ +

CPU Usage

+
+ +
+ + {{ template "graph" (args "Txn per Block " "ccTxnBlock" "avg(rate(consensus{op='committed_txns_count'}[1m]))/avg(rate(consensus{op='committed_blocks_count'}[1m]))" "Txn per Block") }} + + {{ template "graph" (args "Finality (MP)" "mpFinality" "avg(rate(mempool_duration_sum{op='e2e.latency'}[1m])/rate(mempool_duration_count{op='e2e.latency'}[1m]))" "ms(avg)") }} + + {{ template "graph" (args "Round" "ccRound" "avg(consensus_gauge{op='current_round'}) by (peer_id)" "Round") }} + +{{ template "prom_content_tail" . }} + +{{ template "tail" }} diff --git a/terraform/templates/prometheus/consoles/validators.html b/terraform/templates/prometheus/consoles/validators.html new file mode 100644 index 0000000000000..6f3eff060807a --- /dev/null +++ b/terraform/templates/prometheus/consoles/validators.html @@ -0,0 +1,89 @@ +{{ template "head" . }} + +{{ template "prom_right_table_head" }} + + Validators + {{ template "prom_query_drilldown" (args "sum(up{job='validators'})") }} / {{ template "prom_query_drilldown" (args "count(up{job='validators'})") }} + +{{ template "prom_right_table_tail" }} + +{{ template "prom_content_head" . }} +

Validators Overview

+ + + + + + + + + + + + +{{ range query "up{job='validators'}" | sortByLabel "instance" }} + + + + Yes{{ else }} class="alert-danger">No{{ end }} + + + + + + + + + + + + +{{ else }} + +{{ end }} + + +
ValidatorsUpConnected
Peers
Round NumberUptimeCPU UsedMemory
Available
Revision
{{ reReplaceAll "(.*?://)([^:/]+?)(:\\d+)?/.*" "$2" .Labels.peer_id }}{{ template "prom_query_drilldown" (args (printf "network_gauge{op='connected_peers',job='validators',instance='%s'}" .Labels.instance) "") }} Total: {{ template "prom_query_drilldown" (args (printf "consensus_gauge{op='current_round',job='validators',instance='%s'}" .Labels.instance) "") }}
+ Rate: {{ template "prom_query_drilldown" (args (printf "rate(consensus_gauge{op='current_round',job='validators',instance='%s'}[1m])" .Labels.instance) "" "printf.1f") }}
Instance: {{ template "prom_query_drilldown" (args (printf "time() - node_boot_time_seconds{job='validator_nodes',address='%s'}" .Labels.address) "" "humanizeDuration") }} +
+ Container: {{ template "prom_query_drilldown" (args (printf "time() - ecs_start_time_seconds{job='validator_nodes',address='%s'}" .Labels.address) "" "humanizeDuration") }} +
{{ template "prom_query_drilldown" (args (printf "100 * (1 - avg by(address)(irate(node_cpu_seconds_total{job='validator_nodes',mode='idle',address='%s'}[5m])))" .Labels.address) "%" "printf.1f") }}{{ template "prom_query_drilldown" (args (printf "node_memory_MemFree_bytes{job='validator_nodes',address='%s'} + node_memory_Cached_bytes{job='validator_nodes',address='%s'} + node_memory_Buffers_bytes{job='validator_nodes',address='%s'}" .Labels.address .Labels.address .Labels.address) "B" "humanize1024") }}{{ with query (printf "build_info{address='%s'}" .Labels.address) }}{{. | first | label "revision" | printf "%.8s"}}{{end}}
No nodes found.
+ + + + + + + + + +
Transactions rate (Consensus) Who's the proposer? (Consensus)
+ + + +
+ + +{{ template "prom_content_tail" . }} + +{{ template "tail" }} diff --git a/terraform/templates/seed_peers.config.toml b/terraform/templates/seed_peers.config.toml new file mode 100644 index 0000000000000..184f94e9b80d4 --- /dev/null +++ b/terraform/templates/seed_peers.config.toml @@ -0,0 +1,4 @@ +[seed_peers] +%{ for validator in split(",", validators) ~} +${element(split(":", validator), 0)} = ["/ip4/${element(split(":", validator), 1)}/tcp/30303"] +%{ endfor ~} diff --git a/terraform/templates/validator.json b/terraform/templates/validator.json new file mode 100644 index 0000000000000..45a1f72e10f36 --- /dev/null +++ b/terraform/templates/validator.json @@ -0,0 +1,39 @@ +[ + { + "name": "validator", + "image": "${image}${image_version}", + "cpu": ${cpu}, + "memory": ${mem}, + "essential": true, + "portMappings": [ + {"containerPort": 30307, "hostPort": 30307}, + {"containerPort": 30303, "hostPort": 30303}, + {"containerPort": 14297, "hostPort": 14297} + ], + "mountPoints": [ + {"sourceVolume": "libra-data", "containerPath": "/opt/libra/data"} + ], + "environment": [ + {"name": "PEER_ID", "value": "${peer_id}"}, + {"name": "SELF_IP", "value": "${self_ip}"}, + {"name": "SEED_PEERS", "value": ${seed_peers}}, + {"name": "TRUSTED_PEERS", "value": ${trusted_peers}}, + {"name": "GENESIS_BLOB", "value": ${genesis_blob}}, + {"name": "RUST_LOG", "value": "${log_level}"} + ], + "ulimits": [ + {"name": "nofile", "softLimit": 131072, "hardLimit": 131072} + ], + "secrets": [ + {"name": "PEER_KEYPAIRS", "valueFrom": "${secret}"} + ], + "logConfiguration": { + "logDriver": "awslogs", + "options": { + "awslogs-group": "${log_group}", + "awslogs-region": "${log_region}", + "awslogs-stream-prefix": "${log_prefix}" + } + } + } +] diff --git a/terraform/terraform.tfvars b/terraform/terraform.tfvars new file mode 100644 index 0000000000000..4b5fc9ae2f24d --- /dev/null +++ b/terraform/terraform.tfvars @@ -0,0 +1 @@ +peer_ids = ["8deeeaed65f0cd7484a9e4e5ac51fbac548f2f71299a05e000156031ca78fb9f", "1e5d5a74b0fd09f601ac0fca2fe7d213704e02e51943d18cf25a546b8416e9e1", "ab0d6a54ce9d7fc79c061f95883a308f9bdfc987262b6a34a360fdd788fcd9cd", "57ff83747054695f2228042c26eb6a243ac73de1b9038aea103999480b076d45"] diff --git a/terraform/validator-sets/build.sh b/terraform/validator-sets/build.sh new file mode 100755 index 0000000000000..01d167168ab21 --- /dev/null +++ b/terraform/validator-sets/build.sh @@ -0,0 +1,20 @@ +#!/bin/sh +set -e + +OUTDIR="${1?[Specify relative output directory]}" +shift + +mkdir -p "$OUTDIR" + +cd ../.. + +if [ ! 
-e "../setup_scripts/terraform/testnet/validator-sets/$OUTDIR/mint.key" ]; then + cargo run --bin generate_keypair -- -o "../setup_scripts/terraform/testnet/validator-sets/$OUTDIR/mint.key" +fi + +cargo run --bin libra-config -- -b config/data/configs/node.config.toml -m "../setup_scripts/terraform/testnet/validator-sets/$OUTDIR/mint.key" -o "../setup_scripts/terraform/testnet/validator-sets/$OUTDIR" -d "$@" # -r config/data/configs/overrides/testnet.node.config.override.toml + +cd - +cd $OUTDIR +ls *.node.config.toml | head -n1 | xargs -I{} mv {} node.config.toml +rm *.node.config.toml diff --git a/terraform/validator-sets/dev/1e5d5a74b0fd09f601ac0fca2fe7d213704e02e51943d18cf25a546b8416e9e1.node.keys.toml b/terraform/validator-sets/dev/1e5d5a74b0fd09f601ac0fca2fe7d213704e02e51943d18cf25a546b8416e9e1.node.keys.toml new file mode 100644 index 0000000000000..1c62320c11e86 --- /dev/null +++ b/terraform/validator-sets/dev/1e5d5a74b0fd09f601ac0fca2fe7d213704e02e51943d18cf25a546b8416e9e1.node.keys.toml @@ -0,0 +1,6 @@ +network_signing_private_key = "2000000000000000f9e1b4bd35cff88f047043aebedfbd310bb6ca762040ac51c28203fb1873e539" +network_signing_public_key = "200000000000000066535c0f4f9242ea8f9ff7f68b4f50b7c9d8e7961a41cf216124b33a325505ff" +network_identity_private_key = "200000000000000078fea3aa16bf4a9c6471aeaa6fc3e0427d311503a31fe37f0245e30addf14e53" +network_identity_public_key = "2000000000000000882a6058c64664e59415d1692b8ecbe976665b4549631c7ff15e6fbce3e2e50e" +consensus_private_key = "2000000000000000e8df7d746043aa71d45397b001787347bcbf5f23b6fae5a9b83203cfaefc46be" +consensus_public_key = "2000000000000000883bbde2fad70bcfdaf6a8f12a9a4a0722f9e68d2ac4ad4c065d19616f7c42e2" diff --git a/terraform/validator-sets/dev/57ff83747054695f2228042c26eb6a243ac73de1b9038aea103999480b076d45.node.keys.toml b/terraform/validator-sets/dev/57ff83747054695f2228042c26eb6a243ac73de1b9038aea103999480b076d45.node.keys.toml new file mode 100644 index 0000000000000..f25dc0392df3c --- /dev/null +++ b/terraform/validator-sets/dev/57ff83747054695f2228042c26eb6a243ac73de1b9038aea103999480b076d45.node.keys.toml @@ -0,0 +1,6 @@ +network_signing_private_key = "20000000000000007e39e8c9c4efcdd75dfa9660bbbeb584b5593dff4566eac37e857d9b9e21ff06" +network_signing_public_key = "20000000000000006f59175f62857f4b541dcf8226b07300ab3ade5b7a6622cba55c1a22765b04cd" +network_identity_private_key = "2000000000000000385253c287ab3de348e133d87509e1f1678beb1dfa278ade7d766ae58633075e" +network_identity_public_key = "20000000000000007ae91e2907464755c9ddc3ce2f1b403624b9ccc61b4be8c8388ce28c47ac3120" +consensus_private_key = "20000000000000002bb04f975864ebab903123d6ba517ddd4b145efb28f728968e0d7437fc750977" +consensus_public_key = "20000000000000001faf7c0c4dd0b3025e2aa5635395aac947b659dff188a0380e6ead27a2c33a1a" diff --git a/terraform/validator-sets/dev/8deeeaed65f0cd7484a9e4e5ac51fbac548f2f71299a05e000156031ca78fb9f.node.keys.toml b/terraform/validator-sets/dev/8deeeaed65f0cd7484a9e4e5ac51fbac548f2f71299a05e000156031ca78fb9f.node.keys.toml new file mode 100644 index 0000000000000..a29c5982e3231 --- /dev/null +++ b/terraform/validator-sets/dev/8deeeaed65f0cd7484a9e4e5ac51fbac548f2f71299a05e000156031ca78fb9f.node.keys.toml @@ -0,0 +1,6 @@ +network_signing_private_key = "200000000000000082001573a003fd3b7fd72ffb0eaf63aac62f12deb629dca72785a66268ec758b" +network_signing_public_key = "2000000000000000664f6e8f36eacb1770fa879d86c2c1d0fafea145e84fa7d671ab7a011a54d509" +network_identity_private_key = 
"200000000000000018db36900560898178e0ad009abf1f491330dc1c246e3d6cb264f6900271d55c" +network_identity_public_key = "2000000000000000b1df0ea1b4c1400454bab824e2e3ef6669e4231e2b9332020d9630fe1cfb2808" +consensus_private_key = "2000000000000000fb1c12c1efcb64c5603ca15ac896d1abc1082b17b096c9176547992eaa0eb646" +consensus_public_key = "200000000000000090bba9133465da772eea2823cd0d871dbf0f27580ec8b791ebfa21ce18baae7a" diff --git a/terraform/validator-sets/dev/ab0d6a54ce9d7fc79c061f95883a308f9bdfc987262b6a34a360fdd788fcd9cd.node.keys.toml b/terraform/validator-sets/dev/ab0d6a54ce9d7fc79c061f95883a308f9bdfc987262b6a34a360fdd788fcd9cd.node.keys.toml new file mode 100644 index 0000000000000..c229ecfeb96c8 --- /dev/null +++ b/terraform/validator-sets/dev/ab0d6a54ce9d7fc79c061f95883a308f9bdfc987262b6a34a360fdd788fcd9cd.node.keys.toml @@ -0,0 +1,6 @@ +network_signing_private_key = "200000000000000076f5e28163a6f72f4ab72fcd90c69eedef6d5c73539f14e7ca0bf6a9f229f12d" +network_signing_public_key = "20000000000000006721135d6093ee12624bdac7f8fa1350a01410411020010648b2dc8c80b1c2c1" +network_identity_private_key = "20000000000000001092efcfdad11ee26de44f0e843178d2c989d4f21fe9b53c03c12874e83a7066" +network_identity_public_key = "20000000000000009dc8ab3ea3b059e12ade6ba9c1d9cd17f72ce6cb5d4d9f5e23fb0e6d4519f451" +consensus_private_key = "20000000000000003cdc1d88eda836767090fb77dae5ecc00f5bf5ca8879733fac19e8de5c725636" +consensus_public_key = "2000000000000000bd51b393b5d059055e219b5081fd8de113cddc4e26dbbe10d08027d6c1e03e8e" diff --git a/terraform/validator-sets/dev/genesis.blob b/terraform/validator-sets/dev/genesis.blob new file mode 100644 index 0000000000000..51f762b3a98d0 Binary files /dev/null and b/terraform/validator-sets/dev/genesis.blob differ diff --git a/terraform/validator-sets/dev/mint.key b/terraform/validator-sets/dev/mint.key new file mode 100644 index 0000000000000..b148fc2bdc37d Binary files /dev/null and b/terraform/validator-sets/dev/mint.key differ diff --git a/terraform/validator-sets/dev/node.config.toml b/terraform/validator-sets/dev/node.config.toml new file mode 100644 index 0000000000000..6deb27878284e --- /dev/null +++ b/terraform/validator-sets/dev/node.config.toml @@ -0,0 +1,81 @@ +[base] +peer_id = "" +data_dir_path = "" +trusted_peers_file = "/opt/libra/etc/trusted_peers.config.toml" +peer_keypairs_file = "/opt/libra/etc/peer_keypairs.config.toml" +node_sync_batch_size = 1000 +node_sync_retries = 3 +node_sync_channel_buffer_size = 10 +node_async_log_chan_size = 256 + +[metrics] +dir = "metrics" +collection_interval_ms = 1000 +push_server_addr = "" + +[execution] +address = "localhost" +port = 59622 +testnet_genesis = false +genesis_file_location = "/opt/libra/etc/genesis.blob" + +[admission_control] +address = "0.0.0.0" +admission_control_service_port = 30307 +need_to_check_mempool_before_validation = false + +[debug_interface] +admission_control_node_debug_port = 32987 +storage_node_debug_port = 49125 +secret_service_node_debug_port = 50316 +metrics_server_port = 14297 +address = "0.0.0.0" + +[storage] +address = "localhost" +port = 35647 +dir = "/opt/libra/data" + +[network] +seed_peers_file = "/opt/libra/etc/seed_peers.config.toml" +listen_address = "/ip4/0.0.0.0/tcp/30303" +advertised_address = "/ip4/SELF_IP/tcp/30303" +discovery_interval_ms = 1000 +connectivity_check_interval_ms = 5000 +enable_encryption_and_authentication = true + +[consensus] +max_block_size = 100 +proposer_type = "rotating_proposer" +contiguous_rounds = 2 + +[mempool] +broadcast_transactions = true 
+shared_mempool_tick_interval_ms = 50 +shared_mempool_batch_size = 100 +shared_mempool_max_concurrent_inbound_syncs = 100 +capacity = 10000000 +capacity_per_user = 100 +sequence_cache_capacity = 1000 +system_transaction_timeout_secs = 86400 +system_transaction_gc_interval_ms = 180000 +mempool_service_port = 59620 +address = "localhost" + +[log_collector] +is_async = true +use_std_output = true + +[vm_config] + [vm_config.publishing_options] + type = "Locked" + whitelist = [ + "88c0c64595f6cec7d0c0bfe29e1be1886c736ec3d26888d049e30909f7a72836", + "d3493756a00b7a9e4d9ca8482e80fd055411ce53882bdcb08fec97d42eef0bde", + "ee31d65b559ad5a300e6a508ff3edb2d23f1589ef68d0ead124d8f0374073d84", + "2bb3828f55bc640a85b17d9c6e120e84f8c068c9fd850e1a1d61d2f91ed295fd" + ] + +[secret_service] +address = "localhost" +secret_service_port = 59618 diff --git a/terraform/validator-sets/dev/seed_peers.config.toml b/terraform/validator-sets/dev/seed_peers.config.toml new file mode 100644 index 0000000000000..44533e8ca25ed --- /dev/null +++ b/terraform/validator-sets/dev/seed_peers.config.toml @@ -0,0 +1,2 @@ +[seed_peers] +8deeeaed65f0cd7484a9e4e5ac51fbac548f2f71299a05e000156031ca78fb9f = ["/ip4/SEED_IP/tcp/30303"] diff --git a/terraform/validator-sets/dev/trusted_peers.config.toml b/terraform/validator-sets/dev/trusted_peers.config.toml new file mode 100644 index 0000000000000..155b87bd11cff --- /dev/null +++ b/terraform/validator-sets/dev/trusted_peers.config.toml @@ -0,0 +1,19 @@ +[peers.ab0d6a54ce9d7fc79c061f95883a308f9bdfc987262b6a34a360fdd788fcd9cd] +network_signing_pubkey = "20000000000000006721135d6093ee12624bdac7f8fa1350a01410411020010648b2dc8c80b1c2c1" +network_identity_pubkey = "20000000000000009dc8ab3ea3b059e12ade6ba9c1d9cd17f72ce6cb5d4d9f5e23fb0e6d4519f451" +consensus_pubkey = "2000000000000000bd51b393b5d059055e219b5081fd8de113cddc4e26dbbe10d08027d6c1e03e8e" + +[peers.1e5d5a74b0fd09f601ac0fca2fe7d213704e02e51943d18cf25a546b8416e9e1] +network_signing_pubkey = "200000000000000066535c0f4f9242ea8f9ff7f68b4f50b7c9d8e7961a41cf216124b33a325505ff" +network_identity_pubkey = "2000000000000000882a6058c64664e59415d1692b8ecbe976665b4549631c7ff15e6fbce3e2e50e" +consensus_pubkey = "2000000000000000883bbde2fad70bcfdaf6a8f12a9a4a0722f9e68d2ac4ad4c065d19616f7c42e2" + +[peers.8deeeaed65f0cd7484a9e4e5ac51fbac548f2f71299a05e000156031ca78fb9f] +network_signing_pubkey = "2000000000000000664f6e8f36eacb1770fa879d86c2c1d0fafea145e84fa7d671ab7a011a54d509" +network_identity_pubkey = "2000000000000000b1df0ea1b4c1400454bab824e2e3ef6669e4231e2b9332020d9630fe1cfb2808" +consensus_pubkey = "200000000000000090bba9133465da772eea2823cd0d871dbf0f27580ec8b791ebfa21ce18baae7a" + +[peers.57ff83747054695f2228042c26eb6a243ac73de1b9038aea103999480b076d45] +network_signing_pubkey = "20000000000000006f59175f62857f4b541dcf8226b07300ab3ade5b7a6622cba55c1a22765b04cd" +network_identity_pubkey = "20000000000000007ae91e2907464755c9ddc3ce2f1b403624b9ccc61b4be8c8388ce28c47ac3120" +consensus_pubkey = "20000000000000001faf7c0c4dd0b3025e2aa5635395aac947b659dff188a0380e6ead27a2c33a1a" diff --git a/terraform/validators.tf b/terraform/validators.tf new file mode 100644 index 0000000000000..85e4e0ec0f5a5 --- /dev/null +++ b/terraform/validators.tf @@ -0,0 +1,206 @@ +data "aws_ami" "ecs" { + most_recent = true + + filter { + name = "name" + values = ["amzn2-ami-ecs-hvm-2.0.*"] + } + + filter { + name = "architecture" + values = ["x86_64"] + } + + owners = ["amazon"] +} + +locals { + cpu_by_instance = { + "t2.large" = 2048 + "t2.medium" = 2048 + "t3.medium" = 2048 + 
"m5.large" = 2048 + "m5.xlarge" = 4096 + "m5.2xlarge" = 8192 + "m5.4xlarge" = 16384 + "m5.12xlarge" = 49152 + "m5.24xlarge" = 98304 + "c5.large" = 2048 + "c5.xlarge" = 4096 + "c5.2xlarge" = 8192 + "c5.4xlarge" = 16384 + "c5.9xlarge" = 36864 + "c5.18xlarge" = 73728 + } + + mem_by_instance = { + "t2.medium" = 3943 + "t2.large" = 7975 + "t3.medium" = 3884 + "m5.large" = 7680 + "m5.xlarge" = 15576 + "m5.2xlarge" = 31368 + "m5.4xlarge" = 62950 + "m5.12xlarge" = 189283 + "m5.24xlarge" = 378652 + "c5.large" = 3704 + "c5.xlarge" = 7624 + "c5.2xlarge" = 15464 + "c5.4xlarge" = 31142 + "c5.9xlarge" = 70341 + "c5.18xlarge" = 140768 + } +} + +resource "aws_cloudwatch_log_group" "testnet" { + name = terraform.workspace + retention_in_days = 7 +} + +resource "aws_cloudwatch_log_metric_filter" "log_metric_filter" { + name = "critical_log" + pattern = "[code=C*, time, x, file, ...]" + log_group_name = "${aws_cloudwatch_log_group.testnet.name}" + + metric_transformation { + name = "critical_lines" + namespace = "LogMetrics" + value = "1" + } +} + +data "template_file" "user_data" { + template = file("templates/ec2_user_data.sh") + + vars = { + ecs_cluster = aws_ecs_cluster.testnet.name + } +} + +locals { + image_repo = var.image_repo + instance_public_ip = true + user_data = data.template_file.user_data.rendered + image_version = substr(var.image_tag, 0, 6) == "sha256" ? "@${var.image_tag}" : ":${var.image_tag}" +} + +resource "aws_instance" "validator" { + count = length(var.peer_ids) + ami = data.aws_ami.ecs.id + instance_type = var.validator_type + subnet_id = element( + aws_subnet.testnet.*.id, + count.index % length(data.aws_availability_zones.available.names), + ) + vpc_security_group_ids = [aws_security_group.validator.id] + associate_public_ip_address = local.instance_public_ip + key_name = aws_key_pair.libra.key_name + iam_instance_profile = aws_iam_instance_profile.ecsInstanceRole.name + user_data = local.user_data + + root_block_device { + volume_type = "io1" + volume_size = 100 + iops = 5000 # max 50iops/gb + } + + tags = { + Name = "${terraform.workspace}-validator-${substr(var.peer_ids[count.index], 0, 8)}" + Role = "validator" + Workspace = terraform.workspace + PeerId = var.peer_ids[count.index] + } +} + +data "local_file" "keys" { + count = length(var.peer_ids) + filename = "${var.validator_set}/${var.peer_ids[count.index]}.node.keys.toml" +} + +resource "aws_secretsmanager_secret" "validator" { + count = length(var.peer_ids) + name = "${terraform.workspace}-${substr(var.peer_ids[count.index], 0, 8)}" + recovery_window_in_days = 0 +} + +resource "aws_secretsmanager_secret_version" "validator" { + count = length(var.peer_ids) + secret_id = element(aws_secretsmanager_secret.validator.*.id, count.index) + secret_string = element(data.local_file.keys.*.content, count.index) +} + +data "template_file" "seed_peers" { + template = file("templates/seed_peers.config.toml") + + vars = { + validators = join(",", formatlist("%s:%s", var.peer_ids, aws_instance.validator.*.private_ip)) + } +} + +data "template_file" "ecs_task_definition" { + count = length(var.peer_ids) + template = file("templates/validator.json") + + vars = { + image = local.image_repo + image_version = local.image_version + cpu = local.cpu_by_instance[var.validator_type] + mem = local.mem_by_instance[var.validator_type] + self_ip = element(aws_instance.validator.*.private_ip, count.index) + seed_peers = jsonencode(data.template_file.seed_peers.rendered) + trusted_peers = 
jsonencode(file("${var.validator_set}/trusted_peers.config.toml")) + genesis_blob = jsonencode(filebase64("${var.validator_set}/genesis.blob")) + peer_id = var.peer_ids[count.index] + secret = element(aws_secretsmanager_secret.validator.*.arn, count.index) + log_level = var.validator_log_level + log_group = aws_cloudwatch_log_group.testnet.name + log_region = var.region + log_prefix = "validator-${substr(var.peer_ids[count.index], 0, 8)}" + } +} + +resource "aws_ecs_task_definition" "validator" { + count = length(var.peer_ids) + family = "${terraform.workspace}-validator-${substr(var.peer_ids[count.index], 0, 8)}" + container_definitions = element( + data.template_file.ecs_task_definition.*.rendered, + count.index, + ) + execution_role_arn = aws_iam_role.ecsTaskExecutionRole.arn + network_mode = "host" + + volume { + name = "libra-data" + host_path = "/data/libra" + } + + placement_constraints { + type = "memberOf" + expression = "ec2InstanceId == ${element(aws_instance.validator.*.id, count.index)}" + } + + tags = { + PeerId = "${substr(var.peer_ids[count.index], 0, 8)}" + Role = "validator" + Workspace = terraform.workspace + } +} + +resource "aws_ecs_cluster" "testnet" { + name = terraform.workspace +} + +resource "aws_ecs_service" "validator" { + count = length(var.peer_ids) + name = "${terraform.workspace}-validator-${substr(var.peer_ids[count.index], 0, 8)}" + cluster = aws_ecs_cluster.testnet.id + task_definition = element(aws_ecs_task_definition.validator.*.arn, count.index) + desired_count = 1 + deployment_minimum_healthy_percent = 0 + + tags = { + PeerId = "${substr(var.peer_ids[count.index], 0, 8)}" + Role = "validator" + Workspace = terraform.workspace + } +} diff --git a/terraform/variables.tf b/terraform/variables.tf new file mode 100644 index 0000000000000..75c2bd0ca7fa4 --- /dev/null +++ b/terraform/variables.tf @@ -0,0 +1,93 @@ +variable "region" { + default = "us-west-2" +} + +variable "ssh_pub_key" { + type = string + description = "SSH public key for EC2 instance access" +} + +variable "ssh_priv_key_file" { + type = string + description = "Filename of SSH private key for EC2 instance access" +} + +variable "ssh_sources_ipv4" { + type = list(string) + description = "List of IPv4 CIDR blocks from which to allow SSH access" + default = ["0.0.0.0/0"] +} + +variable "ssh_sources_ipv6" { + type = list(string) + description = "List of IPv6 CIDR blocks from which to allow SSH access" + default = ["::/0"] +} + +variable "api_sources_ipv4" { + type = list(string) + description = "List of IPv4 CIDR blocks from which to allow API access" + default = ["0.0.0.0/0"] +} + +variable "image_repo" { + type = string + description = "Docker image repository to use for validator" + default = "libra/validator" +} + +variable "image_tag" { + type = string + description = "Docker image tag to use for validator" + default = "latest" +} + +variable "peer_ids" { + type = list(string) + description = "List of PeerIds" +} + +variable "validator_type" { + description = "EC2 instance type of validator instances" + default = "m5.large" +} + +variable "faucet_image_repo" { + description = "Docker image repository to use for faucet server" + default = "libra/faucet" +} + +variable "faucet_log_level" { + description = "Log level for faucet to pass to gunicorn" + default = "info" +} + +variable "faucet_image_tag" { + description = "Docker image tag to use for faucet server" + default = "latest" +} + +variable "zone_id" { + description = "Route53 ZoneId to create records in" + default = "" +} + +variable 
"validator_set" { + description = "Relative path to directory containing validator set configs" + default = "validator-sets/dev" +} + +variable "validator_log_level" { + description = "Log level for validator processes (set with RUST_LOG)" + default = "debug" +} + +variable "append_workspace_dns" { + description = "Append Terraform workspace to DNS names created" + default = true +} + +variable "prometheus_pagerduty_key" { + default = "" + description = "Key for Prometheus-PagerDuty integration" +} diff --git a/testsuite/Cargo.toml b/testsuite/Cargo.toml new file mode 100644 index 0000000000000..ad7dba8cd71d9 --- /dev/null +++ b/testsuite/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "testsuite" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dev-dependencies] +lazy_static = "1.2.0" +num-traits = "0.2" +rust_decimal = "1.0.1" + +# In order to limit the potential waiting time for binaries to be built while +# running tests all binaries which are being tested under this testsuite +# should have their crates listed as dev-dependencies. +cli = { path = "../client", package="client"} +generate_keypair = { path = "../config/generate_keypair" } +libra_swarm = { path = "../libra_swarm" } +logger = { path = "../common/logger" } +tempfile = "3.0.6" diff --git a/testsuite/libra_fuzzer/Cargo.toml b/testsuite/libra_fuzzer/Cargo.toml new file mode 100644 index 0000000000000..afb131953da5b --- /dev/null +++ b/testsuite/libra_fuzzer/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "libra_fuzzer" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +edition = "2018" + +[dependencies] +failure = { path = "../../common/failure_ext", package = "failure_ext" } +byteorder = "1.3.1" +hex = "0.3.2" +lazy_static = "1.3" +proptest = "0.9.3" +sha-1 = "0.8" +structopt = "0.2.15" +proto_conv = { path = "../../common/proto_conv" } +canonical_serialization = { path = "../../common/canonical_serialization" } + +# List out modules with data structures being fuzzed here. +types = { path = "../../types" } +vm = { path = "../../language/vm" } +vm_runtime = { path = "../../language/vm/vm_runtime" } + +[dev-dependencies] +datatest = "0.3.1" +stats_alloc = "0.1.8" +rusty-fork = "0.2.2" diff --git a/testsuite/libra_fuzzer/README.md b/testsuite/libra_fuzzer/README.md new file mode 100644 index 0000000000000..28217fbca0b9b --- /dev/null +++ b/testsuite/libra_fuzzer/README.md @@ -0,0 +1,51 @@ +## Fuzzing support for Libra + +This crate contains support for fuzzing Libra targets. This support +includes: +* corpus generation with `proptest` +* automatically running failing examples with `cargo test` + +### Prerequisites + +Install [`cargo-fuzz`](https://rust-fuzz.github.io/book/cargo-fuzz.html) if not already available: `cargo install cargo-fuzz`. + +### Fuzzing a target + +To list out known fuzz targets, run `cargo run list`. + +To be effective, fuzzing requires a corpus of existing inputs. This +crate contains support for generating corpuses with `proptest`. Generate +a corpus with `cargo run generate `. + +Once a corpus has been generated, the fuzzer is ready to use: run +`cargo run fuzz `. + +For more options, run `cargo run -- --help`. + +### Adding a new target + +Fuzz targets go in `src/fuzz_targets/`. Adding a new target involves +creating a new type and implementing `FuzzTargetImpl` for it. + +For examples, see the existing implementations in `src/fuzz_targets/`. + +Remember to add your target to `ALL_TARGETS` in `src/fuzz_targets.rs`. 
+Once that has been done, `cargo run list` should list your new target. + +### Debugging and testing artifacts + +If the fuzzer finds a failing artifact, it will save the artifact to a +file inside the `fuzz` directory and print its path. To add this +artifact to the test suite, copy it to a file inside +`artifacts/<target>/`. + +`cargo test` will now test the deserializer against the new artifact. +The test will likely fail at first; use it to reproduce and debug the issue. + +Note that `cargo test` runs each test in a separate process by default +to isolate failures and memory usage; if you're attaching a debugger and +are running a single test, set `NO_FORK=1` to disable forking. + +Once the deserializer has been fixed, check the artifact into the +`artifacts/<target>/` directory. The artifact will then act as a +regression test in `cargo test` runs. diff --git a/testsuite/libra_fuzzer/artifacts/compiled_module/crash-5d65cfbc b/testsuite/libra_fuzzer/artifacts/compiled_module/crash-5d65cfbc new file mode 100644 index 0000000000000..66aea501bb19a Binary files /dev/null and b/testsuite/libra_fuzzer/artifacts/compiled_module/crash-5d65cfbc differ diff --git a/testsuite/libra_fuzzer/artifacts/compiled_module/oom-54725f11 b/testsuite/libra_fuzzer/artifacts/compiled_module/oom-54725f11 new file mode 100644 index 0000000000000..3e94cfea3be52 Binary files /dev/null and b/testsuite/libra_fuzzer/artifacts/compiled_module/oom-54725f11 differ diff --git a/testsuite/libra_fuzzer/artifacts/raw_transaction/crash-11f36fef b/testsuite/libra_fuzzer/artifacts/raw_transaction/crash-11f36fef new file mode 100644 index 0000000000000..ea69c781ef22b Binary files /dev/null and b/testsuite/libra_fuzzer/artifacts/raw_transaction/crash-11f36fef differ diff --git a/testsuite/libra_fuzzer/artifacts/raw_transaction/working-1 b/testsuite/libra_fuzzer/artifacts/raw_transaction/working-1 new file mode 100644 index 0000000000000..ef4a5bdd0058e Binary files /dev/null and b/testsuite/libra_fuzzer/artifacts/raw_transaction/working-1 differ diff --git a/testsuite/libra_fuzzer/fuzz/.gitignore b/testsuite/libra_fuzzer/fuzz/.gitignore new file mode 100644 index 0000000000000..572e03bdf321b --- /dev/null +++ b/testsuite/libra_fuzzer/fuzz/.gitignore @@ -0,0 +1,4 @@ + +target +corpus +artifacts diff --git a/testsuite/libra_fuzzer/fuzz/Cargo.toml b/testsuite/libra_fuzzer/fuzz/Cargo.toml new file mode 100644 index 0000000000000..9b56d4f347552 --- /dev/null +++ b/testsuite/libra_fuzzer/fuzz/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "libra_fuzzer_fuzz" +version = "0.1.0" +edition = "2018" +authors = ["Automatically generated"] +publish = false + +[package.metadata] +cargo-fuzz = true + +[dependencies] +libfuzzer-sys = { git = "https://github.com/rust-fuzz/libfuzzer-sys.git" } +libra_fuzzer = { path = ".." } +lazy_static = "1.3.0" + +# futures-preview is pinned to 0.3.0-alpha.14 by some other packages. The latest version +# 0.3.0-alpha.16 doesn't compile with the current Rust toolchain, and we don't depend on any +# packages that pin futures exactly, so do it here.
+futures-preview = "=0.3.0-alpha.14" + +# Prevent this from interfering with workspaces +[workspace] +members = ["."] + +[[bin]] +name = "fuzz_runner" +path = "fuzz_targets/fuzz_runner.rs" diff --git a/testsuite/libra_fuzzer/fuzz/fuzz_targets/fuzz_runner.rs b/testsuite/libra_fuzzer/fuzz/fuzz_targets/fuzz_runner.rs new file mode 100644 index 0000000000000..f7d0ba8798068 --- /dev/null +++ b/testsuite/libra_fuzzer/fuzz/fuzz_targets/fuzz_runner.rs @@ -0,0 +1,26 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![no_main] + +use lazy_static::lazy_static; +use libfuzzer_sys::fuzz_target; +use libra_fuzzer::FuzzTarget; +use std::process; + +lazy_static! { + static ref FUZZ_TARGET: FuzzTarget = { + match FuzzTarget::from_env() { + Ok(target) => target, + Err(err) => { + // lazy_static behaves poorly with panics, so abort here. + eprintln!("*** [fuzz_runner] Error while determining fuzz target: {}", err); + process::abort(); + } + } + }; +} + +fuzz_target!(|data: &[u8]| { + FUZZ_TARGET.fuzz(data); +}); diff --git a/testsuite/libra_fuzzer/src/commands.rs b/testsuite/libra_fuzzer/src/commands.rs new file mode 100644 index 0000000000000..48895d8d64129 --- /dev/null +++ b/testsuite/libra_fuzzer/src/commands.rs @@ -0,0 +1,106 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::FuzzTarget; +use failure::prelude::*; +use proptest::test_runner::{Config, TestRunner}; +use sha1::{Digest, Sha1}; +use std::{ + env, + ffi::OsString, + fs, + io::Write, + path::{Path, PathBuf}, + process::Command, +}; + +/// Generate data for this fuzz target into the output directory. +/// +/// The corpus directory should be present at the time this method is called. +pub fn make_corpus( + target: FuzzTarget, + num_items: usize, + corpus_dir: &Path, + debug: bool, +) -> Result<()> { + // TODO: Allow custom proptest configs? + let mut runner = TestRunner::new(Config::default()); + + let mut sha1 = Sha1::new(); + + for _ in 0..num_items { + let result = target.generate(&mut runner); + + // Use the SHA-1 of the result as the file name. + sha1.input(&result); + let hash = sha1.result_reset(); + let name = hex::encode(hash.as_slice()); + let path = corpus_dir.join(name); + let mut f = fs::File::create(&path) + .with_context(|_| format!("Failed to create file: {:?}", path))?; + if debug { + println!("Writing {} bytes to file: {:?}", result.len(), path); + } + + f.write_all(&result) + .with_context(|_| format!("Failed to write to file: {:?}", path))?; + } + Ok(()) +} + +/// Fuzz a target by running `cargo fuzz run`. +pub fn fuzz_target( + target: FuzzTarget, + corpus_dir: PathBuf, + artifact_dir: PathBuf, + mut args: Vec, +) -> Result<()> { + static FUZZ_RUNNER: &str = "fuzz_runner"; + + // Do a bit of arg parsing -- look for a "--" and insert the target and corpus directory + // before that. + let dash_dash_pos = args.iter().position(|x| x == "--"); + let splice_pos = dash_dash_pos.unwrap_or_else(|| args.len()); + args.splice( + splice_pos..splice_pos, + vec![FUZZ_RUNNER.into(), corpus_dir.into()], + ); + + // The artifact dir goes at the end. + if dash_dash_pos.is_none() { + args.push("--".into()); + } + let mut artifact_arg: OsString = "-artifact_prefix=".into(); + artifact_arg.push(&artifact_dir); + // Add a trailing slash as required by libfuzzer to put the artifact in a directory. + artifact_arg.push("/"); + args.push(artifact_arg); + + // Pass the target name in as an environment variable. 
+ // Use the manifest directory as the current one. + let manifest_dir = match env::var_os("CARGO_MANIFEST_DIR") { + Some(dir) => dir, + None => bail!("Fuzzing requires CARGO_MANIFEST_DIR to be set (are you using `cargo run`?)"), + }; + + let status = Command::new("cargo") + .arg("fuzz") + .arg("run") + .args(args) + .current_dir(manifest_dir) + .env(FuzzTarget::ENV_VAR, target.name()) + .status() + .context("cargo fuzz run errored")?; + if !status.success() { + bail!("cargo fuzz run failed with status {}", status); + } + Ok(()) +} + +/// List all known fuzz targets. +pub fn list_targets() { + println!("Available fuzz targets:\n"); + for target in FuzzTarget::all_targets() { + println!(" * {0: <24} {1}", target.name(), target.description()) + } +} diff --git a/testsuite/libra_fuzzer/src/fuzz_targets.rs b/testsuite/libra_fuzzer/src/fuzz_targets.rs new file mode 100644 index 0000000000000..1bb9fdc8a0bc3 --- /dev/null +++ b/testsuite/libra_fuzzer/src/fuzz_targets.rs @@ -0,0 +1,107 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{FuzzTarget, FuzzTargetImpl}; +use failure::prelude::*; +use lazy_static::lazy_static; +use proptest::{ + strategy::{Strategy, ValueTree}, + test_runner::TestRunner, +}; +use std::{collections::BTreeMap, env, fmt}; + +/// Convenience macro to return the module name. +macro_rules! module_name { + () => { + module_path!() + .rsplit("::") + .next() + .expect("module path must have at least one component") + }; +} + +/// A fuzz target implementation for protobuf-compiled targets. +macro_rules! proto_fuzz_target { + ($target:ident => $ty:ty) => { + #[derive(Clone, Debug, Default)] + pub struct $target; + + impl $crate::FuzzTargetImpl for $target { + fn name(&self) -> &'static str { + module_name!() + } + + fn description(&self) -> &'static str { + concat!(stringify!($ty), " (protobuf)") + } + + fn generate(&self, runner: &mut ::proptest::test_runner::TestRunner) -> Vec<u8> { + use proto_conv::IntoProtoBytes; + + let value = + $crate::fuzz_targets::new_value(runner, ::proptest::arbitrary::any::<$ty>()); + value + .into_proto_bytes() + .expect("failed to convert to bytes") + } + + fn fuzz(&self, data: &[u8]) { + use proto_conv::FromProtoBytes; + + // Errors are OK -- the fuzzer cares about panics and OOMs. + let _ = <$ty>::from_proto_bytes(data); + } + } + }; +} + +// List fuzz target modules here. +mod compiled_module; +mod raw_transaction; +mod signed_transaction; +mod vm_value; + +lazy_static! { + static ref ALL_TARGETS: BTreeMap<&'static str, Box<dyn FuzzTargetImpl>> = { + let targets: Vec<Box<dyn FuzzTargetImpl>> = vec![ + // List fuzz targets here in this format. + Box::new(compiled_module::CompiledModuleTarget::default()), + Box::new(raw_transaction::RawTransactionTarget::default()), + Box::new(signed_transaction::SignedTransactionTarget::default()), + Box::new(vm_value::ValueTarget::default()), + ]; + targets.into_iter().map(|target| (target.name(), target)).collect() + }; +} + +impl FuzzTarget { + /// The environment variable used for passing fuzz targets to child processes. + pub(crate) const ENV_VAR: &'static str = "FUZZ_TARGET"; + + /// Get the current fuzz target from the environment. + pub fn from_env() -> Result<Self> { + let name = env::var(Self::ENV_VAR)?; + match Self::by_name(&name) { + Some(target) => Ok(target), + None => bail!("Unknown fuzz target '{}'", name), + } + } + + /// Get a fuzz target by name. + pub fn by_name(name: &str) -> Option<Self> { + ALL_TARGETS.get(name).map(|target| FuzzTarget(&**target)) + } + + /// A list of all fuzz targets.
+ pub fn all_targets() -> impl Iterator<Item = FuzzTarget> { + ALL_TARGETS.values().map(|target| FuzzTarget(&**target)) + } +} + +/// Produce a value from this strategy. +fn new_value<T>(runner: &mut TestRunner, strategy: impl Strategy<Value = T>) -> T { + let value_tree = strategy + .new_tree(runner) + .expect("failed to create value tree"); + value_tree.current() +} diff --git a/testsuite/libra_fuzzer/src/fuzz_targets/compiled_module.rs b/testsuite/libra_fuzzer/src/fuzz_targets/compiled_module.rs new file mode 100644 index 0000000000000..04bc00ecb9271 --- /dev/null +++ b/testsuite/libra_fuzzer/src/fuzz_targets/compiled_module.rs @@ -0,0 +1,33 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{fuzz_targets::new_value, FuzzTargetImpl}; +use proptest::{prelude::*, test_runner::TestRunner}; +use vm::file_format::CompiledModule; + +#[derive(Clone, Debug, Default)] +pub struct CompiledModuleTarget; + +impl FuzzTargetImpl for CompiledModuleTarget { + fn name(&self) -> &'static str { + module_name!() + } + + fn description(&self) -> &'static str { + "VM CompiledModule (custom deserializer)" + } + + fn generate(&self, runner: &mut TestRunner) -> Vec<u8> { + let value = new_value(runner, any_with::<CompiledModule>(16)); + let mut out = vec![]; + value + .serialize(&mut out) + .expect("serialization should work"); + out + } + + fn fuzz(&self, data: &[u8]) { + // Errors are OK -- the fuzzer cares about panics and OOMs. + let _ = CompiledModule::deserialize(data); + } +} diff --git a/testsuite/libra_fuzzer/src/fuzz_targets/raw_transaction.rs b/testsuite/libra_fuzzer/src/fuzz_targets/raw_transaction.rs new file mode 100644 index 0000000000000..a344d641b0a5f --- /dev/null +++ b/testsuite/libra_fuzzer/src/fuzz_targets/raw_transaction.rs @@ -0,0 +1,5 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use types::transaction::RawTransaction; +proto_fuzz_target!(RawTransactionTarget => RawTransaction); diff --git a/testsuite/libra_fuzzer/src/fuzz_targets/signed_transaction.rs b/testsuite/libra_fuzzer/src/fuzz_targets/signed_transaction.rs new file mode 100644 index 0000000000000..35f86f98e5bcf --- /dev/null +++ b/testsuite/libra_fuzzer/src/fuzz_targets/signed_transaction.rs @@ -0,0 +1,5 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use types::transaction::SignedTransaction; +proto_fuzz_target!(SignedTransactionTarget => SignedTransaction); diff --git a/testsuite/libra_fuzzer/src/fuzz_targets/vm_value.rs b/testsuite/libra_fuzzer/src/fuzz_targets/vm_value.rs new file mode 100644 index 0000000000000..1e9440540c48b --- /dev/null +++ b/testsuite/libra_fuzzer/src/fuzz_targets/vm_value.rs @@ -0,0 +1,69 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{fuzz_targets::new_value, FuzzTargetImpl}; +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use canonical_serialization::*; +use failure::prelude::*; +use proptest::test_runner::TestRunner; +use std::io::Cursor; +use vm_runtime::{loaded_data::struct_def::StructDef, value::Value}; + +#[derive(Clone, Debug, Default)] +pub struct ValueTarget; + +impl FuzzTargetImpl for ValueTarget { + fn name(&self) -> &'static str { + module_name!() + } + + fn description(&self) -> &'static str { + "VM values + types (custom deserializer)" + } + + fn generate(&self, runner: &mut TestRunner) -> Vec<u8> { + let value = new_value(runner, Value::struct_strategy()); + let struct_def = value.to_struct_def_FOR_TESTING(); + + // Values as currently
serialized are not self-describing, so store a serialized form of the + // type along with the value as well. + let mut serializer = SimpleSerializer::new(); + struct_def + .serialize(&mut serializer) + .expect("must serialize"); + let struct_def_blob: Vec<u8> = serializer.get_output(); + + let value_blob = value.simple_serialize().expect("must serialize"); + let mut blob = vec![]; + // Prefix the struct def blob with its length. + blob.write_u64::<BigEndian>(struct_def_blob.len() as u64) + .expect("writing should work"); + blob.extend_from_slice(&struct_def_blob); + blob.extend_from_slice(&value_blob); + blob + } + + fn fuzz(&self, data: &[u8]) { + let _ = deserialize(data); + } +} + +fn deserialize(data: &[u8]) -> Result<()> { + let mut data = Cursor::new(data); + // Read the length of the struct def blob. + let len = data.read_u64::<BigEndian>()? as usize; + let position = data.position() as usize; + let data = &data.into_inner()[position..]; + + if data.len() < len { + bail!("too little data"); + } + let struct_def_data = &data[0..len]; + let value_data = &data[len..]; + + // Deserialize now. + let mut deserializer = SimpleDeserializer::new(struct_def_data); + let struct_def = StructDef::deserialize(&mut deserializer)?; + let _ = Value::simple_deserialize(value_data, struct_def); + Ok(()) +} diff --git a/testsuite/libra_fuzzer/src/lib.rs b/testsuite/libra_fuzzer/src/lib.rs new file mode 100644 index 0000000000000..0f0433b12d3f2 --- /dev/null +++ b/testsuite/libra_fuzzer/src/lib.rs @@ -0,0 +1,43 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use proptest::test_runner::TestRunner; +use std::{fmt, ops::Deref, str::FromStr}; + +pub mod commands; +pub mod fuzz_targets; + +/// Implementation for a particular target of a fuzz operation. +pub trait FuzzTargetImpl: Sync + Send + fmt::Debug { + /// The name of the fuzz target. + fn name(&self) -> &'static str; + + /// A description for this target. + fn description(&self) -> &'static str; + + /// Generate a new example for this target to store in the corpus. + fn generate(&self, runner: &mut TestRunner) -> Vec<u8>; + + /// Fuzz the target with this data. The fuzzer tests for panics or OOMs with this method. + fn fuzz(&self, data: &[u8]); +} + +/// A fuzz target. +#[derive(Copy, Clone, Debug)] +pub struct FuzzTarget(&'static (dyn FuzzTargetImpl + 'static)); + +impl Deref for FuzzTarget { + type Target = dyn FuzzTargetImpl + 'static; + + fn deref(&self) -> &Self::Target { + self.0 + } +} + +impl FromStr for FuzzTarget { + type Err = String; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + FuzzTarget::by_name(s).ok_or_else(|| format!("Fuzz target '{}' not found (run `list`)", s)) + } +} diff --git a/testsuite/libra_fuzzer/src/main.rs b/testsuite/libra_fuzzer/src/main.rs new file mode 100644 index 0000000000000..5b8637990b151 --- /dev/null +++ b/testsuite/libra_fuzzer/src/main.rs @@ -0,0 +1,149 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Helpers for fuzz testing. + +use lazy_static::lazy_static; +use libra_fuzzer::{commands, FuzzTarget}; +use std::{env, ffi::OsString, fs, path::PathBuf}; +use structopt::StructOpt; + +#[derive(Debug, StructOpt)] +#[structopt(name = "fuzzer", about = "Libra fuzzer")] +struct Opt { + /// Print extended debug output + #[structopt(long = "debug")] + debug: bool, + #[structopt(subcommand)] + cmd: Command, +} + +/// The default number of items to generate in a corpus. +const GENERATE_DEFAULT_ITEMS: usize = 128; +lazy_static!
{ /// A stringified form of `GENERATE_DEFAULT_ITEMS`. + /// + /// Required because structopt only accepts strings as default values. + static ref GENERATE_DEFAULT_ITEMS_STR: String = GENERATE_DEFAULT_ITEMS.to_string(); +} + +#[derive(Debug, StructOpt)] +enum Command { + /// Generate corpus for a particular fuzz target + #[structopt(name = "generate")] + Generate { + /// Number of items to generate in the corpus + #[structopt( + short = "n", + long = "num-items", + raw(default_value = "&GENERATE_DEFAULT_ITEMS_STR") + )] + num_items: usize, + /// Custom directory for corpus output to be stored in (required if not running under + /// `cargo run`) + #[structopt(long = "corpus-dir", parse(from_os_str))] + corpus_dir: Option<PathBuf>, + #[structopt(name = "TARGET")] + /// Name of target to generate (use `list` to list) + target: FuzzTarget, + }, + /// Run fuzzer on specified target (must be run under `cargo run`) + #[structopt(name = "fuzz", usage = "fuzzer fuzz <TARGET> -- [ARGS]")] + Fuzz { + /// Target to fuzz (use `list` to list targets) + #[structopt(name = "TARGET", raw(required = "true"))] + target: FuzzTarget, + /// Custom directory for corpus + #[structopt(long = "corpus-dir", parse(from_os_str))] + corpus_dir: Option<PathBuf>, + /// Custom directory for artifacts + #[structopt(long = "artifact-dir", parse(from_os_str))] + artifact_dir: Option<PathBuf>, + /// Arguments for `cargo fuzz run` + #[structopt(name = "ARGS", parse(from_os_str), raw(allow_hyphen_values = "true"))] + args: Vec<OsString>, + }, + /// List fuzz targets + #[structopt(name = "list")] + List, +} + +/// The default directory for corpuses. Also return whether the directory was freshly created. +fn default_corpus_dir(target: FuzzTarget) -> (PathBuf, bool) { + default_dir(target, "corpus") +} + +/// The default directory for artifacts. +fn default_artifact_dir(target: FuzzTarget) -> PathBuf { + default_dir(target, "artifacts").0 +} + +fn default_dir(target: FuzzTarget, intermediate_dir: &str) -> (PathBuf, bool) { + let mut dir = PathBuf::from(env::var_os("CARGO_MANIFEST_DIR").expect( + "--corpus-dir not set and this binary is not running under cargo run. \ + Either use cargo run or pass in the --corpus-dir flag.", + )); + // If a "fuzz" subdirectory doesn't exist, the user might be doing it wrong. + dir.push("fuzz"); + if !dir.is_dir() { + panic!( + "Subdirectory {:?} of cargo manifest directory does not exist \ + (did you run `cargo fuzz init`?)", + dir + ); + } + + // The name of the corpus is derived from the name of the target. + dir.push(intermediate_dir); + dir.push(target.name()); + + println!("Using default {} directory: {:?}", intermediate_dir, dir); + let created = !dir.exists(); + fs::create_dir_all(&dir).expect("Failed to create directory"); + (dir, created) +} + +fn main() { + let opt: Opt = Opt::from_args(); + + match opt.cmd { + Command::Generate { + num_items, + corpus_dir, + target, + } => { + let corpus_dir = corpus_dir.unwrap_or_else(|| default_corpus_dir(target).0); + commands::make_corpus(target, num_items, &corpus_dir, opt.debug) + .expect("Failed to create corpus"); + println!("Wrote {} items to corpus", num_items); + } + Command::Fuzz { + corpus_dir, + artifact_dir, + target, + args, + } => { + let corpus_dir = match corpus_dir { + Some(dir) => { + // Don't generate the corpus here -- custom directory means the user knows + // what they're doing.
+ dir + } + None => { + let (dir, created) = default_corpus_dir(target); + if created { + println!("New corpus, generating..."); + commands::make_corpus(target, GENERATE_DEFAULT_ITEMS, &dir, opt.debug) + .expect("Failed to create corpus"); + } + dir + } + }; + let artifact_dir = artifact_dir.unwrap_or_else(|| default_artifact_dir(target)); + commands::fuzz_target(target, corpus_dir, artifact_dir, args).unwrap(); + } + Command::List => { + commands::list_targets(); + } + } +} diff --git a/testsuite/libra_fuzzer/tests/artifacts.rs b/testsuite/libra_fuzzer/tests/artifacts.rs new file mode 100644 index 0000000000000..2dba65f995bbe --- /dev/null +++ b/testsuite/libra_fuzzer/tests/artifacts.rs @@ -0,0 +1,90 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Test artifacts: examples known to have crashed in the past. + +#![feature(custom_test_frameworks)] +#![test_runner(datatest::runner)] + +use libra_fuzzer::FuzzTarget; +use rusty_fork::{fork, rusty_fork_id}; +use stats_alloc::{Region, StatsAlloc, INSTRUMENTED_SYSTEM}; +use std::{alloc::System, env, fs, path::Path}; + +#[global_allocator] +static GLOBAL: &StatsAlloc<System> = &INSTRUMENTED_SYSTEM; + +/// The memory limit for each deserializer, in bytes. +const MEMORY_LIMIT: usize = 256 * 1024 * 1024; + +#[datatest::files("artifacts", { + artifact_path in r"^.*/.*" +})] +#[test] +fn test_artifact(artifact_path: &Path) { + let test_name = test_name(artifact_path); + + if no_fork() { + test_artifact_impl(artifact_path); + } else { + fork( + &test_name, + rusty_fork_id!(), + |_| {}, + |child, _file| { + let status = child.wait().expect("failed to wait for child"); + assert!( + status.success(), + "child exited unsuccessfully with {}", + status + ); + }, + || test_artifact_impl(artifact_path), + ) + .expect("forking test failed"); + } +} + +fn no_fork() -> bool { + match env::var_os("NO_FORK") { + Some(x) => x == "1", + // Fork by default. + None => false, + } +} + +fn test_artifact_impl(artifact_path: &Path) { + // Extract the target from the path -- it's the second component after "artifacts/". + let target_name = artifact_path + .iter() + .nth(1) + .expect("artifact path must be in format 'artifacts/<target>/<artifact>'"); + let target_name = target_name.to_str().expect("target must be valid Unicode"); + let target = FuzzTarget::by_name(target_name) + .unwrap_or_else(|| panic!("unknown fuzz target: {}", target_name)); + let data = fs::read(artifact_path).expect("failed to read artifact"); + + let reg = Region::new(&GLOBAL); + target.fuzz(&data); + let stats = reg.change(); + + eprintln!("stats: {:?}", stats); + assert!( + stats.bytes_allocated <= MEMORY_LIMIT, + "Deserializer used too much memory: allocated {} bytes (max {} bytes)", + stats.bytes_allocated, + MEMORY_LIMIT + ); +} + +fn test_name(artifact_path: &Path) -> String { + // This matches the test name generated by datatest. + let mut test_name = "test_artifact::".to_string(); + + let path = artifact_path + .strip_prefix("artifacts/") + .expect("artifact path doesn't begin with artifacts/"); + let subname = path.to_str().expect("name must be valid unicode"); + test_name.push_str(subname); + test_name +} diff --git a/testsuite/src/lib.rs b/testsuite/src/lib.rs new file mode 100644 index 0000000000000..9b5c474f42bc7 --- /dev/null +++ b/testsuite/src/lib.rs @@ -0,0 +1,4 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +// Empty src/lib.rs to get rusty-tags working.
diff --git a/testsuite/tests/libratest/main.rs b/testsuite/tests/libratest/main.rs new file mode 100644 index 0000000000000..3fe41ed20db6c --- /dev/null +++ b/testsuite/tests/libratest/main.rs @@ -0,0 +1,4 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod smoke_test; diff --git a/testsuite/tests/libratest/smoke_test.rs b/testsuite/tests/libratest/smoke_test.rs new file mode 100644 index 0000000000000..399076c5f4a45 --- /dev/null +++ b/testsuite/tests/libratest/smoke_test.rs @@ -0,0 +1,243 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 +#![allow(unused_mut)] +use cli::client_proxy::ClientProxy; +use libra_swarm::swarm::LibraSwarm; +use num_traits::cast::FromPrimitive; +use rust_decimal::Decimal; + +fn setup_swarm_and_client_proxy( + num_nodes: usize, + client_port_index: usize, +) -> (LibraSwarm, ClientProxy) { + ::logger::init_for_e2e_testing(); + + let (faucet_account_keypair, faucet_key_file_path, _temp_dir) = + generate_keypair::load_faucet_key_or_create_default(None); + + let swarm = LibraSwarm::launch_swarm(num_nodes, false, faucet_account_keypair, true); + let port = *swarm + .get_validators_public_ports() + .get(client_port_index) + .unwrap(); + let tmp_mnemonic_file = tempfile::NamedTempFile::new().unwrap(); + let client_proxy = ClientProxy::new( + "localhost", + port.to_string().as_str(), + &swarm.get_trusted_peers_config_path(), + &faucet_key_file_path, + /* faucet server */ None, + Some( + tmp_mnemonic_file + .into_temp_path() + .canonicalize() + .expect("Unable to get canonical path of mnemonic_file_path") + .to_str() + .unwrap() + .to_string(), + ), + ) + .unwrap(); + (swarm, client_proxy) +} + +fn test_smoke_script(mut client_proxy: ClientProxy) { + client_proxy.create_next_account(&["c"]).unwrap(); + client_proxy + .mint_coins(&["mintb", "0", "10"], true) + .unwrap(); + assert_eq!( + Decimal::from_f64(10.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "0"]).unwrap()) + ); + client_proxy.create_next_account(&["c"]).unwrap(); + client_proxy.mint_coins(&["mintb", "1", "1"], true).unwrap(); + client_proxy + .transfer_coins(&["tb", "0", "1", "3"], true) + .unwrap(); + assert_eq!( + Decimal::from_f64(7.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "0"]).unwrap()) + ); + assert_eq!( + Decimal::from_f64(4.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "1"]).unwrap()) + ); + client_proxy.create_next_account(&["c"]).unwrap(); + client_proxy + .mint_coins(&["mintb", "2", "15"], true) + .unwrap(); + assert_eq!( + Decimal::from_f64(15.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "2"]).unwrap()) + ); +} + +#[test] +fn smoke_test_single_node() { + let (_swarm, mut client_proxy) = setup_swarm_and_client_proxy(1, 0); + test_smoke_script(client_proxy); +} + +#[test] +fn smoke_test_multi_node() { + let (_swarm, mut client_proxy) = setup_swarm_and_client_proxy(4, 0); + test_smoke_script(client_proxy); +} + +#[test] +fn test_concurrent_transfers_single_node() { + let (_swarm, mut client_proxy) = setup_swarm_and_client_proxy(1, 0); + client_proxy.create_next_account(&["c"]).unwrap(); + client_proxy + .mint_coins(&["mintb", "0", "100"], true) + .unwrap(); + client_proxy.create_next_account(&["c"]).unwrap(); + for _ in 0..20 { + client_proxy + .transfer_coins(&["t", "0", "1", "1"], false) + .unwrap(); + } + client_proxy + .transfer_coins(&["tb", "0", "1", "1"], true) + .unwrap(); + assert_eq!( + Decimal::from_f64(79.0), + Decimal::from_f64(client_proxy.get_balance(&["b", 
"0"]).unwrap()) + ); + assert_eq!( + Decimal::from_f64(21.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "1"]).unwrap()) + ); +} + +#[test] +fn test_basic_fault_tolerance() { + // A configuration with 4 validators should tolerate single node failure. + let (mut swarm, mut client_proxy) = setup_swarm_and_client_proxy(4, 1); + let validator_ports = swarm.get_validators_public_ports(); + // kill the first validator + swarm.kill_node(*validator_ports.get(0).unwrap()); + + // run the script for the smoke test by submitting requests to the second validator + test_smoke_script(client_proxy); +} + +#[test] +fn test_basic_restartability() { + let (mut swarm, mut client_proxy) = setup_swarm_and_client_proxy(4, 0); + client_proxy.create_next_account(&["c"]).unwrap(); + client_proxy.create_next_account(&["c"]).unwrap(); + client_proxy.mint_coins(&["mb", "0", "100"], true).unwrap(); + client_proxy + .transfer_coins(&["tb", "0", "1", "10"], true) + .unwrap(); + assert_eq!( + Decimal::from_f64(90.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "0"]).unwrap()) + ); + assert_eq!( + Decimal::from_f64(10.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "1"]).unwrap()) + ); + let port = swarm.get_validators_public_ports()[0]; + // restart node + swarm.kill_node(port); + assert!(swarm.add_node(port, false)); + assert_eq!( + Decimal::from_f64(90.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "0"]).unwrap()) + ); + assert_eq!( + Decimal::from_f64(10.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "1"]).unwrap()) + ); + client_proxy + .transfer_coins(&["tb", "0", "1", "10"], true) + .unwrap(); + assert_eq!( + Decimal::from_f64(80.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "0"]).unwrap()) + ); + assert_eq!( + Decimal::from_f64(20.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "1"]).unwrap()) + ); +} + +#[test] +fn test_basic_state_synchronization() { + // + // - Start a swarm of 5 nodes (3 nodes forming a QC). + // - Kill one node and continue submitting transactions to the others. + // - Restart the node + // - Wait for all the nodes to catch up + // - Verify that the restarted node has synced up with the submitted transactions. 
+ let (mut swarm, mut client_proxy) = setup_swarm_and_client_proxy(5, 1); + client_proxy.create_next_account(&["c"]).unwrap(); + client_proxy.create_next_account(&["c"]).unwrap(); + client_proxy.mint_coins(&["mb", "0", "100"], true).unwrap(); + client_proxy + .transfer_coins(&["tb", "0", "1", "10"], true) + .unwrap(); + assert_eq!( + Decimal::from_f64(90.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "0"]).unwrap()) + ); + assert_eq!( + Decimal::from_f64(10.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "1"]).unwrap()) + ); + let node_to_restart = *swarm.get_validators_public_ports().get(0).unwrap(); + + swarm.kill_node(node_to_restart); + // All these are executed while one node is down + assert_eq!( + Decimal::from_f64(90.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "0"]).unwrap()) + ); + assert_eq!( + Decimal::from_f64(10.0), + Decimal::from_f64(client_proxy.get_balance(&["b", "1"]).unwrap()) + ); + for _ in 0..5 { + client_proxy + .transfer_coins(&["tb", "0", "1", "1"], true) + .unwrap(); + } + + // Reconnect and synchronize the state + assert!(swarm.add_node(node_to_restart, false)); + + // Wait for all the nodes to catch up + swarm.wait_for_all_nodes_to_catchup(); + + // Connect to the newly recovered node and verify its state + let tmp_mnemonic_file = tempfile::NamedTempFile::new().unwrap(); + let mut client_proxy2 = ClientProxy::new( + "localhost", + node_to_restart.to_string().as_str(), + &swarm.get_trusted_peers_config_path(), + "", + /* faucet server */ None, + Some( + tmp_mnemonic_file + .into_temp_path() + .canonicalize() + .expect("Unable to get canonical path of mnemonic_file_path") + .to_str() + .unwrap() + .to_string(), + ), + ) + .unwrap(); + client_proxy2.set_accounts(client_proxy.copy_all_accounts()); + assert_eq!( + Decimal::from_f64(85.0), + Decimal::from_f64(client_proxy2.get_balance(&["b", "0"]).unwrap()) + ); + assert_eq!( + Decimal::from_f64(15.0), + Decimal::from_f64(client_proxy2.get_balance(&["b", "1"]).unwrap()) + ); +} diff --git a/types/Cargo.toml b/types/Cargo.toml new file mode 100644 index 0000000000000..dfdf36e453181 --- /dev/null +++ b/types/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "types" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +bech32 = "0.6.0" +byteorder = "1.3.1" +bytes = "0.4.12" +hex = "0.3.2" +itertools = "0.8.0" +lazy_static = "1.3.0" +proptest = "0.9" +proptest-derive = "0.1.0" +protobuf = "2.6" +radix_trie = "0.1.3" +rand = "0.6.5" +serde = { version = "1.0.89", features = ["derive"] } +serde_json = "1.0.38" +tiny-keccak = "1.4.2" + +canonical_serialization = { path = "../common/canonical_serialization"} +crypto = { path = "../crypto/legacy_crypto" } +failure = { path = "../common/failure_ext", package = "failure_ext" } +proto_conv = { path = "../common/proto_conv", features = ["derive"] } + +[build-dependencies] +build_helpers = { path = "../common/build_helpers" } diff --git a/types/build.rs b/types/build.rs new file mode 100644 index 0000000000000..0095edf448c71 --- /dev/null +++ b/types/build.rs @@ -0,0 +1,17 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This compiles all the `.proto` files under `src/` directory. +//! +//! For example, if there is a file `src/a/b/c.proto`, it will generate `src/a/b/c.rs` and +//! `src/a/b/c_grpc.rs`. 
+ +fn main() { + let proto_root = "src/proto"; + + build_helpers::build_helpers::compile_proto( + proto_root, + vec![], /* dependent roots */ + false, /* generate_client_stub */ + ); +} diff --git a/types/src/access_path.rs b/types/src/access_path.rs new file mode 100644 index 0000000000000..9c569a0fc727b --- /dev/null +++ b/types/src/access_path.rs @@ -0,0 +1,393 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Suppose we have the following data structure in a smart contract: +//! +//! struct B { +//! Map<String, String> mymap; +//! } +//! +//! struct A { +//! B b; +//! int my_int; +//! } +//! +//! struct C { +//! List<int> mylist; +//! } +//! +//! A a; +//! C c; +//! +//! and the data belongs to Alice. Then an access to `a.b.mymap` would be translated to an access +//! to an entry in the key-value store whose key is `/a/b/mymap`. In the same way, the access to +//! `c.mylist` would need to query `/c/mylist`. +//! +//! So an account stores its data in a directory structure, for example: +//! /balance: 10 +//! /a/b/mymap: {"Bob" => "abcd", "Carol" => "efgh"} +//! /a/myint: 20 +//! /c/mylist: [3, 5, 7, 9] +//! +//! If someone needs to query the map above and find out what value is associated with "Bob", +//! `address` will be set to Alice and `path` will be set to "/a/b/mymap/Bob". +//! +//! On the other hand, if you want to query only /a/*, `address` will be set to Alice, +//! `path` will be set to "/a", and the `get_prefix()` method from statedb will be used. + +// This is caused by deriving Arbitrary for AccessPath. +#![allow(clippy::unit_arg)] + +use crate::{ + account_address::AccountAddress, + account_config::{ + account_received_event_path, account_resource_path, account_sent_event_path, + association_address, + }, + language_storage::{CodeKey, ResourceKey, StructTag}, + validator_set::validator_set_path, +}; +use canonical_serialization::{ + CanonicalDeserialize, CanonicalDeserializer, CanonicalSerialize, CanonicalSerializer, +}; +use crypto::hash::{CryptoHash, HashValue}; +use failure::prelude::*; +use hex; +use lazy_static::lazy_static; +use proptest_derive::Arbitrary; +use proto_conv::{FromProto, IntoProto}; +use radix_trie::TrieKey; +use serde::{Deserialize, Serialize}; +use std::{ + fmt::{self, Formatter}, + slice::Iter, + str::{self, FromStr}, +}; + +#[derive(Default, Serialize, Deserialize, Debug, PartialEq, Hash, Eq, Clone, Ord, PartialOrd)] +pub struct Field(String); + +impl Field { + pub fn new(s: &str) -> Field { + Field(s.to_string()) + } + + pub fn name(&self) -> &String { + &self.0 + } +} + +impl From<String> for Field { + fn from(s: String) -> Self { + Field(s) + } +} + +impl fmt::Display for Field { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +#[derive(Eq, Hash, Serialize, Deserialize, Debug, Clone, PartialEq, Ord, PartialOrd)] +pub enum Access { + Field(Field), + Index(u64), +} + +impl Access { + pub fn new(s: &str) -> Self { + Access::Field(Field::new(s)) + } +} + +impl FromStr for Access { + type Err = ::std::num::ParseIntError; + + fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> { + if let Ok(idx) = s.parse::<u64>() { + Ok(Access::Index(idx)) + } else { + Ok(Access::Field(Field::new(s))) + } + } +} + +impl fmt::Display for Access { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Access::Field(field) => write!(f, "\"{}\"", field), + Access::Index(i) => write!(f, "{}", i), + } + } +} + +/// Non-empty sequence of field accesses +#[derive(Eq, Hash, Serialize, Deserialize, Debug, Clone,
PartialEq, Ord, PartialOrd)] +pub struct Accesses(Vec<Access>); + +/// SEPARATOR is used as a delimiter between fields. It should not be a legal part of any identifier +/// in the language +const SEPARATOR: char = '/'; + +impl Accesses { + pub fn empty() -> Self { + Accesses(vec![]) + } + + pub fn new(field: Field) -> Self { + Accesses(vec![Access::Field(field)]) + } + + /// Add a field to the end of the sequence + pub fn add_field_to_back(&mut self, field: Field) { + self.0.push(Access::Field(field)) + } + + /// Add an index to the end of the sequence + pub fn add_index_to_back(&mut self, idx: u64) { + self.0.push(Access::Index(idx)) + } + + pub fn append(&mut self, accesses: &mut Accesses) { + self.0.append(&mut accesses.0) + } + + /// Returns the first field in the sequence and reference to the remaining fields + pub fn split_first(&self) -> (&Access, &[Access]) { + self.0.split_first().unwrap() + } + + /// Return the last access in the sequence + pub fn last(&self) -> &Access { + self.0.last().unwrap() // guaranteed not to fail because sequence is non-empty + } + + pub fn iter(&self) -> Iter<'_, Access> { + self.0.iter() + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn as_separated_string(&self) -> String { + let mut path = String::new(); + for access in self.0.iter() { + match access { + Access::Field(s) => { + let access_str = s.name().as_ref(); + assert!(access_str != ""); + path.push_str(access_str) + } + Access::Index(i) => path.push_str(i.to_string().as_ref()), + }; + path.push(SEPARATOR); + } + path + } + + pub fn take_nth(&self, new_len: usize) -> Accesses { + assert!(self.0.len() >= new_len); + Accesses(self.0.clone().into_iter().take(new_len).collect()) + } +} + +impl<'a> IntoIterator for &'a Accesses { + type Item = &'a Access; + type IntoIter = Iter<'a, Access>; + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + +impl From<Vec<Access>> for Accesses { + fn from(accesses: Vec<Access>) -> Accesses { + Accesses(accesses) + } +} + +impl From<Vec<u8>> for Accesses { + fn from(mut raw_bytes: Vec<u8>) -> Accesses { + let access_str = String::from_utf8(raw_bytes.split_off(HashValue::LENGTH + 1)).unwrap(); + let fields_str = access_str.split(SEPARATOR).collect::<Vec<&str>>(); + let mut accesses = vec![]; + for access_str in fields_str.into_iter() { + if access_str != "" { + accesses.push(Access::from_str(access_str).unwrap()); + } + } + Accesses::from(accesses) + } +} + +impl TrieKey for Accesses { + fn encode_bytes(&self) -> Vec<u8> { + self.as_separated_string().into_bytes() + } +} + +lazy_static! { + /// The access path where the Validator Set resource is stored. + pub static ref VALIDATOR_SET_ACCESS_PATH: AccessPath = + AccessPath::new(association_address(), validator_set_path()); +} + +#[derive( + Clone, + Eq, + PartialEq, + Default, + Hash, + Serialize, + Deserialize, + Ord, + PartialOrd, + Arbitrary, + FromProto, + IntoProto, +)] +#[ProtoType(crate::proto::access_path::AccessPath)] +pub struct AccessPath { + pub address: AccountAddress, + pub path: Vec<u8>, +} + +impl AccessPath { + const CODE_TAG: u8 = 0; + const RESOURCE_TAG: u8 = 1; + + pub fn new(address: AccountAddress, path: Vec<u8>) -> Self { + AccessPath { address, path } + } + + /// Given an address, returns the corresponding access path that stores the Account resource. + pub fn new_for_account(address: AccountAddress) -> Self { + Self::new(address, account_resource_path()) + } + + /// Create an AccessPath for a ContractEvent.
+ /// That is an AccessPah that uniquely identifies a given event for a published resource. + pub fn new_for_event(address: AccountAddress, root: &[u8], key: &[u8]) -> Self { + let mut path: Vec = Vec::new(); + path.extend_from_slice(root); + path.push(b'/'); + path.extend_from_slice(key); + path.push(b'/'); + Self::new(address, path) + } + + /// Create an AccessPath to the event for the sender account in a deposit operation. + /// The sent counter in LibraAccount.T (LibraAccount.T.sent_events_count) is used to generate + /// the AccessPath. + /// That AccessPath can be used as a key into the event storage to retrieve all sent + /// events for a given account. + pub fn new_for_sent_event(address: AccountAddress) -> Self { + Self::new(address, account_sent_event_path()) + } + + /// Create an AccessPath to the event for the target account (the receiver) + /// in a deposit operation. + /// The received counter in LibraAccount.T (LibraAccount.T.received_events_count) is used to + /// generate the AccessPath. + /// That AccessPath can be used as a key into the event storage to retrieve all received + /// events for a given account. + pub fn new_for_received_event(address: AccountAddress) -> Self { + Self::new(address, account_received_event_path()) + } + + pub fn resource_access_vec(tag: &StructTag, accesses: &Accesses) -> Vec { + let mut key = vec![]; + key.push(Self::RESOURCE_TAG); + + key.append(&mut tag.hash().to_vec()); + + // We don't need accesses in production right now. Accesses are appended here just for + // passing the old tests. + key.append(&mut accesses.as_separated_string().into_bytes()); + key + } + + /// Convert Accesses into a byte offset which would be used by the storage layer to resolve + /// where fields are stored. + pub fn resource_access_path(key: &ResourceKey, accesses: &Accesses) -> AccessPath { + let path = AccessPath::resource_access_vec(&key.type_(), accesses); + AccessPath { + address: key.address().to_owned(), + path, + } + } + + fn code_access_path_vec(key: &CodeKey) -> Vec { + let mut root = vec![]; + root.push(Self::CODE_TAG); + root.append(&mut key.hash().to_vec()); + root + } + + pub fn code_access_path(key: &CodeKey) -> AccessPath { + let path = AccessPath::code_access_path_vec(key); + AccessPath { + address: *key.address(), + path, + } + } +} + +impl fmt::Debug for AccessPath { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "AccessPath {{ address: {:x}, path: {} }}", + self.address, + hex::encode(&self.path) + ) + } +} + +impl fmt::Display for AccessPath { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + if self.path.len() < 1 + HashValue::LENGTH { + write!(f, "{:?}", self) + } else { + write!(f, "AccessPath {{ address: {:x}, ", self.address)?; + match self.path[0] { + Self::RESOURCE_TAG => write!(f, "type: Resource, ")?, + Self::CODE_TAG => write!(f, "type: Module, ")?, + tag => write!(f, "type: {:?}, ", tag)?, + }; + write!( + f, + "hash: {:?}, ", + hex::encode(&self.path[1..=HashValue::LENGTH]) + )?; + write!( + f, + "suffix: {:?} }} ", + String::from_utf8_lossy(&self.path[1 + HashValue::LENGTH..]) + ) + } + } +} + +impl CanonicalSerialize for AccessPath { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer + .encode_struct(&self.address)? 
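            // Illustrative layout note: `encode_struct(&self.address)` defers to
            // AccountAddress's CanonicalSerialize (defined in account_address.rs
            // below), which writes the 32 address bytes length-prefixed; the next
            // call appends `path` the same way, so the canonical form is roughly
            // len(address) ++ address ++ len(path) ++ path.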
+ .encode_variable_length_bytes(&self.path)?; + Ok(()) + } +} + +impl CanonicalDeserialize for AccessPath { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result { + let address = deserializer.decode_struct::()?; + let path = deserializer.decode_variable_length_bytes()?; + + Ok(Self { address, path }) + } +} diff --git a/types/src/account_address.rs b/types/src/account_address.rs new file mode 100644 index 0000000000000..d8684537d3b08 --- /dev/null +++ b/types/src/account_address.rs @@ -0,0 +1,230 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![allow(clippy::unit_arg)] + +use bech32::{Bech32, FromBase32, ToBase32}; +use bytes::Bytes; +use canonical_serialization::{ + CanonicalDeserialize, CanonicalDeserializer, CanonicalSerialize, CanonicalSerializer, +}; +use crypto::{ + hash::{AccountAddressHasher, CryptoHash, CryptoHasher}, + HashValue, PublicKey, +}; +use failure::prelude::*; +use hex; +use proptest_derive::Arbitrary; +use proto_conv::{FromProto, IntoProto}; +use rand::{rngs::OsRng, Rng}; +use serde::{Deserialize, Serialize}; +use std::{convert::TryFrom, fmt}; +use tiny_keccak::Keccak; + +pub const ADDRESS_LENGTH: usize = 32; + +const SHORT_STRING_LENGTH: usize = 4; + +const LIBRA_NETWORK_ID_SHORT: &str = "lb"; + +/// A struct that represents an account address. +/// Currently Public Key is used. +#[derive( + Arbitrary, Ord, PartialOrd, Eq, PartialEq, Hash, Default, Clone, Serialize, Deserialize, Copy, +)] +pub struct AccountAddress([u8; ADDRESS_LENGTH]); + +impl AccountAddress { + pub fn new(address: [u8; ADDRESS_LENGTH]) -> Self { + AccountAddress(address) + } + + pub fn random() -> Self { + let mut rng = OsRng::new().expect("can't access OsRng"); + let buf: [u8; 32] = rng.gen(); + AccountAddress::new(buf) + } + + // Helpful in log messages + pub fn short_str(&self) -> String { + hex::encode(&self.0[0..SHORT_STRING_LENGTH]).to_string() + } + + pub fn to_vec(&self) -> Vec { + self.0.to_vec() + } +} + +impl CryptoHash for AccountAddress { + type Hasher = AccountAddressHasher; + + fn hash(&self) -> HashValue { + let mut state = Self::Hasher::default(); + state.write(&self.0); + state.finish() + } +} + +impl AsRef<[u8]> for AccountAddress { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl fmt::Display for AccountAddress { + fn fmt(&self, f: &mut fmt::Formatter) -> std::fmt::Result { + // Forward to the LowerHex impl with a "0x" prepended (the # flag). + write!(f, "{:#x}", self) + } +} + +impl fmt::Debug for AccountAddress { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Forward to the LowerHex impl with a "0x" prepended (the # flag). + write!(f, "{:#x}", self) + } +} + +impl fmt::LowerHex for AccountAddress { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", hex::encode(&self.0)) + } +} + +impl TryFrom<&[u8]> for AccountAddress { + type Error = failure::Error; + + /// Tries to convert the provided byte array into Address. + fn try_from(bytes: &[u8]) -> Result { + ensure!( + bytes.len() == ADDRESS_LENGTH, + "The Address {:?} is of invalid length", + bytes + ); + let mut addr = [0u8; ADDRESS_LENGTH]; + addr.copy_from_slice(bytes); + Ok(AccountAddress(addr)) + } +} + +impl TryFrom<&[u8; 32]> for AccountAddress { + type Error = failure::Error; + + /// Tries to convert the provided byte array into Address. 
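    // A minimal conversion sketch (illustrative only), using the TryFrom impls
    // in this file:
    //
    //     let addr = AccountAddress::try_from(vec![0u8; ADDRESS_LENGTH])?;
    //     assert_eq!(addr.to_vec().len(), ADDRESS_LENGTH);
    //     assert_eq!(addr.short_str(), "00000000"); // hex of the first 4 bytes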
+ fn try_from(bytes: &[u8; 32]) -> Result { + AccountAddress::try_from(&bytes[..]) + } +} + +impl TryFrom> for AccountAddress { + type Error = failure::Error; + + /// Tries to convert the provided byte buffer into Address. + fn try_from(bytes: Vec) -> Result { + AccountAddress::try_from(&bytes[..]) + } +} + +impl From for Vec { + fn from(addr: AccountAddress) -> Vec { + addr.0.to_vec() + } +} + +impl From<&AccountAddress> for Vec { + fn from(addr: &AccountAddress) -> Vec { + addr.0.to_vec() + } +} + +impl TryFrom for AccountAddress { + type Error = failure::Error; + + fn try_from(bytes: Bytes) -> Result { + AccountAddress::try_from(bytes.as_ref()) + } +} + +impl From for Bytes { + fn from(addr: AccountAddress) -> Bytes { + addr.0.as_ref().into() + } +} + +impl FromProto for AccountAddress { + type ProtoType = Vec; + + fn from_proto(addr: Self::ProtoType) -> Result { + AccountAddress::try_from(&addr[..]) + } +} + +impl IntoProto for AccountAddress { + type ProtoType = Vec; + + fn into_proto(self) -> Self::ProtoType { + self.0.to_vec() + } +} + +impl From for AccountAddress { + fn from(public_key: PublicKey) -> AccountAddress { + // TODO: using keccak directly instead of crypto::hash because we have to make sure we use + // the same hash function that the Move transaction prologue is using. + // TODO: keccak is just a placeholder, make a principled choose for the hash function + let mut keccak = Keccak::new_sha3_256(); + let mut hash = [0u8; ADDRESS_LENGTH]; + keccak.update(&public_key.to_slice()); + keccak.finalize(&mut hash); + AccountAddress::new(hash) + } +} + +impl From<&AccountAddress> for String { + fn from(addr: &AccountAddress) -> String { + ::hex::encode(addr.as_ref()) + } +} + +impl TryFrom for AccountAddress { + type Error = failure::Error; + + fn try_from(s: String) -> Result { + assert!(!s.is_empty()); + let bytes_out = ::hex::decode(s)?; + AccountAddress::try_from(bytes_out.as_slice()) + } +} + +impl TryFrom for AccountAddress { + type Error = failure::Error; + + fn try_from(encoded_input: Bech32) -> Result { + let base32_hash = encoded_input.data(); + let hash = Vec::from_base32(&base32_hash)?; + AccountAddress::try_from(&hash[..]) + } +} + +impl TryFrom for Bech32 { + type Error = failure::Error; + + fn try_from(addr: AccountAddress) -> Result { + let base32_hash = addr.0.to_base32(); + bech32::Bech32::new(LIBRA_NETWORK_ID_SHORT.into(), base32_hash).map_err(Into::into) + } +} + +impl CanonicalSerialize for AccountAddress { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer.encode_variable_length_bytes(&self.0)?; + Ok(()) + } +} + +impl CanonicalDeserialize for AccountAddress { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result { + let bytes = deserializer.decode_variable_length_bytes()?; + Self::try_from(bytes) + } +} diff --git a/types/src/account_config.rs b/types/src/account_config.rs new file mode 100644 index 0000000000000..3a96b8ec34db2 --- /dev/null +++ b/types/src/account_config.rs @@ -0,0 +1,239 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + access_path::{AccessPath, Accesses}, + account_address::AccountAddress, + account_state_blob::AccountStateBlob, + byte_array::ByteArray, + language_storage::StructTag, +}; +use canonical_serialization::{ + CanonicalDeserialize, CanonicalDeserializer, CanonicalSerialize, CanonicalSerializer, + SimpleDeserializer, +}; +use failure::prelude::*; +use std::{collections::BTreeMap, convert::TryInto}; + +/// An 
account object. This is the top-level entry in global storage. We'll never need to create an +/// `Account` struct, but if we did, it would look something like +/// pub struct Account { +/// // Address holding this account +/// address: Address, +/// // Struct types defined by this account +/// code: HashMap, +/// // Resurces owned by this account +/// resoruces: HashMap, +/// } + +// LibraCoin +pub const COIN_MODULE_NAME: &str = "LibraCoin"; +pub const COIN_STRUCT_NAME: &str = "T"; + +// Account +pub const ACCOUNT_MODULE_NAME: &str = "LibraAccount"; +pub const ACCOUNT_STRUCT_NAME: &str = "T"; + +// Hash +pub const HASH_MODULE_NAME: &str = "Hash"; + +pub fn core_code_address() -> AccountAddress { + AccountAddress::default() +} +pub fn association_address() -> AccountAddress { + AccountAddress::default() +} + +pub fn coin_struct_tag() -> StructTag { + StructTag { + module: COIN_MODULE_NAME.to_string(), + name: COIN_STRUCT_NAME.to_string(), + address: core_code_address(), + type_params: vec![], + } +} + +pub fn account_struct_tag() -> StructTag { + StructTag { + module: ACCOUNT_MODULE_NAME.to_string(), + name: ACCOUNT_STRUCT_NAME.to_string(), + address: core_code_address(), + type_params: vec![], + } +} + +/// A Rust representation of an Account resource. +/// This is not how the Account is represented in the VM but it's a convenient representation. +#[derive(Debug, Default)] +pub struct AccountResource { + balance: u64, + sequence_number: u64, + authentication_key: ByteArray, + sent_events_count: u64, + received_events_count: u64, +} + +impl AccountResource { + /// Constructs an Account resource. + pub fn new( + balance: u64, + sequence_number: u64, + authentication_key: ByteArray, + sent_events_count: u64, + received_events_count: u64, + ) -> Self { + AccountResource { + balance, + sequence_number, + authentication_key, + sent_events_count, + received_events_count, + } + } + + /// Given an account map (typically from storage) retrieves the Account resource associated. + pub fn make_from(account_map: &BTreeMap, Vec>) -> Result { + let ap = account_resource_path(); + match account_map.get(&ap) { + Some(bytes) => SimpleDeserializer::deserialize(bytes), + None => bail!("No data for {:?}", ap), + } + } + + /// Return the sequence_number field for the given AccountResource + pub fn sequence_number(&self) -> u64 { + self.sequence_number + } + + /// Return the balance field for the given AccountResource + pub fn balance(&self) -> u64 { + self.balance + } + + /// Return the authentication_key field for the given AccountResource + pub fn authentication_key(&self) -> &ByteArray { + &self.authentication_key + } + + /// Return the sent_events_count field for the given AccountResource + pub fn sent_events_count(&self) -> u64 { + self.sent_events_count + } + + /// Return the received_events_count field for the given AccountResource + pub fn received_events_count(&self) -> u64 { + self.received_events_count + } +} + +impl CanonicalSerialize for AccountResource { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + // TODO(drussi): the order in which these fields are serialized depends on some + // implementation details in the VM. + serializer + .encode_struct(&self.authentication_key)? + .encode_u64(self.balance)? + .encode_u64(self.received_events_count)? + .encode_u64(self.sent_events_count)? 
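            // Field-order note (illustrative): the order chosen here must be mirrored
            // exactly by `deserialize` below -- authentication_key, balance,
            // received_events_count, sent_events_count, then sequence_number -- since
            // the canonical format carries positional values, not field names.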
+ .encode_u64(self.sequence_number)?; + Ok(()) + } +} + +impl CanonicalDeserialize for AccountResource { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result { + let authentication_key = deserializer.decode_struct()?; + let balance = deserializer.decode_u64()?; + let received_events_count = deserializer.decode_u64()?; + let sent_events_count = deserializer.decode_u64()?; + let sequence_number = deserializer.decode_u64()?; + + Ok(AccountResource { + balance, + sequence_number, + authentication_key, + sent_events_count, + received_events_count, + }) + } +} + +pub fn get_account_resource_or_default( + account_state: &Option, +) -> Result { + match account_state { + Some(blob) => { + let account_btree = blob.try_into()?; + AccountResource::make_from(&account_btree) + } + None => Ok(AccountResource::default()), + } +} + +/// Return the path to the Account resource. It can be used to create an AccessPath for an +/// Account resource. +pub fn account_resource_path() -> Vec { + AccessPath::resource_access_vec( + &StructTag { + address: core_code_address(), + module: ACCOUNT_MODULE_NAME.to_string(), + name: ACCOUNT_STRUCT_NAME.to_string(), + type_params: vec![], + }, + &Accesses::empty(), + ) +} + +/// Return the path to the sent event counter for an Account resource. +/// It can be used to query the event DB for the given event. +pub fn account_sent_event_path() -> Vec { + let mut path = account_resource_path(); + path.push(b'/'); + path.extend_from_slice(b"sent_events_count"); + path.push(b'/'); + path +} + +/// Return the path to the received event counter for an Account resource. +/// It can be used to query the event DB for the given event. +pub fn account_received_event_path() -> Vec { + let mut path = account_resource_path(); + path.push(b'/'); + path.extend_from_slice(b"received_events_count"); + path.push(b'/'); + path +} + +/// Generic struct that represents an Account event. +/// Both SentPaymentEvent and ReceivedPaymentEvent are representable with this struct. +/// They have an AccountAddress for the sender or receiver and the amount transferred. +#[derive(Debug, Default)] +pub struct AccountEvent { + account: AccountAddress, + amount: u64, +} + +impl AccountEvent { + /// Get the account related to the event + pub fn account(&self) -> AccountAddress { + self.account + } + + /// Get the amount sent or received + pub fn amount(&self) -> u64 { + self.amount + } +} + +impl CanonicalDeserialize for AccountEvent { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result { + // TODO: this is a horrible hack and we need to come up with a proper separation of + // data/code so that we don't need the entire VM to read an Account event. 
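        // Consumer-side sketch (illustrative): a raw event payload is decoded
        // elsewhere in this crate as
        //     SimpleDeserializer::deserialize::<AccountEvent>(event.event_data())
        // relying on the read order below: amount first, then account.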
+ // Also we cannot depend on the VM here as we would have a circular dependency and + // it's not clear if this API should live in the VM or in types + let amount = deserializer.decode_u64()?; + let account = deserializer.decode_struct()?; + + Ok(AccountEvent { account, amount }) + } +} diff --git a/types/src/account_state_blob/account_state_blob_test.rs b/types/src/account_state_blob/account_state_blob_test.rs new file mode 100644 index 0000000000000..9482869311ce8 --- /dev/null +++ b/types/src/account_state_blob/account_state_blob_test.rs @@ -0,0 +1,29 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use proptest::{collection::vec, prelude::*}; +use proto_conv::test_helper::assert_protobuf_encode_decode; + +fn hash_blob(blob: &[u8]) -> HashValue { + let mut hasher = AccountStateBlobHasher::default(); + hasher.write(blob); + hasher.finish() +} + +proptest! { + #[test] + fn account_state_blob_roundtrip(account_state_blob in any::()) { + assert_protobuf_encode_decode(&account_state_blob); + } + + #[test] + fn account_state_blob_hash(blob in vec(any::(), 1..100)) { + prop_assert_eq!(hash_blob(&blob), AccountStateBlob::from(blob.clone()).hash()); + } + + #[test] + fn account_state_with_proof(account_state_with_proof in any::()) { + assert_protobuf_encode_decode(&account_state_with_proof); + } +} diff --git a/types/src/account_state_blob/mod.rs b/types/src/account_state_blob/mod.rs new file mode 100644 index 0000000000000..ade34904b34b5 --- /dev/null +++ b/types/src/account_state_blob/mod.rs @@ -0,0 +1,171 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![allow(clippy::unit_arg)] + +use crate::{ + account_address::AccountAddress, + account_config::get_account_resource_or_default, + ledger_info::LedgerInfo, + proof::{verify_account_state, AccountStateProof}, + transaction::Version, +}; +use canonical_serialization::{SimpleDeserializer, SimpleSerializer}; +use crypto::{ + hash::{AccountStateBlobHasher, CryptoHash, CryptoHasher}, + HashValue, +}; +use failure::prelude::*; +use proptest_derive::Arbitrary; +use proto_conv::{FromProto, IntoProto}; +use std::{collections::BTreeMap, convert::TryFrom, fmt}; + +#[derive(Arbitrary, Clone, Eq, PartialEq, FromProto, IntoProto)] +#[ProtoType(crate::proto::account_state_blob::AccountStateBlob)] +pub struct AccountStateBlob { + blob: Vec, +} + +impl fmt::Debug for AccountStateBlob { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "AccountStateBlob {{ \n \ + Raw: 0x{} \n \ + Decoded: {:#?} \n \ + }}", + hex::encode(&self.blob), + get_account_resource_or_default(&Some(self.clone())) + ) + } +} + +impl AsRef<[u8]> for AccountStateBlob { + fn as_ref(&self) -> &[u8] { + &self.blob + } +} + +impl From for Vec { + fn from(account_state_blob: AccountStateBlob) -> Vec { + account_state_blob.blob + } +} + +impl From> for AccountStateBlob { + fn from(blob: Vec) -> AccountStateBlob { + AccountStateBlob { blob } + } +} + +impl TryFrom<&BTreeMap, Vec>> for AccountStateBlob { + type Error = failure::Error; + + fn try_from(map: &BTreeMap, Vec>) -> Result { + Ok(Self { + blob: SimpleSerializer::serialize(map)?, + }) + } +} + +impl TryFrom<&AccountStateBlob> for BTreeMap, Vec> { + type Error = failure::Error; + + fn try_from(account_state_blob: &AccountStateBlob) -> Result { + SimpleDeserializer::deserialize(&account_state_blob.blob) + } +} + +impl CryptoHash for AccountStateBlob { + type Hasher = AccountStateBlobHasher; + + fn hash(&self) -> 
HashValue { + let mut hasher = Self::Hasher::default(); + hasher.write(&self.blob); + hasher.finish() + } +} + +#[derive(Arbitrary, Clone, Debug, Eq, PartialEq)] +pub struct AccountStateWithProof { + /// The transaction version at which this account state is seen. + pub version: Version, + /// Blob value representing the account state. If this field is not set, it + /// means the account does not exist. + pub blob: Option, + /// The proof the client can use to authenticate the value. + pub proof: AccountStateProof, +} + +impl AccountStateWithProof { + /// Constructor. + pub fn new(version: Version, blob: Option, proof: AccountStateProof) -> Self { + Self { + version, + blob, + proof, + } + } + + /// Verifies the the account state blob with the proof, both carried by `self`. + /// + /// Two things are ensured if no error is raised: + /// 1. This account state exists in the ledger represented by `ledger_info`. + /// 2. It belongs to account of `address` and is seen at the time the transaction at version + /// `state_version` is just committed. To make sure this is the latest state, pass in + /// `ledger_info.version()` as `state_version`. + pub fn verify( + &self, + ledger_info: &LedgerInfo, + version: Version, + address: AccountAddress, + ) -> Result<()> { + ensure!( + self.version == version, + "State version ({}) is not expected ({}).", + self.version, + version, + ); + + verify_account_state( + ledger_info, + version, + address.hash(), + &self.blob, + &self.proof, + ) + } +} + +impl FromProto for AccountStateWithProof { + type ProtoType = crate::proto::account_state_blob::AccountStateWithProof; + + fn from_proto(mut object: Self::ProtoType) -> Result { + Ok(AccountStateWithProof { + version: object.get_version(), + blob: object + .blob + .take() + .map(AccountStateBlob::from_proto) + .transpose()?, + proof: AccountStateProof::from_proto(object.take_proof())?, + }) + } +} + +impl IntoProto for AccountStateWithProof { + type ProtoType = crate::proto::account_state_blob::AccountStateWithProof; + + fn into_proto(self) -> Self::ProtoType { + let mut out = Self::ProtoType::new(); + out.set_version(self.version); + if let Some(blob) = self.blob { + out.set_blob(blob.into_proto()); + } + out.set_proof(self.proof.into_proto()); + out + } +} + +#[cfg(test)] +mod account_state_blob_test; diff --git a/types/src/byte_array.rs b/types/src/byte_array.rs new file mode 100644 index 0000000000000..98a9bcbc3af47 --- /dev/null +++ b/types/src/byte_array.rs @@ -0,0 +1,77 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use canonical_serialization::{ + CanonicalDeserialize, CanonicalDeserializer, CanonicalSerialize, CanonicalSerializer, +}; +use failure::Result; +use hex; +use serde::{Deserialize, Serialize}; + +#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Default, Clone, Serialize, Deserialize)] +/// A struct that represents a ByteArray in Move. 
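// A minimal usage sketch (illustrative only), based on the impls below:
//
//     let ba = ByteArray::new(vec![0xca, 0xfe]);
//     assert_eq!(format!("{:?}", ba), "0xcafe");   // Debug: 0x-prefixed hex
//     assert_eq!(format!("{}", ba), "b\"cafe\"");  // Display: Move-style literal
//     assert_eq!(ba.len(), 2);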
+pub struct ByteArray(Vec); + +impl ByteArray { + pub fn as_bytes(&self) -> &[u8] { + &self.0 + } + + pub fn new(buf: Vec) -> Self { + ByteArray(buf) + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl std::fmt::Debug for ByteArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "0x{}", hex::encode(&self.0)) + } +} + +impl std::fmt::Display for ByteArray { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "b\"{}\"", hex::encode(&self.0)) + } +} + +impl std::ops::Index for ByteArray { + type Output = u8; + + #[inline] + fn index(&self, index: usize) -> &Self::Output { + std::ops::Index::index(&*self.0, index) + } +} + +impl CanonicalSerialize for ByteArray { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer.encode_variable_length_bytes(&self.0)?; + Ok(()) + } +} + +impl CanonicalDeserialize for ByteArray { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result { + let bytes = deserializer.decode_variable_length_bytes()?; + Ok(ByteArray(bytes)) + } +} + +/* TODO: Once we implement char as byte, then we can allow for Range Slicing of ByteArrays +impl std::ops::Index> for ByteArray { + type Output = [u8]; + + #[inline] + fn index(&self, index: std::ops::RangeToInclusive) -> &Self::Output { + std::ops::Index::index(&*self.0, index) + } +} +*/ diff --git a/types/src/contract_event.rs b/types/src/contract_event.rs new file mode 100644 index 0000000000000..61d7aa67c610e --- /dev/null +++ b/types/src/contract_event.rs @@ -0,0 +1,191 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![allow(clippy::unit_arg)] +use crate::{ + access_path::AccessPath, + account_config::AccountEvent, + ledger_info::LedgerInfo, + proof::{verify_event, EventProof}, + transaction::Version, +}; +use canonical_serialization::{ + CanonicalSerialize, CanonicalSerializer, SimpleDeserializer, SimpleSerializer, +}; +use crypto::{ + hash::{ContractEventHasher, CryptoHash, CryptoHasher}, + HashValue, +}; +use failure::prelude::*; +use proptest_derive::Arbitrary; +use proto_conv::{FromProto, IntoProto}; + +/// Entry produced via a call to the `emit_event` builtin. 
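// Illustrative note: for payment events the `access_path` below is the event
// counter path built in account_config.rs, e.g.
// `AccessPath::new_for_sent_event(addr)` is the account resource path extended
// with "/sent_events_count/", and the received side ends in
// "/received_events_count/".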
+#[derive(Clone, Default, Eq, PartialEq, FromProto, IntoProto)] +#[ProtoType(crate::proto::events::Event)] +pub struct ContractEvent { + /// The path that the event was emitted to + access_path: AccessPath, + /// The number of messages that have been emitted to the path previously + sequence_number: u64, + /// The data payload of the event + event_data: Vec, +} + +impl ContractEvent { + pub fn new(access_path: AccessPath, sequence_number: u64, event_data: Vec) -> Self { + ContractEvent { + access_path, + sequence_number, + event_data, + } + } + + pub fn access_path(&self) -> &AccessPath { + &self.access_path + } + + pub fn sequence_number(&self) -> u64 { + self.sequence_number + } + + pub fn event_data(&self) -> &[u8] { + &self.event_data + } +} + +impl std::fmt::Debug for ContractEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ContractEvent {{ access_path: {:?}, index: {:?}, event_data: {:?} }}", + self.access_path, + self.sequence_number, + hex::encode(&self.event_data) + ) + } +} + +impl std::fmt::Display for ContractEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Ok(payload) = SimpleDeserializer::deserialize::(&self.event_data) { + write!( + f, + "ContractEvent {{ access_path: {}, index: {:?}, event_data: {:?} }}", + self.access_path, self.sequence_number, payload, + ) + } else { + write!(f, "{:?}", self) + } + } +} + +impl CanonicalSerialize for ContractEvent { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer + .encode_struct(&self.access_path)? + .encode_u64(self.sequence_number)? + .encode_variable_length_bytes(&self.event_data)?; + Ok(()) + } +} + +impl CryptoHash for ContractEvent { + type Hasher = ContractEventHasher; + + fn hash(&self) -> HashValue { + let mut state = Self::Hasher::default(); + state.write(&SimpleSerializer::>::serialize(self).expect("Failed to serialize.")); + state.finish() + } +} + +#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, FromProto, IntoProto)] +#[ProtoType(crate::proto::events::EventWithProof)] +pub struct EventWithProof { + pub transaction_version: u64, // Should be `Version`, but FromProto derive won't work that way. + pub event_index: u64, + pub event: ContractEvent, + pub proof: EventProof, +} + +impl std::fmt::Display for EventWithProof { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "EventWithProof {{ \n\ttransaction_version: {}, \n\tevent_index: {}, \ + \n\tevent: {}, \n\tproof: {:?} \n}}", + self.transaction_version, self.event_index, self.event, self.proof + ) + } +} + +impl EventWithProof { + /// Constructor. + pub fn new( + transaction_version: Version, + event_index: u64, + event: ContractEvent, + proof: EventProof, + ) -> Self { + Self { + transaction_version, + event_index, + event, + proof, + } + } + + /// Verifies the event with the proof, both carried by `self`. + /// + /// Two things are ensured if no error is raised: + /// 1. This event exists in the ledger represented by `ledger_info`. + /// 2. And this event has the same `access_path`, `sequence_number`, `transaction_version`, + /// and `event_index` as indicated in the parameter list. If any of these parameter is unknown + /// to the call site and is supposed to be informed by this struct, get it from the struct + /// itself, such as: `event_with_proof.event.access_path()`, `event_with_proof.event_index()`, + /// etc. 
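    // Caller-side sketch (illustrative): the expected values typically come from
    // the original request plus this struct itself, mirroring the parameter list
    // below:
    //
    //     event_with_proof.verify(
    //         ledger_info,
    //         &requested_access_path,   // from the request
    //         expected_seq_num,         // from the request
    //         event_with_proof.transaction_version,
    //         event_with_proof.event_index,
    //     )?;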
+ pub fn verify( + &self, + ledger_info: &LedgerInfo, + access_path: &AccessPath, + sequence_number: u64, + transaction_version: Version, + event_index: u64, + ) -> Result<()> { + ensure!( + self.event.access_path() == access_path, + "Access path ({}) not expected ({}).", + self.event.access_path(), + *access_path, + ); + ensure!( + self.event.sequence_number == sequence_number, + "Sequence number ({}) not expected ({}).", + self.event.sequence_number(), + sequence_number, + ); + ensure!( + self.transaction_version == transaction_version, + "Transaction version ({}) not expected ({}).", + self.transaction_version, + transaction_version, + ); + ensure!( + self.event_index == event_index, + "Event index ({}) not expected ({}).", + self.event_index, + event_index, + ); + + verify_event( + ledger_info, + self.event.hash(), + transaction_version, + event_index, + &self.proof, + )?; + + Ok(()) + } +} diff --git a/types/src/get_with_proof.rs b/types/src/get_with_proof.rs new file mode 100644 index 0000000000000..7899ce1073de0 --- /dev/null +++ b/types/src/get_with_proof.rs @@ -0,0 +1,673 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![allow(clippy::unit_arg)] + +use crate::{ + access_path::AccessPath, + account_address::AccountAddress, + account_config::{ + account_received_event_path, account_sent_event_path, get_account_resource_or_default, + }, + account_state_blob::{AccountStateBlob, AccountStateWithProof}, + contract_event::EventWithProof, + ledger_info::{LedgerInfo, LedgerInfoWithSignatures}, + proto::get_with_proof::{ + GetAccountStateRequest, GetAccountStateResponse, + GetAccountTransactionBySequenceNumberRequest, + GetAccountTransactionBySequenceNumberResponse, GetEventsByEventAccessPathRequest, + GetEventsByEventAccessPathResponse, GetTransactionsRequest, GetTransactionsResponse, + }, + transaction::{SignedTransactionWithProof, TransactionListWithProof, Version}, + validator_change::ValidatorChangeEventWithProof, + validator_verifier::ValidatorVerifier, +}; +use crypto::hash::CryptoHash; +use failure::prelude::*; +use proptest_derive::Arbitrary; +use proto_conv::{FromProto, IntoProto}; +use std::{cmp, mem, sync::Arc}; + +#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, FromProto, IntoProto)] +#[ProtoType(crate::proto::get_with_proof::UpdateToLatestLedgerRequest)] +pub struct UpdateToLatestLedgerRequest { + pub client_known_version: u64, + pub requested_items: Vec, +} + +impl UpdateToLatestLedgerRequest { + pub fn new(client_known_version: u64, requested_items: Vec) -> Self { + UpdateToLatestLedgerRequest { + client_known_version, + requested_items, + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, FromProto, IntoProto)] +#[ProtoType(crate::proto::get_with_proof::UpdateToLatestLedgerResponse)] +pub struct UpdateToLatestLedgerResponse { + pub response_items: Vec, + pub ledger_info_with_sigs: LedgerInfoWithSignatures, + pub validator_change_events: Vec, +} + +impl UpdateToLatestLedgerResponse { + /// Constructor. + pub fn new( + response_items: Vec, + ledger_info_with_sigs: LedgerInfoWithSignatures, + validator_change_events: Vec, + ) -> Self { + UpdateToLatestLedgerResponse { + response_items, + ledger_info_with_sigs, + validator_change_events, + } + } + + /// Verifies that the response has items corresponding to request items and each of them are + /// supported by proof it carries and is what the request item asks for. + /// + /// After calling this one can trust the info in the response items without further + /// verification. 
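    // Client-side sketch (illustrative): a typical round trip pairs this with
    // the request type above, where `validator_verifier` is an
    // Arc<ValidatorVerifier> for the current epoch:
    //
    //     let request = UpdateToLatestLedgerRequest::new(
    //         known_version,
    //         vec![RequestItem::GetAccountState { address }],
    //     );
    //     // ... send `request`, receive `response` ...
    //     response.verify(validator_verifier.clone(), &request)?;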
+ pub fn verify( + &self, + validator_verifier: Arc, + request: &UpdateToLatestLedgerRequest, + ) -> Result<()> { + verify_update_to_latest_ledger_response( + validator_verifier, + request.client_known_version, + &request.requested_items, + &self.response_items, + &self.ledger_info_with_sigs, + ) + } +} + +/// Verifies content of an [`UpdateToLatestLedgerResponse`] against the proofs it +/// carries and the content of the corresponding [`UpdateToLatestLedgerRequest`] +pub fn verify_update_to_latest_ledger_response( + validator_verifier: Arc, + req_client_known_version: u64, + req_request_items: &[RequestItem], + response_items: &[ResponseItem], + ledger_info_with_sigs: &LedgerInfoWithSignatures, +) -> Result<()> { + let (ledger_info, signatures) = ( + ledger_info_with_sigs.ledger_info(), + ledger_info_with_sigs.signatures(), + ); + + // Verify that the same or a newer ledger info is returned. + ensure!( + ledger_info.version() >= req_client_known_version, + "Got stale ledger_info with version {}, known version: {}.", + ledger_info.version(), + req_client_known_version, + ); + + // Verify ledger info signatures. + if !(ledger_info.version() == 0 && signatures.is_empty()) { + validator_verifier.verify_aggregated_signature(ledger_info.hash(), signatures)?; + } + + // Verify each sub response. + ensure!( + req_request_items.len() == response_items.len(), + "Number of request items ({}) does not match that of response items ({}).", + req_request_items.len(), + response_items.len(), + ); + itertools::zip_eq(req_request_items, response_items) + .map(|(req, res)| verify_response_item(ledger_info, req, res)) + .collect::>>()?; + + Ok(()) +} + +fn verify_response_item( + ledger_info: &LedgerInfo, + req: &RequestItem, + res: &ResponseItem, +) -> Result<()> { + match (req, res) { + // GetAccountState + ( + RequestItem::GetAccountState { address }, + ResponseItem::GetAccountState { + account_state_with_proof, + }, + ) => account_state_with_proof.verify(ledger_info, ledger_info.version(), *address), + // GetAccountTransactionBySequenceNumber + ( + RequestItem::GetAccountTransactionBySequenceNumber { + account, + sequence_number, + fetch_events, + }, + ResponseItem::GetAccountTransactionBySequenceNumber { + signed_transaction_with_proof, + proof_of_current_sequence_number, + }, + ) => verify_get_txn_by_seq_num_resp( + ledger_info, + *account, + *sequence_number, + *fetch_events, + signed_transaction_with_proof.as_ref(), + proof_of_current_sequence_number.as_ref(), + ), + // GetEventsByEventAccessPath + ( + RequestItem::GetEventsByEventAccessPath { + access_path, + start_event_seq_num, + ascending, + limit, + }, + ResponseItem::GetEventsByEventAccessPath { + events_with_proof, + proof_of_latest_event, + }, + ) => verify_get_events_by_access_path_resp( + ledger_info, + access_path, + *start_event_seq_num, + *ascending, + *limit, + events_with_proof, + proof_of_latest_event.as_ref(), + ), + // GetTransactions + ( + RequestItem::GetTransactions { + start_version, + limit, + fetch_events, + }, + ResponseItem::GetTransactions { + txn_list_with_proof, + }, + ) => verify_get_txns_resp( + ledger_info, + *start_version, + *limit, + *fetch_events, + txn_list_with_proof, + ), + // Request-response item types mismatch. + _ => bail!( + "ResquestItem/ResponseItem types mismatch. 
request: {:?}, response: {:?}", + mem::discriminant(req), + mem::discriminant(res), + ), + } +} + +fn verify_get_txn_by_seq_num_resp( + ledger_info: &LedgerInfo, + req_account: AccountAddress, + req_sequence_number: u64, + req_fetch_events: bool, + signed_transaction_with_proof: Option<&SignedTransactionWithProof>, + proof_of_current_sequence_number: Option<&AccountStateWithProof>, +) -> Result<()> { + match (signed_transaction_with_proof, proof_of_current_sequence_number) { + (Some(signed_transaction_with_proof), None) => { + ensure!( + req_fetch_events == signed_transaction_with_proof.events.is_some(), + "Bad GetAccountTxnBySeqNum response. Events requested: {}, events returned: {}.", + req_fetch_events, + signed_transaction_with_proof.events.is_some(), + ); + signed_transaction_with_proof.verify( + ledger_info, + signed_transaction_with_proof.version, + req_account, + req_sequence_number, + ) + }, + (None, Some(proof_of_current_sequence_number)) => { + let sequence_number_in_ledger = + get_account_resource_or_default(&proof_of_current_sequence_number.blob)? + .sequence_number(); + ensure!( + sequence_number_in_ledger <= req_sequence_number, + "Server returned no transactions while it should. Seq num requested: {}, latest seq num in ledger: {}.", + req_sequence_number, + sequence_number_in_ledger + ); + proof_of_current_sequence_number.verify(ledger_info, ledger_info.version(), req_account) + }, + _ => bail!( + "Bad GetAccountTxnBySeqNum response. txn_proof.is_none():{}, cur_seq_num_proof.is_none():{}", + signed_transaction_with_proof.is_none(), + proof_of_current_sequence_number.is_none(), + ) + } +} + +fn verify_get_events_by_access_path_resp( + ledger_info: &LedgerInfo, + req_access_path: &AccessPath, + req_start_seq_num: u64, + req_ascending: bool, + req_limit: u64, + events_with_proof: &[EventWithProof], + proof_of_latest_event: Option<&AccountStateWithProof>, +) -> Result<()> { + let seq_num_upper_bound = match proof_of_latest_event { + Some(proof) => { + proof.verify(ledger_info, ledger_info.version(), req_access_path.address)?; + get_next_event_seq_num(&proof.blob, &req_access_path)? + } + None => u64::max_value(), + }; + let cursor = + if !req_ascending && req_start_seq_num == u64::max_value() && seq_num_upper_bound > 0 { + seq_num_upper_bound - 1 + } else { + req_start_seq_num + }; + + let expected_seq_nums = if cursor >= seq_num_upper_bound { + // Unreachable, so empty. + Vec::new() + } else if req_ascending { + // Ascending, from start to upper bound or limit. + (cursor..cmp::min(cursor + req_limit, seq_num_upper_bound)).collect() + } else if cursor + 1 < req_limit { + // Descending and hitting 0. + (0..=cursor).rev().collect() + } else { + // Descending and hitting limit. 
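        // Worked example (illustrative): with cursor == 9 and req_limit == 3 this
        // branch yields [9, 8, 7]; the `cursor + 1 < req_limit` branch above
        // handles e.g. cursor == 1, req_limit == 3 by clamping at zero, giving
        // [1, 0].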
+ (cursor + 1 - req_limit..=cursor).rev().collect() + }; + + ensure!( + expected_seq_nums.len() == events_with_proof.len(), + "Expecting {} events, got {}.", + expected_seq_nums.len(), + events_with_proof.len(), + ); + itertools::zip_eq(events_with_proof, expected_seq_nums) + .map(|(e, seq_num)| { + e.verify( + ledger_info, + req_access_path, + seq_num, + e.transaction_version, + e.event_index, + ) + }) + .collect::>>()?; + + Ok(()) +} + +fn get_next_event_seq_num( + account_state_blob: &Option, + access_path: &AccessPath, +) -> Result { + let account_blob = get_account_resource_or_default(account_state_blob)?; + if account_received_event_path() == access_path.path { + Ok(account_blob.received_events_count()) + } else if account_sent_event_path() == access_path.path { + Ok(account_blob.sent_events_count()) + } else { + bail!("Unrecognized access path: {}", access_path); + } +} + +fn verify_get_txns_resp( + ledger_info: &LedgerInfo, + req_start_version: Version, + req_limit: u64, + req_fetch_events: bool, + txn_list_with_proof: &TransactionListWithProof, +) -> Result<()> { + ensure!( + req_fetch_events == txn_list_with_proof.events.is_some(), + "Bad GetTransactions response. Events requested: {}, events returned: {}.", + req_fetch_events, + txn_list_with_proof.events.is_some(), + ); + + if req_limit == 0 || req_start_version > ledger_info.version() { + txn_list_with_proof.verify(ledger_info, None) + } else { + let num_txns = txn_list_with_proof.transaction_and_infos.len(); + ensure!( + cmp::min(req_limit, ledger_info.version() - req_start_version + 1) + == txn_list_with_proof.transaction_and_infos.len() as u64, + "Number of transactions returned not expected. num_txns: {}, start version: {}, latest version: {}", + num_txns, + req_start_version, + ledger_info.version(), + ); + txn_list_with_proof.verify(ledger_info, Some(req_start_version)) + } +} + +#[derive(Arbitrary, Clone, Debug, Eq, PartialEq)] +pub enum RequestItem { + GetAccountTransactionBySequenceNumber { + account: AccountAddress, + sequence_number: u64, + fetch_events: bool, + }, + // this can't be the first variant, tracked here https://github.com/AltSysrq/proptest/issues/141 + GetAccountState { + address: AccountAddress, + }, + GetEventsByEventAccessPath { + access_path: AccessPath, + start_event_seq_num: u64, + ascending: bool, + limit: u64, + }, + GetTransactions { + start_version: Version, + limit: u64, + fetch_events: bool, + }, +} + +impl FromProto for RequestItem { + type ProtoType = crate::proto::get_with_proof::RequestItem; + + fn from_proto(mut object: Self::ProtoType) -> Result { + Ok(if object.has_get_account_state_request() { + let address = + AccountAddress::from_proto(object.take_get_account_state_request().take_address())?; + RequestItem::GetAccountState { address } + } else if object.has_get_account_transaction_by_sequence_number_request() { + let mut req = object.take_get_account_transaction_by_sequence_number_request(); + let account = AccountAddress::from_proto(req.take_account())?; + let sequence_number = req.get_sequence_number(); + let fetch_events = req.get_fetch_events(); + + RequestItem::GetAccountTransactionBySequenceNumber { + account, + sequence_number, + fetch_events, + } + } else if object.has_get_events_by_event_access_path_request() { + let mut req = object.take_get_events_by_event_access_path_request(); + + let access_path = AccessPath::from_proto(req.take_access_path())?; + let start_event_seq_num = req.get_start_event_seq_num(); + let ascending = req.get_ascending(); + let limit = 
req.get_limit(); + + RequestItem::GetEventsByEventAccessPath { + access_path, + start_event_seq_num, + ascending, + limit, + } + } else if object.has_get_transactions_request() { + let req = object.get_get_transactions_request(); + let start_version = req.get_start_version(); + let limit = req.get_limit(); + let fetch_events = req.get_fetch_events(); + + RequestItem::GetTransactions { + start_version, + limit, + fetch_events, + } + } else { + unreachable!("Unknown RequestItem type.") + }) + } +} + +impl IntoProto for RequestItem { + type ProtoType = crate::proto::get_with_proof::RequestItem; + + fn into_proto(self) -> Self::ProtoType { + let mut out = Self::ProtoType::new(); + match self { + RequestItem::GetAccountState { address } => { + let mut req = GetAccountStateRequest::new(); + req.set_address(address.into_proto()); + out.set_get_account_state_request(req); + } + RequestItem::GetAccountTransactionBySequenceNumber { + account, + sequence_number, + fetch_events, + } => { + let mut req = GetAccountTransactionBySequenceNumberRequest::new(); + req.set_account(account.into_proto()); + req.set_sequence_number(sequence_number); + req.set_fetch_events(fetch_events); + + out.set_get_account_transaction_by_sequence_number_request(req); + } + RequestItem::GetEventsByEventAccessPath { + access_path, + start_event_seq_num, + ascending, + limit, + } => { + let mut req = GetEventsByEventAccessPathRequest::new(); + req.set_access_path(access_path.into_proto()); + req.set_start_event_seq_num(start_event_seq_num); + req.set_ascending(ascending); + req.set_limit(limit); + + out.set_get_events_by_event_access_path_request(req); + } + RequestItem::GetTransactions { + start_version, + limit, + fetch_events, + } => { + let mut req = GetTransactionsRequest::new(); + req.set_start_version(start_version); + req.set_limit(limit); + req.set_fetch_events(fetch_events); + + out.set_get_transactions_request(req); + } + } + out + } +} + +#[allow(clippy::large_enum_variant)] +#[derive(Arbitrary, Clone, Debug, Eq, PartialEq)] +pub enum ResponseItem { + GetAccountTransactionBySequenceNumber { + signed_transaction_with_proof: Option, + proof_of_current_sequence_number: Option, + }, + // this can't be the first variant, tracked here https://github.com/AltSysrq/proptest/issues/141 + GetAccountState { + account_state_with_proof: AccountStateWithProof, + }, + GetEventsByEventAccessPath { + events_with_proof: Vec, + proof_of_latest_event: Option, + }, + GetTransactions { + txn_list_with_proof: TransactionListWithProof, + }, +} + +impl ResponseItem { + pub fn into_get_account_state_response(self) -> Result { + match self { + ResponseItem::GetAccountState { + account_state_with_proof, + } => Ok(account_state_with_proof), + _ => bail!("Not ResponseItem::GetAccountState."), + } + } + + pub fn into_get_account_txn_by_seq_num_response( + self, + ) -> Result<( + Option, + Option, + )> { + match self { + ResponseItem::GetAccountTransactionBySequenceNumber { + signed_transaction_with_proof, + proof_of_current_sequence_number, + } => Ok(( + signed_transaction_with_proof, + proof_of_current_sequence_number, + )), + _ => bail!("Not ResponseItem::GetAccountTransactionBySequenceNumber."), + } + } + + pub fn into_get_events_by_access_path_response( + self, + ) -> Result<(Vec, Option)> { + match self { + ResponseItem::GetEventsByEventAccessPath { + events_with_proof, + proof_of_latest_event, + } => Ok((events_with_proof, proof_of_latest_event)), + _ => bail!("Not ResponseItem::GetEventsByEventAccessPath."), + } + } + + pub fn 
into_get_transactions_response(self) -> Result { + match self { + ResponseItem::GetTransactions { + txn_list_with_proof, + } => Ok(txn_list_with_proof), + _ => bail!("Not ResponseItem::GetTransactions."), + } + } +} + +impl FromProto for ResponseItem { + type ProtoType = crate::proto::get_with_proof::ResponseItem; + + fn from_proto(mut object: Self::ProtoType) -> Result { + Ok(if object.has_get_account_state_response() { + let account_state_with_proof = AccountStateWithProof::from_proto( + object + .take_get_account_state_response() + .take_account_state_with_proof(), + )?; + + ResponseItem::GetAccountState { + account_state_with_proof, + } + } else if object.has_get_account_transaction_by_sequence_number_response() { + let mut res = object.take_get_account_transaction_by_sequence_number_response(); + let signed_transaction_with_proof = res + .signed_transaction_with_proof + .take() + .map(SignedTransactionWithProof::from_proto) + .transpose()?; + let proof_of_current_sequence_number = res + .proof_of_current_sequence_number + .take() + .map(AccountStateWithProof::from_proto) + .transpose()?; + + ResponseItem::GetAccountTransactionBySequenceNumber { + signed_transaction_with_proof, + proof_of_current_sequence_number, + } + } else if object.has_get_events_by_event_access_path_response() { + let mut res = object.take_get_events_by_event_access_path_response(); + + let events_with_proof = res + .take_events_with_proof() + .into_iter() + .map(EventWithProof::from_proto) + .collect::>>()?; + + let proof_of_latest_event = res + .proof_of_latest_event + .take() + .map(AccountStateWithProof::from_proto) + .transpose()?; + + ResponseItem::GetEventsByEventAccessPath { + events_with_proof, + proof_of_latest_event, + } + } else if object.has_get_transactions_response() { + let mut res = object.take_get_transactions_response(); + let txn_list_with_proof = + TransactionListWithProof::from_proto(res.take_txn_list_with_proof())?; + + ResponseItem::GetTransactions { + txn_list_with_proof, + } + } else { + unreachable!("Unknown ResponseItem type.") + }) + } +} + +impl IntoProto for ResponseItem { + type ProtoType = crate::proto::get_with_proof::ResponseItem; + + fn into_proto(self) -> Self::ProtoType { + let mut out = Self::ProtoType::new(); + match self { + ResponseItem::GetAccountState { + account_state_with_proof, + } => { + let mut res = GetAccountStateResponse::new(); + res.set_account_state_with_proof(account_state_with_proof.into_proto()); + + out.set_get_account_state_response(res); + } + ResponseItem::GetAccountTransactionBySequenceNumber { + signed_transaction_with_proof, + proof_of_current_sequence_number, + } => { + let mut res = GetAccountTransactionBySequenceNumberResponse::new(); + + if let Some(t) = signed_transaction_with_proof { + res.set_signed_transaction_with_proof(t.into_proto()) + } + if let Some(p) = proof_of_current_sequence_number { + res.set_proof_of_current_sequence_number(p.into_proto()) + } + + out.set_get_account_transaction_by_sequence_number_response(res); + } + ResponseItem::GetEventsByEventAccessPath { + events_with_proof, + proof_of_latest_event, + } => { + let mut res = GetEventsByEventAccessPathResponse::new(); + res.set_events_with_proof(::protobuf::RepeatedField::from_vec( + events_with_proof + .into_iter() + .map(EventWithProof::into_proto) + .collect(), + )); + if let Some(p) = proof_of_latest_event { + res.set_proof_of_latest_event(p.into_proto()); + } + + out.set_get_events_by_event_access_path_response(res); + } + ResponseItem::GetTransactions { + 
txn_list_with_proof, + } => { + let mut res = GetTransactionsResponse::new(); + res.set_txn_list_with_proof(txn_list_with_proof.into_proto()); + + out.set_get_transactions_response(res) + } + } + out + } +} diff --git a/types/src/language_storage.rs b/types/src/language_storage.rs new file mode 100644 index 0000000000000..3d827963c3ba5 --- /dev/null +++ b/types/src/language_storage.rs @@ -0,0 +1,136 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{access_path::AccessPath, account_address::AccountAddress}; +use canonical_serialization::{ + CanonicalDeserialize, CanonicalDeserializer, CanonicalSerialize, CanonicalSerializer, + SimpleSerializer, +}; +use crypto::hash::{AccessPathHasher, CryptoHash, CryptoHasher, HashValue}; +use failure::Result; +use serde::{Deserialize, Serialize}; +use std::string::String; + +#[derive(Serialize, Deserialize, Debug, PartialEq, Hash, Eq, Clone, PartialOrd, Ord)] +pub struct StructTag { + pub address: AccountAddress, + pub module: String, + pub name: String, + pub type_params: Vec, +} + +/// Represents the intitial key into global storage where we first index by the address, and then +/// the struct tag +#[derive(Serialize, Deserialize, Debug, PartialEq, Hash, Eq, Clone, PartialOrd, Ord)] +pub struct ResourceKey { + address: AccountAddress, + type_: StructTag, +} + +impl ResourceKey { + pub fn address(&self) -> AccountAddress { + self.address + } + + pub fn type_(&self) -> &StructTag { + &self.type_ + } +} + +impl ResourceKey { + pub fn new(address: AccountAddress, type_: StructTag) -> Self { + ResourceKey { address, type_ } + } +} + +/// Represents the intitial key into global storage where we first index by the address, and then +/// the struct tag +#[derive(Serialize, Deserialize, Debug, PartialEq, Hash, Eq, Clone, PartialOrd, Ord)] +pub struct CodeKey { + address: AccountAddress, + name: String, +} + +impl CodeKey { + pub fn new(address: AccountAddress, name: String) -> Self { + CodeKey { address, name } + } + + pub fn name(&self) -> &String { + &self.name + } + + pub fn address(&self) -> &AccountAddress { + &self.address + } +} + +impl<'a> Into for &'a CodeKey { + fn into(self) -> AccessPath { + AccessPath::code_access_path(self) + } +} + +impl CanonicalSerialize for CodeKey { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer + .encode_struct(&self.address)? + .encode_variable_length_bytes(self.name.as_bytes())?; + Ok(()) + } +} + +impl CanonicalDeserialize for CodeKey { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result { + let address = deserializer.decode_struct::()?; + let name = String::from_utf8(deserializer.decode_variable_length_bytes()?)?; + + Ok(Self { address, name }) + } +} + +impl CryptoHash for CodeKey { + type Hasher = AccessPathHasher; + + fn hash(&self) -> HashValue { + let mut state = Self::Hasher::default(); + state.write(&SimpleSerializer::>::serialize(self).unwrap()); + state.finish() + } +} + +impl CanonicalSerialize for StructTag { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer + .encode_struct(&self.address)? + .encode_variable_length_bytes(self.module.as_bytes())? + .encode_variable_length_bytes(self.name.as_bytes())? 
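            // Illustrative note: this canonical byte form is what the CryptoHash impl
            // below hashes, and access_path.rs embeds that hash right after
            // RESOURCE_TAG when building a resource key (see
            // AccessPath::resource_access_vec).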
+ .encode_vec(&self.type_params)?; + Ok(()) + } +} + +impl CanonicalDeserialize for StructTag { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result { + let address = deserializer.decode_struct::()?; + let module = String::from_utf8(deserializer.decode_variable_length_bytes()?)?; + let name = String::from_utf8(deserializer.decode_variable_length_bytes()?)?; + let type_params = deserializer.decode_vec::()?; + Ok(Self { + address, + name, + module, + type_params, + }) + } +} + +impl CryptoHash for StructTag { + type Hasher = AccessPathHasher; + + fn hash(&self) -> HashValue { + let mut state = Self::Hasher::default(); + state.write(&SimpleSerializer::>::serialize(self).unwrap()); + state.finish() + } +} diff --git a/types/src/ledger_info.rs b/types/src/ledger_info.rs new file mode 100644 index 0000000000000..65585e1b70d6a --- /dev/null +++ b/types/src/ledger_info.rs @@ -0,0 +1,274 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![allow(clippy::unit_arg)] + +use crate::{ + account_address::AccountAddress, + transaction::Version, + validator_verifier::{ValidatorVerifier, VerifyError}, +}; +use canonical_serialization::{CanonicalSerialize, CanonicalSerializer, SimpleSerializer}; +use crypto::{ + hash::{CryptoHash, CryptoHasher, LedgerInfoHasher}, + HashValue, Signature, +}; +use failure::prelude::*; +use proptest_derive::Arbitrary; +use proto_conv::{FromProto, IntoProto}; +use serde::{Deserialize, Serialize}; +use std::{ + collections::HashMap, + fmt::{Display, Formatter}, +}; + +/// This structure serves a dual purpose. +/// +/// First, if this structure is signed by 2f+1 validators it signifies the state of the ledger at +/// version `version` -- it contains the transaction accumulator at that version which commits to +/// all historical transactions. This structure may be expanded to include other information that +/// is derived from that accumulator (e.g. the current time according to the time contract) to +/// reduce the number of proofs a client must get. +/// +/// Second, the structure contains a `consensus_data_hash` value. This is the hash of an internal +/// data structure that represents a block that is voted on in HotStuff. If 2f+1 signatures are +/// gathered on the same ledger info that represents a Quorum Certificate (QC) on the HotStuff +/// data. +/// +/// Combining these two concepts when the consensus algorithm votes on a block B it votes for a +/// LedgerInfo with the `version` being the latest version that will be committed if B gets 2f+1 +/// votes. It sets `consensus_data_hash` to represent B so that if those 2f+1 votes are gathered a +/// QC is formed on B. +#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, IntoProto, Serialize, Deserialize)] +#[ProtoType(crate::proto::ledger_info::LedgerInfo)] +pub struct LedgerInfo { + /// The version of latest transaction in the ledger. + version: Version, + + /// The root hash of transaction accumulator. + transaction_accumulator_hash: HashValue, + + /// Hash of consensus specific data that is opaque to all parts of the system other than + /// consensus. + consensus_data_hash: HashValue, + + /// Block id of the last committed block corresponding to this LedgerInfo + /// as reported by consensus. + consensus_block_id: HashValue, + + /// Epoch number corresponds to the set of validators that are active for this ledger info. + epoch_num: u64, + + // Timestamp that represents the microseconds since the epoch (unix time) that is + // generated by the proposer of the block. 
This is strictly increasing with every block. + // If a client reads a timestamp > the one they specified for transaction expiration time, + // they can be certain that their transaction will never be included in a block in the future + // (assuming that their transaction has not yet been included) + timestamp_usecs: u64, +} + +impl Display for LedgerInfo { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!( + f, + "LedgerInfo: [committed_block_id: {}, version: {}, epoch_num: {}, timestamp (us): {}]", + self.consensus_block_id(), + self.version(), + self.epoch_num(), + self.timestamp_usecs() + ) + } +} + +impl LedgerInfo { + /// Constructs a `LedgerInfo` object at a specific version using given transaction accumulator + /// root and hot stuff data hash. + pub fn new( + version: Version, + transaction_accumulator_hash: HashValue, + consensus_data_hash: HashValue, + consensus_block_id: HashValue, + epoch_num: u64, + timestamp_usecs: u64, + ) -> Self { + LedgerInfo { + version, + transaction_accumulator_hash, + consensus_data_hash, + consensus_block_id, + epoch_num, + timestamp_usecs, + } + } + + /// Returns the version of this `LedgerInfo`. + pub fn version(&self) -> Version { + self.version + } + + /// Returns the transaction accumulator root of this `LedgerInfo`. + pub fn transaction_accumulator_hash(&self) -> HashValue { + self.transaction_accumulator_hash + } + + /// Returns hash of consensus data in this `LedgerInfo`. + pub fn consensus_data_hash(&self) -> HashValue { + self.consensus_data_hash + } + + pub fn consensus_block_id(&self) -> HashValue { + self.consensus_block_id + } + + pub fn set_consensus_data_hash(&mut self, consensus_data_hash: HashValue) { + self.consensus_data_hash = consensus_data_hash; + } + + pub fn epoch_num(&self) -> u64 { + self.epoch_num + } + + pub fn timestamp_usecs(&self) -> u64 { + self.timestamp_usecs + } + + /// A ledger info is nominal if it's not certifying any real version. + pub fn is_zero(&self) -> bool { + self.version == 0 + } +} + +impl FromProto for LedgerInfo { + type ProtoType = crate::proto::ledger_info::LedgerInfo; + + fn from_proto(proto: Self::ProtoType) -> Result { + Ok(LedgerInfo::new( + proto.get_version(), + HashValue::from_slice(proto.get_transaction_accumulator_hash())?, + HashValue::from_slice(proto.get_consensus_data_hash())?, + HashValue::from_slice(proto.get_consensus_block_id())?, + proto.get_epoch_num(), + proto.get_timestamp_usecs(), + )) + } +} + +impl CanonicalSerialize for LedgerInfo { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer + .encode_u64(self.version)? + .encode_raw_bytes(self.transaction_accumulator_hash.as_ref())? + .encode_raw_bytes(self.consensus_data_hash.as_ref())? + .encode_raw_bytes(self.consensus_block_id.as_ref())? + .encode_u64(self.epoch_num)? + .encode_u64(self.timestamp_usecs)?; + Ok(()) + } +} + +impl CryptoHash for LedgerInfo { + type Hasher = LedgerInfoHasher; + + fn hash(&self) -> HashValue { + let mut state = Self::Hasher::default(); + state.write( + &SimpleSerializer::>::serialize(self).expect("Serialization should work."), + ); + state.finish() + } +} + +// The validator node returns this structure which includes signatures +// from each validator to confirm the state. 
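// Verification sketch (illustrative), matching `verify` below: the client
// recomputes `ledger_info.hash()` and checks the aggregated signatures against
// the validator set, e.g.
//
//     ledger_info_with_sigs.verify(&validator_verifier)?;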
The client only needs to pass back
+// the LedgerInfo element, since the validator node does not need the signatures again when the
+// client performs a query; the signatures are only there so that the client
+// can verify the state.
+#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
+pub struct LedgerInfoWithSignatures {
+ ledger_info: LedgerInfo,
+ /// The validator is identified by its account address: in order to verify a signature
+ /// one needs to retrieve the public key of the validator for the given epoch.
+ signatures: HashMap<AccountAddress, Signature>,
+}
+
+impl Display for LedgerInfoWithSignatures {
+ fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+ write!(f, "LedgerInfoWithSignatures: {}", self.ledger_info)
+ }
+}
+
+impl LedgerInfoWithSignatures {
+ pub fn new(ledger_info: LedgerInfo, signatures: HashMap<AccountAddress, Signature>) -> Self {
+ LedgerInfoWithSignatures {
+ ledger_info,
+ signatures,
+ }
+ }
+
+ pub fn ledger_info(&self) -> &LedgerInfo {
+ &self.ledger_info
+ }
+
+ pub fn add_signature(&mut self, validator: AccountAddress, signature: Signature) {
+ self.signatures.entry(validator).or_insert(signature);
+ }
+
+ pub fn signatures(&self) -> &HashMap<AccountAddress, Signature> {
+ &self.signatures
+ }
+
+ pub fn verify(&self, validator: &ValidatorVerifier) -> ::std::result::Result<(), VerifyError> {
+ if self.ledger_info.is_zero() {
+ // We're not trying to verify nominal ledger info that does not carry any information.
+ return Ok(());
+ }
+ let ledger_hash = self.ledger_info().hash();
+ validator.verify_aggregated_signature(ledger_hash, self.signatures())
+ }
+}
+
+impl FromProto for LedgerInfoWithSignatures {
+ type ProtoType = crate::proto::ledger_info::LedgerInfoWithSignatures;
+
+ fn from_proto(mut proto: Self::ProtoType) -> Result<Self> {
+ let ledger_info = LedgerInfo::from_proto(proto.take_ledger_info())?;
+
+ let signatures_proto = proto.take_signatures();
+ let num_signatures = signatures_proto.len();
+ let signatures = signatures_proto
+ .into_iter()
+ .map(|proto| {
+ let validator_id = AccountAddress::from_proto(proto.get_validator_id().to_vec())?;
+ let signature = Signature::from_compact(proto.get_signature())?;
+ Ok((validator_id, signature))
+ })
+ .collect::<Result<HashMap<_, _>>>()?;
+ ensure!(
+ signatures.len() == num_signatures,
+ "Signatures should be from different validators."
+ );
+
+ Ok(LedgerInfoWithSignatures {
+ ledger_info,
+ signatures,
+ })
+ }
+}
+
+impl IntoProto for LedgerInfoWithSignatures {
+ type ProtoType = crate::proto::ledger_info::LedgerInfoWithSignatures;
+
+ fn into_proto(self) -> Self::ProtoType {
+ let mut proto = Self::ProtoType::new();
+ proto.set_ledger_info(self.ledger_info.into_proto());
+ self.signatures
+ .into_iter()
+ .for_each(|(validator_id, signature)| {
+ let mut validator_signature = crate::proto::ledger_info::ValidatorSignature::new();
+ validator_signature.set_validator_id(validator_id.into_proto());
+ validator_signature.set_signature(signature.to_compact().to_vec());
+ proto.mut_signatures().push(validator_signature)
+ });
+ proto
+ }
+}
diff --git a/types/src/lib.rs b/types/src/lib.rs
new file mode 100644
index 0000000000000..e5a570ba85419
--- /dev/null
+++ b/types/src/lib.rs
@@ -0,0 +1,35 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+#![feature(test)]
+
+pub mod access_path;
+pub mod account_address;
+pub mod account_config;
+pub mod account_state_blob;
+pub mod byte_array;
+pub mod contract_event;
+pub mod get_with_proof;
+pub mod language_storage;
+pub mod ledger_info;
+pub mod proof;
+pub mod proptest_types;
+pub mod proto;
+pub mod test_helpers;
+pub mod transaction;
+pub mod transaction_helpers;
+pub mod validator_change;
+pub mod validator_public_keys;
+pub mod validator_set;
+pub mod validator_signer;
+pub mod validator_verifier;
+pub mod vm_error;
+pub mod write_set;
+
+pub use account_address::AccountAddress as PeerId;
+
+#[cfg(test)]
+extern crate test;
+
+#[cfg(test)]
+mod unit_tests;
diff --git a/types/src/proof/definition.rs b/types/src/proof/definition.rs
new file mode 100644
index 0000000000000..7e1cb216791c9
--- /dev/null
+++ b/types/src/proof/definition.rs
@@ -0,0 +1,547 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module has definitions of various proofs.
+
+#[cfg(test)]
+#[path = "unit_tests/proof_proto_conversion_test.rs"]
+mod proof_proto_conversion_test;
+
+use self::bitmap::{AccumulatorBitmap, SparseMerkleBitmap};
+use crate::transaction::TransactionInfo;
+use crypto::{
+ hash::{ACCUMULATOR_PLACEHOLDER_HASH, SPARSE_MERKLE_PLACEHOLDER_HASH},
+ HashValue,
+};
+use failure::prelude::*;
+use proto_conv::{FromProto, IntoProto};
+
+/// A proof that can be used to authenticate an element in an accumulator given a trusted root
+/// hash. For example, both `LedgerInfoToTransactionInfoProof` and `TransactionInfoToEventProof`
+/// can be constructed on top of this structure.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct AccumulatorProof {
+ /// All siblings in this proof, including the default ones. Siblings near the root are at the
+ /// beginning of the vector.
+ siblings: Vec<HashValue>,
+}
+
+impl AccumulatorProof {
+ /// Constructs a new `AccumulatorProof` using a list of siblings.
+ pub fn new(siblings: Vec<HashValue>) -> Self {
+ // The sibling list could be empty in case the accumulator is empty or has a single
+ // element. When it's not empty, the top most sibling will never be default, otherwise the
+ // accumulator should have collapsed to a smaller one.
+ if let Some(first_sibling) = siblings.first() {
+ assert_ne!(*first_sibling, *ACCUMULATOR_PLACEHOLDER_HASH);
+ }
+
+ AccumulatorProof { siblings }
+ }
+
+ /// Returns the list of siblings in this proof.
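+ ///
+ /// Siblings are ordered root to leaf, so a verifier folds them in reverse, starting from the
+ /// element hash at the leaf, as `verify_accumulator_element` in `proof/mod.rs` does later in
+ /// this change.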
+ pub fn siblings(&self) -> &[HashValue] { + &self.siblings + } +} + +impl FromProto for AccumulatorProof { + type ProtoType = crate::proto::proof::AccumulatorProof; + + fn from_proto(mut proto_proof: Self::ProtoType) -> Result { + let bitmap = proto_proof.get_bitmap(); + let num_non_default_siblings = bitmap.count_ones() as usize; + ensure!( + num_non_default_siblings == proto_proof.get_non_default_siblings().len(), + "Malformed proof. Bitmap indicated {} non-default siblings. Found {} siblings.", + num_non_default_siblings, + proto_proof.get_non_default_siblings().len() + ); + + let mut proto_siblings = proto_proof.take_non_default_siblings().into_iter(); + // Iterate from the leftmost 1-bit to LSB in the bitmap. If a bit is set, the corresponding + // sibling is non-default and we take the sibling from proto_siblings. Otherwise the + // sibling on this position is default. + let siblings = AccumulatorBitmap::new(bitmap) + .iter() + .map(|x| { + if x { + let hash_bytes = proto_siblings + .next() + .expect("Unexpected number of siblings."); + HashValue::from_slice(&hash_bytes) + } else { + Ok(*ACCUMULATOR_PLACEHOLDER_HASH) + } + }) + .collect::>>()?; + + Ok(AccumulatorProof::new(siblings)) + } +} + +impl IntoProto for AccumulatorProof { + type ProtoType = crate::proto::proof::AccumulatorProof; + + fn into_proto(self) -> Self::ProtoType { + let mut proto_proof = Self::ProtoType::new(); + // Iterate over all siblings. For each non-default sibling, add to protobuf struct and set + // the corresponding bit in the bitmap. + let bitmap: AccumulatorBitmap = self + .siblings + .into_iter() + .map(|sibling| { + if sibling != *ACCUMULATOR_PLACEHOLDER_HASH { + proto_proof + .mut_non_default_siblings() + .push(sibling.to_vec()); + true + } else { + false + } + }) + .collect(); + proto_proof.set_bitmap(bitmap.into()); + proto_proof + } +} + +/// A proof that can be used to authenticate an element in a Sparse Merkle Tree given trusted root +/// hash. For example, `TransactionInfoToAccountProof` can be constructed on top of this structure. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct SparseMerkleProof { + /// This proof can be used to authenticate whether a given leaf exists in the tree or not. + /// - If this is `Some(HashValue, HashValue)` + /// - If the first `HashValue` equals requested key, this is an inclusion proof and the + /// second `HashValue` equals the hash of the corresponding account blob. + /// - Otherwise this is a non-inclusion proof. The first `HashValue` is the only key + /// that exists in the subtree and the second `HashValue` equals the hash of the + /// corresponding account blob. + /// - If this is `None`, this is also a non-inclusion proof which indicates the subtree is + /// empty. + leaf: Option<(HashValue, HashValue)>, + + /// All siblings in this proof, including the default ones. Siblings near the root are at the + /// beginning of the vector. + siblings: Vec, +} + +impl SparseMerkleProof { + /// Constructs a new `SparseMerkleProof` using leaf and a list of siblings. + pub fn new(leaf: Option<(HashValue, HashValue)>, siblings: Vec) -> Self { + // The sibling list could be empty in case the Sparse Merkle Tree is empty or has a single + // element. When it's not empty, the bottom most sibling will never be default, otherwise a + // leaf and a default sibling should have collapsed to a leaf. 
+ if let Some(last_sibling) = siblings.last() {
+ assert_ne!(*last_sibling, *SPARSE_MERKLE_PLACEHOLDER_HASH);
+ }
+
+ SparseMerkleProof { leaf, siblings }
+ }
+
+ /// Returns the leaf node in this proof.
+ pub fn leaf(&self) -> Option<(HashValue, HashValue)> {
+ self.leaf
+ }
+
+ /// Returns the list of siblings in this proof.
+ pub fn siblings(&self) -> &[HashValue] {
+ &self.siblings
+ }
+}
+
+impl FromProto for SparseMerkleProof {
+ type ProtoType = crate::proto::proof::SparseMerkleProof;
+
+ /// Validates `proto_proof` and converts it to `Self` if validation passed.
+ fn from_proto(mut proto_proof: Self::ProtoType) -> Result<Self> {
+ let proto_leaf = proto_proof.take_leaf();
+ let leaf = if proto_leaf.is_empty() {
+ None
+ } else if proto_leaf.len() == HashValue::LENGTH * 2 {
+ let key = HashValue::from_slice(&proto_leaf[0..HashValue::LENGTH])?;
+ let value_hash = HashValue::from_slice(&proto_leaf[HashValue::LENGTH..])?;
+ Some((key, value_hash))
+ } else {
+ bail!(
+ "Malformed proof. Leaf has {} bytes. Expect 0 or {} bytes.",
+ proto_leaf.len(),
+ HashValue::LENGTH * 2
+ );
+ };
+
+ let bitmap = proto_proof.take_bitmap();
+ if let Some(last_byte) = bitmap.last() {
+ ensure!(
+ *last_byte != 0,
+ "Malformed proof. The last byte of the bitmap is zero."
+ );
+ }
+ let num_non_default_siblings = bitmap.iter().fold(0, |total, x| total + x.count_ones());
+ ensure!(
+ num_non_default_siblings as usize == proto_proof.get_non_default_siblings().len(),
+ "Malformed proof. Bitmap indicated {} non-default siblings. Found {} siblings.",
+ num_non_default_siblings,
+ proto_proof.get_non_default_siblings().len()
+ );
+
+ let mut proto_siblings = proto_proof.take_non_default_siblings().into_iter();
+ // Iterate from the MSB of the first byte to the rightmost 1-bit in the bitmap. If a bit is
+ // set, the corresponding sibling is non-default and we take the sibling from
+ // proto_siblings. Otherwise the sibling on this position is default.
+ let siblings: Result<Vec<_>> = SparseMerkleBitmap::new(bitmap)
+ .iter()
+ .map(|x| {
+ if x {
+ let hash_bytes = proto_siblings
+ .next()
+ .expect("Unexpected number of siblings.");
+ HashValue::from_slice(&hash_bytes)
+ } else {
+ Ok(*SPARSE_MERKLE_PLACEHOLDER_HASH)
+ }
+ })
+ .collect();
+
+ Ok(SparseMerkleProof::new(leaf, siblings?))
+ }
+}
+
+impl IntoProto for SparseMerkleProof {
+ type ProtoType = crate::proto::proof::SparseMerkleProof;
+
+ fn into_proto(self) -> Self::ProtoType {
+ let mut proto_proof = Self::ProtoType::new();
+ // If a leaf is present, we write the key and value hash as a single byte array of 64
+ // bytes. Otherwise we write an empty byte array.
+ if let Some((key, value_hash)) = self.leaf {
+ proto_proof.mut_leaf().extend_from_slice(key.as_ref());
+ proto_proof
+ .mut_leaf()
+ .extend_from_slice(value_hash.as_ref());
+ }
+ // Iterate over all siblings. For each non-default sibling, add to protobuf struct and set
+ // the corresponding bit in the bitmap.
+ let bitmap: SparseMerkleBitmap = self
+ .siblings
+ .into_iter()
+ .map(|sibling| {
+ if sibling != *SPARSE_MERKLE_PLACEHOLDER_HASH {
+ proto_proof
+ .mut_non_default_siblings()
+ .push(sibling.to_vec());
+ true
+ } else {
+ false
+ }
+ })
+ .collect();
+ proto_proof.set_bitmap(bitmap.into());
+ proto_proof
+ }
+}
+
+/// The complete proof used to authenticate a `SignedTransaction` object.
This structure consists +/// of an `AccumulatorProof` from `LedgerInfo` to `TransactionInfo` the verifier needs to verify +/// the correctness of the `TransactionInfo` object, and the `TransactionInfo` object that is +/// supposed to match the `SignedTransaction`. +#[derive(Clone, Debug, Eq, PartialEq, FromProto, IntoProto)] +#[ProtoType(crate::proto::proof::SignedTransactionProof)] +pub struct SignedTransactionProof { + /// The accumulator proof from ledger info root to leaf that authenticates the hash of the + /// `TransactionInfo` object. + ledger_info_to_transaction_info_proof: AccumulatorProof, + + /// The `TransactionInfo` object at the leaf of the accumulator. + transaction_info: TransactionInfo, +} + +impl SignedTransactionProof { + /// Constructs a new `SignedTransactionProof` object using given + /// `ledger_info_to_transaction_info_proof`. + pub fn new( + ledger_info_to_transaction_info_proof: AccumulatorProof, + transaction_info: TransactionInfo, + ) -> Self { + SignedTransactionProof { + ledger_info_to_transaction_info_proof, + transaction_info, + } + } + + /// Returns the `ledger_info_to_transaction_info_proof` object in this proof. + pub fn ledger_info_to_transaction_info_proof(&self) -> &AccumulatorProof { + &self.ledger_info_to_transaction_info_proof + } + + /// Returns the `transaction_info` object in this proof. + pub fn transaction_info(&self) -> &TransactionInfo { + &self.transaction_info + } +} + +/// The complete proof used to authenticate the state of an account. This structure consists of the +/// `AccumulatorProof` from `LedgerInfo` to `TransactionInfo`, the `TransactionInfo` object and the +/// `SparseMerkleProof` from state root to the account. +#[derive(Clone, Debug, Eq, PartialEq, FromProto, IntoProto)] +#[ProtoType(crate::proto::proof::AccountStateProof)] +pub struct AccountStateProof { + /// The accumulator proof from ledger info root to leaf that authenticates the hash of the + /// `TransactionInfo` object. + ledger_info_to_transaction_info_proof: AccumulatorProof, + + /// The `TransactionInfo` object at the leaf of the accumulator. + transaction_info: TransactionInfo, + + /// The sparse merkle proof from state root to the account state. + transaction_info_to_account_proof: SparseMerkleProof, +} + +impl AccountStateProof { + /// Constructs a new `AccountStateProof` using given `ledger_info_to_transaction_info_proof`, + /// `transaction_info` and `transaction_info_to_account_proof`. + pub fn new( + ledger_info_to_transaction_info_proof: AccumulatorProof, + transaction_info: TransactionInfo, + transaction_info_to_account_proof: SparseMerkleProof, + ) -> Self { + AccountStateProof { + ledger_info_to_transaction_info_proof, + transaction_info, + transaction_info_to_account_proof, + } + } + + /// Returns the `ledger_info_to_transaction_info_proof` object in this proof. + pub fn ledger_info_to_transaction_info_proof(&self) -> &AccumulatorProof { + &self.ledger_info_to_transaction_info_proof + } + + /// Returns the `transaction_info` object in this proof. + pub fn transaction_info(&self) -> &TransactionInfo { + &self.transaction_info + } + + /// Returns the `transaction_info_to_account_proof` object in this proof. + pub fn transaction_info_to_account_proof(&self) -> &SparseMerkleProof { + &self.transaction_info_to_account_proof + } +} + +/// The complete proof used to authenticate a contract event. 
This structure consists of the +/// `AccumulatorProof` from `LedgerInfo` to `TransactionInfo`, the `TransactionInfo` object and the +/// `AccumulatorProof` from event accumulator root to the event. +#[derive(Clone, Debug, Eq, PartialEq, FromProto, IntoProto)] +#[ProtoType(crate::proto::proof::EventProof)] +pub struct EventProof { + /// The accumulator proof from ledger info root to leaf that authenticates the hash of the + /// `TransactionInfo` object. + ledger_info_to_transaction_info_proof: AccumulatorProof, + + /// The `TransactionInfo` object at the leaf of the accumulator. + transaction_info: TransactionInfo, + + /// The accumulator proof from event root to the actual event. + transaction_info_to_event_proof: AccumulatorProof, +} + +impl EventProof { + /// Constructs a new `EventProof` using given `ledger_info_to_transaction_info_proof`, + /// `transaction_info` and `transaction_info_to_event_proof`. + pub fn new( + ledger_info_to_transaction_info_proof: AccumulatorProof, + transaction_info: TransactionInfo, + transaction_info_to_event_proof: AccumulatorProof, + ) -> Self { + EventProof { + ledger_info_to_transaction_info_proof, + transaction_info, + transaction_info_to_event_proof, + } + } + + /// Returns the `ledger_info_to_transaction_info_proof` object in this proof. + pub fn ledger_info_to_transaction_info_proof(&self) -> &AccumulatorProof { + &self.ledger_info_to_transaction_info_proof + } + + /// Returns the `transaction_info` object in this proof. + pub fn transaction_info(&self) -> &TransactionInfo { + &self.transaction_info + } + + /// Returns the `transaction_info_to_event_proof` object in this proof. + pub fn transaction_info_to_event_proof(&self) -> &AccumulatorProof { + &self.transaction_info_to_event_proof + } +} + +mod bitmap { + /// The bitmap indicating which siblings are default in a compressed accumulator proof. 1 means + /// non-default and 0 means default. The LSB corresponds to the sibling at the bottom of the + /// accumulator. The leftmost 1-bit corresponds to the sibling at the top of the accumulator, + /// since this one is always non-default. + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct AccumulatorBitmap(u64); + + impl AccumulatorBitmap { + pub fn new(bitmap: u64) -> Self { + AccumulatorBitmap(bitmap) + } + + pub fn iter(self) -> AccumulatorBitmapIterator { + AccumulatorBitmapIterator::new(self.0) + } + } + + impl std::convert::From for u64 { + fn from(bitmap: AccumulatorBitmap) -> u64 { + bitmap.0 + } + } + + /// Given a u64 bitmap, this iterator generates one bit at a time starting from the leftmost + /// 1-bit. 
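+ ///
+ /// As an illustrative sketch of the contract (not part of the original docs): the bitmap
+ /// `0b101` has its leftmost 1-bit at position 2, so the iterator yields exactly three bits.
+ /// ```text
+ /// let bits: Vec<bool> = AccumulatorBitmap::new(0b101).iter().collect();
+ /// assert_eq!(bits, vec![true, false, true]);
+ /// ```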
+ pub struct AccumulatorBitmapIterator { + bitmap: AccumulatorBitmap, + mask: u64, + } + + impl AccumulatorBitmapIterator { + fn new(bitmap: u64) -> Self { + let num_leading_zeros = bitmap.leading_zeros(); + let mask = if num_leading_zeros >= 64 { + 0 + } else { + 1 << (63 - num_leading_zeros) + }; + AccumulatorBitmapIterator { + bitmap: AccumulatorBitmap(bitmap), + mask, + } + } + } + + impl std::iter::Iterator for AccumulatorBitmapIterator { + type Item = bool; + + fn next(&mut self) -> Option { + if self.mask == 0 { + return None; + } + let ret = self.bitmap.0 & self.mask != 0; + self.mask >>= 1; + Some(ret) + } + } + + impl std::iter::FromIterator for AccumulatorBitmap { + fn from_iter(iter: I) -> Self + where + I: std::iter::IntoIterator, + { + let mut bitmap = 0; + for (i, bit) in iter.into_iter().enumerate() { + if i == 0 { + assert!(bit, "The first bit should always be set."); + } else if i > 63 { + panic!("Trying to put more than 64 bits in AccumulatorBitmap."); + } + bitmap <<= 1; + bitmap |= bit as u64; + } + AccumulatorBitmap::new(bitmap) + } + } + + /// The bitmap indicating which siblings are default in a compressed sparse merkle proof. 1 + /// means non-default and 0 means default. The MSB of the first byte corresponds to the + /// sibling at the top of the Sparse Merkle Tree. The rightmost 1-bit of the last byte + /// corresponds to the sibling at the bottom, since this one is always non-default. + #[derive(Clone, Debug, Eq, PartialEq)] + pub struct SparseMerkleBitmap(Vec); + + impl SparseMerkleBitmap { + pub fn new(bitmap: Vec) -> Self { + SparseMerkleBitmap(bitmap) + } + + pub fn iter(&self) -> SparseMerkleBitmapIterator { + SparseMerkleBitmapIterator::new(&self.0) + } + } + + impl std::convert::From for Vec { + fn from(bitmap: SparseMerkleBitmap) -> Vec { + bitmap.0 + } + } + + /// Given a `Vec` bitmap, this iterator generates one bit at a time starting from the MSB + /// of the first byte. All trailing zeros of the last byte are discarded. + pub struct SparseMerkleBitmapIterator<'a> { + bitmap: &'a [u8], + index: usize, + len: usize, + } + + impl<'a> SparseMerkleBitmapIterator<'a> { + fn new(bitmap: &'a [u8]) -> Self { + match bitmap.last() { + Some(last_byte) => { + assert_ne!( + *last_byte, 0, + "The last byte of the bitmap should never be zero." + ); + SparseMerkleBitmapIterator { + bitmap, + index: 0, + len: bitmap.len() * 8 - last_byte.trailing_zeros() as usize, + } + } + None => SparseMerkleBitmapIterator { + bitmap, + index: 0, + len: 0, + }, + } + } + } + + impl<'a> std::iter::Iterator for SparseMerkleBitmapIterator<'a> { + type Item = bool; + + fn next(&mut self) -> Option { + // We are past the last useful bit. 
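+ // (`len` was computed in `new` to exclude the trailing zeros of the last byte, so the
+ // iterator stops right after the rightmost 1-bit. E.g., for index 10: pos = 1, bit = 2,
+ // and we test bit 5 of the second byte, counting from the least significant bit.)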
+ if self.index >= self.len { + return None; + } + + let pos = self.index / 8; + let bit = self.index % 8; + let ret = self.bitmap[pos] >> (7 - bit) & 1 != 0; + self.index += 1; + Some(ret) + } + } + + impl std::iter::FromIterator for SparseMerkleBitmap { + fn from_iter(iter: I) -> Self + where + I: std::iter::IntoIterator, + { + let mut bitmap = vec![]; + for (i, bit) in iter.into_iter().enumerate() { + let pos = i % 8; + if pos == 0 { + bitmap.push(0); + } + let last_byte = bitmap + .last_mut() + .expect("The bitmap vector should not be empty"); + *last_byte |= (bit as u8) << (7 - pos); + } + SparseMerkleBitmap::new(bitmap) + } + } +} diff --git a/types/src/proof/mod.rs b/types/src/proof/mod.rs new file mode 100644 index 0000000000000..d754216ce6617 --- /dev/null +++ b/types/src/proof/mod.rs @@ -0,0 +1,540 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +pub mod definition; +pub mod position; +pub mod proptest_proof; +pub mod treebits; + +#[cfg(test)] +#[path = "unit_tests/proof_test.rs"] +mod proof_test; +#[cfg(test)] +mod unit_tests; + +use crate::{ + account_state_blob::AccountStateBlob, + contract_event::ContractEvent, + ledger_info::LedgerInfo, + transaction::{TransactionInfo, TransactionListWithProof, Version}, +}; +use crypto::{ + hash::{ + CryptoHash, CryptoHasher, EventAccumulatorHasher, SparseMerkleInternalHasher, + SparseMerkleLeafHasher, TestOnlyHasher, TransactionAccumulatorHasher, + ACCUMULATOR_PLACEHOLDER_HASH, SPARSE_MERKLE_PLACEHOLDER_HASH, + }, + HashValue, +}; +use failure::prelude::*; +use std::{collections::VecDeque, marker::PhantomData}; + +pub use crate::proof::definition::{ + AccountStateProof, AccumulatorProof, EventProof, SignedTransactionProof, SparseMerkleProof, +}; + +/// Verifies that a `SignedTransaction` with hash value of `signed_transaction_hash` +/// is the version `transaction_version` transaction in the ledger using the provided proof. +/// If event_root_hash is provided, it's also verified against the proof. +pub fn verify_signed_transaction( + ledger_info: &LedgerInfo, + signed_transaction_hash: HashValue, + event_root_hash: Option, + transaction_version: Version, + signed_transaction_proof: &SignedTransactionProof, +) -> Result<()> { + let transaction_info = signed_transaction_proof.transaction_info(); + + ensure!( + signed_transaction_hash == transaction_info.signed_transaction_hash(), + "The hash of signed transaction does not match the transaction info in proof. \ + Transaction hash: {:x}. Transaction hash provided by proof: {:x}.", + signed_transaction_hash, + transaction_info.signed_transaction_hash() + ); + + if let Some(event_root_hash) = event_root_hash { + ensure!( + event_root_hash == transaction_info.event_root_hash(), + "Event root hash ({}) doesn't match that in the transaction info ({}).", + event_root_hash, + transaction_info.event_root_hash(), + ); + } + + verify_transaction_info( + ledger_info, + transaction_version, + transaction_info, + signed_transaction_proof.ledger_info_to_transaction_info_proof(), + )?; + Ok(()) +} + +/// Verifies that the state of an account at version `state_version` is correct using the provided +/// proof. If `account_state_blob` is present, we expect the account to exist, otherwise we +/// expect the account to not exist. 
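+///
+/// A sketch of the two intended call patterns (the local value names are hypothetical):
+///
+/// ```text
+/// // Inclusion: `blob` is claimed to be the account's state at `version`.
+/// verify_account_state(&ledger_info, version, address_hash, &Some(blob), &proof)?;
+/// // Non-inclusion: the account is claimed to be absent at `version`.
+/// verify_account_state(&ledger_info, version, address_hash, &None, &proof)?;
+/// ```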
+pub fn verify_account_state(
+ ledger_info: &LedgerInfo,
+ state_version: Version,
+ account_address_hash: HashValue,
+ account_state_blob: &Option<AccountStateBlob>,
+ account_state_proof: &AccountStateProof,
+) -> Result<()> {
+ let transaction_info = account_state_proof.transaction_info();
+
+ verify_sparse_merkle_element(
+ transaction_info.state_root_hash(),
+ account_address_hash,
+ account_state_blob,
+ account_state_proof.transaction_info_to_account_proof(),
+ )?;
+
+ verify_transaction_info(
+ ledger_info,
+ state_version,
+ transaction_info,
+ account_state_proof.ledger_info_to_transaction_info_proof(),
+ )?;
+ Ok(())
+}
+
+/// Verifies that a given event is correct using provided proof.
+pub(crate) fn verify_event(
+ ledger_info: &LedgerInfo,
+ event_hash: HashValue,
+ transaction_version: Version,
+ event_version_within_transaction: Version,
+ event_proof: &EventProof,
+) -> Result<()> {
+ let transaction_info = event_proof.transaction_info();
+
+ verify_event_accumulator_element(
+ transaction_info.event_root_hash(),
+ event_hash,
+ event_version_within_transaction,
+ event_proof.transaction_info_to_event_proof(),
+ )?;
+
+ verify_transaction_info(
+ ledger_info,
+ transaction_version,
+ transaction_info,
+ event_proof.ledger_info_to_transaction_info_proof(),
+ )?;
+
+ Ok(())
+}
+
+pub(crate) fn verify_transaction_list(
+ ledger_info: &LedgerInfo,
+ transaction_list_with_proof: &TransactionListWithProof,
+) -> Result<()> {
+ let (transaction_and_infos, event_lists, first_transaction_version, first_proof, last_proof) = (
+ &transaction_list_with_proof.transaction_and_infos,
+ transaction_list_with_proof.events.as_ref(),
+ transaction_list_with_proof.first_transaction_version,
+ transaction_list_with_proof
+ .proof_of_first_transaction
+ .as_ref(),
+ transaction_list_with_proof
+ .proof_of_last_transaction
+ .as_ref(),
+ );
+
+ let num_txns = transaction_and_infos.len();
+ if let Some(event_lists) = event_lists {
+ ensure!(
+ num_txns == event_lists.len(),
+ "Number of the event lists doesn't match that of the transactions: {} vs {}",
+ num_txns,
+ event_lists.len(),
+ );
+ }
+
+ // 1. Empty list.
+ if num_txns == 0 {
+ ensure!(
+ first_proof.is_none(),
+ "List is empty but proof of the first transaction is provided."
+ );
+ ensure!(
+ last_proof.is_none(),
+ "List is empty but proof of the last transaction is provided."
+ );
+ ensure!(
+ first_transaction_version.is_none(),
+ "List is empty but the version of the first transaction is provided.",
+ );
+ return Ok(());
+ }
+
+ // 2. Non-empty list.
+ let first_version = first_transaction_version.ok_or_else(|| {
+ format_err!("Invalid TransactionListWithProof: First_transaction_version is None.")
+ })?;
+ let siblings_of_first_txn = first_proof
+ .ok_or_else(|| {
+ format_err!("Invalid TransactionListWithProof: First transaction proof is None")
+ })?
+ .siblings();
+ let siblings_of_last_txn = match (num_txns, last_proof) {
+ (1, None) => siblings_of_first_txn,
+ (_, Some(last_proof)) => last_proof.siblings(),
+ _ => bail!(
+ "Invalid TransactionListWithProof: Last transaction proof is_none:{}, num_txns:{}",
+ last_proof.is_none(),
+ num_txns
+ ),
+ };
+
+ // Verify event root hashes match what is carried on the transaction infos.
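+ // Each transaction's events form a small accumulator of their own; recomputing its root
+ // and comparing against `txn_info.event_root_hash()` ties the events to the transaction.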
+ if let Some(event_lists) = event_lists {
+ itertools::zip_eq(event_lists, transaction_and_infos).map(|(events, (_txn, txn_info))| {
+ let event_hashes: Vec<_> = events.iter().map(ContractEvent::hash).collect();
+ let event_root_hash = get_accumulator_root_hash::<EventAccumulatorHasher>(&event_hashes);
+ ensure!(
+ event_root_hash == txn_info.event_root_hash(),
+ "Some event root hash calculated doesn't match that carried on the transaction info.",
+ );
+ Ok(())
+ }).collect::<Result<Vec<_>>>()?;
+ }
+
+ // Get the hashes of all nodes at the accumulator leaf level.
+ let mut hashes = transaction_and_infos
+ .iter()
+ .map(|(txn, txn_info)| {
+ // Verify all transaction_infos and signed_transactions are consistent.
+ ensure!(
+ txn.hash() == txn_info.signed_transaction_hash(),
+ "Some hash of signed transaction does not match the corresponding transaction info in proof"
+ );
+ Ok(txn_info.hash())
+ })
+ .collect::<Result<VecDeque<_>>>()?;
+
+ let mut first_index = first_version;
+
+ // Verify level by level from the leaf level upwards.
+ for (first_sibling, last_sibling) in siblings_of_first_txn
+ .iter()
+ .zip(siblings_of_last_txn.iter())
+ .rev()
+ {
+ assert!(!hashes.is_empty());
+ let num_nodes = hashes.len();
+
+ if num_nodes > 1 {
+ let last_index = first_index + num_nodes as u64 - 1;
+ if last_index % 2 == 0 {
+ // If `last_index` is even, it is the left child of its parent, so the sibling is
+ // not in `hashes`; we have to append it to `hashes` to generate the parent nodes'
+ // hashes.
+ hashes.push_back(*last_sibling);
+ } else {
+ // Otherwise, the sibling should be the second to last hash.
+ // Note: if we check `first_index` first we cannot use num_nodes to index because
+ // hashes length may change.
+ ensure!(hashes[num_nodes - 2] == *last_sibling,
+ "Invalid TransactionListWithProof: Last transaction proof doesn't match provided siblings");
+ }
+ // We haven't reached the first common ancestor of all transactions in the list.
+ if first_index % 2 == 0 {
+ // If `first_index` is even, it is the left child of its parent, so the sibling
+ // must be the next node.
+ ensure!(hashes[1] == *first_sibling,
+ "Invalid TransactionListWithProof: First transaction proof doesn't match provided siblings");
+ } else {
+ // Otherwise, the sibling is not in `hashes`; we have to prepend it to `hashes` to
+ // generate the parent nodes' hashes.
+ hashes.push_front(*first_sibling);
+ }
+ } else {
+ // We have reached the first common ancestor of all the transactions in the list.
+ ensure!(
+ first_sibling == last_sibling,
+ "Invalid TransactionListWithProof: Either proof is invalid."
+ );
+ if first_index % 2 == 0 {
+ hashes.push_back(*first_sibling);
+ } else {
+ hashes.push_front(*first_sibling);
+ }
+ }
+ let mut hash_iter = hashes.into_iter();
+ let mut parent_hashes = VecDeque::new();
+ while let Some(left) = hash_iter.next() {
+ let right = hash_iter.next().expect("Can't be None");
+ parent_hashes.push_back(
+ MerkleTreeInternalNode::<TransactionAccumulatorHasher>::new(left, right).hash(),
+ )
+ }
+ hashes = parent_hashes;
+ // The parent node index at its level should be floor(index / 2).
+ first_index /= 2;
+ }
+ assert!(hashes.len() == 1);
+ let expected_root_hash = ledger_info.transaction_accumulator_hash();
+ ensure!(
+ hashes[0] == expected_root_hash,
+ "Root hashes do not match. Actual root hash: {:x}. Expected root hash: {:x}.",
+ hashes[0],
+ expected_root_hash
+ );
+ Ok(())
+}
+
+/// Verifies that a given `transaction_info` exists in the ledger using provided proof.
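+/// This is a two-step check: the transaction version must not exceed the version the
+/// `LedgerInfo` certifies, and the accumulator path from `transaction_info.hash()` must fold up
+/// to `ledger_info.transaction_accumulator_hash()`.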
+fn verify_transaction_info( + ledger_info: &LedgerInfo, + transaction_version: Version, + transaction_info: &TransactionInfo, + ledger_info_to_transaction_info_proof: &AccumulatorProof, +) -> Result<()> { + ensure!( + transaction_version <= ledger_info.version(), + "Transaction version {} is newer than LedgerInfo version {}.", + transaction_version, + ledger_info.version(), + ); + + let transaction_info_hash = transaction_info.hash(); + verify_transaction_accumulator_element( + ledger_info.transaction_accumulator_hash(), + transaction_info_hash, + transaction_version, + ledger_info_to_transaction_info_proof, + )?; + + Ok(()) +} + +/// Verifies an element whose hash is `element_hash` and version is `element_version` exists in the +/// accumulator whose root hash is `expected_root_hash` using the provided proof. +fn verify_accumulator_element( + expected_root_hash: HashValue, + element_hash: HashValue, + element_index: u64, + accumulator_proof: &AccumulatorProof, +) -> Result<()> { + let siblings = accumulator_proof.siblings(); + ensure!( + siblings.len() <= 63, + "Accumulator proof has more than 63 ({}) siblings.", + siblings.len() + ); + + let actual_root_hash = siblings + .iter() + .rev() + .fold( + (element_hash, element_index), + // `index` denotes the index of the ancestor of the element at the current level. + |(hash, index), sibling_hash| { + ( + if index % 2 == 0 { + // the current node is a left child. + MerkleTreeInternalNode::::new(hash, *sibling_hash).hash() + } else { + // the current node is a right child. + MerkleTreeInternalNode::::new(*sibling_hash, hash).hash() + }, + // The index of the parent at its level. + index / 2, + ) + }, + ) + .0; + ensure!( + actual_root_hash == expected_root_hash, + "Root hashes do not match. Actual root hash: {:x}. Expected root hash: {:x}.", + actual_root_hash, + expected_root_hash + ); + + Ok(()) +} + +pub(crate) fn get_accumulator_root_hash( + element_hashes: &[HashValue], +) -> HashValue { + if element_hashes.is_empty() { + return *ACCUMULATOR_PLACEHOLDER_HASH; + } + + let mut next_level: Vec; + let mut current_level: &[HashValue] = element_hashes; + + while current_level.len() > 1 { + next_level = current_level + .chunks(2) + .map(|t| { + if t.len() == 2 { + MerkleTreeInternalNode::::new(t[0], t[1]).hash() + } else { + MerkleTreeInternalNode::::new(t[0], *ACCUMULATOR_PLACEHOLDER_HASH).hash() + } + }) + .collect(); + + current_level = &next_level; + } + + current_level[0] +} + +type AccumulatorElementVerifier = fn( + expected_root_hash: HashValue, + element_hash: HashValue, + element_version: Version, + accumulator_proof: &AccumulatorProof, +) -> Result<()>; + +#[allow(non_upper_case_globals)] +pub const verify_event_accumulator_element: AccumulatorElementVerifier = + verify_accumulator_element::; + +#[allow(non_upper_case_globals)] +pub const verify_transaction_accumulator_element: AccumulatorElementVerifier = + verify_accumulator_element::; + +#[allow(non_upper_case_globals)] +pub const verify_test_accumulator_element: AccumulatorElementVerifier = + verify_accumulator_element::; + +/// If `element_blob` is present, verifies an element whose key is `element_key` and value +/// is `element_blob` exists in the Sparse Merkle Tree using the provided proof. +/// Otherwise verifies the proof is a valid non-inclusion proof that shows this key doesn't exist +/// in the tree. 
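+///
+/// The four `(element_blob, proof.leaf())` combinations handled below, summarized:
+///
+/// ```text
+/// (Some(blob), Some((key, hash)))  inclusion: key == element_key and hash == blob.hash()
+/// (Some(_),    None)               error: an inclusion proof was expected
+/// (None,       Some((key, _)))     non-inclusion: key != element_key, yet both keys share a
+///                                  prefix at least as long as the sibling list
+/// (None,       None)               non-inclusion in an empty subtree: nothing extra to check
+/// ```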
+pub fn verify_sparse_merkle_element( + expected_root_hash: HashValue, + element_key: HashValue, + element_blob: &Option, + sparse_merkle_proof: &SparseMerkleProof, +) -> Result<()> { + let siblings = sparse_merkle_proof.siblings(); + ensure!( + siblings.len() <= HashValue::LENGTH_IN_BITS, + "Sparse Merkle Tree proof has more than {} ({}) siblings.", + HashValue::LENGTH_IN_BITS, + siblings.len() + ); + + match (element_blob, sparse_merkle_proof.leaf()) { + (Some(blob), Some((proof_key, proof_value_hash))) => { + // This is an inclusion proof, so the key and value hash provided in the proof should + // match element_key and element_value_hash. + ensure!( + element_key == proof_key, + "Keys do not match. Key in proof: {:x}. Expected key: {:x}.", + proof_key, + element_key + ); + let hash = blob.hash(); + ensure!( + hash == proof_value_hash, + "Value hashes do not match. Value hash in proof: {:x}. Expected value hash: {:x}", + proof_value_hash, + hash, + ); + } + (Some(_blob), None) => bail!("Expected inclusion proof. Found non-inclusion proof."), + (None, Some((proof_key, _))) => { + // The proof intends to show that proof_key is the only key in a subtree and + // element_key would have ended up in the same subtree if it existed in the tree. + ensure!( + element_key != proof_key, + "Expected non-inclusion proof, but key exists in proof." + ); + ensure!( + element_key.common_prefix_bits_len(proof_key) >= siblings.len(), + "Key would not have ended up in the subtree where the provided key in proof is \ + the only existing key, if it existed. So this is not a valid non-inclusion proof." + ); + } + (None, None) => (), + } + + let leaf_hash = match sparse_merkle_proof.leaf() { + Some((key, value_hash)) => SparseMerkleLeafNode::new(key, value_hash).hash(), + None => *SPARSE_MERKLE_PLACEHOLDER_HASH, + }; + let actual_root_hash = siblings + .iter() + .rev() + .zip( + element_key + .iter_bits() + .rev() + .skip(HashValue::LENGTH_IN_BITS - siblings.len()), + ) + .fold(leaf_hash, |hash, (sibling_hash, bit)| { + if bit { + SparseMerkleInternalNode::new(*sibling_hash, hash).hash() + } else { + SparseMerkleInternalNode::new(hash, *sibling_hash).hash() + } + }); + ensure!( + actual_root_hash == expected_root_hash, + "Root hashes do not match. Actual root hash: {:x}. 
Expected root hash: {:x}.", + actual_root_hash, + expected_root_hash + ); + + Ok(()) +} + +pub struct MerkleTreeInternalNode { + left_child: HashValue, + right_child: HashValue, + hasher: PhantomData, +} + +impl MerkleTreeInternalNode { + pub fn new(left_child: HashValue, right_child: HashValue) -> Self { + Self { + left_child, + right_child, + hasher: PhantomData, + } + } +} + +impl CryptoHash for MerkleTreeInternalNode { + type Hasher = H; + + fn hash(&self) -> HashValue { + let mut state = Self::Hasher::default(); + state.write(self.left_child.as_ref()); + state.write(self.right_child.as_ref()); + state.finish() + } +} + +pub type SparseMerkleInternalNode = MerkleTreeInternalNode; +pub type TransactionAccumulatorInternalNode = MerkleTreeInternalNode; +pub type EventAccumulatorInternalNode = MerkleTreeInternalNode; +pub type TestAccumulatorInternalNode = MerkleTreeInternalNode; + +pub struct SparseMerkleLeafNode { + key: HashValue, + value_hash: HashValue, +} + +impl SparseMerkleLeafNode { + pub fn new(key: HashValue, value_hash: HashValue) -> Self { + SparseMerkleLeafNode { key, value_hash } + } +} + +impl CryptoHash for SparseMerkleLeafNode { + type Hasher = SparseMerkleLeafHasher; + + fn hash(&self) -> HashValue { + let mut state = Self::Hasher::default(); + state.write(self.key.as_ref()); + state.write(self.value_hash.as_ref()); + state.finish() + } +} diff --git a/types/src/proof/position.rs b/types/src/proof/position.rs new file mode 100644 index 0000000000000..524551ca22f4c --- /dev/null +++ b/types/src/proof/position.rs @@ -0,0 +1,219 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module provides an abstraction for positioning a node in a binary tree, +//! A `Position` uniquely identifies the location of a node +//! +//! In this implementation, `Position` is represented by the in-order-traversal sequence number +//! of the node. +//! The process of locating a node and jumping between nodes is done through position calculation, +//! which comes from treebits. +//! +//! For example +//! ```text +//! 3 +//! / \ +//! / \ +//! 1 5 <-[Node index, a.k.a, Position] +//! / \ / \ +//! 0 2 4 6 +//! +//! 0 1 2 3 <[Leaf index] +//! ``` +//! Note1: The in-order-traversal counts from 0 +//! Note2: The level of tree counts from leaf level, start from 0 +//! Note3: The leaf index starting from left-most leaf, starts from 0 + +use super::treebits; + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +pub struct Position(u64); + +impl Position { + pub fn from_inorder_index(index: u64) -> Self { + Position(index) + } + + pub fn to_inorder_index(self) -> u64 { + self.0 + } + + pub fn get_parent(self) -> Self { + Self::from_inorder_index(treebits::parent(self.0)) + } + + // Note: if self is root, the sibling will overflow + pub fn get_sibling(self) -> Self { + Self::from_inorder_index(treebits::sibling(self.0)) + } + + // Requirement: self can not be leaf. + pub fn get_left_child(self) -> Position { + Self::from_inorder_index(treebits::left_child(self.0)) + } + + // Requirement: self can not be leaf. 
+ pub fn get_right_child(self) -> Position { + Self::from_inorder_index(treebits::right_child(self.0)) + } + + // Note: if self is root, the direction will overflow (and will always be left) + pub fn get_direction_for_self(self) -> treebits::NodeDirection { + treebits::direction_from_parent(self.0) + } + + // The level start from 0 counting from the leaf level + pub fn get_level(self) -> u32 { + treebits::level(self.0) + } + + // Given the position, return the leaf index counting from the left + pub fn to_leaf_index(self) -> u64 { + treebits::pos_counting_from_left(self.0) + } + + // Opposite of get_left_node_count_from_position. + pub fn from_leaf_index(leaf_index: u64) -> Position { + Self::from_inorder_index(treebits::node_from_level_and_pos(0, leaf_index)) + } + + /// Given a position, returns the position next to it on the right on the same level. For + /// example, given input 5 this function should return 9. + /// + /// ```text + /// 3 + /// / \ + /// 1 5 9 + /// / \ / \ / \ + /// 0 2 4 6 8 10 + /// ``` + pub fn get_next_sibling(self) -> Position { + let level = self.get_level(); + let pos = treebits::pos_counting_from_left(self.0); + Position(treebits::node_from_level_and_pos(level, pos + 1)) + } + + // Given a leaf index, calculate the position of a minimum root which contains this leaf + pub fn get_root_position(leaf_index: u64) -> Position { + let leaf = Self::from_leaf_index(leaf_index); + Self::from_inorder_index(treebits::get_root(leaf.0)) + } + + // Given index of right most leaf, calculate if a position is the root + // of a perfect subtree that does not contains placeholder nodes. + pub fn is_freezable(self, leaf_index: u64) -> bool { + let leaf = Self::from_leaf_index(leaf_index); + treebits::is_freezable(self.0, leaf.0) + } + + // Given index of right most leaf, calculate if a position should be a placeholder node at this + // moment + pub fn is_placeholder(self, leaf_index: u64) -> bool { + let leaf = Self::from_leaf_index(leaf_index); + treebits::is_placeholder(self.0, leaf.0) + } + + /// Creates an `AncestorIterator` using this position. + pub fn iter_ancestor(self) -> AncestorIterator { + AncestorIterator { position: self } + } + + /// Creates an `AncestorSiblingIterator` using this position. + pub fn iter_ancestor_sibling(self) -> AncestorSiblingIterator { + AncestorSiblingIterator { position: self } + } +} + +/// `AncestorSiblingIterator` generates current sibling position and moves itself to its parent +/// position for each iteration. +#[derive(Debug)] +pub struct AncestorSiblingIterator { + position: Position, +} + +impl Iterator for AncestorSiblingIterator { + type Item = Position; + + fn next(&mut self) -> Option { + let current_sibling_position = self.position.get_sibling(); + self.position = self.position.get_parent(); + Some(current_sibling_position) + } +} + +/// `AncestorIterator` generates current position and moves itself to its parent position for each +/// iteration. +#[derive(Debug)] +pub struct AncestorIterator { + position: Position, +} + +impl Iterator for AncestorIterator { + type Item = Position; + + fn next(&mut self) -> Option { + let current_position = self.position; + self.position = self.position.get_parent(); + Some(current_position) + } +} + +/// Traverse leaves from left to right in groups that forms full subtrees, yielding root positions +/// of such subtrees. +/// Note that each 1-bit in num_leaves corresponds to a full subtree. 
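+/// (Equivalently, a tree with `num_leaves` leaves decomposes into exactly
+/// `num_leaves.count_ones()` frozen subtrees, yielded here from left to right.)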
+/// For example, in the tree below with 5=0b101 leaves, the two 1-bits correspond to Fzn2 and L4
+/// respectively.
+///
+/// ```text
+///              Non-fzn
+///             /       \
+///            /         \
+///           /           \
+///         Fzn2         Non-fzn
+///         /  \          /    \
+///        /    \        /      \
+///     Fzn1    Fzn3  Non-fzn  [Placeholder]
+///     /  \    /  \    /  \
+///    L0  L1  L2  L3  L4  [Placeholder]
+/// ```
+pub struct FrozenSubTreeIterator {
+ bitmap: u64,
+ seen_leaves: u64,
+}
+
+impl FrozenSubTreeIterator {
+ pub fn new(num_leaves: u64) -> Self {
+ Self {
+ bitmap: num_leaves,
+ seen_leaves: 0,
+ }
+ }
+}
+
+impl Iterator for FrozenSubTreeIterator {
+ type Item = Position;
+
+ fn next(&mut self) -> Option<Position> {
+ if self.bitmap == 0 {
+ return None;
+ }
+
+ // Find the remaining biggest full subtree.
+ // The MSB of the bitmap represents it. For example, for a tree of 0b1010=10 leaves, the
+ // biggest and leftmost full subtree has 0b1000=8 leaves, which can be obtained by smearing
+ // all bits after the MSB with 1-bits (giving 0b1111), right shifting once (giving 0b0111)
+ // and adding 1 (giving 0b1000=8). At the same time, we also observe that the in-order
+ // numbering of a full subtree root is (num_leaves - 1) greater than that of the leftmost
+ // leaf, and also (num_leaves - 1) less than that of the rightmost leaf.
+ let root_offset = treebits::smear_ones_for_u64(self.bitmap) >> 1;
+ let num_leaves = root_offset + 1;
+ let leftmost_leaf = Position::from_leaf_index(self.seen_leaves);
+ let root = Position::from_inorder_index(leftmost_leaf.to_inorder_index() + root_offset);
+
+ // Mark it consumed.
+ self.bitmap &= !num_leaves;
+ self.seen_leaves += num_leaves;
+
+ Some(root)
+ }
+}
diff --git a/types/src/proof/proptest_proof.rs b/types/src/proof/proptest_proof.rs
new file mode 100644
index 0000000000000..c3dec3f6df01a
--- /dev/null
+++ b/types/src/proof/proptest_proof.rs
@@ -0,0 +1,111 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! All proofs generated in this module are not valid proofs. They are only for the purpose of
+//! testing conversion between Rust and Protobuf.
+
+use crate::{
+ proof::{
+ AccountStateProof, AccumulatorProof, EventProof, SignedTransactionProof, SparseMerkleProof,
+ },
+ transaction::TransactionInfo,
+};
+use crypto::{
+ hash::{ACCUMULATOR_PLACEHOLDER_HASH, SPARSE_MERKLE_PLACEHOLDER_HASH},
+ HashValue,
+};
+use proptest::{collection::vec, prelude::*};
+use rand::{seq::SliceRandom, thread_rng};
+
+prop_compose! {
+ fn arb_accumulator_proof()(
+ non_default_siblings in vec(any::<HashValue>(), 0..63usize),
+ total_num_siblings in 0..64usize,
+ ) -> AccumulatorProof {
+ let mut siblings = non_default_siblings;
+ if !siblings.is_empty() {
+ let total_num_siblings = std::cmp::max(siblings.len(), total_num_siblings);
+ for _ in siblings.len()..total_num_siblings {
+ siblings.push(ACCUMULATOR_PLACEHOLDER_HASH.clone());
+ }
+ assert_eq!(siblings.len(), total_num_siblings);
+ (&mut siblings[1..]).shuffle(&mut thread_rng());
+ }
+ AccumulatorProof::new(siblings)
+ }
+}
+
+prop_compose!
{ + fn arb_sparse_merkle_proof()( + leaf in any::>(), + non_default_siblings in vec(any::(), 0..256usize), + total_num_siblings in 0..257usize, + ) -> SparseMerkleProof { + let mut siblings = non_default_siblings; + if !siblings.is_empty() { + let total_num_siblings = std::cmp::max(siblings.len(), total_num_siblings); + for _ in siblings.len()..total_num_siblings { + siblings.insert(0, SPARSE_MERKLE_PLACEHOLDER_HASH.clone()); + } + assert_eq!(siblings.len(), total_num_siblings); + (&mut siblings[0..total_num_siblings-1]).shuffle(&mut thread_rng()); + } + SparseMerkleProof::new(leaf, siblings) + } +} + +prop_compose! { + fn arb_signed_transaction_proof()( + ledger_info_to_transaction_info_proof in any::(), + transaction_info in any::(), + ) -> SignedTransactionProof { + SignedTransactionProof::new(ledger_info_to_transaction_info_proof, transaction_info) + } +} + +prop_compose! { + fn arb_account_state_proof()( + ledger_info_to_transaction_info_proof in any::(), + transaction_info in any::(), + transaction_info_to_account_proof in any::(), + ) -> AccountStateProof { + AccountStateProof::new( + ledger_info_to_transaction_info_proof, + transaction_info, + transaction_info_to_account_proof, + ) + } +} + +prop_compose! { + fn arb_event_proof()( + ledger_info_to_transaction_info_proof in any::(), + transaction_info in any::(), + transaction_info_to_event_proof in any::(), + ) -> EventProof { + EventProof::new( + ledger_info_to_transaction_info_proof, + transaction_info, + transaction_info_to_event_proof, + ) + } +} + +macro_rules! impl_arbitrary_for_proof { + ($proof_type: ident, $arb_func: ident) => { + impl Arbitrary for $proof_type { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + $arb_func().boxed() + } + } + }; +} + +impl_arbitrary_for_proof!(AccumulatorProof, arb_accumulator_proof); +impl_arbitrary_for_proof!(SparseMerkleProof, arb_sparse_merkle_proof); +impl_arbitrary_for_proof!(SignedTransactionProof, arb_signed_transaction_proof); +impl_arbitrary_for_proof!(AccountStateProof, arb_account_state_proof); +impl_arbitrary_for_proof!(EventProof, arb_event_proof); diff --git a/types/src/proof/treebits.rs b/types/src/proof/treebits.rs new file mode 100644 index 0000000000000..ea419bc83214b --- /dev/null +++ b/types/src/proof/treebits.rs @@ -0,0 +1,297 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module provides functions to manipulate a perfect binary tree. Nodes +//! in the tree are represented by the order they would be visited in an +//! in-order traversal. For example -- +//! ```text +//! 3 +//! / \ +//! / \ +//! 1 5 +//! / \ / \ +//! 0 2 4 6 +//! ``` +//! +//! This module can answer questions like "what is the level of 5" +//! (`level(5)=1`), "what is the right child of 3" `right_child(3)=5` +#[derive(Debug, Eq, PartialEq)] +pub enum NodeDirection { + Left, + Right, +} + +/// What level is this node in the tree, 0 if the node is a leaf, +/// 1 if the level is one above a leaf, etc. +pub fn level(node: u64) -> u32 { + (!node).trailing_zeros() +} + +fn is_leaf(node: u64) -> bool { + node & 1 == 0 +} + +/// What position is the node within the level? i.e. 
how many nodes
+/// are to the left of this node at the same level
+pub fn pos_counting_from_left(node: u64) -> u64 {
+ node >> (level(node) + 1)
+}
+
+/// pos count starts from 0 on each level
+pub fn node_from_level_and_pos(level: u32, pos: u64) -> u64 {
+ let level_one_bits = (1u64 << level) - 1;
+ let shifted_pos = pos << (level + 1);
+ shifted_pos | level_one_bits
+}
+
+/// What is the parent of this node?
+pub fn parent(node: u64) -> u64 {
+ (node | isolate_rightmost_zero_bit(node)) & !(isolate_rightmost_zero_bit(node) << 1)
+}
+
+/// What is the left child of this node? Will overflow if the node is a leaf
+pub fn left_child(node: u64) -> u64 {
+ child(node, NodeDirection::Left)
+}
+
+/// What is the right child of this node? Will overflow if the node is a leaf
+pub fn right_child(node: u64) -> u64 {
+ child(node, NodeDirection::Right)
+}
+
+pub fn child(node: u64, dir: NodeDirection) -> u64 {
+ assert!(!is_leaf(node));
+
+ let direction_bit = match dir {
+ NodeDirection::Left => 0,
+ NodeDirection::Right => isolate_rightmost_zero_bit(node),
+ };
+ (node | direction_bit) & !(isolate_rightmost_zero_bit(node) >> 1)
+}
+
+/// This method takes in a node position and returns a `NodeDirection` indicating whether it is
+/// the left or the right child of its parent. Similar to `sibling`, the observation is that,
+/// after stripping out the right-most common bits, if the next right-most bit is 0, the node is
+/// a left child; otherwise it is a right child.
+pub fn direction_from_parent(node: u64) -> NodeDirection {
+ match node & (isolate_rightmost_zero_bit(node) << 1) {
+ 0 => NodeDirection::Left,
+ _ => NodeDirection::Right,
+ }
+}
+
+/// This method takes in a node position and returns its sibling position.
+///
+/// The observation is that, after stripping out the right-most common bits,
+/// two sibling nodes flip the next right-most bits with each other.
+/// To find out the right-most common bits, first remove all the right-most ones
+/// because they correspond to the level's indicator. Then remove the next zero right after.
+pub fn sibling(node: u64) -> u64 {
+ node ^ (isolate_rightmost_zero_bit(node) << 1)
+}
+
+/// This method calculates the index of the smallest root which contains this leaf.
+/// Observe that the root position is composed of a "height" number of ones.
+///
+/// For example
+/// ```text
+/// 0010010(node)
+/// 0011111(smearing)
+/// -------
+/// 0001111(root)
+/// ```
+pub fn get_root(node: u64) -> u64 {
+ smear_ones_for_u64(node) >> 1
+}
+
+/// Smearing all the bits starting from MSB with ones
+pub fn smear_ones_for_u64(v: u64) -> u64 {
+ let mut n = v;
+ n |= n >> 1;
+ n |= n >> 2;
+ n |= n >> 4;
+ n |= n >> 8;
+ n |= n >> 16;
+ n |= n >> 32;
+ n
+}
+
+/// Returns the number of children a node `level` nodes high in a perfect
+/// binary tree has.
+///
+/// Recursively,
+///
+/// children_from_level(0) = 0
+/// children_from_level(n) = 2 * (1 + children(n-1))
+///
+/// But expanding the series this can be computed non-recursively
+/// sum 2^n, n=1 to x = 2^(x+1) - 2
+pub fn children_from_level(level: u32) -> u64 {
+ (1u64 << (level + 1)) - 2
+}
+
+pub fn children_of_node(node: u64) -> u64 {
+ (isolate_rightmost_zero_bit(node) << 1) - 2
+}
+
+/// Finds the rightmost 0-bit, turns off all bits, and sets this bit to 1 in
+/// the result.
For example: +/// +/// ```text +/// 01110111 (x) +/// -------- +/// 10001000 (~x) +/// & 01111000 (x+1) +/// -------- +/// 00001000 +/// ``` +/// http://www.catonmat.net/blog/low-level-bit-hacks-you-absolutely-must-know/ +fn isolate_rightmost_zero_bit(v: u64) -> u64 { + !v & (v + 1) +} + +/// In a post-order tree traversal, how many nodes are traversed before `node` +/// not including nodes that are children of `node`. +pub fn nodes_to_left_of(node: u64) -> u64 { + // If node = 0b0100111, ones_up_to_level = 0b111 + let ones_up_to_level = isolate_rightmost_zero_bit(node) - 1; + // Unset all the 1s due to the level + let unset_level_zeros = node ^ ones_up_to_level; + + // What remains is a 1 bit set every time a node navigated right + // For example, consider node=5=0b101. unset_level_zeros=0b100. + // the 1 bit in unset_level_zeros at position 2 represents the + // fact that 5 is the right child at the level 1. At this level + // there are 2^2 - 1 children on the left hand side. + // + // So what we do is subtract the count of one bits from unset_level_zeros + // to account for the fact that if the node is the right child at level + // n that there are 2^n - 1 children. + unset_level_zeros - u64::from(unset_level_zeros.count_ones()) +} + +/// This method checks if a node is freezable. +/// A freezable is a node with a perfect subtree that does not include any placeholder node. +/// +/// First find its right most child +/// the right most child of any node will be at leaf level, which will be a either placeholder node +/// or leaf node. if right most child is a leaf node, then it is freezable. +/// if right most child is larger than max_leaf_node, it is a placeholder node, and not freezable. +pub fn is_freezable(node: u64, max_leaf_node: u64) -> bool { + let right_most_child = right_most_child(node); + right_most_child <= max_leaf_node +} + +/// This method checks if a node is a placeholder node. +/// A node is a placeholder if both two conditions below are true: +/// 1, the node's in order traversal seq > max_leaf_node's, and +/// 2, the node does not have left child or right child. +pub fn is_placeholder(node: u64, max_leaf_node: u64) -> bool { + if node <= max_leaf_node { + return false; + } + if left_most_child(node) <= max_leaf_node { + return false; + } + true +} + +/// Given a node, find its right most child in its subtree. +/// Right most child is a node, could be itself, at level 0 +pub fn right_most_child(node: u64) -> u64 { + let level = level(node); + node + (1_u64 << level) - 1 +} + +/// Given a node, find its left most child in its subtree +/// Left most child is a node, could be itself, at level 0 +pub fn left_most_child(node: u64) -> u64 { + // Turn off its right most x bits. while x=level of node + let level = level(node); + turn_off_right_most_n_bits(node, level) +} + +/// Turn off n right most bits +/// +/// For example +/// ```text +/// 00010010101 +/// ----------- +/// 00010010100 n=1 +/// 00010010000 n=3 +/// ``` +fn turn_off_right_most_n_bits(v: u64, n: u32) -> u64 { + (v >> n) << n +} + +/// Given `node`, an index in an in-order traversal of a perfect binary tree, +/// what order would the node be visited in in post-order traversal? +/// For example, consider this tree of in-order nodes. 
+/// +/// ```text +/// 3 +/// / \ +/// / \ +/// 1 5 +/// / \ / \ +/// 0 2 4 6 +/// ``` +/// +/// The post-order ordering of the nodes is: +/// ```text +/// 6 +/// / \ +/// / \ +/// 2 5 +/// / \ / \ +/// 0 1 3 4 +/// ``` +/// +/// post_order_index(1) == 2 +/// post_order_index(4) == 3 +pub fn post_order_index(node: u64) -> u64 { + let children = children_of_node(node); + let left_nodes = nodes_to_left_of(node); + + children + left_nodes +} + +/// Defines an order for the nodes optimized for writing to disk. In this order +/// all of the nodes with the same level(node)/levels_to_collapse and a common +/// parent will share the same Disk Write Order. The Disk Write Order +/// of a node will be greater than that of any of it's children. In this tree: +/// ```text +/// 3 +/// / \ +/// / \ +/// 1 5 +/// / \ / \ +/// 0 2 4 6 +/// ``` +/// With `levels_to_collapse=2`, 0, 1, 2 have a DWO of 0; 4,5,6 a DWO of 1. +/// 3 will have a DWO of 4 to account for the fact that the nodes with DWOs of +/// 2 and 3 will have parents that have a DWO of 4. +pub fn disk_write_order(node: u64, levels_to_collapse: u32) -> u64 { + let new_level = level(node) / levels_to_collapse; + let new_pos = pos_counting_from_left(node) + >> (levels_to_collapse - 1 - (level(node) - new_level * levels_to_collapse)); + let children_this_level = children_from_level_nary(new_level, levels_to_collapse); + + (1 + new_pos) * children_this_level + new_pos + (new_pos >> u64::from(levels_to_collapse)) +} + +/// In a perfect (2^levels_to_collapse)-ary tree, how many children are there in +/// level `level` +pub fn children_from_level_nary(level: u32, levels_to_collapse: u32) -> u64 { + if level == 0 { + 0 + } else { + let two_pow_levels_collapse = 1u64 << levels_to_collapse; + // (2^level)^levels_to_collapse = 2^(level*levels_to_collapse) + let two_pow_levels_collapse_times_level = 1u64 << (level * levels_to_collapse); + + // Sum[a^n, {n, 1, x}] = (a (a^x-1))/(a-1) + (two_pow_levels_collapse * (two_pow_levels_collapse_times_level - 1)) + / (two_pow_levels_collapse - 1) + } +} diff --git a/types/src/proof/unit_tests/mod.rs b/types/src/proof/unit_tests/mod.rs new file mode 100644 index 0000000000000..1b05b0d3150ec --- /dev/null +++ b/types/src/proof/unit_tests/mod.rs @@ -0,0 +1,5 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod position_test; +mod treebits_test; diff --git a/types/src/proof/unit_tests/position_test.rs b/types/src/proof/unit_tests/position_test.rs new file mode 100644 index 0000000000000..fc645efb87fd0 --- /dev/null +++ b/types/src/proof/unit_tests/position_test.rs @@ -0,0 +1,321 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::proof::{ + position::*, + treebits::{pos_counting_from_left, NodeDirection}, +}; + +/// Position is marked with in-order-traversal sequence. 
+/// +/// For example +/// ```text +/// 0 +/// ``` +/// +/// Another example +/// ```text +/// 3 +/// / \ +/// / \ +/// 1 5 +/// / \ / \ +/// 0 2 4 6 +/// ``` +#[test] +fn test_position_get_parent() { + let position = Position::from_inorder_index(5); + let target = position.get_parent(); + assert_eq!(target, Position::from_inorder_index(3)); +} + +#[test] +fn test_position_get_sibling_right() { + let position = Position::from_inorder_index(5); + let target = position.get_sibling(); + assert_eq!(target, Position::from_inorder_index(1)); +} + +#[test] +fn test_position_get_sibling_left() { + let position = Position::from_inorder_index(4); + let target = position.get_sibling(); + assert_eq!(target, Position::from_inorder_index(6)); +} + +#[test] +fn test_position_get_left_child() { + let position = Position::from_inorder_index(5); + let target = position.get_left_child(); + assert_eq!(target, Position::from_inorder_index(4)); +} + +#[test] +fn test_position_get_right_child() { + let position = Position::from_inorder_index(5); + let target = position.get_right_child(); + assert_eq!(target, Position::from_inorder_index(6)); +} + +#[test] +#[should_panic] +fn test_position_get_left_child_from_leaf() { + let position = Position::from_inorder_index(0); + let _target = position.get_left_child(); +} +#[test] +#[should_panic] +fn test_position_get_right_child_from_leaf() { + let position = Position::from_inorder_index(0); + let _target = position.get_right_child(); +} + +#[test] +fn test_position_get_level() { + let mut position = Position::from_inorder_index(5); + let level = position.get_level(); + assert_eq!(level, 1); + + position = Position::from_inorder_index(0); + let level = position.get_level(); + assert_eq!(level, 0); +} + +#[test] +fn test_position_get_next_sibling() { + for i in 0..1000 { + let left_position = Position::from_inorder_index(i); + let position = left_position.get_next_sibling(); + assert_eq!(left_position.get_level(), position.get_level()); + assert_eq!( + pos_counting_from_left(left_position.to_inorder_index()) + 1, + pos_counting_from_left(position.to_inorder_index()) + ); + } +} + +#[test] +fn test_position_get_direction() { + assert_eq!( + Position::from_inorder_index(5).get_direction_for_self(), + NodeDirection::Right + ); + assert_eq!( + Position::from_inorder_index(6).get_direction_for_self(), + NodeDirection::Right + ); + assert_eq!( + Position::from_inorder_index(2).get_direction_for_self(), + NodeDirection::Right + ); + assert_eq!( + Position::from_inorder_index(11).get_direction_for_self(), + NodeDirection::Right + ); + assert_eq!( + Position::from_inorder_index(13).get_direction_for_self(), + NodeDirection::Right + ); + assert_eq!( + Position::from_inorder_index(14).get_direction_for_self(), + NodeDirection::Right + ); + assert_eq!( + Position::from_inorder_index(10).get_direction_for_self(), + NodeDirection::Right + ); + assert_eq!( + Position::from_inorder_index(1).get_direction_for_self(), + NodeDirection::Left + ); + assert_eq!( + Position::from_inorder_index(0).get_direction_for_self(), + NodeDirection::Left + ); + assert_eq!( + Position::from_inorder_index(3).get_direction_for_self(), + NodeDirection::Left + ); + assert_eq!( + Position::from_inorder_index(8).get_direction_for_self(), + NodeDirection::Left + ); + assert_eq!( + Position::from_inorder_index(12).get_direction_for_self(), + NodeDirection::Left + ); +} + +#[test] +fn test_position_get_direction_from_root() { + assert_eq!( + Position::from_inorder_index(7).get_direction_for_self(), + 
NodeDirection::Left
+    );
+}
+
+#[test]
+fn test_position_get_root_position() {
+    let target = Position::get_root_position(6);
+    assert_eq!(target, Position::from_inorder_index(7));
+
+    let target = Position::get_root_position(0);
+    assert_eq!(target, Position::from_inorder_index(0));
+
+    let target = Position::get_root_position(3);
+    assert_eq!(target, Position::from_inorder_index(3));
+}
+
+#[test]
+fn test_is_freezable() {
+    let mut position = Position::from_inorder_index(5);
+    assert_eq!(position.is_freezable(2), false);
+    assert_eq!(position.is_freezable(3), true);
+    assert_eq!(position.is_freezable(4), true);
+
+    position = Position::from_inorder_index(0);
+    assert_eq!(position.is_freezable(0), true);
+    assert_eq!(position.is_freezable(3), true);
+    assert_eq!(position.is_freezable(4), true);
+
+    // Testing a root
+    position = Position::from_inorder_index(7);
+    assert_eq!(position.is_freezable(6), false);
+    assert_eq!(position.is_freezable(7), true);
+    assert_eq!(position.is_freezable(8), true);
+
+    // Testing a leaf
+    position = Position::from_inorder_index(10);
+    assert_eq!(position.is_freezable(5), true);
+}
+
+#[test]
+fn test_is_freezable_out_of_boundary() {
+    // Testing out of boundary
+    let position = Position::from_inorder_index(10);
+    assert_eq!(position.is_freezable(2), false);
+}
+
+#[test]
+fn test_is_placeholder() {
+    assert_eq!(Position::from_inorder_index(5).is_placeholder(0), true);
+    assert_eq!(Position::from_inorder_index(5).is_placeholder(1), true);
+    assert_eq!(Position::from_inorder_index(5).is_placeholder(2), false);
+    assert_eq!(Position::from_inorder_index(5).is_placeholder(3), false);
+    assert_eq!(Position::from_inorder_index(13).is_placeholder(5), true);
+    assert_eq!(Position::from_inorder_index(13).is_placeholder(6), false);
+}
+
+#[test]
+fn test_is_placeholder_out_of_boundary() {
+    // Testing out of boundary
+    assert_eq!(Position::from_inorder_index(7).is_placeholder(2), false);
+    assert_eq!(Position::from_inorder_index(11).is_placeholder(2), true);
+    assert_eq!(Position::from_inorder_index(14).is_placeholder(2), true);
+}
+
+#[test]
+pub fn test_sibling_sequence() {
+    let sibling_sequence1 = Position::from_inorder_index(0)
+        .iter_ancestor_sibling()
+        .take(20)
+        .map(Position::to_inorder_index)
+        .collect::<Vec<_>>();
+    assert_eq!(
+        sibling_sequence1,
+        vec![
+            2, 5, 11, 23, 47, 95, 191, 383, 767, 1535, 3071, 6143, 12287, 24575, 49151, 98303,
+            196_607, 393_215, 786_431, 1_572_863
+        ]
+    );
+
+    let sibling_sequence2 = Position::from_inorder_index(6)
+        .iter_ancestor_sibling()
+        .take(20)
+        .map(Position::to_inorder_index)
+        .collect::<Vec<_>>();
+    assert_eq!(
+        sibling_sequence2,
+        vec![
+            4, 1, 11, 23, 47, 95, 191, 383, 767, 1535, 3071, 6143, 12287, 24575, 49151, 98303,
+            196_607, 393_215, 786_431, 1_572_863
+        ]
+    );
+
+    let sibling_sequence3 = Position::from_inorder_index(7)
+        .iter_ancestor_sibling()
+        .take(20)
+        .map(Position::to_inorder_index)
+        .collect::<Vec<_>>();
+    assert_eq!(
+        sibling_sequence3,
+        vec![
+            23, 47, 95, 191, 383, 767, 1535, 3071, 6143, 12287, 24575, 49151, 98303, 196_607,
+            393_215, 786_431, 1_572_863, 3_145_727, 6_291_455, 12_582_911
+        ]
+    );
+}
+
+#[test]
+pub fn test_parent_sequence() {
+    let parent_sequence1 = Position::from_inorder_index(0)
+        .iter_ancestor()
+        .take(20)
+        .map(Position::to_inorder_index)
+        .collect::<Vec<_>>();
+    assert_eq!(
+        parent_sequence1,
+        vec![
+            0, 1, 3, 7, 15, 31, 63, 127, 255, 511, 1023, 2047, 4095, 8191, 16383, 32767, 65535,
+            131_071, 262_143, 524_287
+        ]
+    );
+
+    let parent_sequence2 = Position::from_inorder_index(12)
+        .iter_ancestor()
+        .take(20)
+        .map(Position::to_inorder_index)
+        .collect::<Vec<_>>();
+    assert_eq!(
+        parent_sequence2,
+        vec![
+            12, 13, 11, 7, 15, 31, 63, 127, 255, 511, 1023, 2047, 4095, 8191, 16383, 32767, 65535,
+            131_071, 262_143, 524_287
+        ]
+    );
+}
+
+fn slow_get_frozen_subtree_roots_impl(root: Position, max_leaf_index: u64) -> Vec<Position> {
+    if root.is_freezable(max_leaf_index) {
+        vec![root]
+    } else if root.is_placeholder(max_leaf_index) {
+        Vec::new()
+    } else {
+        let mut roots = slow_get_frozen_subtree_roots_impl(root.get_left_child(), max_leaf_index);
+        roots.extend(slow_get_frozen_subtree_roots_impl(
+            root.get_right_child(),
+            max_leaf_index,
+        ));
+        roots
+    }
+}
+
+fn slow_get_frozen_subtree_roots(num_leaves: u64) -> Vec<Position> {
+    if num_leaves == 0 {
+        Vec::new()
+    } else {
+        let max_leaf_index = num_leaves - 1;
+        let root = Position::get_root_position(max_leaf_index);
+        slow_get_frozen_subtree_roots_impl(root, max_leaf_index)
+    }
+}
+
+#[test]
+fn test_frozen_subtree_iterator() {
+    for n in 0..10000 {
+        assert_eq!(
+            FrozenSubTreeIterator::new(n).collect::<Vec<_>>(),
+            slow_get_frozen_subtree_roots(n),
+        );
+    }
+}
diff --git a/types/src/proof/unit_tests/proof_proto_conversion_test.rs b/types/src/proof/unit_tests/proof_proto_conversion_test.rs
new file mode 100644
index 0000000000000..aad1a29280559
--- /dev/null
+++ b/types/src/proof/unit_tests/proof_proto_conversion_test.rs
@@ -0,0 +1,330 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::proof::{
+    definition::bitmap::{AccumulatorBitmap, SparseMerkleBitmap},
+    AccountStateProof, AccumulatorProof, EventProof, SignedTransactionProof, SparseMerkleProof,
+};
+use crypto::{
+    hash::{TestOnlyHash, ACCUMULATOR_PLACEHOLDER_HASH, SPARSE_MERKLE_PLACEHOLDER_HASH},
+    HashValue,
+};
+use proptest::{collection::vec, prelude::*};
+use proto_conv::{test_helper::assert_protobuf_encode_decode, FromProto, IntoProto};
+
+fn accumulator_bitmap_iterator_test(bitmap_value: u64, expected_bits: Vec<bool>) {
+    let bitmap = AccumulatorBitmap::new(bitmap_value);
+    let bits: Vec<_> = bitmap.iter().collect();
+    assert_eq!(bits, expected_bits);
+    let bitmap2: AccumulatorBitmap = bits.into_iter().collect();
+    let bitmap_value2: u64 = bitmap2.into();
+    assert_eq!(bitmap_value, bitmap_value2);
+}
+
+#[test]
+fn test_accumulator_bitmap() {
+    accumulator_bitmap_iterator_test(0b0, vec![]);
+    accumulator_bitmap_iterator_test(0b1, vec![true]);
+    accumulator_bitmap_iterator_test(0b10_1101, vec![true, false, true, true, false, true]);
+}
+
+fn sparse_merkle_bitmap_iterator_test(bitmap_value: Vec<u8>, expected_bits: Vec<bool>) {
+    let bitmap = SparseMerkleBitmap::new(bitmap_value.clone());
+    let bits: Vec<_> = bitmap.iter().collect();
+    assert_eq!(bits, expected_bits);
+    let bitmap2: SparseMerkleBitmap = bits.into_iter().collect();
+    let bitmap_value2: Vec<_> = bitmap2.into();
+    assert_eq!(bitmap_value, bitmap_value2);
+}
+
+#[test]
+fn test_sparse_merkle_bitmap() {
+    sparse_merkle_bitmap_iterator_test(vec![], vec![]);
+    sparse_merkle_bitmap_iterator_test(vec![0b1000_0000], vec![true]);
+    sparse_merkle_bitmap_iterator_test(vec![0b0100_0000], vec![false, true]);
+    sparse_merkle_bitmap_iterator_test(vec![0b1001_0000], vec![true, false, false, true]);
+    sparse_merkle_bitmap_iterator_test(
+        vec![0b0001_0011],
+        vec![false, false, false, true, false, false, true, true],
+    );
+    sparse_merkle_bitmap_iterator_test(
+        vec![0b0001_0011, 0b0010_0000],
+        vec![
+            false, false, false, true, false, false, true, true, false, false, true,
+        ],
+    );
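+    // Note: judging from the cases above, the bitmap bytes are consumed
+    // MSB-first and iteration stops at the last set bit, so trailing zeros
+    // are dropped; e.g. vec![0b1100_0000] would decode to [true, true].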
+    sparse_merkle_bitmap_iterator_test(
+        vec![0b1001_0011, 0b0010_0011],
+        vec![
+            true, false, false, true, false, false, true, true, false, false, true, false, false,
+            false, true, true,
+        ],
+    );
+}
+
+fn accumulator_proof_protobuf_conversion_test(
+    siblings: Vec<HashValue>,
+    expected_bitmap: u64,
+    expected_num_non_default_siblings: usize,
+) {
+    let proof = AccumulatorProof::new(siblings);
+    let compressed_proof = proof.clone().into_proto();
+    assert_eq!(compressed_proof.get_bitmap(), expected_bitmap);
+    assert_eq!(
+        compressed_proof.get_non_default_siblings().len(),
+        expected_num_non_default_siblings
+    );
+    let decompressed_proof = AccumulatorProof::from_proto(compressed_proof).unwrap();
+    assert_eq!(decompressed_proof, proof);
+}
+
+#[test]
+fn test_convert_accumulator_proof_to_protobuf() {
+    accumulator_proof_protobuf_conversion_test(vec![], 0b0, 0);
+    accumulator_proof_protobuf_conversion_test(vec![b"0".test_only_hash()], 0b1, 1);
+    accumulator_proof_protobuf_conversion_test(
+        vec![
+            b"0".test_only_hash(),
+            b"1".test_only_hash(),
+            b"2".test_only_hash(),
+        ],
+        0b111,
+        3,
+    );
+    accumulator_proof_protobuf_conversion_test(
+        vec![
+            b"0".test_only_hash(),
+            *ACCUMULATOR_PLACEHOLDER_HASH,
+            b"2".test_only_hash(),
+        ],
+        0b101,
+        2,
+    );
+    accumulator_proof_protobuf_conversion_test(
+        vec![
+            b"0".test_only_hash(),
+            *ACCUMULATOR_PLACEHOLDER_HASH,
+            *ACCUMULATOR_PLACEHOLDER_HASH,
+        ],
+        0b100,
+        1,
+    );
+}
+
+#[test]
+fn test_convert_accumulator_proof_wrong_number_of_siblings() {
+    let sibling0 = b"0".test_only_hash();
+    let sibling1 = b"1".test_only_hash();
+
+    let mut compressed_proof = crate::proto::proof::AccumulatorProof::new();
+    compressed_proof.set_bitmap(0b100);
+    compressed_proof
+        .mut_non_default_siblings()
+        .push(sibling0.to_vec());
+    compressed_proof
+        .mut_non_default_siblings()
+        .push(sibling1.to_vec());
+    assert!(AccumulatorProof::from_proto(compressed_proof).is_err());
+}
+
+#[test]
+fn test_convert_accumulator_proof_malformed_hashes() {
+    let mut sibling0 = b"0".test_only_hash().to_vec();
+    sibling0.push(1);
+
+    let mut compressed_proof = crate::proto::proof::AccumulatorProof::new();
+    compressed_proof.set_bitmap(0b100);
+    compressed_proof.mut_non_default_siblings().push(sibling0);
+    assert!(AccumulatorProof::from_proto(compressed_proof).is_err());
+}
+
+fn sparse_merkle_proof_protobuf_conversion_test(
+    leaf: Option<(HashValue, HashValue)>,
+    siblings: Vec<HashValue>,
+    expected_bitmap: Vec<u8>,
+    expected_num_non_default_siblings: usize,
+) {
+    let proof = SparseMerkleProof::new(leaf, siblings);
+    let compressed_proof = proof.clone().into_proto();
+    assert_eq!(expected_bitmap, compressed_proof.get_bitmap());
+    assert_eq!(
+        compressed_proof.get_non_default_siblings().len(),
+        expected_num_non_default_siblings
+    );
+    let decompressed_proof = SparseMerkleProof::from_proto(compressed_proof).unwrap();
+    assert_eq!(decompressed_proof, proof);
+}
+
+#[test]
+fn test_convert_sparse_merkle_proof_to_protobuf() {
+    sparse_merkle_proof_protobuf_conversion_test(None, vec![], vec![], 0);
+    sparse_merkle_proof_protobuf_conversion_test(
+        None,
+        vec![b"0".test_only_hash()],
+        vec![0b1000_0000],
+        1,
+    );
+    sparse_merkle_proof_protobuf_conversion_test(
+        None,
+        vec![
+            b"0".test_only_hash(),
+            b"1".test_only_hash(),
+            b"2".test_only_hash(),
+        ],
+        vec![0b1110_0000],
+        3,
+    );
+    sparse_merkle_proof_protobuf_conversion_test(
+        None,
+        vec![*SPARSE_MERKLE_PLACEHOLDER_HASH, b"1".test_only_hash()],
+        vec![0b0100_0000],
+        1,
+    );
+    sparse_merkle_proof_protobuf_conversion_test(
+        None,
+        vec![
b"0".test_only_hash(), + *SPARSE_MERKLE_PLACEHOLDER_HASH, + b"2".test_only_hash(), + ], + vec![0b1010_0000], + 2, + ); + sparse_merkle_proof_protobuf_conversion_test( + None, + vec![ + b"0".test_only_hash(), + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + b"7".test_only_hash(), + ], + vec![0b1000_0001], + 2, + ); + sparse_merkle_proof_protobuf_conversion_test( + None, + vec![ + b"0".test_only_hash(), + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + *SPARSE_MERKLE_PLACEHOLDER_HASH, + b"7".test_only_hash(), + b"8".test_only_hash(), + ], + vec![0b1000_0001, 0b1000_0000], + 3, + ); + sparse_merkle_proof_protobuf_conversion_test( + Some((HashValue::random(), HashValue::random())), + vec![b"0".test_only_hash()], + vec![0b1000_0000], + 1, + ); +} + +#[test] +fn test_convert_sparse_merkle_proof_wrong_number_of_siblings() { + let sibling0 = b"0".test_only_hash(); + let sibling1 = b"1".test_only_hash(); + + let mut compressed_proof = crate::proto::proof::SparseMerkleProof::new(); + compressed_proof.mut_bitmap().push(0b1000_0000); + compressed_proof + .mut_non_default_siblings() + .push(sibling0.to_vec()); + compressed_proof + .mut_non_default_siblings() + .push(sibling1.to_vec()); + assert!(SparseMerkleProof::from_proto(compressed_proof).is_err()); +} + +#[test] +fn test_convert_sparse_merkle_proof_malformed_hashes() { + let mut sibling0 = b"0".test_only_hash().to_vec(); + sibling0.push(1); + + let mut compressed_proof = crate::proto::proof::SparseMerkleProof::new(); + compressed_proof.mut_bitmap().push(0b1000_0000); + compressed_proof.mut_non_default_siblings().push(sibling0); + assert!(SparseMerkleProof::from_proto(compressed_proof).is_err()); +} + +#[test] +fn test_convert_sparse_merkle_proof_malformed_leaf() { + let sibling0 = b"0".test_only_hash().to_vec(); + + let mut compressed_proof = crate::proto::proof::SparseMerkleProof::new(); + compressed_proof.set_leaf(vec![1, 2, 3]); + compressed_proof.mut_bitmap().push(0b1000_0000); + compressed_proof.mut_non_default_siblings().push(sibling0); + assert!(SparseMerkleProof::from_proto(compressed_proof).is_err()); +} + +proptest! 
{
+    #[test]
+    fn test_accumulator_bitmap_iterator_roundtrip(value in any::<u64>()) {
+        let bitmap = AccumulatorBitmap::new(value);
+        let iter = bitmap.iter();
+        let bitmap2 = iter.collect();
+        prop_assert_eq!(bitmap, bitmap2);
+    }
+
+    #[test]
+    fn test_accumulator_bitmap_iterator_inverse_roundtrip(mut value in vec(any::<bool>(), 0..63)) {
+        value.insert(0, true);
+        let bitmap: AccumulatorBitmap = value.iter().cloned().collect();
+        let value2: Vec<_> = bitmap.iter().collect();
+        prop_assert_eq!(value, value2);
+    }
+
+    #[test]
+    fn test_sparse_merkle_bitmap_iterator_roundtrip(mut value in vec(any::<u8>(), 0..64)) {
+        if !value.is_empty() && *value.last().unwrap() == 0 {
+            *value.last_mut().unwrap() |= 0b100;
+        }
+        let bitmap = SparseMerkleBitmap::new(value);
+        let iter = bitmap.iter();
+        let bitmap2 = iter.collect();
+        prop_assert_eq!(bitmap, bitmap2);
+    }
+
+    #[test]
+    fn test_sparse_merkle_bitmap_iterator_inverse_roundtrip(mut value in vec(any::<bool>(), 0..255)) {
+        value.push(true);
+        let bitmap: SparseMerkleBitmap = value.iter().cloned().collect();
+        let value2: Vec<_> = bitmap.iter().collect();
+        prop_assert_eq!(value, value2);
+    }
+
+    #[test]
+    fn test_accumulator_protobuf_conversion_roundtrip(proof in any::<AccumulatorProof>()) {
+        assert_protobuf_encode_decode(&proof);
+    }
+
+    #[test]
+    fn test_sparse_merkle_protobuf_conversion_roundtrip(proof in any::<SparseMerkleProof>()) {
+        assert_protobuf_encode_decode(&proof);
+    }
+
+    #[test]
+    fn test_signed_transaction_proof_protobuf_conversion_roundtrip(proof in any::<SignedTransactionProof>()) {
+        assert_protobuf_encode_decode(&proof);
+    }
+
+    #[test]
+    fn test_account_state_proof_protobuf_conversion_roundtrip(proof in any::<AccountStateProof>()) {
+        assert_protobuf_encode_decode(&proof);
+    }
+
+    #[test]
+    fn test_event_proof_protobuf_conversion_roundtrip(proof in any::<EventProof>()) {
+        assert_protobuf_encode_decode(&proof);
+    }
+}
diff --git a/types/src/proof/unit_tests/proof_test.rs b/types/src/proof/unit_tests/proof_test.rs
new file mode 100644
index 0000000000000..7c83ddb15bd1a
--- /dev/null
+++ b/types/src/proof/unit_tests/proof_test.rs
@@ -0,0 +1,598 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    account_address::AccountAddress,
+    account_state_blob::AccountStateBlob,
+    ledger_info::LedgerInfo,
+    proof::{
+        verify_account_state, verify_event, verify_signed_transaction,
+        verify_sparse_merkle_element, verify_test_accumulator_element, AccountStateProof,
+        AccumulatorProof, EventAccumulatorInternalNode, EventProof, MerkleTreeInternalNode,
+        SignedTransactionProof, SparseMerkleInternalNode, SparseMerkleLeafNode, SparseMerkleProof,
+        TestAccumulatorInternalNode, TransactionAccumulatorInternalNode,
+    },
+    transaction::{
+        Program, RawTransaction, SignedTransaction, TransactionInfo, TransactionListWithProof,
+    },
+};
+use crypto::{
+    hash::{
+        CryptoHash, TestOnlyHash, TransactionAccumulatorHasher, ACCUMULATOR_PLACEHOLDER_HASH,
+        GENESIS_BLOCK_ID, SPARSE_MERKLE_PLACEHOLDER_HASH,
+    },
+    signing::generate_keypair,
+    HashValue,
+};
+use proptest::{collection::vec, prelude::*};
+
+#[test]
+fn test_verify_empty_accumulator() {
+    let element_hash = b"hello".test_only_hash();
+    let root_hash = *ACCUMULATOR_PLACEHOLDER_HASH;
+    let proof = AccumulatorProof::new(vec![]);
+    assert!(verify_test_accumulator_element(root_hash, element_hash, 0, &proof).is_err());
+}
+
+#[test]
+fn test_verify_single_element_accumulator() {
+    let element_hash = b"hello".test_only_hash();
+    let root_hash = element_hash;
+    let proof = AccumulatorProof::new(vec![]);
+    assert!(verify_test_accumulator_element(root_hash,
element_hash, 0, &proof).is_ok()); +} + +#[test] +fn test_verify_two_element_accumulator() { + let element0_hash = b"hello".test_only_hash(); + let element1_hash = b"world".test_only_hash(); + let root_hash = TestAccumulatorInternalNode::new(element0_hash, element1_hash).hash(); + + assert!(verify_test_accumulator_element( + root_hash, + element0_hash, + 0, + &AccumulatorProof::new(vec![element1_hash]), + ) + .is_ok()); + assert!(verify_test_accumulator_element( + root_hash, + element1_hash, + 1, + &AccumulatorProof::new(vec![element0_hash]), + ) + .is_ok()); +} + +#[test] +fn test_verify_three_element_accumulator() { + let element0_hash = b"hello".test_only_hash(); + let element1_hash = b"world".test_only_hash(); + let element2_hash = b"!".test_only_hash(); + let internal0_hash = TestAccumulatorInternalNode::new(element0_hash, element1_hash).hash(); + let internal1_hash = + TestAccumulatorInternalNode::new(element2_hash, *ACCUMULATOR_PLACEHOLDER_HASH).hash(); + let root_hash = TestAccumulatorInternalNode::new(internal0_hash, internal1_hash).hash(); + + assert!(verify_test_accumulator_element( + root_hash, + element0_hash, + 0, + &AccumulatorProof::new(vec![internal1_hash, element1_hash]), + ) + .is_ok()); + assert!(verify_test_accumulator_element( + root_hash, + element1_hash, + 1, + &AccumulatorProof::new(vec![internal1_hash, element0_hash]), + ) + .is_ok()); + assert!(verify_test_accumulator_element( + root_hash, + element2_hash, + 2, + &AccumulatorProof::new(vec![internal0_hash, *ACCUMULATOR_PLACEHOLDER_HASH]), + ) + .is_ok()); +} + +#[test] +fn test_accumulator_proof_63_siblings_leftmost() { + let element_hash = b"hello".test_only_hash(); + let mut siblings = vec![]; + for i in 0..63 { + siblings.push(HashValue::new([i; 32])); + } + let root_hash = siblings + .iter() + .rev() + .fold(element_hash, |hash, sibling_hash| { + TestAccumulatorInternalNode::new(hash, *sibling_hash).hash() + }); + let proof = AccumulatorProof::new(siblings); + + assert!(verify_test_accumulator_element(root_hash, element_hash, 0, &proof).is_ok()); +} + +#[test] +fn test_accumulator_proof_63_siblings_rightmost() { + let element_hash = b"hello".test_only_hash(); + let mut siblings = vec![]; + for i in 0..63 { + siblings.push(HashValue::new([i; 32])); + } + let root_hash = siblings + .iter() + .rev() + .fold(element_hash, |hash, sibling_hash| { + TestAccumulatorInternalNode::new(*sibling_hash, hash).hash() + }); + let leaf_index = (std::u64::MAX - 1) / 2; + let proof = AccumulatorProof::new(siblings); + + assert!(verify_test_accumulator_element(root_hash, element_hash, leaf_index, &proof).is_ok()); +} + +#[test] +fn test_accumulator_proof_64_siblings() { + let element_hash = b"hello".test_only_hash(); + let mut siblings = vec![]; + for i in 0..64 { + siblings.push(HashValue::new([i; 32])); + } + let root_hash = siblings + .iter() + .rev() + .fold(element_hash, |hash, sibling_hash| { + TestAccumulatorInternalNode::new(hash, *sibling_hash).hash() + }); + let proof = AccumulatorProof::new(siblings); + + assert!(verify_test_accumulator_element(root_hash, element_hash, 0, &proof).is_err()); +} + +#[test] +fn test_verify_empty_sparse_merkle() { + let key = b"hello".test_only_hash(); + let blob = b"world".to_vec().into(); + let root_hash = *SPARSE_MERKLE_PLACEHOLDER_HASH; + let proof = SparseMerkleProof::new(None, vec![]); + + // Trying to show that this key doesn't exist. + assert!(verify_sparse_merkle_element(root_hash, key, &None, &proof).is_ok()); + // Trying to show that this key exists. 
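+    // (With an empty tree the root is the placeholder hash and the proof
+    // carries no leaf, so any claim that a key exists must fail.)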
+    assert!(verify_sparse_merkle_element(root_hash, key, &Some(blob), &proof).is_err());
+}
+
+#[test]
+fn test_verify_single_element_sparse_merkle() {
+    let key = b"hello".test_only_hash();
+    let blob: Option<AccountStateBlob> = Some((b"world".to_vec()).into());
+    let blob_hash = blob.as_ref().unwrap().hash();
+    let non_existing_blob = b"world?".to_vec().into();
+    let root_hash = SparseMerkleLeafNode::new(key, blob_hash).hash();
+    let proof = SparseMerkleProof::new(Some((key, blob_hash)), vec![]);
+
+    // Trying to show this exact key exists with its value.
+    assert!(verify_sparse_merkle_element(root_hash, key, &blob, &proof).is_ok());
+    // Trying to show this exact key exists with another value.
+    assert!(
+        verify_sparse_merkle_element(root_hash, key, &Some(non_existing_blob), &proof).is_err()
+    );
+    // Trying to show this key doesn't exist.
+    assert!(verify_sparse_merkle_element(root_hash, key, &None, &proof).is_err());
+
+    let non_existing_key = b"HELLO".test_only_hash();
+
+    // The proof can be used to show non_existing_key doesn't exist.
+    assert!(verify_sparse_merkle_element(root_hash, non_existing_key, &None, &proof).is_ok());
+    // The proof can't be used to show that non_existing_key exists.
+    assert!(verify_sparse_merkle_element(root_hash, non_existing_key, &blob, &proof).is_err());
+}
+
+#[test]
+fn test_verify_three_element_sparse_merkle() {
+    //            root
+    //           /    \
+    //          a      default
+    //         / \
+    //     key1   b
+    //           / \
+    //       key2   key3
+    let key1 = b"hello".test_only_hash();
+    let key2 = b"world".test_only_hash();
+    let key3 = b"!".test_only_hash();
+    assert_eq!(key1[0], 0b0011_0011);
+    assert_eq!(key2[0], 0b0100_0010);
+    assert_eq!(key3[0], 0b0110_1001);
+
+    let blob1 = Some(AccountStateBlob::from(b"1".to_vec()));
+    let blob2 = Some(AccountStateBlob::from(b"2".to_vec()));
+    let blob3 = Some(AccountStateBlob::from(b"3".to_vec()));
+
+    let leaf1_hash = SparseMerkleLeafNode::new(key1, blob1.as_ref().unwrap().hash()).hash();
+    let leaf2_hash = SparseMerkleLeafNode::new(key2, blob2.as_ref().unwrap().hash()).hash();
+    let leaf3_hash = SparseMerkleLeafNode::new(key3, blob3.as_ref().unwrap().hash()).hash();
+    let internal_b_hash = SparseMerkleInternalNode::new(leaf2_hash, leaf3_hash).hash();
+    let internal_a_hash = SparseMerkleInternalNode::new(leaf1_hash, internal_b_hash).hash();
+    let root_hash =
+        SparseMerkleInternalNode::new(internal_a_hash, *SPARSE_MERKLE_PLACEHOLDER_HASH).hash();
+
+    let non_existing_key1 = b"abc".test_only_hash();
+    let non_existing_key2 = b"def".test_only_hash();
+    assert_eq!(non_existing_key1[0], 0b0011_1010);
+    assert_eq!(non_existing_key2[0], 0b1000_1110);
+
+    {
+        // Construct a proof of key1.
+        let proof = SparseMerkleProof::new(
+            Some((key1, blob1.as_ref().unwrap().hash())),
+            vec![*SPARSE_MERKLE_PLACEHOLDER_HASH, internal_b_hash],
+        );
+
+        // The exact key value exists.
+        assert!(verify_sparse_merkle_element(root_hash, key1, &(blob1), &proof).is_ok());
+        // Trying to show that this key has another value.
+        assert!(verify_sparse_merkle_element(root_hash, key1, &(blob2), &proof).is_err());
+        // Trying to show that this key doesn't exist.
+        assert!(verify_sparse_merkle_element(root_hash, key1, &None, &proof).is_err());
+        // This proof can't be used to show anything about key2.
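+        // (The proof's leaf commits to key1, so hashing the proof elements
+        // along key2's path cannot reproduce the root.)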
+ assert!(verify_sparse_merkle_element(root_hash, key2, &None, &proof).is_err()); + assert!(verify_sparse_merkle_element(root_hash, key2, &(blob1), &proof).is_err()); + assert!(verify_sparse_merkle_element(root_hash, key2, &(blob2), &proof).is_err()); + + // This proof can be used to show that non_existing_key1 indeed doesn't exist. + assert!(verify_sparse_merkle_element(root_hash, non_existing_key1, &None, &proof).is_ok()); + // This proof can't be used to show that non_existing_key2 doesn't exist because it lives + // in a different subtree. + assert!(verify_sparse_merkle_element(root_hash, non_existing_key2, &None, &proof).is_err()); + } + + { + // Construct a proof of the default node. + let proof = SparseMerkleProof::new(None, vec![internal_a_hash]); + + // This proof can't be used to show that a key starting with 0 doesn't exist. + assert!(verify_sparse_merkle_element(root_hash, non_existing_key1, &None, &proof).is_err()); + // This proof can be used to show that a key starting with 1 doesn't exist. + assert!(verify_sparse_merkle_element(root_hash, non_existing_key2, &None, &proof).is_ok()); + } +} + +#[test] +fn test_verify_signed_transaction() { + // root + // / \ + // / \ + // a b + // / \ / \ + // txn0 txn1 txn2 default + let txn_info0_hash = b"hello".test_only_hash(); + let txn_info2_hash = b"!".test_only_hash(); + + let txn1_hash = HashValue::random(); + let state_root1_hash = b"a".test_only_hash(); + let event_root1_hash = b"b".test_only_hash(); + let txn_info1 = TransactionInfo::new( + txn1_hash, + state_root1_hash, + event_root1_hash, + /* gas_used = */ 0, + ); + let txn_info1_hash = txn_info1.hash(); + + let internal_a_hash = + TransactionAccumulatorInternalNode::new(txn_info0_hash, txn_info1_hash).hash(); + let internal_b_hash = + TransactionAccumulatorInternalNode::new(txn_info2_hash, *ACCUMULATOR_PLACEHOLDER_HASH) + .hash(); + let root_hash = + TransactionAccumulatorInternalNode::new(internal_a_hash, internal_b_hash).hash(); + let consensus_data_hash = b"c".test_only_hash(); + let ledger_info = LedgerInfo::new( + /* version = */ 2, + root_hash, + consensus_data_hash, + *GENESIS_BLOCK_ID, + 0, + /* timestamp = */ 10000, + ); + + let ledger_info_to_transaction_info_proof = + AccumulatorProof::new(vec![internal_b_hash, txn_info0_hash]); + let proof = SignedTransactionProof::new(ledger_info_to_transaction_info_proof, txn_info1); + + // The proof can be used to verify txn1. + assert!(verify_signed_transaction(&ledger_info, txn1_hash, None, 1, &proof).is_ok()); + // Replacing txn1 with some other txn should cause the verification to fail. + assert!(verify_signed_transaction(&ledger_info, HashValue::random(), None, 1, &proof).is_err()); + // Trying to show that txn1 is at version 2. 
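+    // (The claimed index decides on which side each sibling is folded in,
+    // so a wrong version produces a different root hash.)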
+ assert!(verify_signed_transaction(&ledger_info, txn1_hash, None, 2, &proof).is_err()); +} + +#[test] +fn test_verify_account_state_and_event() { + // root + // / \ + // / \ + // a b + // / \ / \ + // txn0 txn1 txn2 default + // ^ + // | + // transaction_info2 + // / / \ + // / / \ + // signed_txn state_root event_root + // / \ / \ + // c default event0 event1 + // / \ + // key1 d + // / \ + // key2 key3 + let key1 = b"hello".test_only_hash(); + let key2 = b"world".test_only_hash(); + let key3 = b"!".test_only_hash(); + let non_existing_key = b"#".test_only_hash(); + assert_eq!(key1[0], 0b0011_0011); + assert_eq!(key2[0], 0b0100_0010); + assert_eq!(key3[0], 0b0110_1001); + assert_eq!(non_existing_key[0], 0b0100_0001); + + let blob1 = AccountStateBlob::from(b"value1".to_vec()); + let blob2 = AccountStateBlob::from(b"value2".to_vec()); + let blob3 = AccountStateBlob::from(b"value3".to_vec()); + + let leaf1_hash = SparseMerkleLeafNode::new(key1, blob1.hash()).hash(); + let leaf2_hash = SparseMerkleLeafNode::new(key2, blob2.hash()).hash(); + let leaf3_hash = SparseMerkleLeafNode::new(key3, blob3.hash()).hash(); + let internal_d_hash = SparseMerkleInternalNode::new(leaf2_hash, leaf3_hash).hash(); + let internal_c_hash = SparseMerkleInternalNode::new(leaf1_hash, internal_d_hash).hash(); + let state_root_hash = + SparseMerkleInternalNode::new(internal_c_hash, *SPARSE_MERKLE_PLACEHOLDER_HASH).hash(); + + let txn_info0_hash = b"hellohello".test_only_hash(); + let txn_info1_hash = b"worldworld".test_only_hash(); + + let (privkey, pubkey) = generate_keypair(); + let txn2_hash = RawTransaction::new( + AccountAddress::from(pubkey), + /* sequence_number = */ 0, + Program::new(vec![], vec![], vec![]), + /* max_gas_amount = */ 0, + /* gas_unit_price = */ 0, + /* expiration_time = */ std::time::Duration::new(0, 0), + ) + .sign(&privkey, pubkey) + .expect("Signing failed.") + .hash(); + + let event0_hash = b"event0".test_only_hash(); + let event1_hash = b"event1".test_only_hash(); + let event_root_hash = EventAccumulatorInternalNode::new(event0_hash, event1_hash).hash(); + + let txn_info2 = TransactionInfo::new( + txn2_hash, + state_root_hash, + event_root_hash, + /* gas_used = */ 0, + ); + let txn_info2_hash = txn_info2.hash(); + + let internal_a_hash = + TransactionAccumulatorInternalNode::new(txn_info0_hash, txn_info1_hash).hash(); + let internal_b_hash = + TransactionAccumulatorInternalNode::new(txn_info2_hash, *ACCUMULATOR_PLACEHOLDER_HASH) + .hash(); + let root_hash = + TransactionAccumulatorInternalNode::new(internal_a_hash, internal_b_hash).hash(); + + // consensus_data_hash isn't used in proofs, but we need it to construct LedgerInfo. + let consensus_data_hash = b"consensus_data".test_only_hash(); + let ledger_info = LedgerInfo::new( + /* version = */ 2, + root_hash, + consensus_data_hash, + *GENESIS_BLOCK_ID, + 0, + /* timestamp = */ 10000, + ); + + let ledger_info_to_transaction_info_proof = + AccumulatorProof::new(vec![internal_a_hash, *ACCUMULATOR_PLACEHOLDER_HASH]); + let transaction_info_to_account_proof = SparseMerkleProof::new( + Some((key2, blob2.hash())), + vec![*SPARSE_MERKLE_PLACEHOLDER_HASH, leaf1_hash, leaf3_hash], + ); + let account_state_proof = AccountStateProof::new( + ledger_info_to_transaction_info_proof.clone(), + txn_info2.clone(), + transaction_info_to_account_proof, + ); + + // Prove that account at `key2` has value `value2`. 
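+    // (Verification is two-step: the accumulator proof ties txn_info2 to the
+    // ledger info's root hash, and the sparse merkle proof ties the account
+    // blob to txn_info2's state root.)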
+    assert!(verify_account_state(
+        &ledger_info,
+        /* state_version = */ 2,
+        key2,
+        &Some(blob2),
+        &account_state_proof,
+    )
+    .is_ok());
+    // Use the same proof to prove that `non_existing_key` doesn't exist.
+    assert!(verify_account_state(
+        &ledger_info,
+        /* state_version = */ 2,
+        non_existing_key,
+        &None,
+        &account_state_proof,
+    )
+    .is_ok());
+
+    let bad_blob2 = b"3".to_vec().into();
+    assert!(verify_account_state(
+        &ledger_info,
+        /* state_version = */ 2,
+        key2,
+        &Some(bad_blob2),
+        &account_state_proof,
+    )
+    .is_err());
+
+    let transaction_info_to_event_proof = AccumulatorProof::new(vec![event1_hash]);
+    let event_proof = EventProof::new(
+        ledger_info_to_transaction_info_proof.clone(),
+        txn_info2.clone(),
+        transaction_info_to_event_proof,
+    );
+
+    // Prove that the first event within transaction 2 is `event0`.
+    assert!(verify_event(
+        &ledger_info,
+        event0_hash,
+        /* transaction_version = */ 2,
+        /* event_version_within_transaction = */ 0,
+        &event_proof,
+    )
+    .is_ok());
+
+    let bad_event0_hash = b"event1".test_only_hash();
+    assert!(verify_event(
+        &ledger_info,
+        bad_event0_hash,
+        /* transaction_version = */ 2,
+        /* event_version_within_transaction = */ 0,
+        &event_proof,
+    )
+    .is_err());
+}
+
+// Returns a variable-length list of (transaction, transaction_info) pairs,
+// along with a random subrange within [0, list_length).
+fn arb_signed_txn_list_and_range(
+) -> impl Strategy<Value = (Vec<(SignedTransaction, TransactionInfo)>, usize, usize)> {
+    vec(
+        (any::<SignedTransaction>(), any::<TransactionInfo>()),
+        0..100,
+    )
+    .prop_flat_map(|list| {
+        let len = list.len();
+        (Just(list), 0..std::cmp::max(len, 1))
+    })
+    .prop_flat_map(|(list, start)| {
+        let len = list.len();
+        (Just(list), Just(start), start..std::cmp::max(len, 1))
+    })
+    .prop_map(|(list, start, end)| {
+        let final_list = list
+            .into_iter()
+            .map(|(txn, txn_info)| {
+                let txn_hash = txn.hash();
+                (
+                    txn,
+                    TransactionInfo::new(
+                        txn_hash,
+                        txn_info.state_root_hash(),
+                        txn_info.event_root_hash(),
+                        txn_info.gas_used(),
+                    ),
+                )
+            })
+            .collect::<Vec<_>>();
+        (final_list, start, end)
+    })
+}
+
+proptest! {
+    #![proptest_config(ProptestConfig::with_cases(20))]
+
+    #[test]
+    fn test_transaction_list_with_proof((txn_and_infos, first_version, last_version) in arb_signed_txn_list_and_range()) {
+        let mut root_hash = *ACCUMULATOR_PLACEHOLDER_HASH;
+
+        let txn_list_with_proof =
+            if txn_and_infos.is_empty() {
+                TransactionListWithProof::new(vec![], None, None, None, None)
+            } else {
+                let mut hashes = txn_and_infos
+                    .iter()
+                    .map(|(_, txn_info)|
+                        txn_info.hash()
+                    ).collect::<Vec<_>>();
+                if hashes.len() % 2 == 1 && hashes.len() != 1 {
+                    hashes.push(*ACCUMULATOR_PLACEHOLDER_HASH);
+                }
+                let mut tree = vec![hashes];
+                while tree.last().unwrap().len() > 1 {
+                    let mut parent_hashes = vec![];
+                    let mut hash_iter = tree.last().unwrap().iter();
+                    while let Some(left) = hash_iter.next() {
+                        let right = hash_iter.next().expect("Can't be None");
+                        parent_hashes.push(
+                            MerkleTreeInternalNode::<TransactionAccumulatorHasher>::new(*left, *right).hash(),
+                        )
+                    }
+                    hashes = parent_hashes;
+                    if hashes.len() % 2 == 1 && hashes.len() != 1 {
+                        hashes.push(*ACCUMULATOR_PLACEHOLDER_HASH);
+                    }
+                    tree.push(hashes);
+                }
+                assert_eq!(tree.last().unwrap().len(), 1);
+                root_hash = tree.pop().unwrap()[0];
+
+                // Get proofs.
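+                // At each level the sibling is the neighbor picked by index
+                // parity (even index => right neighbor, odd => left), and
+                // dividing the index by two ascends to the parent level.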
+                let mut first_index = first_version;
+                let mut last_index = last_version;
+                let mut first_siblings = vec![];
+                let mut last_siblings = vec![];
+                for nodes in tree {
+                    first_siblings.push(
+                        if first_index % 2 == 0 {
+                            nodes[first_index + 1]
+                        } else {
+                            nodes[first_index - 1]
+                        }
+                    );
+                    last_siblings.push(
+                        if last_index % 2 == 0 {
+                            nodes[last_index + 1]
+                        } else {
+                            nodes[last_index - 1]
+                        }
+                    );
+                    first_index /= 2;
+                    last_index /= 2;
+                }
+                let first_proof =
+                    Some(AccumulatorProof::new(first_siblings.into_iter().rev().collect::<Vec<_>>()));
+                let last_proof = if first_version == last_version {
+                    None
+                } else {
+                    Some(AccumulatorProof::new(last_siblings.into_iter().rev().collect::<Vec<_>>()))
+                };
+
+                TransactionListWithProof::new(
+                    txn_and_infos[first_version..=last_version].to_vec(),
+                    None,
+                    Some(first_version as u64),
+                    first_proof,
+                    last_proof,
+                )
+            };
+
+        // consensus_data_hash isn't used in proofs, but we need it to construct LedgerInfo.
+        let consensus_data_hash = b"consensus_data".test_only_hash();
+        let ledger_info = LedgerInfo::new(
+            /* version = */ std::cmp::max(1, txn_and_infos.len()) as u64 - 1,
+            root_hash,
+            consensus_data_hash,
+            *GENESIS_BLOCK_ID,
+            0,
+            /* timestamp = */ 10000,
+        );
+        let first_version = if txn_and_infos.is_empty() { None } else { Some(first_version as u64) };
+        prop_assert!(txn_list_with_proof.verify(&ledger_info, first_version).is_ok());
+    }
+}
diff --git a/types/src/proof/unit_tests/treebits_test.rs b/types/src/proof/unit_tests/treebits_test.rs
new file mode 100644
index 0000000000000..4a4a9053d4a8f
--- /dev/null
+++ b/types/src/proof/unit_tests/treebits_test.rs
@@ -0,0 +1,257 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::proof::treebits::*;
+
+fn slow_nodes_to_left_of(node: u64) -> u64 {
+    let ret_add = if node == right_child(parent(node)) {
+        children_from_level(level(node)) + 1
+    } else {
+        0
+    };
+    let parent_add = if pos_counting_from_left(node) == 0 {
+        0
+    } else {
+        nodes_to_left_of(parent(node))
+    };
+    ret_add + parent_add
+}
+
+fn test_invariant(invariant_fn: fn(u64) -> bool) {
+    for node in 0..300 {
+        assert!(invariant_fn(node), "node = {}", node)
+    }
+}
+
+fn test_invariant_non_leaf(invariant_fn: fn(u64) -> bool) {
+    for node in 0..300 {
+        assert!(level(node) == 0 || invariant_fn(node), "node = {}", node)
+    }
+}
+
+#[test]
+fn test_basic_invariants() {
+    test_invariant_non_leaf(|node| node == parent(right_child(node)));
+    test_invariant_non_leaf(|node| node == parent(left_child(node)));
+
+    test_invariant(|node| level(node) == level(parent(node)) - 1);
+    test_invariant(|node| {
+        node_from_level_and_pos(level(node), pos_counting_from_left(node)) == node
+    });
+
+    test_invariant_non_leaf(|node| {
+        pos_counting_from_left(right_child(node)) == pos_counting_from_left(left_child(node)) + 1
+    });
+
+    test_invariant_non_leaf(|node| left_child(node) < node);
+    test_invariant_non_leaf(|node| node < right_child(node));
+    test_invariant_non_leaf(|node| post_order_index(left_child(node)) < post_order_index(node));
+    test_invariant_non_leaf(|node| post_order_index(right_child(node)) < post_order_index(node));
+
+    test_invariant_non_leaf(|node| {
+        post_order_index(right_child(node)) + 1 == post_order_index(node)
+    });
+
+    test_invariant_non_leaf(|node| right_child(node) == sibling(left_child(node)));
+    test_invariant_non_leaf(|node| sibling(right_child(node)) == left_child(node));
+
+    test_invariant_non_leaf(|node| right_child(node) == child(node, NodeDirection::Right));
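+    // An extra sanity check along the same lines (assumes, as the two
+    // invariants above suggest, that `sibling` is an involution):
+    test_invariant(|node| sibling(sibling(node)) == node);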
test_invariant_non_leaf(|node| left_child(node) == child(node, NodeDirection::Left)); + test_invariant(|node| node == child(parent(node), direction_from_parent(node))); +} + +#[test] +fn test_children_from_level_nary() { + assert_eq!(children_from_level_nary(0, 2), 0); + assert_eq!(children_from_level_nary(1, 2), 4); + assert_eq!(children_from_level_nary(2, 2), 20); + assert_eq!(children_from_level_nary(3, 2), 84); +} + +#[test] +fn test_disk_write_order() { + assert_eq!(disk_write_order(0, 2), 0); + assert_eq!(disk_write_order(1, 2), 0); + assert_eq!(disk_write_order(2, 2), 0); + assert_eq!(disk_write_order(4, 2), 1); + assert_eq!(disk_write_order(3, 2), 4); + assert_eq!(disk_write_order(35, 2), 14); + + assert_eq!(disk_write_order(57, 2), 17); +} + +#[test] +#[allow(clippy::cognitive_complexity)] +fn test_treebits() { + for x in 0..300 { + assert_eq!(slow_nodes_to_left_of(x), nodes_to_left_of(x)); + assert_eq!( + node_from_level_and_pos(level(x), pos_counting_from_left(x)), + x + ); + } + + for x in &[1u64 << 33, 1u64 << 63] { + assert_eq!(slow_nodes_to_left_of(*x), nodes_to_left_of(*x)); + assert_eq!( + node_from_level_and_pos(level(*x), pos_counting_from_left(*x)), + *x + ); + } + + assert_eq!(children_from_level(0), 0); + assert_eq!(children_from_level(1), 2); + assert_eq!(children_from_level(2), 6); + assert_eq!(children_from_level(3), 14); + assert_eq!(children_from_level(4), 30); + assert_eq!(children_from_level(5), 62); + assert_eq!(children_from_level(6), 126); + assert_eq!(children_from_level(7), 254); + assert_eq!(children_from_level(8), 510); + assert_eq!(children_from_level(9), 1022); + // Test for level > 32 to discover overflow bugs + assert_eq!(children_from_level(50), 2_251_799_813_685_246); + assert_eq!(level(0), 0); + assert_eq!(pos_counting_from_left(0), 0); + assert_eq!(post_order_index(0), 0); + assert_eq!(parent(0), 1); + + assert_eq!(level(1), 1); + assert_eq!(pos_counting_from_left(1), 0); + assert_eq!(post_order_index(1), 2); + assert_eq!(parent(1), 3); + assert_eq!(left_child(1), 0); + assert_eq!(right_child(1), 2); + + assert_eq!(level(2), 0); + assert_eq!(pos_counting_from_left(2), 1); + assert_eq!(post_order_index(2), 1); + assert_eq!(parent(2), 1); + + assert_eq!(level(3), 2); + assert_eq!(pos_counting_from_left(3), 0); + assert_eq!(post_order_index(3), 6); + assert_eq!(parent(3), 7); + assert_eq!(left_child(3), 1); + assert_eq!(right_child(3), 5); + + assert_eq!(level(4), 0); + assert_eq!(pos_counting_from_left(4), 2); + assert_eq!(post_order_index(4), 3); + assert_eq!(parent(4), 5); + + assert_eq!(level(5), 1); + assert_eq!(pos_counting_from_left(5), 1); + assert_eq!(post_order_index(5), 5); + assert_eq!(parent(5), 3); + assert_eq!(left_child(5), 4); + assert_eq!(right_child(5), 6); + + assert_eq!(level(6), 0); + assert_eq!(pos_counting_from_left(6), 3); + assert_eq!(post_order_index(6), 4); + assert_eq!(parent(6), 5); + + assert_eq!(level(7), 3); + assert_eq!(pos_counting_from_left(7), 0); + assert_eq!(post_order_index(7), 14); + assert_eq!(parent(7), 15); + assert_eq!(left_child(7), 3); + assert_eq!(right_child(7), 11); + + assert_eq!(level(8), 0); + assert_eq!(pos_counting_from_left(8), 4); + assert_eq!(post_order_index(8), 7); + assert_eq!(parent(8), 9); + + assert_eq!(level(9), 1); + assert_eq!(pos_counting_from_left(9), 2); + assert_eq!(post_order_index(9), 9); + assert_eq!(parent(9), 11); + assert_eq!(left_child(9), 8); + assert_eq!(right_child(9), 10); + + assert_eq!(level(10), 0); + assert_eq!(pos_counting_from_left(10), 5); + 
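+    // Worked example: node 10 is a leaf (level 0, so no children), and
+    // nodes_to_left_of(10) == 8, hence post_order_index(10) == 0 + 8 == 8,
+    // matching the assertion below.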
assert_eq!(post_order_index(10), 8); + assert_eq!(parent(10), 9); + + assert_eq!(level(11), 2); + assert_eq!(pos_counting_from_left(11), 1); + assert_eq!(post_order_index(11), 13); + assert_eq!(parent(11), 7); + assert_eq!(left_child(11), 9); + assert_eq!(right_child(11), 13); + + assert_eq!(level(12), 0); + assert_eq!(pos_counting_from_left(12), 6); + assert_eq!(post_order_index(12), 10); + assert_eq!(parent(12), 13); + + assert_eq!(level(13), 1); + assert_eq!(pos_counting_from_left(13), 3); + assert_eq!(post_order_index(13), 12); + assert_eq!(parent(13), 11); + assert_eq!(left_child(13), 12); + assert_eq!(right_child(13), 14); + + assert_eq!(level(14), 0); + assert_eq!(pos_counting_from_left(14), 7); + assert_eq!(post_order_index(14), 11); + assert_eq!(parent(14), 13); + + assert_eq!(level(15), 4); + assert_eq!(pos_counting_from_left(15), 0); + assert_eq!(post_order_index(15), 30); + assert_eq!(parent(15), 31); + assert_eq!(left_child(15), 7); + assert_eq!(right_child(15), 23); + + assert_eq!(level(16), 0); + assert_eq!(pos_counting_from_left(16), 8); + assert_eq!(post_order_index(16), 15); + assert_eq!(parent(16), 17); + + assert_eq!(level(17), 1); + assert_eq!(pos_counting_from_left(17), 4); + assert_eq!(post_order_index(17), 17); + assert_eq!(parent(17), 19); + assert_eq!(left_child(17), 16); + assert_eq!(right_child(17), 18); + + assert_eq!(level(18), 0); + assert_eq!(pos_counting_from_left(18), 9); + assert_eq!(post_order_index(18), 16); + assert_eq!(parent(18), 17); + + assert_eq!(level(19), 2); + assert_eq!(pos_counting_from_left(19), 2); + assert_eq!(post_order_index(19), 21); + assert_eq!(parent(19), 23); + assert_eq!(left_child(19), 17); + assert_eq!(right_child(19), 21); +} + +#[test] +fn test_right_most_child() { + assert_eq!(right_most_child(0), 0); + assert_eq!(right_most_child(1), 2); + assert_eq!(right_most_child(5), 6); + assert_eq!(right_most_child(7), 14); + assert_eq!(right_most_child(3), 6); + assert_eq!(right_most_child(11), 14); + assert_eq!(right_most_child(12), 12); + assert_eq!(right_most_child(14), 14); +} + +#[test] +fn test_left_most_child() { + assert_eq!(left_most_child(0), 0); + assert_eq!(left_most_child(1), 0); + assert_eq!(left_most_child(5), 4); + assert_eq!(left_most_child(7), 0); + assert_eq!(left_most_child(3), 0); + assert_eq!(left_most_child(11), 8); + assert_eq!(left_most_child(12), 12); + assert_eq!(left_most_child(14), 14); +} diff --git a/types/src/proptest_types.rs b/types/src/proptest_types.rs new file mode 100644 index 0000000000000..4de442538af03 --- /dev/null +++ b/types/src/proptest_types.rs @@ -0,0 +1,523 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + access_path::AccessPath, + account_address::AccountAddress, + account_state_blob::AccountStateBlob, + byte_array::ByteArray, + contract_event::ContractEvent, + get_with_proof::{ResponseItem, UpdateToLatestLedgerResponse}, + ledger_info::{LedgerInfo, LedgerInfoWithSignatures}, + proof::AccumulatorProof, + transaction::{ + Program, RawTransaction, SignedTransaction, TransactionArgument, TransactionInfo, + TransactionListWithProof, TransactionPayload, TransactionStatus, TransactionToCommit, + Version, + }, + validator_change::ValidatorChangeEventWithProof, + vm_error::VMStatus, + write_set::{WriteOp, WriteSet, WriteSetMut}, +}; +use crypto::{ + hash::CryptoHash, + signing::{sign_message, PrivateKey as OldPrivateKey, PublicKey as OldPublicKey}, + utils::{keypair_strategy as gen_keypair_strategy, keypair_strategy}, + HashValue, 
Signature,
+};
+use proptest::{
+    collection::{hash_map, hash_set, vec, SizeRange},
+    option,
+    prelude::*,
+    strategy::Union,
+};
+use std::{collections::HashMap, time::Duration};
+
+prop_compose! {
+    #[inline]
+    pub fn arb_byte_array()(byte_array in vec(any::<u8>(), 1..=10)) -> ByteArray {
+        ByteArray::new(byte_array)
+    }
+}
+
+impl Arbitrary for ByteArray {
+    type Parameters = ();
+    type Strategy = BoxedStrategy<Self>;
+
+    #[inline]
+    fn arbitrary_with(_args: ()) -> Self::Strategy {
+        arb_byte_array().boxed()
+    }
+}
+
+impl WriteOp {
+    pub fn value_strategy() -> impl Strategy<Value = Self> {
+        vec(any::<u8>(), 0..64).prop_map(WriteOp::Value)
+    }
+
+    pub fn deletion_strategy() -> impl Strategy<Value = Self> {
+        Just(WriteOp::Deletion)
+    }
+}
+
+impl Arbitrary for WriteOp {
+    type Parameters = ();
+    type Strategy = BoxedStrategy<Self>;
+
+    fn arbitrary_with(_args: ()) -> Self::Strategy {
+        prop_oneof![Self::deletion_strategy(), Self::value_strategy()].boxed()
+    }
+}
+
+impl WriteSet {
+    fn genesis_strategy() -> impl Strategy<Value = Self> {
+        vec((any::<AccessPath>(), WriteOp::value_strategy()), 0..64).prop_map(|write_set| {
+            let write_set_mut = WriteSetMut::new(write_set);
+            write_set_mut
+                .freeze()
+                .expect("generated write sets should always be valid")
+        })
+    }
+}
+
+impl Arbitrary for WriteSet {
+    type Parameters = ();
+    type Strategy = BoxedStrategy<Self>;
+
+    fn arbitrary_with(_args: ()) -> Self::Strategy {
+        // XXX there's no checking for repeated access paths here, nor in write_set. Is that
+        // important? Not sure.
+        vec((any::<AccessPath>(), any::<WriteOp>()), 0..64)
+            .prop_map(|write_set| {
+                let write_set_mut = WriteSetMut::new(write_set);
+                write_set_mut
+                    .freeze()
+                    .expect("generated write sets should always be valid")
+            })
+            .boxed()
+    }
+}
+
+impl RawTransaction {
+    fn strategy_impl(
+        address_strategy: impl Strategy<Value = AccountAddress>,
+        payload_strategy: impl Strategy<Value = TransactionPayload>,
+    ) -> impl Strategy<Value = Self> {
+        // XXX what other constraints do these need to obey?
+        (
+            address_strategy,
+            any::<u64>(),
+            payload_strategy,
+            any::<u64>(),
+            any::<u64>(),
+            any::<u64>(),
+        )
+            .prop_map(
+                |(
+                    sender,
+                    sequence_number,
+                    payload,
+                    max_gas_amount,
+                    gas_unit_price,
+                    expiration_time_secs,
+                )| {
+                    match payload {
+                        TransactionPayload::Program(program) => RawTransaction::new(
+                            sender,
+                            sequence_number,
+                            program,
+                            max_gas_amount,
+                            gas_unit_price,
+                            Duration::from_secs(expiration_time_secs),
+                        ),
+                        TransactionPayload::WriteSet(write_set) => {
+                            // It's a bit unfortunate that max_gas_amount etc is generated but
+                            // not used, but it isn't a huge deal.
+                            RawTransaction::new_write_set(sender, sequence_number, write_set)
+                        }
+                    }
+                },
+            )
+    }
+}
+
+impl Arbitrary for RawTransaction {
+    type Parameters = ();
+    type Strategy = BoxedStrategy<Self>;
+
+    fn arbitrary_with(_args: ()) -> Self::Strategy {
+        Self::strategy_impl(any::<AccountAddress>(), any::<TransactionPayload>()).boxed()
+    }
+}
+
+impl SignedTransaction {
+    // This isn't an Arbitrary impl because this doesn't generate *any* possible SignedTransaction,
+    // just one kind of them.
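+    // Each helper below fixes the payload shape while still deriving the
+    // sender address from the supplied keypair strategy.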
+    pub fn program_strategy(
+        keypair_strategy: impl Strategy<Value = (OldPrivateKey, OldPublicKey)>,
+    ) -> impl Strategy<Value = Self> {
+        Self::strategy_impl(keypair_strategy, TransactionPayload::program_strategy())
+    }
+
+    pub fn write_set_strategy(
+        keypair_strategy: impl Strategy<Value = (OldPrivateKey, OldPublicKey)>,
+    ) -> impl Strategy<Value = Self> {
+        Self::strategy_impl(keypair_strategy, TransactionPayload::write_set_strategy())
+    }
+
+    pub fn genesis_strategy(
+        keypair_strategy: impl Strategy<Value = (OldPrivateKey, OldPublicKey)>,
+    ) -> impl Strategy<Value = Self> {
+        Self::strategy_impl(keypair_strategy, TransactionPayload::genesis_strategy())
+    }
+
+    fn strategy_impl(
+        keypair_strategy: impl Strategy<Value = (OldPrivateKey, OldPublicKey)>,
+        payload_strategy: impl Strategy<Value = TransactionPayload>,
+    ) -> impl Strategy<Value = Self> {
+        (keypair_strategy, payload_strategy)
+            .prop_flat_map(|(keypair, payload)| {
+                let address = AccountAddress::from(keypair.1);
+                (
+                    Just(keypair),
+                    RawTransaction::strategy_impl(Just(address), Just(payload)),
+                )
+            })
+            .prop_map(|((private_key, public_key), raw_txn)| {
+                raw_txn
+                    .sign(&private_key, public_key)
+                    .expect("signing should always work")
+            })
+    }
+}
+
+impl Arbitrary for SignedTransaction {
+    type Parameters = ();
+    type Strategy = BoxedStrategy<Self>;
+
+    fn arbitrary_with(_args: ()) -> Self::Strategy {
+        Self::strategy_impl(gen_keypair_strategy(), any::<TransactionPayload>()).boxed()
+    }
+}
+
+impl TransactionPayload {
+    pub fn program_strategy() -> impl Strategy<Value = Self> {
+        any::<Program>().prop_map(TransactionPayload::Program)
+    }
+
+    pub fn write_set_strategy() -> impl Strategy<Value = Self> {
+        any::<WriteSet>().prop_map(TransactionPayload::WriteSet)
+    }
+
+    /// Similar to `write_set_strategy` except generates a valid write set for the genesis block.
+    pub fn genesis_strategy() -> impl Strategy<Value = Self> {
+        WriteSet::genesis_strategy().prop_map(TransactionPayload::WriteSet)
+    }
+}
+
+prop_compose! {
+    fn arb_transaction_status()(vm_status in any::<VMStatus>()) -> TransactionStatus {
+        vm_status.into()
+    }
+}
+
+impl Arbitrary for TransactionStatus {
+    type Parameters = ();
+    type Strategy = BoxedStrategy<Self>;
+
+    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
+        arb_transaction_status().boxed()
+    }
+}
+
+impl Arbitrary for TransactionPayload {
+    type Parameters = ();
+    type Strategy = BoxedStrategy<Self>;
+
+    fn arbitrary_with(_args: ()) -> Self::Strategy {
+        // Most transactions in practice will be programs, but other parts of the system should
+        // at least not choke on write set strategies so introduce them with decent probability.
+        // The figures below are probability weights.
+        prop_oneof![
+            9 => Self::program_strategy(),
+            1 => Self::write_set_strategy(),
+        ]
+        .boxed()
+    }
+}
+
+impl Arbitrary for Program {
+    type Parameters = ();
+    type Strategy = BoxedStrategy<Self>;
+
+    fn arbitrary_with(_args: ()) -> Self::Strategy {
+        // XXX This should eventually be an actually valid program, maybe?
+        // How should we generate random modules?
+        // The vector sizes are picked out of thin air.
+        (
+            vec(any::<u8>(), 0..100),
+            vec(any::<Vec<u8>>(), 0..100),
+            vec(any::<TransactionArgument>(), 0..10),
+        )
+            .prop_map(|(code, modules, args)| Program::new(code, modules, args))
+            .boxed()
+    }
+}
+
+impl Arbitrary for TransactionArgument {
+    type Parameters = ();
+    type Strategy = BoxedStrategy<Self>;
+
+    fn arbitrary_with(_args: ()) -> Self::Strategy {
+        prop_oneof![
+            any::<u64>().prop_map(TransactionArgument::U64),
+            any::<AccountAddress>().prop_map(TransactionArgument::Address),
+            any::<ByteArray>().prop_map(TransactionArgument::ByteArray),
+            ".*".prop_map(TransactionArgument::String),
+        ]
+        .boxed()
+    }
+}
+
+prop_compose!
{
+    fn arb_validator_signature_for_hash(hash: HashValue)(
+        hash in Just(hash),
+        (private_key, public_key) in keypair_strategy(),
+    ) -> (AccountAddress, Signature) {
+        let signature = sign_message(hash, &private_key).unwrap();
+        (AccountAddress::from(public_key), signature)
+    }
+}
+
+impl Arbitrary for LedgerInfoWithSignatures {
+    type Parameters = SizeRange;
+    type Strategy = BoxedStrategy<Self>;
+
+    fn arbitrary_with(num_validators_range: Self::Parameters) -> Self::Strategy {
+        (any::<LedgerInfo>(), Just(num_validators_range))
+            .prop_flat_map(|(ledger_info, num_validators_range)| {
+                let hash = ledger_info.hash();
+                (
+                    Just(ledger_info),
+                    prop::collection::vec(
+                        arb_validator_signature_for_hash(hash),
+                        num_validators_range,
+                    ),
+                )
+            })
+            .prop_map(|(ledger_info, signatures)| {
+                LedgerInfoWithSignatures::new(ledger_info, signatures.into_iter().collect())
+            })
+            .boxed()
+    }
+}
+
+prop_compose! {
+    fn arb_update_to_latest_ledger_response()(
+        response_items in vec(any::<ResponseItem>(), 0..10),
+        ledger_info_with_sigs in any::<LedgerInfoWithSignatures>(),
+        validator_change_events in vec(any::<ValidatorChangeEventWithProof>(), 0..10),
+    ) -> UpdateToLatestLedgerResponse {
+        UpdateToLatestLedgerResponse::new(
+            response_items, ledger_info_with_sigs, validator_change_events)
+    }
+}
+
+impl Arbitrary for UpdateToLatestLedgerResponse {
+    type Parameters = ();
+    type Strategy = BoxedStrategy<Self>;
+
+    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
+        arb_update_to_latest_ledger_response().boxed()
+    }
+}
+
+#[allow(clippy::implicit_hasher)]
+pub fn renumber_events(
+    events: &[ContractEvent],
+    next_seq_num_by_access_path: &mut HashMap<AccessPath, u64>,
+) -> Vec<ContractEvent> {
+    events
+        .iter()
+        .map(|e| {
+            let next_seq_num = next_seq_num_by_access_path
+                .entry(e.access_path().clone())
+                .or_insert(0);
+            *next_seq_num += 1;
+            ContractEvent::new(
+                e.access_path().clone(),
+                *next_seq_num - 1,
+                e.event_data().to_vec(),
+            )
+        })
+        .collect::<Vec<_>>()
+}
+
+pub fn arb_txn_to_commit_batch(
+    num_accounts: usize,
+    num_event_paths: usize,
+    num_transactions: usize,
+) -> impl Strategy<Value = Vec<TransactionToCommit>> {
+    (
+        vec(gen_keypair_strategy(), num_accounts),
+        hash_set(any::<Vec<u8>>(), num_event_paths),
+        Just(num_transactions),
+    )
+        .prop_flat_map(|(keypairs, event_paths, num_transactions)| {
+            let keypair_strategy = Union::new(keypairs.into_iter().map(Just)).boxed();
+            let event_path_strategy = Union::new(event_paths.into_iter().map(Just));
+            vec(
+                TransactionToCommit::strategy_impl(keypair_strategy, event_path_strategy),
+                num_transactions,
+            )
+        })
+        .prop_map(|txns_to_commit| {
+            // Renumber the events so the sequence numbers are consistent.
+            let mut next_seq_num_by_access_path = HashMap::new();
+            txns_to_commit
+                .into_iter()
+                .map(|t| {
+                    let events = renumber_events(t.events(), &mut next_seq_num_by_access_path);
+                    TransactionToCommit::new(
+                        t.signed_txn().clone(),
+                        t.account_states().clone(),
+                        events,
+                        t.gas_used(),
+                    )
+                })
+                .collect::<Vec<_>>()
+        })
+}
+
+impl ContractEvent {
+    pub fn strategy_impl(
+        access_path_strategy: impl Strategy<Value = AccessPath>,
+    ) -> impl Strategy<Value = Self> {
+        (access_path_strategy, any::<u64>(), vec(any::<u8>(), 1..10)).prop_map(
+            |(access_path, seq_num, event_data)| {
+                ContractEvent::new(access_path, seq_num, event_data)
+            },
+        )
+    }
+}
+
+impl Arbitrary for ContractEvent {
+    type Parameters = ();
+    type Strategy = BoxedStrategy<Self>;
+
+    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
+        ContractEvent::strategy_impl(any::<AccessPath>()).boxed()
+    }
+}
+
+impl TransactionToCommit {
+    fn strategy_impl(
+        keypair_strategy: BoxedStrategy<(OldPrivateKey, OldPublicKey)>,
+        event_path_strategy: impl Strategy<Value = Vec<u8>>,
+    ) -> impl Strategy<Value = Self> {
+        // signed_txn
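+        // (SignedTransaction::strategy_impl derives the sender address from
+        // the drawn keypair, so each generated txn is signed consistently.)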
+        let signed_txn_strategy =
+            SignedTransaction::strategy_impl(keypair_strategy.clone(), any::<TransactionPayload>());
+
+        // account_states
+        let address_strategy = keypair_strategy
+            .clone()
+            .prop_map(|(_, public_key)| AccountAddress::from(public_key));
+        let account_states_strategy =
+            hash_map(address_strategy.clone(), any::<AccountStateBlob>(), 1..10);
+
+        // events
+        let access_path_strategy = (address_strategy, event_path_strategy)
+            .prop_map(|(address, path)| AccessPath::new(address, path));
+        let events_strategy = vec(ContractEvent::strategy_impl(access_path_strategy), 0..10);
+
+        // gas_used
+        let gas_used_strategy = any::<u64>();
+
+        // Combine the above into result.
+        (
+            signed_txn_strategy,
+            account_states_strategy,
+            events_strategy,
+            gas_used_strategy,
+        )
+            .prop_map(|(signed_txn, account_states, events, gas_used)| {
+                Self::new(signed_txn, account_states, events, gas_used)
+            })
+    }
+}
+
+impl Arbitrary for TransactionToCommit {
+    type Parameters = ();
+    type Strategy = BoxedStrategy<Self>;
+
+    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
+        TransactionToCommit::strategy_impl(gen_keypair_strategy().boxed(), any::<Vec<u8>>()).boxed()
+    }
+}
+
+fn arb_transaction_list_with_proof() -> impl Strategy<Value = TransactionListWithProof> {
+    vec(
+        (
+            any::<SignedTransaction>(),
+            any::<TransactionInfo>(),
+            vec(any::<ContractEvent>(), 0..10),
+        ),
+        0..10,
+    )
+    .prop_flat_map(|transaction_and_infos_and_events| {
+        let transaction_and_infos: Vec<_> = transaction_and_infos_and_events
+            .clone()
+            .into_iter()
+            .map(|(transaction, info, _event)| (transaction, info))
+            .collect();
+        let events: Vec<_> = transaction_and_infos_and_events
+            .into_iter()
+            .map(|(_transaction, _info, event)| event)
+            .collect();
+
+        (
+            Just(transaction_and_infos),
+            option::of(Just(events)),
+            any::<Version>(),
+            any::<AccumulatorProof>(),
+            any::<AccumulatorProof>(),
+        )
+    })
+    .prop_map(
+        |(
+            transaction_and_infos,
+            events,
+            first_txn_version,
+            proof_of_first_txn,
+            proof_of_last_txn,
+        )| {
+            match transaction_and_infos.len() {
+                0 => TransactionListWithProof::new_empty(),
+                1 => TransactionListWithProof::new(
+                    transaction_and_infos,
+                    events,
+                    Some(first_txn_version),
+                    Some(proof_of_first_txn),
+                    None,
+                ),
+                _ => TransactionListWithProof::new(
+                    transaction_and_infos,
+                    events,
+                    Some(first_txn_version),
+                    Some(proof_of_first_txn),
+                    Some(proof_of_last_txn),
+                ),
+            }
+        },
+    )
+}
+
+impl Arbitrary for TransactionListWithProof {
+    type Parameters = ();
+    type Strategy = BoxedStrategy<Self>;
+
+    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
+        arb_transaction_list_with_proof().boxed()
+    }
+}
diff --git a/types/src/proto/access_path.proto b/types/src/proto/access_path.proto
new file mode 100644
index 0000000000000..d9cd6725bd0e5
--- /dev/null
+++ b/types/src/proto/access_path.proto
@@ -0,0 +1,11 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+syntax = "proto3";
+
+package types;
+
+message AccessPath {
+  bytes address = 1;
+  bytes path = 2;
+}
diff --git a/types/src/proto/account_state_blob.proto b/types/src/proto/account_state_blob.proto
new file mode 100644
index 0000000000000..d791c649778dd
--- /dev/null
+++ b/types/src/proto/account_state_blob.proto
@@ -0,0 +1,16 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+syntax = "proto3";
+
+package types;
+
+import "proof.proto";
+
+message AccountStateBlob { bytes blob = 1; }
+
+message AccountStateWithProof {
+  uint64 version = 1;
+  AccountStateBlob blob = 2;
+  AccountStateProof proof = 3;
+}
diff --git a/types/src/proto/events.proto b/types/src/proto/events.proto
new file mode 100644
index
index 0000000000000..a878d3cf6264a
--- /dev/null
+++ b/types/src/proto/events.proto
@@ -0,0 +1,38 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+// This file contains proto definitions related to events. Events are emitted
+// by smart contract execution. These could include events such as received
+// transactions, sent transactions, etc.
+
+syntax = "proto3";
+
+package types;
+
+import "access_path.proto";
+import "proof.proto";
+
+// An event emitted from a smart contract
+message Event {
+  AccessPath access_path = 1;
+  uint64 sequence_number = 2;
+  bytes event_data = 3;
+}
+
+// An event along with the proof for the event
+message EventWithProof {
+  uint64 transaction_version = 1;
+  uint64 event_index = 2;
+  Event event = 3;
+  EventProof proof = 4;
+}
+
+// A list of events.
+message EventsList {
+  repeated Event events = 1;
+}
+
+// A list of EventsLists, each representing all events for one transaction.
+message EventsForVersions {
+  repeated EventsList events_for_version = 1;
+}
diff --git a/types/src/proto/get_with_proof.proto b/types/src/proto/get_with_proof.proto
new file mode 100644
index 0000000000000..a86f9a04d87c9
--- /dev/null
+++ b/types/src/proto/get_with_proof.proto
@@ -0,0 +1,310 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+// This file contains proto definitions for performing queries and getting back
+// results with proofs. This is the interface for a client to query data from
+// the system. Every query result must include a proof so that a client can be
+// certain that the data returned is valid. A client must verify this proof to
+// ensure that a node isn't lying to them.
+
+// How to verify the response as a client:
+// (Note that every response comes in the form of GetWithProofResponse which
+// wraps the inner response elements that correspond to the specific request
+// types. Below we will assume a single request/response type. The
+// verification can be extended as needed for multiple types. Also note that we
+// will use the following notation: resp = GetWithProofResponse and req =
+// GetWithProofRequest).
Also note that the following will be considered +// equivalent for brevity: req.requested_items.get_account_state_request == +// req.get_account_state_request And, resp.values.get_account_state_response == +// resp.get_account_state_response +// +// GetAccountStateResponse: +// - let state_req = req.requested_items.get_account_state_request; +// - let state_resp = resp.values.get_account_state_response; +// - Verify that: +// - state_req.access_path == state_resp.access_path +// - This ensures that the server is responding with the correct access +// path +// - let state_data_hash = Hash(state_resp.value); +// - let state_proof = resp.values.proof.state_proof_value.sparse_merkle_proof; +// - Validate state_proof using state_data_hash as the leaf +// - When verifying the state tree, use: +// state_root_hash = resp.values.transaction_info.state_root_hash +// - Validate accumulator using resp.values.transaction_info as the leaf +// - When verifying the accumulator, use: +// root_hash = +// resp.ledger_info_with_sigs.ledger_info.ledger_info.txn_root_hash; +// - Validate that the transaction root hash submitted in +// req.known_value.node_value.txn_root_hash +// exists in the proof for accumulator and that the proof is valid with +// this hash +// - Validate ledger info +// - let ledger_info_hash = +// Hash(resp.ledger_info_with_sigs.ledger_info.ledger_info); +// - Verify signatures from resp.ledger_info_with_sigs.signatures are +// signing +// ledger_info_hash and that there are >2/3 nodes signing this +// correctly +// - Validate that the timestamp is relatively recent in +// resp.ledger_info_with_sigs.ledger_info.timestamp +// +// +// GetAccountTransactionBySequenceNumberResponse: +// - Note that other than type completed_transaction, there will be no proof +// returned +// since the transaction has not yet been committed. 
To ensure that a +// validator is telling the truth about it not being committed yet, a +// client should query for their account state and verify that their +// current sequence number is less than what they are searching for with +// GetAccountTransactionBySequenceNumberResponse +// - let txn = +// resp.get_account_transaction_by_sequence_number_response.transaction.committed_transaction; +// - let txn_hash = Hash(txn); +// - Verify that resp.proof.transaction_info.signed_transaction_hash == txn_hash +// - Validate accumulator using resp.proof.transaction_info as the leaf +// - When verifying the accumulator, use: +// root_hash = +// resp.ledger_info_with_sigs.ledger_info.ledger_info.txn_root_hash; +// - Validate that the transaction root hash submitted in +// req.known_value.node_value.txn_root_hash +// exists in the proof for accumulator and that the proof is valid with +// this hash +// - Validate ledger info +// - let ledger_info_hash = +// Hash(resp.ledger_info_with_sigs.ledger_info.ledger_info); +// - Verify signatures from resp.ledger_info_with_sigs.signatures are +// signing +// ledger_info_hash and that there are >2/3 nodes signing this +// correctly +// - Validate that the timestamp is relatively recent in +// resp.ledger_info_with_sigs.ledger_info.timestamp +// +// +// GetTransactionsResponse: +// - for txn in resp.get_transactions_response.transactions: +// - let txn = txn.committed_transaction; +// - let txn_hash = Hash(txn); +// - Verify that txn.proof.transaction_info.signed_transaction_hash == +// txn_hash +// - Validate accumulator using txn.proof.transaction_info as the leaf +// - When verifying the accumulator, use: +// root_hash = +// resp.ledger_info_with_sigs.ledger_info.ledger_info.txn_root_hash; +// - Verify that transactions are sequential and none are missing +// - Validate ledger info +// - let ledger_info_hash = +// Hash(resp.ledger_info_with_sigs.ledger_info.ledger_info); +// - Verify signatures from resp.ledger_info_with_sigs.signatures are +// signing +// ledger_info_hash and that there are >2/3 nodes signing this +// correctly +// - Validate that the timestamp is relatively recent in +// resp.ledger_info_with_sigs.ledger_info.timestamp +// - If the number of transactions returned is less than limit for an ascending +// query +// or if the requested offset > current version for a descending query, +// the client should verify that the timestamp in ledger info is relatively +// recent to determine if it is likely that all transactions available were +// returned +syntax = "proto3"; + +package types; + +import "access_path.proto"; +import "account_state_blob.proto"; +import "events.proto"; +import "ledger_info.proto"; +import "transaction.proto"; +import "validator_change.proto"; + +// ----------------------------------------------------------------------------- +// ---------------- Update to latest ledger request +// ----------------------------------------------------------------------------- + +// This API is used to update the client to the latest ledger version and +// optionally also request 1..n other pieces of data. This allows for batch +// queries. All queries return proofs that a client should check to validate +// the data. +// +// Note that if a client only wishes to update to the latest LedgerInfo and +// receive the proof that this latest ledger extends the client_known_version +// ledger the client had, they can simply set the requested_items to an empty +// list. 
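The ">2/3 nodes signing" requirement that recurs in the verification steps above reduces to simple integer arithmetic once the client has matched each signature to a known validator. A minimal, self-contained sketch, where `has_quorum` is an illustrative helper and not part of this codebase:

```rust
/// Returns true when strictly more than 2/3 of the validator set signed,
/// i.e. at least 2f + 1 signers out of n = 3f + 1 validators.
fn has_quorum(num_valid_signatures: usize, num_validators: usize) -> bool {
    // Integer form of `num_valid_signatures > (2/3) * num_validators`,
    // avoiding floating point.
    3 * num_valid_signatures > 2 * num_validators
}

fn main() {
    assert!(has_quorum(3, 4)); // 2f + 1 = 3 of n = 4 validators is a quorum
    assert!(!has_quorum(2, 4)); // 2 of 4 is not
}
```

Multiplying both sides by three keeps the comparison exact; dividing `num_validators` by three first would silently truncate.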
+message UpdateToLatestLedgerRequest { + // This is the version the client already trusts. Usually the client should + // set this to the version it obtained the last time it synced with the + // chain. If this is the first time ever the client sends a request, it must + // use the waypoint hard-coded in its software. + uint64 client_known_version = 1; + + // The items for which we are requesting data in this API call. + repeated RequestItem requested_items = 2; +} + +message RequestItem { + oneof requested_items { + GetAccountStateRequest get_account_state_request = 1; + GetAccountTransactionBySequenceNumberRequest + get_account_transaction_by_sequence_number_request = 2; + GetEventsByEventAccessPathRequest get_events_by_event_access_path_request = + 3; + GetTransactionsRequest get_transactions_request = 4; + } +} + +// ----------------------------------------------------------------------------- +// ---------------- Update to latest ledger response +// ----------------------------------------------------------------------------- + +// Response from getting latest ledger +message UpdateToLatestLedgerResponse { + // Responses to the queries posed by the requests. The proofs generated will + // be relative to the version of the latest ledger provided below. + repeated ResponseItem response_items = 1; + + // The latest ledger info this node has. It will come with at least 2f+1 + // validator signatures as well as a proof that shows the latest ledger + // extends the old ledger the client had. + LedgerInfoWithSignatures ledger_info_with_sigs = 2; + + // Validator change events from what the client last knew. This is used to + // inform the client of validator changes from the client's last known version + // until the current version + repeated ValidatorChangeEventWithProof validator_change_events = 3; +} + +// Individual response items to the queries posed by the requests +message ResponseItem { + oneof response_items { + GetAccountStateResponse get_account_state_response = 3; + GetAccountTransactionBySequenceNumberResponse + get_account_transaction_by_sequence_number_response = 4; + GetEventsByEventAccessPathResponse get_events_by_event_access_path_response = 5; + GetTransactionsResponse get_transactions_response = 6; + } +} + +// ----------------------------------------------------------------------------- +// ---------------- Get account state (balance, sequence number, etc.) +// ----------------------------------------------------------------------------- + +// Gets latest state for an account. +message GetAccountStateRequest { + // Account for which we are fetching the state. + bytes address = 1; +} + +// State information returned by a get account state query. +message GetAccountStateResponse { + // Blob value representing the account state together with proof the client + // can utilize to verify it. + AccountStateWithProof account_state_with_proof = 1; +} + +// ----------------------------------------------------------------------------- +// ---------------- Get single transaction by account + sequence number +// ----------------------------------------------------------------------------- +// Get transactions that altered an account - this includes both sent and +// received. A user of this should check that the data returned matches what +// they expect. As an example, a potential attack vector would be something +// like the following: Alice is buying an apple from Bob. Alice's phone signs a +// transaction X with sequence number N that pays coins to Bob. 
Alice transmits
+// this signature to Bob's payment terminal, which then submits the transaction
+// and checks its status to see if Alice can be given the apple. However, while
+// Bob is doing this, Alice constructs a second transaction X', also with
+// sequence number N, and gets that transaction inserted in the blockchain. If
+// Bob isn't thoughtful about how he uses this API, he may assume that when the
+// API returns the N'th transaction on Alice's account, his transaction is the
+// one that went through. The point here is that one should be careful about
+// reading too much into "transaction X is on the chain" and focus on the logs,
+// which tell you what the transaction did.
+//
+// If a client submitted a transaction, they should also verify that the hash of
+// the returned transaction matches what they submitted. As an example, if a
+// client has two wallets that share the same account, they may both submit a
+// transaction at the same sequence number and only one will be committed. A
+// client should never assume that a response saying a transaction at this
+// sequence number was included means it is definitely the transaction they
+// submitted; they should check that the hash matches what they sent.
+message GetAccountTransactionBySequenceNumberRequest {
+  // Account for which to query transactions
+  bytes account = 1;
+
+  uint64 sequence_number = 2;
+
+  // Set to true to fetch events for the transaction at this version
+  bool fetch_events = 3;
+}
+
+// Transaction information for transactions requested by
+// GetAccountTransactionBySequenceNumberRequest
+message GetAccountTransactionBySequenceNumberResponse {
+  // When the transaction requested is committed, return the committed
+  // transaction with proof.
+  SignedTransactionWithProof signed_transaction_with_proof = 2;
+  // When the transaction requested is not committed, we give a proof that
+  // shows the current sequence number is smaller than what would have been if
+  // the transaction was committed.
+  AccountStateWithProof proof_of_current_sequence_number = 3;
+}
+
+// -----------------------------------------------------------------------------
+// ---------------- Get events by event access path
+// -----------------------------------------------------------------------------
+
+// Get events that exist on an event access path. In the current world,
+// a user may specify events that were received, events that were sent, or any
+// event that modifies their account
+message GetEventsByEventAccessPathRequest {
+  AccessPath access_path = 1;
+
+  // The sequence number of the event to start with for this query. Use a
+  // sequence number of MAX_INT to represent the latest.
+  uint64 start_event_seq_num = 2;
+
+  // If ascending is true this query will return up to `limit` events that were
+  // emitted after `start_event_seq_num`. Otherwise it will return up to `limit`
+  // events before the offset. Both cases are inclusive.
+  bool ascending = 3;
+
+  // Limit number of results
+  uint64 limit = 4;
+}
+
+message GetEventsByEventAccessPathResponse {
+  // Returns an event and proof of each of the events in the request. The first
+  // element of proofs will be the closest to `start_event_seq_num`.
+  repeated EventWithProof events_with_proof = 1;
+
+  // If the number of events returned is less than `limit` for an ascending
+  // query or if start_event_seq_num > the latest seq_num for a descending
+  // query, returns the state of the account containing the given access path
+  // in the latest state.
This allows the client to verify that there are in + // fact no extra events. + // + // The LedgerInfoWithSignatures which is on the main + // UpdateToLatestLedgerResponse can be used to validate this. + AccountStateWithProof proof_of_latest_event = 2; +} + +// ----------------------------------------------------------------------------- +// ---------------- Get transactions +// ----------------------------------------------------------------------------- + +// Get up to limit transactions starting from start_version. +message GetTransactionsRequest { + // The version of the transaction to start with for this query. Use a version + // of MAX_INT to represent the latest. + uint64 start_version = 1; + + // Limit number of results + uint64 limit = 2; + + // Set to true to fetch events for the transaction at each version + bool fetch_events = 3; +} + +message GetTransactionsResponse { + TransactionListWithProof txn_list_with_proof = 1; +} diff --git a/types/src/proto/ledger_info.proto b/types/src/proto/ledger_info.proto new file mode 100644 index 0000000000000..eddc9affa19a6 --- /dev/null +++ b/types/src/proto/ledger_info.proto @@ -0,0 +1,82 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package types; + +/// Even though we don't always need all hashes, we pass them in and return them +/// always so that we keep them in sync on the client and don't make the client +/// worry about which one(s) to pass in which cases +/// +/// This structure serves a dual purpose. +/// +/// First, if this structure is signed by 2f+1 validators it signifies the state +/// of the ledger at version `version` -- it contains the transaction +/// accumulator at that version which commits to all historical transactions. +/// This structure may be expanded to include other information that is derived +/// from that accumulator (e.g. the current time according to the time contract) +/// to reduce the number of proofs a client must get. +/// +/// Second, the structure contains a `consensus_data_hash` value. This is the +/// hash of an internal data structure that represents a block that is voted on +/// by consensus. +/// +/// Combining these two concepts when the consensus algorithm votes on a block B +/// it votes for a LedgerInfo with the `version` being the latest version that +/// will be committed if B gets 2f+1 votes. It sets `consensus_data_hash` to +/// represent B so that if those 2f+1 votes are gathered, the block is valid to +/// commit +message LedgerInfo { + // Current latest version of the system + uint64 version = 1; + + // Root hash of transaction accumulator at this version + bytes transaction_accumulator_hash = 2; + + // Hash of consensus-specific data that is opaque to all parts of the system + // other than consensus. This is needed to verify signatures because + // consensus signing includes this hash + bytes consensus_data_hash = 3; + + // The block id of the last committed block corresponding to this ledger info. + // This field is not particularly interesting to the clients, but can be used + // by the validators for synchronization. + bytes consensus_block_id = 4; + + // Epoch number corresponds to the set of validators that are active for this + // ledger info. The main motivation for keeping the epoch number in the + // LedgerInfo is to ensure that the client has enough information to verify + // that the signatures for this info are coming from the validators that + // indeed form a quorum. 
Without the epoch number, a potential attack could reuse
+  // the signatures from the validators in one epoch in order to sign the wrong
+  // info belonging to another epoch, in which these validators do not form a
+  // quorum. The very first epoch number is 0.
+  uint64 epoch_num = 5;
+
+  // Timestamp in microseconds since the epoch (Unix time), generated by the
+  // proposer of the block. This is strictly increasing with every block.
+  // If a client reads a timestamp > the one they specified for transaction expiration time,
+  // they can be certain that their transaction will never be included in a block in the future
+  // (assuming that their transaction has not yet been included)
+  uint64 timestamp_usecs = 6;
+}
+
+/// The validator node returns this structure which includes signatures
+/// from each validator to confirm the state. The client needs to only pass
+/// back the LedgerInfo element, since the validator node doesn't need to see
+/// the signatures again when the client performs a query; they are only there
+/// for the client to be able to verify the state.
+message LedgerInfoWithSignatures {
+  // Signatures of the root node from each validator
+  repeated ValidatorSignature signatures = 1;
+
+  LedgerInfo ledger_info = 2;
+}
+
+message ValidatorSignature {
+  // The account address of the validator, which can be used for retrieving its
+  // public key during the given epoch.
+  bytes validator_id = 1;
+  bytes signature = 2;
+}
diff --git a/types/src/proto/mod.rs b/types/src/proto/mod.rs
new file mode 100644
index 0000000000000..7bcc0d55b6dc9
--- /dev/null
+++ b/types/src/proto/mod.rs
@@ -0,0 +1,15 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod access_path;
+pub mod account_state_blob;
+pub mod events;
+pub mod get_with_proof;
+pub mod ledger_info;
+pub mod proof;
+pub mod transaction;
+pub mod transaction_info;
+pub mod validator_change;
+pub mod validator_public_keys;
+pub mod validator_set;
+pub mod vm_errors;
diff --git a/types/src/proto/proof.proto b/types/src/proto/proof.proto
new file mode 100644
index 0000000000000..1037771f11ee8
--- /dev/null
+++ b/types/src/proto/proof.proto
@@ -0,0 +1,72 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+syntax = "proto3";
+
+package types;
+
+import "transaction_info.proto";
+
+message AccumulatorProof {
+  // The bitmap indicating which siblings are default. 1 means non-default and
+  // 0 means default. The LSB corresponds to the sibling at the bottom of the
+  // accumulator. The leftmost 1-bit corresponds to the sibling at the level
+  // just below root level in the accumulator, since this one is always
+  // non-default.
+  uint64 bitmap = 1;
+
+  // The non-default siblings. The ones near the root are at the beginning of
+  // the list.
+  repeated bytes non_default_siblings = 2;
+}
+
+message SparseMerkleProof {
+  // This proof can be used to authenticate whether a given leaf exists in the
+  // tree or not. In Rust:
+  //   - If this is `Some(HashValue, HashValue)`
+  //     - If the first `HashValue` equals the requested key, this is an
+  //       inclusion proof and the second `HashValue` equals the hash of the
+  //       corresponding account blob.
+  //     - Otherwise this is a non-inclusion proof. The first `HashValue` is
+  //       the only key that exists in the subtree and the second `HashValue`
+  //       equals the hash of the corresponding account blob.
+ // - If this is `None`, this is also a non-inclusion proof which indicates + // the subtree is empty. + // + // In protobuf, this leaf field should either be + // - empty, which corresponds to None in the Rust structure. + // - exactly 64 bytes, which corresponds to Some<(HashValue, HashValue)> + // in the Rust structure. + bytes leaf = 1; + + // The bitmap indicating which siblings are default. 1 means non-default and + // 0 means default. The MSB of the first byte corresponds to the sibling at + // the top of the Sparse Merkle Tree. The rightmost 1-bit of the last byte + // corresponds to the sibling at the bottom, since this one is always + // non-default. + bytes bitmap = 2; + + // The non-default siblings. The ones near the root are at the beginning of + // the list. + repeated bytes non_default_siblings = 3; +} + +// The complete proof used to authenticate a signed transaction. +message SignedTransactionProof { + AccumulatorProof ledger_info_to_transaction_info_proof = 1; + TransactionInfo transaction_info = 2; +} + +// The complete proof used to authenticate an account state. +message AccountStateProof { + AccumulatorProof ledger_info_to_transaction_info_proof = 1; + TransactionInfo transaction_info = 2; + SparseMerkleProof transaction_info_to_account_proof = 3; +} + +// The complete proof used to authenticate an event. +message EventProof { + AccumulatorProof ledger_info_to_transaction_info_proof = 1; + TransactionInfo transaction_info = 2; + AccumulatorProof transaction_info_to_event_proof = 3; +} diff --git a/types/src/proto/transaction.proto b/types/src/proto/transaction.proto new file mode 100644 index 0000000000000..ac3c5f149d4c4 --- /dev/null +++ b/types/src/proto/transaction.proto @@ -0,0 +1,170 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package types; + +import "access_path.proto"; +import "events.proto"; +import "proof.proto"; +import "transaction_info.proto"; +import "google/protobuf/wrappers.proto"; + +// A generic structure that describes a transaction that a client submits +message RawTransaction { + // Sender's account address + bytes sender_account = 1; + // Sequence number of this transaction corresponding to sender's account. + uint64 sequence_number = 2; + oneof payload { + // The transaction script to execute. + Program program = 3; + // A write set, used for genesis blocks and other magic transactions. + // This bypasses the rules for regular transactions so will typically be + // rejected. Only under special circumstances will it be accepted. + WriteSet write_set = 4; + } + // Maximal total gas specified by wallet to spend for this transaction. + uint64 max_gas_amount = 5; + // The price to be paid for each unit of gas. + uint64 gas_unit_price = 6; + // Expiration time for this transaction. If storage is queried and + // the time returned is greater than or equal to this time and this + // transaction has not been included, you can be certain that it will + // never be included. 
+  // If set to 0, there will be no expiration time.
+  uint64 expiration_time = 7;
+}
+
+// The code for the transaction to execute
+message Program {
+  bytes code = 1;
+  repeated TransactionArgument arguments = 2;
+  repeated bytes modules = 3;
+}
+
+// An argument to the transaction if the transaction takes arguments
+message TransactionArgument {
+  enum ArgType {
+    U64 = 0;
+    ADDRESS = 1;
+    STRING = 2;
+    BYTEARRAY = 3;
+  }
+  ArgType type = 1;
+  bytes data = 2;
+}
+
+// A generic structure that represents a signed RawTransaction
message SignedTransaction {
+  // The serialized Protobuf bytes for RawTransaction, for which the signature
+  // was signed. Protobuf doesn't guarantee that the serialized bytes are
+  // canonical across different language implementations, but for our
+  // transaction use case that is not necessary, because the client is the only
+  // one to produce these bytes, which are then persisted in storage.
+  bytes raw_txn_bytes = 1;
+  // public key that corresponds to RawTransaction::sender_account
+  bytes sender_public_key = 2;
+  // signature for the hash
+  bytes sender_signature = 3;
+}
+
+message SignedTransactionWithProof {
+  // The version of the returned signed transaction.
+  uint64 version = 1;
+
+  // The transaction itself.
+  SignedTransaction signed_transaction = 2;
+
+  // The proof authenticating the signed transaction.
+  SignedTransactionProof proof = 3;
+
+  // The events yielded by executing the transaction, if requested.
+  EventsList events = 4;
+}
+
+// A generic structure that represents a block of transactions originating from
+// a particular validator instance.
+message SignedTransactionsBlock {
+  // Set of Signed Transactions
+  repeated SignedTransaction transactions = 1;
+  // Public key of the validator that created this block
+  bytes validator_public_key = 2;
+  // Signature of the validator that created this block
+  bytes validator_signature = 3;
+}
+
+// Set of WriteOps to save to storage.
+message WriteSet {
+  // Set of WriteOp for storage update.
+  repeated WriteOp write_set = 1;
+}
+
+// Write Operation on underlying storage.
+message WriteOp {
+  // AccessPath of the write set.
+  AccessPath access_path = 1;
+  // The value of the write op. Empty if `type` is Delete.
+  bytes value = 2;
+  // WriteOp type.
+  WriteOpType type = 3;
+}
+
+// Type of write operation
+enum WriteOpType {
+  // The WriteOp is to create/update the field in storage.
+  Write = 0;
+  // The WriteOp is to delete the field from storage.
+  Delete = 1;
+}
+
+// Account state as a whole.
+// After execution, updates to accounts are passed in this form to storage for
+// persistence.
+message AccountState {
+  // Account address
+  bytes address = 1;
+  // Account state blob
+  bytes blob = 2;
+}
+
+// Transaction struct to commit to storage
+message TransactionToCommit {
+  // The signed transaction which was executed
+  SignedTransaction signed_txn = 1;
+  // State db updates
+  repeated AccountState account_states = 2;
+  // Events yielded by the transaction.
+  repeated Event events = 3;
+  // The amount of gas used.
+  uint64 gas_used = 4;
+}
+
+// A list of consecutive transactions with proof. This is mainly used for state
+// synchronization when a validator would request a list of transactions from a
+// peer, verify the proof, execute the transactions and persist them. Note that
+// the transactions are supposed to belong to the same epoch E, otherwise
+// verification will fail.
+message TransactionListWithProof {
+  // The list of transactions.
+  repeated SignedTransaction transactions = 1;
+
+  // The list of corresponding TransactionInfo objects.
+  repeated TransactionInfo infos = 2;
+
+  // The list of corresponding Event objects (only present if fetch_events was set to true in req)
+  EventsForVersions events_for_versions = 3;
+
+  // If the list is not empty, the version of the first transaction.
+  google.protobuf.UInt64Value first_transaction_version = 4;
+
+  // The proofs of the first and last transaction in this chunk. When this is
+  // used for state synchronization, the validator who requests the transactions
+  // will provide a version in the request and the proofs will be relative to
+  // the given version. When this is returned in GetTransactionsResponse, the
+  // proofs will be relative to the ledger info returned in
+  // UpdateToLatestLedgerResponse.
+  AccumulatorProof proof_of_first_transaction = 5;
+  AccumulatorProof proof_of_last_transaction = 6;
+}
diff --git a/types/src/proto/transaction_info.proto b/types/src/proto/transaction_info.proto
new file mode 100644
index 0000000000000..d17073455c3f4
--- /dev/null
+++ b/types/src/proto/transaction_info.proto
@@ -0,0 +1,26 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+syntax = "proto3";
+
+package types;
+
+// `TransactionInfo` is the object we store in the transaction accumulator. It
+// consists of the transaction as well as the execution result of this
+// transaction. These are later returned to the client so that the client can
+// validate the tree.
+message TransactionInfo {
+  // Hash of the signed transaction that is stored
+  bytes signed_transaction_hash = 1;
+
+  // The root hash of the Sparse Merkle Tree describing the world state at the
+  // end of this transaction
+  bytes state_root_hash = 2;
+
+  // The root hash of the Merkle Accumulator storing all events emitted during
+  // this transaction.
+  bytes event_root_hash = 3;
+
+  // The amount of gas used by this transaction.
+  uint64 gas_used = 4;
+}
diff --git a/types/src/proto/validator_change.proto b/types/src/proto/validator_change.proto
new file mode 100644
index 0000000000000..731db3b33037b
--- /dev/null
+++ b/types/src/proto/validator_change.proto
@@ -0,0 +1,27 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+syntax = "proto3";
+
+package types;
+
+import "events.proto";
+import "ledger_info.proto";
+
+// This is used to prove validator changes. When a validator is changing, it
+// triggers an event on /validator_change_account/events/sent. To tell the
+// client about validator changes, we query
+// /validator_change_account/events/sent to get all versions that contain
+// validator changes after the version that we are trying to update from. For
+// each of these versions, the old validator set would have signed the ledger
+// info at that version. The client needs this as well as the event results +
+// proof. The client can then verify that these events were under the current
+// tree and that the changes were signed by the old validators (and that the
+// events correctly show which validators are the new validators).
+//
+// This message represents a single validator change event and the proof that
+// corresponds to it
+message ValidatorChangeEventWithProof {
+  LedgerInfoWithSignatures ledger_info_with_sigs = 1;
+  EventWithProof event_with_proof = 2;
+}
diff --git a/types/src/proto/validator_public_keys.proto b/types/src/proto/validator_public_keys.proto
new file mode 100644
index 0000000000000..efcfcadb02dad
--- /dev/null
+++ b/types/src/proto/validator_public_keys.proto
@@ -0,0 +1,18 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+syntax = "proto3";
+
+package types;
+
+// Protobuf definition for the Rust struct ValidatorPublicKeys
+message ValidatorPublicKeys {
+  // Validator account address
+  bytes account_address = 1;
+  // Consensus public key
+  bytes consensus_public_key = 2;
+  // Network signing public key
+  bytes network_signing_public_key = 3;
+  // Network identity public key
+  bytes network_identity_public_key = 4;
+}
diff --git a/types/src/proto/validator_set.proto b/types/src/proto/validator_set.proto
new file mode 100644
index 0000000000000..2bec2a670a064
--- /dev/null
+++ b/types/src/proto/validator_set.proto
@@ -0,0 +1,13 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+syntax = "proto3";
+
+package types;
+
+import "validator_public_keys.proto";
+
+// Protobuf definition for the Rust struct ValidatorSet.
+message ValidatorSet {
+  repeated ValidatorPublicKeys validator_public_keys = 1;
+}
diff --git a/types/src/proto/vm_errors.proto b/types/src/proto/vm_errors.proto
new file mode 100644
index 0000000000000..b29d32e00420f
--- /dev/null
+++ b/types/src/proto/vm_errors.proto
@@ -0,0 +1,276 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+syntax = "proto3";
+
+package types;
+
+// The statuses and errors produced by the VM can be categorized into a
+// couple different types:
+// 1. Validation Statuses: all the errors that can (/should) be
+//    the result of executing the prologue -- these are primarily used by
+//    the vm validator and AC.
+// 2. Verification Errors: errors that are the result of performing
+//    bytecode verification (happens at the time of publishing).
+// 3. VM Invariant Errors: errors that arise from an internal invariant of
+//    the VM being violated. These signify a problem with either the VM or
+//    the bytecode verifier.
+// 4. Binary Errors: errors that can occur during the process of
+//    deserialization of a transaction.
+// 5. Runtime Statuses: errors that can arise from the execution of a
+//    transaction (assuming the prologue executes without error). These are
+//    errors that can occur during execution due to things such as division
+//    by zero, running out of gas, etc. These do not signify an issue with
+//    the VM.
+
+// NB: we make a distinction between a status and an error here: a
+// status contains errors, along with possible affirmation of a successful
+// execution or valid prologue.
+
+// The status of a transaction as determined by the prologue.
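+// (The prologue is the part of execution that runs before the payload: it
+// checks the signature, sequence number, balance and expiration, which is why
+// the codes below mirror exactly those checks.)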
+enum VMValidationStatusCode {
+  // We don't want the default value to be valid
+  UnknownValidationStatus = 0;
+  // The transaction has a bad signature
+  InvalidSignature = 1;
+  // Bad account authentication key
+  InvalidAuthKey = 2;
+  // Sequence number is too old
+  SequenceNumberTooOld = 3;
+  // Sequence number is too new
+  SequenceNumberTooNew = 4;
+  // Insufficient balance to pay minimum transaction fee
+  InsufficientBalanceForTransactionFee = 5;
+  // The transaction has expired
+  TransactionExpired = 6;
+  // The sending account does not exist
+  SendingAccountDoesNotExist = 7;
+  // This write set transaction was rejected because it did not meet the
+  // requirements for one.
+  RejectedWriteSet = 8;
+  // This write set transaction cannot be applied to the current state.
+  InvalidWriteSet = 9;
+  // Length of program field in raw transaction exceeded max length
+  ExceededMaxTransactionSize = 10;
+  // This script is not on our whitelist of scripts.
+  UnknownScript = 11;
+  // Transaction is trying to publish a new module.
+  UnknownModule = 12;
+  // Max gas units submitted with transaction exceeds max gas units bound
+  // in VM
+  MaxGasUnitsExceedsMaxGasUnitsBound = 13;
+  // Max gas units submitted with transaction not enough to cover the
+  // intrinsic cost of the transaction.
+  MaxGasUnitsBelowMinTransactionGasUnits = 14;
+  // Gas unit price submitted with transaction is below the minimum gas price
+  // set in the VM.
+  GasUnitPriceBelowMinBound = 15;
+  // Gas unit price submitted with the transaction is above the maximum
+  // gas price set in the VM.
+  GasUnitPriceAboveMaxBound = 16;
+}
+
+message VMValidationStatus {
+  VMValidationStatusCode code = 1;
+  string message = 2;
+}
+
+message VMVerificationStatusList {
+  repeated VMVerificationStatus status_list = 1;
+}
+
+message VMVerificationStatus {
+  enum StatusKind {
+    SCRIPT = 0;
+    MODULE = 1;
+  }
+  StatusKind status_kind = 1;
+  // For StatusKind::SCRIPT this is ignored.
+  uint32 module_idx = 2;
+  VMVerificationErrorKind error_kind = 3;
+  string message = 4;
+}
+
+// When a code module/script is published, it is verified. These are the
+// possible errors that can arise from the verification process.
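+// Note that the numeric values below are part of the protobuf wire format:
+// new variants should be appended rather than renumbering existing ones.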
+enum VMVerificationErrorKind {
+  // Likewise, default to an unknown verification error
+  UnknownVerificationError = 0;
+  IndexOutOfBounds = 1;
+  RangeOutOfBounds = 2;
+  InvalidSignatureToken = 3;
+  InvalidFieldDefReference = 4;
+  RecursiveStructDefinition = 5;
+  InvalidResourceField = 6;
+  InvalidFallThrough = 7;
+  JoinFailure = 8;
+  NegativeStackSizeWithinBlock = 9;
+  UnbalancedStack = 10;
+  InvalidMainFunctionSignature = 11;
+  DuplicateElement = 12;
+  InvalidModuleHandle = 13;
+  UnimplementedHandle = 14;
+  InconsistentFields = 15;
+  UnusedFields = 16;
+  LookupFailed = 17;
+  VisibilityMismatch = 18;
+  TypeResolutionFailure = 19;
+  TypeMismatch = 20;
+  MissingDependency = 21;
+  PopReferenceError = 22;
+  PopResourceError = 23;
+  ReleaseRefTypeMismatchError = 24;
+  BrTypeMismatchError = 25;
+  AssertTypeMismatchError = 26;
+  StLocTypeMismatchError = 27;
+  StLocUnsafeToDestroyError = 28;
+  RetUnsafeToDestroyError = 29;
+  RetTypeMismatchError = 30;
+  FreezeRefTypeMismatchError = 31;
+  FreezeRefExistsMutableBorrowError = 32;
+  BorrowFieldTypeMismatchError = 33;
+  BorrowFieldBadFieldError = 34;
+  BorrowFieldExistsMutableBorrowError = 35;
+  CopyLocUnavailableError = 36;
+  CopyLocResourceError = 37;
+  CopyLocExistsBorrowError = 38;
+  MoveLocUnavailableError = 39;
+  MoveLocExistsBorrowError = 40;
+  BorrowLocReferenceError = 41;
+  BorrowLocUnavailableError = 42;
+  BorrowLocExistsBorrowError = 43;
+  CallTypeMismatchError = 44;
+  CallBorrowedMutableReferenceError = 45;
+  PackTypeMismatchError = 46;
+  UnpackTypeMismatchError = 47;
+  ReadRefTypeMismatchError = 48;
+  ReadRefResourceError = 49;
+  ReadRefExistsMutableBorrowError = 50;
+  WriteRefTypeMismatchError = 51;
+  WriteRefResourceError = 52;
+  WriteRefExistsBorrowError = 53;
+  WriteRefNoMutableReferenceError = 54;
+  IntegerOpTypeMismatchError = 55;
+  BooleanOpTypeMismatchError = 56;
+  EqualityOpTypeMismatchError = 57;
+  ExistsResourceTypeMismatchError = 58;
+  BorrowGlobalTypeMismatchError = 59;
+  BorrowGlobalNoResourceError = 60;
+  MoveFromTypeMismatchError = 61;
+  MoveFromNoResourceError = 62;
+  MoveToSenderTypeMismatchError = 63;
+  MoveToSenderNoResourceError = 64;
+  CreateAccountTypeMismatchError = 65;
+  // The self address of a module the transaction is publishing is not the sender address
+  ModuleAddressDoesNotMatchSender = 66;
+  // The module does not have any module handles. Each module or script must have at least one module handle.
+  NoModuleHandles = 67;
+}
+
+// These are errors that the VM might raise if a violation of internal
+// invariants takes place.
+enum VMInvariantViolationError {
+  UnknownInvariantViolationError = 0;
+  OutOfBoundsIndex = 1;
+  OutOfBoundsRange = 2;
+  EmptyValueStack = 3;
+  EmptyCallStack = 4;
+  PCOverflow = 5;
+  LinkerError = 6;
+  LocalReferenceError = 7;
+  StorageError = 8;
+}
+
+// Errors that can arise from binary decoding (deserialization)
+enum BinaryError {
+  UnknownBinaryError = 0;
+  Malformed = 1;
+  BadMagic = 2;
+  UnknownVersion = 3;
+  UnknownTableType = 4;
+  UnknownSignatureType = 5;
+  UnknownSerializedType = 6;
+  UnknownOpcode = 7;
+  BadHeaderTable = 8;
+  UnexpectedSignatureType = 9;
+  DuplicateTable = 10;
+}
+
+//*************************
+// Runtime errors/status
+//*************************
+
+enum RuntimeStatus {
+  UnknownRuntimeStatus = 0;
+  Executed = 1;
+  OutOfGas = 2;
+  // We tried to access a resource that does not exist under the account.
+  ResourceDoesNotExist = 3;
+  // We tried to create a resource under an account where that resource
+  // already exists.
+  ResourceAlreadyExists = 4;
+  // We accessed an account that is evicted.
+  EvictedAccountAccess = 5;
+  // We tried to create an account at an address where an account already
+  // exists.
+  AccountAddressAlreadyExists = 6;
+  TypeError = 7;
+  MissingData = 8;
+  DataFormatError = 9;
+  InvalidData = 10;
+  RemoteDataError = 11;
+  CannotWriteExistingResource = 12;
+  ValueSerializationError = 13;
+  ValueDeserializationError = 14;
+  // The sender is trying to publish a module named `M`, but the sender's account already contains
+  // a module with this name.
+  DuplicateModuleName = 15;
+}
+
+// User-defined assertion error code number
+message AssertionFailure {
+  uint64 assertion_error_code = 1;
+}
+
+message ArithmeticError {
+  enum ArithmeticErrorType {
+    UnknownArithmeticError = 0;
+    Underflow = 1;
+    Overflow = 2;
+    DivisionByZero = 3;
+    // Fill with more later
+  }
+  ArithmeticErrorType error_code = 1;
+}
+
+message DynamicReferenceError {
+  enum DynamicReferenceErrorType {
+    UnknownDynamicReferenceError = 0;
+    MoveOfBorrowedResource = 1;
+    GlobalRefAlreadyReleased = 2;
+    MissingReleaseRef = 3;
+    GlobalAlreadyBorrowed = 4;
+    // Fill with more later
+  }
+  DynamicReferenceErrorType error_code = 1;
+}
+
+message ExecutionStatus {
+  oneof execution_status {
+    RuntimeStatus runtime_status = 1;
+    AssertionFailure assertion_failure = 2;
+    ArithmeticError arithmetic_error = 3;
+    DynamicReferenceError reference_error = 4;
+  }
+}
+
+// The status of the VM
+message VMStatus {
+  oneof error_type {
+    VMValidationStatus validation = 1;
+    VMVerificationStatusList verification = 2;
+    VMInvariantViolationError invariant_violation = 3;
+    BinaryError deserialization = 4;
+    ExecutionStatus execution = 5;
+  }
+}
diff --git a/types/src/test_helpers/fixtures/scripts/make_placeholder_script.sh b/types/src/test_helpers/fixtures/scripts/make_placeholder_script.sh
new file mode 100755
index 0000000000000..f8127cd511f57
--- /dev/null
+++ b/types/src/test_helpers/fixtures/scripts/make_placeholder_script.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+# Copyright (c) The Libra Core Contributors
+# SPDX-License-Identifier: Apache-2.0
+
+# Run this script if the serialization format changes.
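+# It recompiles the placeholder MoveIR script below into the binary fixture
+# (placeholder_script.mvbin) that the test helpers embed at build time.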
+cwd=$(dirname "$0")
+INPUT="$cwd/../../../../../language/stdlib/transaction_scripts/placeholder_script.mvir"
+OUTPUT="$cwd/placeholder_script.mvbin"
+cargo run --package compiler -- --script --no-stdlib "$INPUT" --output "$OUTPUT"
diff --git a/types/src/test_helpers/fixtures/scripts/placeholder_script.mvbin b/types/src/test_helpers/fixtures/scripts/placeholder_script.mvbin
new file mode 100644
index 0000000000000..1b701f8efc8c4
Binary files /dev/null and b/types/src/test_helpers/fixtures/scripts/placeholder_script.mvbin differ
diff --git a/types/src/test_helpers/mod.rs b/types/src/test_helpers/mod.rs
new file mode 100644
index 0000000000000..da06d46ba51c3
--- /dev/null
+++ b/types/src/test_helpers/mod.rs
@@ -0,0 +1,4 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod transaction_test_helpers;
diff --git a/types/src/test_helpers/transaction_test_helpers.rs b/types/src/test_helpers/transaction_test_helpers.rs
new file mode 100644
index 0000000000000..71e6d56341a4a
--- /dev/null
+++ b/types/src/test_helpers/transaction_test_helpers.rs
@@ -0,0 +1,180 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    account_address::AccountAddress,
+    proto::transaction::{
+        RawTransaction as ProtoRawTransaction, SignedTransaction as ProtoSignedTransaction,
+        SignedTransactionsBlock,
+    },
+    transaction::{Program, RawTransaction, RawTransactionBytes, SignedTransaction},
+    transaction_helpers::get_signed_transactions_digest,
+    write_set::WriteSet,
+};
+use crypto::{hash::CryptoHash, signing::sign_message, PrivateKey, PublicKey};
+use proto_conv::{FromProto, IntoProto};
+use protobuf::Message;
+use std::time::{SystemTime, UNIX_EPOCH};
+
+static PLACEHOLDER_SCRIPT: &[u8] = include_bytes!("fixtures/scripts/placeholder_script.mvbin");
+
+const MAX_GAS_AMOUNT: u64 = 10_000;
+const MAX_GAS_PRICE: u64 = 1;
+
+// Test helper for transaction creation
+pub fn get_test_signed_transaction(
+    sender: AccountAddress,
+    sequence_number: u64,
+    private_key: PrivateKey,
+    public_key: PublicKey,
+    program: Option<Program>,
+    expiration_time: u64,
+    gas_unit_price: u64,
+    max_gas_amount: Option<u64>,
+) -> ProtoSignedTransaction {
+    let mut raw_txn = ProtoRawTransaction::new();
+    raw_txn.set_sender_account(sender.as_ref().to_vec());
+    raw_txn.set_sequence_number(sequence_number);
+    raw_txn.set_program(program.unwrap_or_else(placeholder_script).into_proto());
+    raw_txn.set_expiration_time(expiration_time);
+    raw_txn.set_max_gas_amount(max_gas_amount.unwrap_or(MAX_GAS_AMOUNT));
+    raw_txn.set_gas_unit_price(gas_unit_price);
+
+    let bytes = raw_txn.write_to_bytes().unwrap();
+    let hash = RawTransactionBytes(&bytes).hash();
+    let signature = sign_message(hash, &private_key).unwrap();
+
+    let mut signed_txn = ProtoSignedTransaction::new();
+    signed_txn.set_raw_txn_bytes(bytes);
+    signed_txn.set_sender_public_key(public_key.to_slice().to_vec());
+    signed_txn.set_sender_signature(signature.to_compact().to_vec());
+    signed_txn
+}
+
+// from_proto checks the transaction's signature -- something we want to be able
+// to turn off when we want to make sure that the VM is the one testing it.
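+// Unlike get_test_signed_transaction above, this helper builds the result via
+// SignedTransaction::new_for_test, so no signature verification happens while
+// constructing the value.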
+pub fn get_unverified_test_signed_transaction(
+    sender: AccountAddress,
+    sequence_number: u64,
+    private_key: PrivateKey,
+    public_key: PublicKey,
+    program: Option<Program>,
+    expiration_time: u64,
+    gas_unit_price: u64,
+    max_gas_amount: Option<u64>,
+) -> SignedTransaction {
+    let mut raw_txn = ProtoRawTransaction::new();
+    raw_txn.set_sender_account(sender.as_ref().to_vec());
+    raw_txn.set_sequence_number(sequence_number);
+    raw_txn.set_program(program.unwrap_or_else(placeholder_script).into_proto());
+    raw_txn.set_expiration_time(expiration_time);
+    raw_txn.set_max_gas_amount(max_gas_amount.unwrap_or(MAX_GAS_AMOUNT));
+    raw_txn.set_gas_unit_price(gas_unit_price);
+
+    let bytes = raw_txn.write_to_bytes().unwrap();
+    let hash = RawTransactionBytes(&bytes).hash();
+    let signature = sign_message(hash, &private_key).unwrap();
+
+    SignedTransaction::new_for_test(
+        RawTransaction::from_proto(raw_txn).unwrap(),
+        public_key,
+        signature,
+    )
+}
+
+// Test helper for transaction creation. A short version of
+// get_test_signed_transaction that omits some fields.
+pub fn get_test_signed_txn(
+    sender: AccountAddress,
+    sequence_number: u64,
+    private_key: PrivateKey,
+    public_key: PublicKey,
+    program: Option<Program>,
+) -> ProtoSignedTransaction {
+    let expiration_time = SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .unwrap()
+        .as_secs()
+        + 10; // 10 seconds from now.
+    get_test_signed_transaction(
+        sender,
+        sequence_number,
+        private_key,
+        public_key,
+        program,
+        expiration_time,
+        MAX_GAS_PRICE,
+        None,
+    )
+}
+
+pub fn get_unverified_test_signed_txn(
+    sender: AccountAddress,
+    sequence_number: u64,
+    private_key: PrivateKey,
+    public_key: PublicKey,
+    program: Option<Program>,
+) -> SignedTransaction {
+    let expiration_time = SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .unwrap()
+        .as_secs()
+        + 10; // 10 seconds from now.
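+    // Delegate to the full helper, defaulting the gas price and gas limit.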
+    get_unverified_test_signed_transaction(
+        sender,
+        sequence_number,
+        private_key,
+        public_key,
+        program,
+        expiration_time,
+        MAX_GAS_PRICE,
+        None,
+    )
+}
+
+pub fn placeholder_script() -> Program {
+    Program::new(PLACEHOLDER_SCRIPT.to_vec(), vec![], vec![])
+}
+
+pub fn get_write_set_txn(
+    sender: AccountAddress,
+    sequence_number: u64,
+    private_key: PrivateKey,
+    public_key: PublicKey,
+    write_set: Option<WriteSet>,
+) -> SignedTransaction {
+    let write_set = write_set.unwrap_or_default();
+    RawTransaction::new_write_set(sender, sequence_number, write_set)
+        .sign(&private_key, public_key)
+        .unwrap()
+}
+
+// Test helper for transaction block creation
+pub fn create_signed_transactions_block(
+    sender: AccountAddress,
+    starting_sequence_number: u64,
+    num_transactions_in_block: u64,
+    priv_key: &PrivateKey,
+    pub_key: &PublicKey,
+    validator_priv_key: &PrivateKey,
+    validator_pub_key: &PublicKey,
+) -> SignedTransactionsBlock {
+    let mut signed_txns_block = SignedTransactionsBlock::new();
+    for i in starting_sequence_number..(starting_sequence_number + num_transactions_in_block) {
+        // Add some transactions to the block
+        signed_txns_block.transactions.push(get_test_signed_txn(
+            sender,
+            i, /* seq_number */
+            priv_key.clone(),
+            *pub_key,
+            None,
+        ));
+    }
+
+    let message = get_signed_transactions_digest(&signed_txns_block.transactions);
+    let signature = sign_message(message, &validator_priv_key).unwrap();
+    signed_txns_block.set_validator_signature(signature.to_compact().to_vec());
+    signed_txns_block.set_validator_public_key(validator_pub_key.to_slice().to_vec());
+
+    signed_txns_block
+}
diff --git a/types/src/transaction.rs b/types/src/transaction.rs
new file mode 100644
index 0000000000000..e145042d2e8bf
--- /dev/null
+++ b/types/src/transaction.rs
@@ -0,0 +1,1005 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+#![allow(clippy::unit_arg)]
+
+use crate::{
+    account_address::AccountAddress,
+    account_state_blob::AccountStateBlob,
+    contract_event::ContractEvent,
+    ledger_info::LedgerInfo,
+    proof::{
+        get_accumulator_root_hash, verify_signed_transaction, verify_transaction_list,
+        AccumulatorProof, SignedTransactionProof,
+    },
+    proto::events::{EventsForVersions, EventsList},
+    vm_error::VMStatus,
+    write_set::WriteSet,
+};
+use canonical_serialization::{
+    CanonicalDeserialize, CanonicalDeserializer, CanonicalSerialize, CanonicalSerializer,
+    SimpleSerializer,
+};
+use crypto::{
+    hash::{
+        CryptoHash, CryptoHasher, EventAccumulatorHasher, RawTransactionHasher,
+        SignedTransactionHasher, TransactionInfoHasher,
+    },
+    signing, HashValue, PrivateKey, PublicKey, Signature,
+};
+use failure::prelude::*;
+use proptest_derive::Arbitrary;
+use proto_conv::{FromProto, IntoProto, IntoProtoBytes};
+use serde::{Deserialize, Serialize};
+use std::{collections::HashMap, convert::TryFrom, fmt, time::Duration};
+
+mod program;
+
+pub use program::{Program, TransactionArgument, SCRIPT_HASH_LENGTH};
+use protobuf::well_known_types::UInt64Value;
+
+pub type Version = u64; // Height - also used for MVCC in StateDB
+
+pub const MAX_TRANSACTION_SIZE_IN_BYTES: usize = 4096;
+
+/// RawTransaction is the portion of a transaction that a client signs.
+#[derive(Clone, Debug, Hash, Eq, PartialEq, Serialize, Deserialize)]
+pub struct RawTransaction {
+    /// Sender's address.
+    sender: AccountAddress,
+    // Sequence number of this transaction corresponding to sender's account.
+    sequence_number: u64,
+    // The transaction script to execute.
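+    // (For genesis and other special transactions this is a WriteSet instead;
+    // see TransactionPayload below.)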
+    payload: TransactionPayload,
+
+    // Maximal total gas specified by wallet to spend for this transaction.
+    max_gas_amount: u64,
+    // Maximal price that can be paid per unit of gas.
+    gas_unit_price: u64,
+    // Expiration time for this transaction. If storage is queried and
+    // the time returned is greater than or equal to this time and this
+    // transaction has not been included, you can be certain that it will
+    // never be included.
+    // A transaction that doesn't expire is represented by a very large value like
+    // u64::max_value().
+    expiration_time: Duration,
+}
+
+impl RawTransaction {
+    /// Create a new `RawTransaction` with a program.
+    ///
+    /// Almost all transactions are program transactions. See `new_write_set` for write-set
+    /// transactions.
+    pub fn new(
+        sender: AccountAddress,
+        sequence_number: u64,
+        program: Program,
+        max_gas_amount: u64,
+        gas_unit_price: u64,
+        expiration_time: Duration,
+    ) -> Self {
+        RawTransaction {
+            sender,
+            sequence_number,
+            payload: TransactionPayload::Program(program),
+            max_gas_amount,
+            gas_unit_price,
+            expiration_time,
+        }
+    }
+
+    pub fn new_write_set(
+        sender: AccountAddress,
+        sequence_number: u64,
+        write_set: WriteSet,
+    ) -> Self {
+        RawTransaction {
+            sender,
+            sequence_number,
+            payload: TransactionPayload::WriteSet(write_set),
+            // Since write-set transactions bypass the VM, these fields aren't relevant.
+            max_gas_amount: 0,
+            gas_unit_price: 0,
+            // Write-set transactions are special and important and shouldn't expire.
+            expiration_time: Duration::new(u64::max_value(), 0),
+        }
+    }
+
+    /// Signs the given `RawTransaction`. Note that this consumes the `RawTransaction` and turns it
+    /// into a `SignedTransaction`.
+    pub fn sign(
+        self,
+        private_key: &PrivateKey,
+        public_key: PublicKey,
+    ) -> Result<SignedTransaction> {
+        let raw_txn_bytes = self.clone().into_proto_bytes()?;
+        let hash = RawTransactionBytes(&raw_txn_bytes).hash();
+        let signature = signing::sign_message(hash, private_key)?;
+        Ok(SignedTransaction {
+            raw_txn: self,
+            public_key,
+            signature,
+            raw_txn_bytes,
+        })
+    }
+
+    pub fn into_payload(self) -> TransactionPayload {
+        self.payload
+    }
+
+    pub fn format_for_client(&self, get_transaction_name: impl Fn(&[u8]) -> String) -> String {
+        let empty_vec = vec![];
+        let (code, args) = match &self.payload {
+            TransactionPayload::Program(program) => {
+                (get_transaction_name(program.code()), program.args())
+            }
+            TransactionPayload::WriteSet(_) => ("genesis".to_string(), &empty_vec[..]),
+        };
+        let mut f_args: String = "".to_string();
+        for arg in args {
+            f_args = format!("{}\n\t\t\t{:#?},", f_args, arg);
+        }
+        format!(
+            "RawTransaction {{ \n\
+             \tsender: {}, \n\
+             \tsequence_number: {}, \n\
+             \tpayload: {{, \n\
+             \t\ttransaction: {}, \n\
+             \t\targs: [ {} \n\
+             \t\t]\n\
+             \t}}, \n\
+             \tmax_gas_amount: {}, \n\
+             \tgas_unit_price: {}, \n\
+             \texpiration_time: {:#?}, \n\
+             }}",
+            self.sender,
+            self.sequence_number,
+            code,
+            f_args,
+            self.max_gas_amount,
+            self.gas_unit_price,
+            self.expiration_time,
+        )
+    }
+}
+
+pub struct RawTransactionBytes<'a>(pub &'a [u8]);
+
+impl<'a> CryptoHash for RawTransactionBytes<'a> {
+    type Hasher = RawTransactionHasher;
+
+    fn hash(&self) -> HashValue {
+        let mut state = Self::Hasher::default();
+        state.write(self.0);
+        state.finish()
+    }
+}
+
+impl FromProto for RawTransaction {
+    type ProtoType = crate::proto::transaction::RawTransaction;
+
+    fn from_proto(mut txn: Self::ProtoType) -> Result<Self> {
+        Ok(RawTransaction {
+            sender: AccountAddress::try_from(txn.get_sender_account())?,
+            sequence_number: txn.sequence_number,
+            payload: if txn.has_program() {
+                TransactionPayload::Program(Program::from_proto(txn.take_program())?)
+            } else if txn.has_write_set() {
+                TransactionPayload::WriteSet(WriteSet::from_proto(txn.take_write_set())?)
+            } else {
+                bail!("RawTransaction payload missing");
+            },
+            max_gas_amount: txn.max_gas_amount,
+            gas_unit_price: txn.gas_unit_price,
+            expiration_time: Duration::from_secs(txn.expiration_time),
+        })
+    }
+}
+
+impl IntoProto for RawTransaction {
+    type ProtoType = crate::proto::transaction::RawTransaction;
+
+    fn into_proto(self) -> Self::ProtoType {
+        let mut transaction = Self::ProtoType::new();
+        transaction.set_sender_account(self.sender.as_ref().to_vec());
+        transaction.set_sequence_number(self.sequence_number);
+        match self.payload {
+            TransactionPayload::Program(program) => transaction.set_program(program.into_proto()),
+            TransactionPayload::WriteSet(write_set) => {
+                transaction.set_write_set(write_set.into_proto())
+            }
+        }
+        transaction.set_gas_unit_price(self.gas_unit_price);
+        transaction.set_max_gas_amount(self.max_gas_amount);
+        transaction.set_expiration_time(self.expiration_time.as_secs());
+        transaction
+    }
+}
+
+#[derive(Clone, Debug, Hash, Eq, PartialEq, Serialize, Deserialize)]
+pub enum TransactionPayload {
+    /// A regular programmatic transaction that is executed by the VM.
+    Program(Program),
+    WriteSet(WriteSet),
+}
+
+/// SignedTransaction is what a client submits to a validator node.
+#[derive(Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
+pub struct SignedTransaction {
+    /// The raw transaction
+    raw_txn: RawTransaction,
+
+    /// Sender's public key. When checking the signature, we first need to check whether this key
+    /// is indeed the pre-image of the pubkey hash stored under the sender's account.
+    public_key: PublicKey,
+
+    /// Signature of the transaction that corresponds to the public key
+    signature: Signature,
+
+    // The original raw bytes from the protobuf are also stored here so that we use
+    // these bytes when generating the canonical serialization of the SignedTransaction struct,
+    // rather than re-serializing them, to avoid the risk of non-determinism in the process.
+
+    // The raw transaction bytes generated from the wallet.
+    raw_txn_bytes: Vec<u8>,
+}
+
+impl fmt::Debug for SignedTransaction {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "SignedTransaction {{ \n \
+             {{ raw_txn: {:#?}, \n \
+             public_key: {:#?}, \n \
+             signature: {:#?}, \n \
+             }} \n \
+             }}",
+            self.raw_txn, self.public_key, self.signature,
+        )
+    }
+}
+
+impl SignedTransaction {
+    pub fn new_for_test(
+        raw_txn: RawTransaction,
+        public_key: PublicKey,
+        signature: Signature,
+    ) -> SignedTransaction {
+        SignedTransaction {
+            raw_txn: raw_txn.clone(),
+            public_key,
+            signature,
+            // In the real world, raw_txn would be derived from raw_txn_bytes,
+            // not the other way around.
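+            // Tests construct the struct directly, so we re-serialize here instead.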
+ raw_txn_bytes: raw_txn.into_proto_bytes().expect("Should convert."), + } + } + + pub fn public_key(&self) -> PublicKey { + self.public_key + } + + pub fn signature(&self) -> Signature { + self.signature + } + + pub fn sender(&self) -> AccountAddress { + self.raw_txn.sender + } + + pub fn into_raw_transaction(self) -> RawTransaction { + self.raw_txn + } + + pub fn sequence_number(&self) -> u64 { + self.raw_txn.sequence_number + } + + pub fn payload(&self) -> &TransactionPayload { + &self.raw_txn.payload + } + + pub fn max_gas_amount(&self) -> u64 { + self.raw_txn.max_gas_amount + } + + pub fn gas_unit_price(&self) -> u64 { + self.raw_txn.gas_unit_price + } + + pub fn expiration_time(&self) -> Duration { + self.raw_txn.expiration_time + } + + pub fn raw_txn_bytes_len(&self) -> usize { + self.raw_txn_bytes.len() + } + + /// Verifies the signature of given transaction. Returns `Ok()` if the signature is valid. + pub fn verify_signature(&self) -> Result<()> { + let hash = RawTransactionBytes(&self.raw_txn_bytes).hash(); + signing::verify_message(hash, &self.signature, &self.public_key)?; + Ok(()) + } + + pub fn format_for_client(&self, get_transaction_name: impl Fn(&[u8]) -> String) -> String { + format!( + "SignedTransaction {{ \n \ + raw_txn: {}, \n \ + public_key: {:#?}, \n \ + signature: {:#?}, \n \ + }}", + self.raw_txn.format_for_client(get_transaction_name), + self.public_key, + self.signature, + ) + } +} + +impl CryptoHash for SignedTransaction { + type Hasher = SignedTransactionHasher; + + fn hash(&self) -> HashValue { + let mut state = Self::Hasher::default(); + state.write(&SimpleSerializer::>::serialize(self).expect("serialization failed")); + state.finish() + } +} + +impl FromProto for SignedTransaction { + type ProtoType = crate::proto::transaction::SignedTransaction; + + fn from_proto(txn: Self::ProtoType) -> Result { + let proto_raw_transaction = protobuf::parse_from_bytes::< + crate::proto::transaction::RawTransaction, + >(txn.raw_txn_bytes.as_ref())?; + + // First check if extra data is being sent in the proto. Note that this is a temporary + // measure to prevent extraneous data from being packaged. Longer-term, we will likely + // need to allow this for compatibility reasons. Note that we only need to do this + // for raw bytes under the signed transaction. We do this because we actually store this + // field in the DB. + // TODO: Remove prevention of unknown fields + ensure!( + proto_raw_transaction.unknown_fields.fields.is_none(), + "Unknown fields not allowed in testnet proto for raw transaction" + ); + + let t = SignedTransaction { + raw_txn: RawTransaction::from_proto(proto_raw_transaction)?, + public_key: PublicKey::from_slice(txn.get_sender_public_key())?, + signature: Signature::from_compact(txn.get_sender_signature())?, + raw_txn_bytes: txn.raw_txn_bytes, + }; + + // Please do not remove this check. It may appear redundant, as it is also performed by VM, + // but its goal is to ensure that: + // - transactions parsed from a GRPC request are validated before being processed by other + // portions of code; + // - Moxie Marlinspike's Cryptographic Doom Principle is mitigated; + // - resources are committed only for valid data. 
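+        // Hedged sketch of the round trip this check protects (using items defined above):
+        //     let signed = raw_txn.sign(&private_key, public_key)?;             // client signs
+        //     let parsed = SignedTransaction::from_proto(signed.into_proto())?; // re-verified here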
+ match t.verify_signature() { + Ok(_) => Ok(t), + Err(e) => Err(e), + } + } +} + +impl IntoProto for SignedTransaction { + type ProtoType = crate::proto::transaction::SignedTransaction; + + fn into_proto(self) -> Self::ProtoType { + let mut transaction = Self::ProtoType::new(); + transaction.set_raw_txn_bytes(self.raw_txn_bytes); + transaction.set_sender_public_key(self.public_key.to_slice().to_vec()); + transaction.set_sender_signature(self.signature.to_compact().to_vec()); + transaction + } +} + +#[derive(Arbitrary, Clone, Debug, Eq, PartialEq)] +pub struct SignedTransactionWithProof { + pub version: Version, + pub signed_transaction: SignedTransaction, + pub events: Option>, + pub proof: SignedTransactionProof, +} + +impl SignedTransactionWithProof { + /// Verifies the signed transaction with the proof, both carried by `self`. + /// + /// Two things are ensured if no error is raised: + /// 1. This signed transaction exists in the ledger represented by `ledger_info`. + /// 2. And this signed transaction has the same `version`, `sender`, and `sequence_number` as + /// indicated by the parameter list. If any of these parameter is unknown to the call site that + /// is supposed to be informed via this struct, get it from the struct itself, such as: + /// `signed_txn_with_proof.version`, `signed_txn_with_proof.signed_transaction.sender()`, etc. + pub fn verify( + &self, + ledger_info: &LedgerInfo, + version: Version, + sender: AccountAddress, + sequence_number: u64, + ) -> Result<()> { + ensure!( + self.version == version, + "Version ({}) is not expected ({}).", + self.version, + version, + ); + ensure!( + self.signed_transaction.sender() == sender, + "Sender ({}) not expected ({}).", + self.signed_transaction.sender(), + sender, + ); + ensure!( + self.signed_transaction.sequence_number() == sequence_number, + "Sequence number ({}) not expected ({}).", + self.signed_transaction.sequence_number(), + sequence_number, + ); + + let events_root_hash = self.events.as_ref().map(|events| { + let event_hashes: Vec<_> = events.iter().map(ContractEvent::hash).collect(); + get_accumulator_root_hash::(&event_hashes) + }); + verify_signed_transaction( + ledger_info, + self.signed_transaction.hash(), + events_root_hash, + version, + &self.proof, + ) + } +} + +impl FromProto for SignedTransactionWithProof { + type ProtoType = crate::proto::transaction::SignedTransactionWithProof; + + fn from_proto(mut object: Self::ProtoType) -> Result { + let version = object.get_version(); + let signed_transaction = SignedTransaction::from_proto(object.take_signed_transaction())?; + let proof = SignedTransactionProof::from_proto(object.take_proof())?; + let events = object + .events + .take() + .map(|mut list| { + list.take_events() + .into_iter() + .map(ContractEvent::from_proto) + .collect::>>() + }) + .transpose()?; + + Ok(Self { + version, + signed_transaction, + proof, + events, + }) + } +} + +impl IntoProto for SignedTransactionWithProof { + type ProtoType = crate::proto::transaction::SignedTransactionWithProof; + + fn into_proto(self) -> Self::ProtoType { + let mut proto = Self::ProtoType::new(); + proto.set_version(self.version); + proto.set_signed_transaction(self.signed_transaction.into_proto()); + proto.set_proof(self.proof.into_proto()); + if let Some(events) = self.events { + let mut events_list = EventsList::new(); + events_list.set_events(protobuf::RepeatedField::from_vec( + events.into_iter().map(ContractEvent::into_proto).collect(), + )); + proto.set_events(events_list); + } + + proto + } +} + +impl 
CanonicalSerialize for SignedTransaction { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer + .encode_variable_length_bytes(&self.raw_txn_bytes)? + .encode_variable_length_bytes(&self.public_key.to_slice())? + .encode_variable_length_bytes(&self.signature.to_compact())?; + Ok(()) + } +} + +impl CanonicalDeserialize for SignedTransaction { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result + where + Self: Sized, + { + let raw_txn_bytes = deserializer.decode_variable_length_bytes()?; + let public_key_bytes = deserializer.decode_variable_length_bytes()?; + let signature_bytes = deserializer.decode_variable_length_bytes()?; + let proto_raw_transaction = protobuf::parse_from_bytes::< + crate::proto::transaction::RawTransaction, + >(raw_txn_bytes.as_ref())?; + + Ok(SignedTransaction { + raw_txn: RawTransaction::from_proto(proto_raw_transaction)?, + public_key: PublicKey::from_slice(&public_key_bytes)?, + signature: Signature::from_compact(&signature_bytes)?, + raw_txn_bytes, + }) + } +} + +/// The status of executing a transaction. The VM decides whether or not we should `Keep` the +/// transaction output or `Discard` it based upon the execution of the transaction. We wrap these +/// decisions around a `VMStatus` that provides more detail on the final execution state of the VM. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum TransactionStatus { + /// Discard the transaction output + Discard(VMStatus), + + /// Keep the transaction output + Keep(VMStatus), +} + +impl From for TransactionStatus { + fn from(vm_status: VMStatus) -> Self { + let should_discard = match vm_status { + // Any error that is a validation status (i.e. an error arising from the prologue) + // causes the transaction to not be included. + VMStatus::Validation(_) => true, + // If the VM encountered an invalid internal state, we should discard the transaction. + VMStatus::InvariantViolation(_) => true, + // A transaction that publishes code that cannot be verified is currently not charged. + // Therefore the transaction can be excluded. + // + // The original plan was to charge for verification, but the code didn't implement it + // properly. The decision of whether to charge or not will be made based on data (if + // verification checks are too expensive then yes, otherwise no). + VMStatus::Verification(_) => true, + // Even if we are unable to decode the transaction, there should be a charge made to + // that user's account for the gas fees related to decoding, running the prologue etc. + VMStatus::Deserialization(_) => false, + // Any error encountered during the execution of the transaction will charge gas. + VMStatus::Execution(_) => false, + }; + + if should_discard { + TransactionStatus::Discard(vm_status) + } else { + TransactionStatus::Keep(vm_status) + } + } +} + +/// The output of executing a transaction. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct TransactionOutput { + /// The list of writes this transaction intends to do. + write_set: WriteSet, + + /// The list of events emitted during this transaction. + events: Vec, + + /// The amount of gas used during execution. + gas_used: u64, + + /// The execution status. 
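+    // Illustrative sketch of the `From<VMStatus>` mapping above: a validation failure
+    // such as `VMStatus::Validation(..)` becomes `TransactionStatus::Discard(..)`,
+    // while an execution failure such as `VMStatus::Execution(..)` becomes
+    // `TransactionStatus::Keep(..)` so gas can still be charged.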
+ status: TransactionStatus, +} + +impl TransactionOutput { + pub fn new( + write_set: WriteSet, + events: Vec, + gas_used: u64, + status: TransactionStatus, + ) -> Self { + TransactionOutput { + write_set, + events, + gas_used, + status, + } + } + + pub fn write_set(&self) -> &WriteSet { + &self.write_set + } + + pub fn events(&self) -> &[ContractEvent] { + &self.events + } + + pub fn gas_used(&self) -> u64 { + self.gas_used + } + + pub fn status(&self) -> &TransactionStatus { + &self.status + } +} + +/// `TransactionInfo` is the object we store in the transaction accumulator. It consists of the +/// transaction as well as the execution result of this transaction. +#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, FromProto, IntoProto)] +#[ProtoType(crate::proto::transaction_info::TransactionInfo)] +pub struct TransactionInfo { + /// The hash of this transaction. + signed_transaction_hash: HashValue, + + /// The root hash of Sparse Merkle Tree describing the world state at the end of this + /// transaction. + state_root_hash: HashValue, + + /// The root hash of Merkle Accumulator storing all events emitted during this transaction. + event_root_hash: HashValue, + + /// The amount of gas used. + gas_used: u64, +} + +impl TransactionInfo { + /// Constructs a new `TransactionInfo` object using signed transaction hash, state root hash + /// and event root hash. + pub fn new( + signed_transaction_hash: HashValue, + state_root_hash: HashValue, + event_root_hash: HashValue, + gas_used: u64, + ) -> TransactionInfo { + TransactionInfo { + signed_transaction_hash, + state_root_hash, + event_root_hash, + gas_used, + } + } + + /// Returns the hash of this transaction. + pub fn signed_transaction_hash(&self) -> HashValue { + self.signed_transaction_hash + } + + /// Returns root hash of Sparse Merkle Tree describing the world state at the end of this + /// transaction. + pub fn state_root_hash(&self) -> HashValue { + self.state_root_hash + } + + /// Returns the root hash of Merkle Accumulator storing all events emitted during this + /// transaction. + pub fn event_root_hash(&self) -> HashValue { + self.event_root_hash + } + + /// Returns the amount of gas used by this transaction. + pub fn gas_used(&self) -> u64 { + self.gas_used + } +} + +impl CanonicalSerialize for TransactionInfo { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer + .encode_raw_bytes(self.signed_transaction_hash.as_ref())? + .encode_raw_bytes(self.state_root_hash.as_ref())? + .encode_raw_bytes(self.event_root_hash.as_ref())? 
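+            // Layout sketch (descriptive assumption): three 32-byte hashes in the order
+            // written above, followed by gas_used as a u64 below.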
+            .encode_u64(self.gas_used)?;
+        Ok(())
+    }
+}
+
+impl CryptoHash for TransactionInfo {
+    type Hasher = TransactionInfoHasher;
+
+    fn hash(&self) -> HashValue {
+        let mut state = Self::Hasher::default();
+        state.write(
+            &SimpleSerializer::<Vec<u8>>::serialize(self).expect("Serialization should work."),
+        );
+        state.finish()
+    }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct TransactionToCommit {
+    signed_txn: SignedTransaction,
+    account_states: HashMap<AccountAddress, AccountStateBlob>,
+    events: Vec<ContractEvent>,
+    gas_used: u64,
+}
+
+impl TransactionToCommit {
+    pub fn new(
+        signed_txn: SignedTransaction,
+        account_states: HashMap<AccountAddress, AccountStateBlob>,
+        events: Vec<ContractEvent>,
+        gas_used: u64,
+    ) -> Self {
+        TransactionToCommit {
+            signed_txn,
+            account_states,
+            events,
+            gas_used,
+        }
+    }
+
+    pub fn signed_txn(&self) -> &SignedTransaction {
+        &self.signed_txn
+    }
+
+    pub fn account_states(&self) -> &HashMap<AccountAddress, AccountStateBlob> {
+        &self.account_states
+    }
+
+    pub fn events(&self) -> &[ContractEvent] {
+        &self.events
+    }
+
+    pub fn gas_used(&self) -> u64 {
+        self.gas_used
+    }
+}
+
+impl FromProto for TransactionToCommit {
+    type ProtoType = crate::proto::transaction::TransactionToCommit;
+
+    fn from_proto(mut object: Self::ProtoType) -> Result<Self> {
+        let signed_txn = SignedTransaction::from_proto(object.take_signed_txn())?;
+        let account_states_proto = object.take_account_states();
+        let num_account_states = account_states_proto.len();
+        let account_states = account_states_proto
+            .into_iter()
+            .map(|mut x| {
+                Ok((
+                    AccountAddress::from_proto(x.take_address())?,
+                    AccountStateBlob::from(x.take_blob()),
+                ))
+            })
+            .collect::<Result<HashMap<_, _>>>()?;
+        ensure!(
+            account_states.len() == num_account_states,
+            "account_states should have no duplication."
+        );
+        let events = object
+            .take_events()
+            .into_iter()
+            .map(ContractEvent::from_proto)
+            .collect::<Result<Vec<_>>>()?;
+        let gas_used = object.get_gas_used();
+
+        Ok(TransactionToCommit {
+            signed_txn,
+            account_states,
+            events,
+            gas_used,
+        })
+    }
+}
+
+impl IntoProto for TransactionToCommit {
+    type ProtoType = crate::proto::transaction::TransactionToCommit;
+
+    fn into_proto(self) -> Self::ProtoType {
+        let mut proto = Self::ProtoType::new();
+        proto.set_signed_txn(self.signed_txn.into_proto());
+        proto.set_account_states(protobuf::RepeatedField::from_vec(
+            self.account_states
+                .into_iter()
+                .map(|(address, blob)| {
+                    let mut account_state = crate::proto::transaction::AccountState::new();
+                    account_state.set_address(address.as_ref().to_vec());
+                    account_state.set_blob(blob.into());
+                    account_state
+                })
+                .collect::<Vec<_>>(),
+        ));
+        proto.set_events(protobuf::RepeatedField::from_vec(
+            self.events
+                .into_iter()
+                .map(ContractEvent::into_proto)
+                .collect::<Vec<_>>(),
+        ));
+        proto.set_gas_used(self.gas_used);
+        proto
+    }
+}
+
+/// The list may have three states:
+/// 1. The list is empty. Both proofs must be `None`.
+/// 2. The list has only 1 transaction/transaction_info. Then `proof_of_first_transaction`
+/// must exist and `proof_of_last_transaction` must be `None`.
+/// 3. The list has 2+ transactions/transaction_infos. Then both proofs must exist.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct TransactionListWithProof {
+    pub transaction_and_infos: Vec<(SignedTransaction, TransactionInfo)>,
+    pub events: Option<Vec<Vec<ContractEvent>>>,
+    pub first_transaction_version: Option<Version>,
+    pub proof_of_first_transaction: Option<AccumulatorProof>,
+    pub proof_of_last_transaction: Option<AccumulatorProof>,
+}
+
+impl TransactionListWithProof {
+    /// Constructor.
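+    /// # Example (illustrative sketch of state 1 above: the empty list)
+    ///
+    /// ```ignore
+    /// let list = TransactionListWithProof::new_empty();
+    /// assert!(list.proof_of_first_transaction.is_none());
+    /// assert!(list.proof_of_last_transaction.is_none());
+    /// ```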
+ pub fn new( + transaction_and_infos: Vec<(SignedTransaction, TransactionInfo)>, + events: Option>>, + first_transaction_version: Option, + proof_of_first_transaction: Option, + proof_of_last_transaction: Option, + ) -> Self { + Self { + transaction_and_infos, + events, + first_transaction_version, + proof_of_first_transaction, + proof_of_last_transaction, + } + } + + /// Creates an empty transaction list. + pub fn new_empty() -> Self { + Self::new(Vec::new(), None, None, None, None) + } + + /// Verifies the transaction list with the proofs, both carried on `self`. + /// + /// Two things are ensured if no error is raised: + /// 1. All the transactions exist on the ledger represented by `ledger_info`. + /// 2. And the transactions in the list has consecutive versions starting from + /// `first_transaction_version`. When `first_transaction_version` is None, ensures the list is + /// empty. + pub fn verify( + &self, + ledger_info: &LedgerInfo, + first_transaction_version: Option, + ) -> Result<()> { + ensure!( + self.first_transaction_version == first_transaction_version, + "First transaction version ({}) not expected ({}).", + Self::display_option_version(self.first_transaction_version), + Self::display_option_version(first_transaction_version), + ); + + verify_transaction_list(ledger_info, self) + } + + fn display_option_version(version: Option) -> String { + match version { + Some(v) => format!("{}", v), + None => String::from("absent"), + } + } +} + +impl FromProto for TransactionListWithProof { + type ProtoType = crate::proto::transaction::TransactionListWithProof; + + fn from_proto(mut object: Self::ProtoType) -> Result { + let num_txns = object.get_transactions().len(); + let num_infos = object.get_infos().len(); + ensure!( + num_txns == num_infos, + "Number of transactions ({}) does not match the number of transaction infos ({}).", + num_txns, + num_infos + ); + let (has_first, has_last, has_first_version) = ( + object.has_proof_of_first_transaction(), + object.has_proof_of_last_transaction(), + object.has_first_transaction_version(), + ); + match num_txns { + 0 => ensure!( + !has_first && !has_last && !has_first_version, + "Some proof exists with 0 transactions" + ), + 1 => ensure!( + has_first && !has_last && has_first_version, + "Proof of last transaction exists with 1 transaction" + ), + _ => ensure!( + has_first && has_last && has_first_version, + "Both proofs of first and last transactions must exist with 2+ transactions" + ), + } + + let events = object + .events_for_versions + .take() // Option + .map(|mut events_for_versions| { + // EventsForVersion + events_for_versions + .take_events_for_version() + .into_iter() + .map(|mut events_for_version| { + events_for_version + .take_events() + .into_iter() + .map(ContractEvent::from_proto) + .collect::>>() + }) + .collect::>>() + }) + .transpose()?; + + let transaction_and_infos = itertools::zip_eq( + object.take_transactions().into_iter(), + object.take_infos().into_iter(), + ) + .map(|(txn, info)| { + Ok(( + SignedTransaction::from_proto(txn)?, + TransactionInfo::from_proto(info)?, + )) + }) + .collect::>>()?; + + Ok(TransactionListWithProof { + transaction_and_infos, + events, + proof_of_first_transaction: object + .proof_of_first_transaction + .take() + .map(AccumulatorProof::from_proto) + .transpose()?, + proof_of_last_transaction: object + .proof_of_last_transaction + .take() + .map(AccumulatorProof::from_proto) + .transpose()?, + first_transaction_version: object + .first_transaction_version + .take() + .map(|v| 
v.get_value()), + }) + } +} + +impl IntoProto for TransactionListWithProof { + type ProtoType = crate::proto::transaction::TransactionListWithProof; + + fn into_proto(self) -> Self::ProtoType { + let (transactions, infos) = self + .transaction_and_infos + .into_iter() + .map(|(txn, info)| (txn.into_proto(), info.into_proto())) + .unzip(); + + let mut out = Self::ProtoType::new(); + out.set_transactions(protobuf::RepeatedField::from_vec(transactions)); + out.set_infos(protobuf::RepeatedField::from_vec(infos)); + + if let Some(all_events) = self.events { + let mut events_for_versions = EventsForVersions::new(); + for events_for_version in all_events { + let mut events_this_version = EventsList::new(); + events_this_version.set_events(protobuf::RepeatedField::from_vec( + events_for_version + .into_iter() + .map(ContractEvent::into_proto) + .collect(), + )); + events_for_versions + .events_for_version + .push(events_this_version); + } + out.set_events_for_versions(events_for_versions); + } + + if let Some(first_transaction_version) = self.first_transaction_version { + let mut ver = UInt64Value::new(); + ver.set_value(first_transaction_version); + out.set_first_transaction_version(ver); + } + if let Some(proof_of_first_transaction) = self.proof_of_first_transaction { + out.set_proof_of_first_transaction(proof_of_first_transaction.into_proto()); + } + if let Some(proof_of_last_transaction) = self.proof_of_last_transaction { + out.set_proof_of_last_transaction(proof_of_last_transaction.into_proto()); + } + out + } +} diff --git a/types/src/transaction/program.rs b/types/src/transaction/program.rs new file mode 100644 index 0000000000000..976b5688686ee --- /dev/null +++ b/types/src/transaction/program.rs @@ -0,0 +1,164 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + account_address::AccountAddress, + byte_array::ByteArray, + proto::transaction::{TransactionArgument as ProtoArgument, TransactionArgument_ArgType}, +}; +use byteorder::{LittleEndian, WriteBytesExt}; +use failure::prelude::*; +use proto_conv::{FromProto, IntoProto}; +use serde::{Deserialize, Serialize}; +use std::{convert::TryFrom, fmt}; + +pub const SCRIPT_HASH_LENGTH: usize = 32; + +#[derive(Clone, Hash, Eq, PartialEq, Serialize, Deserialize)] +pub struct Program { + code: Vec, + args: Vec, + modules: Vec>, +} + +impl Program { + pub fn new(code: Vec, modules: Vec>, args: Vec) -> Program { + Program { + code, + modules, + args, + } + } + + pub fn code(&self) -> &[u8] { + &self.code + } + + pub fn args(&self) -> &[TransactionArgument] { + &self.args + } + + pub fn modules(&self) -> &[Vec] { + &self.modules + } + + pub fn into_inner(self) -> (Vec, Vec, Vec>) { + (self.code, self.args, self.modules) + } +} + +impl fmt::Debug for Program { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // XXX note that "code" will eventually be encoded bytecode and will no longer be a + // UTF8-ish string -- at that point the from_utf8_lossy will stop making sense. 
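+        // e.g. (illustrative): `Program { code: "main() {...}", args: [{U64: 42}] }`;
+        // the lossy string is for human consumption only.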
+ f.debug_struct("Program") + .field("code", &String::from_utf8_lossy(&self.code)) + .field("args", &self.args) + .finish() + } +} + +impl FromProto for Program { + type ProtoType = crate::proto::transaction::Program; + + fn from_proto(proto_program: Self::ProtoType) -> Result { + let mut args = vec![]; + for arg in proto_program.get_arguments() { + let argument = match arg.get_field_type() { + TransactionArgument_ArgType::U64 => { + let mut bytes = [0u8; 8]; + let data = arg.get_data(); + ensure!( + bytes.len() == data.len(), + "data has incorrect length: expected {} bytes, found {} bytes", + bytes.len(), + data.len() + ); + bytes.copy_from_slice(arg.get_data()); + let amount = u64::from_le_bytes(bytes); + TransactionArgument::U64(amount) + } + TransactionArgument_ArgType::ADDRESS => { + TransactionArgument::Address(AccountAddress::try_from(arg.get_data())?) + } + TransactionArgument_ArgType::STRING => { + TransactionArgument::String(String::from_utf8(arg.get_data().to_vec())?) + } + TransactionArgument_ArgType::BYTEARRAY => { + TransactionArgument::ByteArray(ByteArray::new(arg.get_data().to_vec())) + } + }; + args.push(argument); + } + let mut modules = vec![]; + for m in proto_program.get_modules() { + modules.push(m.to_vec()); + } + Ok(Program::new( + proto_program.get_code().to_vec(), + modules, + args, + )) + } +} + +impl IntoProto for Program { + type ProtoType = crate::proto::transaction::Program; + + fn into_proto(self) -> Self::ProtoType { + let mut proto_program = Self::ProtoType::new(); + proto_program.set_code(self.code); + for arg in self.args { + let mut argument = ProtoArgument::new(); + + match arg { + TransactionArgument::U64(amount) => { + argument.set_field_type(TransactionArgument_ArgType::U64); + let mut amount_vec = vec![]; + amount_vec + .write_u64::(amount) + .expect("Writing to a vec is guaranteed to work"); + argument.set_data(amount_vec); + } + TransactionArgument::Address(address) => { + argument.set_field_type(TransactionArgument_ArgType::ADDRESS); + argument.set_data(address.as_ref().to_vec()); + } + TransactionArgument::String(string) => { + argument.set_field_type(TransactionArgument_ArgType::STRING); + argument.set_data(string.into_bytes()); + } + TransactionArgument::ByteArray(byte_array) => { + argument.set_field_type(TransactionArgument_ArgType::BYTEARRAY); + argument.set_data(byte_array.as_bytes().to_vec()) + } + } + proto_program.mut_arguments().push(argument); + } + for m in self.modules { + proto_program.mut_modules().push(m); + } + proto_program + } +} + +#[derive(Clone, Hash, Eq, PartialEq, Serialize, Deserialize)] +pub enum TransactionArgument { + U64(u64), + Address(AccountAddress), + ByteArray(ByteArray), + String(String), +} + +impl fmt::Debug for TransactionArgument { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TransactionArgument::U64(value) => write!(f, "{{U64: {}}}", value), + TransactionArgument::Address(address) => write!(f, "{{ADDRESS: {:?}}}", address), + TransactionArgument::String(string) => write!(f, "{{STRING: {}}}", string), + TransactionArgument::ByteArray(byte_array) => { + write!(f, "{{ByteArray: 0x{}}}", byte_array) + } + } + } +} diff --git a/types/src/transaction_helpers.rs b/types/src/transaction_helpers.rs new file mode 100644 index 0000000000000..596e0a7a6a3fc --- /dev/null +++ b/types/src/transaction_helpers.rs @@ -0,0 +1,15 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::proto::transaction::SignedTransaction; +use 
crypto::{hash::TestOnlyHash, HashValue}; + +/// Used to get the digest of a set of signed transactions. This is used by a validator +/// to sign a block and to verify the signatures of other validators on a block +pub fn get_signed_transactions_digest(signed_txns: &[SignedTransaction]) -> HashValue { + let mut signatures = vec![]; + for transaction in signed_txns { + signatures.extend_from_slice(&transaction.sender_signature); + } + signatures.test_only_hash() +} diff --git a/types/src/unit_tests/access_path_test.rs b/types/src/unit_tests/access_path_test.rs new file mode 100644 index 0000000000000..23560a8779750 --- /dev/null +++ b/types/src/unit_tests/access_path_test.rs @@ -0,0 +1,62 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + access_path::AccessPath, + account_address::{AccountAddress, ADDRESS_LENGTH}, +}; +use proptest::prelude::*; +use proto_conv::{test_helper::assert_protobuf_encode_decode, FromProto, IntoProto}; + +#[test] +fn access_path_ord() { + let ap1 = AccessPath { + address: AccountAddress::new([1u8; ADDRESS_LENGTH]), + path: b"/foo/b".to_vec(), + }; + let ap2 = AccessPath { + address: AccountAddress::new([1u8; ADDRESS_LENGTH]), + path: b"/foo/c".to_vec(), + }; + let ap3 = AccessPath { + address: AccountAddress::new([1u8; ADDRESS_LENGTH]), + path: b"/foo/c".to_vec(), + }; + let ap4 = AccessPath { + address: AccountAddress::new([2u8; ADDRESS_LENGTH]), + path: b"/foo/a".to_vec(), + }; + assert!(ap1 < ap2); + assert_eq!(ap2, ap3); + assert!(ap3 < ap4); +} + +#[test] +fn test_access_path_protobuf_conversion() { + let address = AccountAddress::new([1u8; ADDRESS_LENGTH]); + let path = b"/foo/bar".to_vec(); + let ap = AccessPath { + address, + path: path.clone(), + }; + let proto_ap = ap.clone().into_proto(); + assert_eq!(Vec::from(&address), proto_ap.get_address()); + assert_eq!(path, proto_ap.get_path()); + assert_eq!(AccessPath::from_proto(proto_ap).unwrap(), ap); +} + +#[test] +fn test_access_path_protobuf_conversion_error() { + let mut proto_ap = crate::proto::access_path::AccessPath::new(); + // Not a valid address. + proto_ap.set_address(vec![0x12, 0x34]); + proto_ap.set_path(b"/foo/bar".to_vec()); + assert!(AccessPath::from_proto(proto_ap).is_err()); +} + +proptest! { + #[test] + fn test_access_path_to_protobuf_roundtrip(access_path in any::()) { + assert_protobuf_encode_decode(&access_path); + } +} diff --git a/types/src/unit_tests/address_test.rs b/types/src/unit_tests/address_test.rs new file mode 100644 index 0000000000000..9dd4eeeb4c3a9 --- /dev/null +++ b/types/src/unit_tests/address_test.rs @@ -0,0 +1,141 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::account_address::{AccountAddress, ADDRESS_LENGTH}; +use bech32::Bech32; +use crypto::{hash::CryptoHash, HashValue}; +use hex::FromHex; +use proptest::prelude::*; +use proto_conv::{FromProto, IntoProto}; +use rand::{thread_rng, Rng}; +use std::convert::{AsRef, TryFrom}; +use test::Bencher; + +#[test] +fn test_address_bytes() { + let hex = Vec::from_hex("ca843279e3427144cead5e4d5999a3d0ccf92b8e124793820837625638742903") + .expect("You must provide a valid Hex format"); + + assert_eq!( + hex.len(), + ADDRESS_LENGTH as usize, + "Address {:?} is not {}-bytes long. Addresses must be {} bytes", + hex, + ADDRESS_LENGTH, + ADDRESS_LENGTH, + ); + let address = AccountAddress::try_from(&hex[..]).unwrap_or_else(|_| { + panic!( + "The address {:?} is of invalid length. 
Addresses must be 32-bytes long", + &hex + ) + }); + + assert_eq!(address.as_ref().to_vec(), hex); +} + +#[test] +fn test_address() { + let hex = Vec::from_hex("ca843279e3427144cead5e4d5999a3d0ccf92b8e124793820837625638742903") + .expect("You must provide a valid Hex format"); + + assert_eq!( + hex.len(), + ADDRESS_LENGTH as usize, + "Address {:?} is not {}-bytes long. Addresses must be {} bytes", + hex, + ADDRESS_LENGTH, + ADDRESS_LENGTH, + ); + + let address: AccountAddress = AccountAddress::try_from(&hex[..]).unwrap_or_else(|_| { + panic!( + "The address {:?} is of invalid length. Addresses must be 32-bytes long", + &hex + ) + }); + + let hash_vec = + &Vec::from_hex("2e10c936c9c69d9b4d99030e13b41c88bd09bb2b29bec7f48699f76eac383956") + .expect("You must provide a valid Hex format"); + + let mut hash = [0u8; 32]; + let bytes = &hash_vec[..32]; + hash.copy_from_slice(&bytes); + + assert_eq!(address.hash(), HashValue::new(hash)); + assert_eq!(address.as_ref().to_vec(), hex); +} + +#[test] +fn test_ref() { + let address = AccountAddress::new([1u8; 32]); + let _: &[u8] = address.as_ref(); +} + +#[test] +fn test_bech32() { + let address = AccountAddress::try_from( + &Vec::from_hex("269bdde7f42c25476707821eb44d5ce3c6c9e50a774f43ddebc5494a42870aa6") + .expect("You must provide a valid Hex format")[..], + ) + .expect("Address is not a valid hex format"); + let bech32 = Bech32::try_from(address).unwrap(); + assert_eq!( + bech32.to_string(), + "lb1y6damel59sj5wec8sg0tgn2uu0rvneg2wa858h0tc4y55s58p2nqjyd2lr".to_string() + ); + let bech32_address = AccountAddress::try_from(bech32) + .expect("The provided input string is not a valid bech32 format"); + assert_eq!( + address.as_ref().to_vec(), + bech32_address.as_ref().to_vec(), + "The two addresses do not match", + ); +} + +#[bench] +fn test_n_bech32(bh: &mut Bencher) { + bh.iter(|| { + let mut rng = thread_rng(); + let random_bytes: [u8; ADDRESS_LENGTH] = rng.gen(); + let address = AccountAddress::new(random_bytes); + let bech32 = Bech32::try_from(address).unwrap(); + let address_from_bech32 = AccountAddress::try_from(bech32) + .expect("The provided input string is not valid bech32 format"); + assert_eq!( + address.as_ref().to_vec(), + address_from_bech32.as_ref().to_vec() + ); + }); +} + +#[test] +fn test_address_from_proto_invalid_length() { + let bytes = vec![1; 123]; + assert!(AccountAddress::from_proto(bytes).is_err()); +} + +proptest! 
{ + #[test] + fn test_address_string_roundtrip(addr in any::()) { + let s = String::from(&addr); + let addr2 = AccountAddress::try_from(s).expect("roundtrip to string should work"); + prop_assert_eq!(addr, addr2); + } + + #[test] + fn test_address_bech32_roundtrip(addr in any::()) { + let b = Bech32::try_from(addr).unwrap(); + let addr2 = AccountAddress::try_from(b).expect("Address::from_bech32 should work"); + prop_assert_eq!(addr, addr2); + } + + #[test] + fn test_address_protobuf_roundtrip(addr in any::()) { + let bytes = addr.into_proto(); + prop_assert_eq!(bytes.clone(), addr.as_ref()); + let addr2 = AccountAddress::from_proto(bytes).unwrap(); + prop_assert_eq!(addr, addr2); + } +} diff --git a/types/src/unit_tests/contract_event_proto_conversion_test.rs b/types/src/unit_tests/contract_event_proto_conversion_test.rs new file mode 100644 index 0000000000000..acbf372749861 --- /dev/null +++ b/types/src/unit_tests/contract_event_proto_conversion_test.rs @@ -0,0 +1,18 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::contract_event::{ContractEvent, EventWithProof}; +use proptest::prelude::*; +use proto_conv::test_helper::assert_protobuf_encode_decode; + +proptest! { + #[test] + fn test_event(event in any::()) { + assert_protobuf_encode_decode(&event); + } + + #[test] + fn test_event_with_proof(event_with_proof in any::()) { + assert_protobuf_encode_decode(&event_with_proof); + } +} diff --git a/types/src/unit_tests/get_with_proof_proto_conversion_test.rs b/types/src/unit_tests/get_with_proof_proto_conversion_test.rs new file mode 100644 index 0000000000000..3b1d82c8e536f --- /dev/null +++ b/types/src/unit_tests/get_with_proof_proto_conversion_test.rs @@ -0,0 +1,38 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::get_with_proof::{ + RequestItem, ResponseItem, UpdateToLatestLedgerRequest, UpdateToLatestLedgerResponse, +}; +use proptest::prelude::*; +use proto_conv::test_helper::assert_protobuf_encode_decode; + +proptest! { + #[test] + fn test_update_to_latest_ledger_request( + request in any::() + ) { + assert_protobuf_encode_decode(&request); + } + + #[test] + fn test_request_item_conversion(item in any::()) { + assert_protobuf_encode_decode(&item); + } +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(20))] + + #[test] + fn test_response_item(item in any::()) { + assert_protobuf_encode_decode(&item); + } + + #[test] + fn test_update_to_latest_ledger_response( + response in any::() + ) { + assert_protobuf_encode_decode(&response); + } +} diff --git a/types/src/unit_tests/ledger_info_proto_conversion_test.rs b/types/src/unit_tests/ledger_info_proto_conversion_test.rs new file mode 100644 index 0000000000000..4ff73f80e8a7c --- /dev/null +++ b/types/src/unit_tests/ledger_info_proto_conversion_test.rs @@ -0,0 +1,32 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::ledger_info::{LedgerInfo, LedgerInfoWithSignatures}; +use proptest::prelude::*; +use proto_conv::test_helper::assert_protobuf_encode_decode; + +proptest! { + #[test] + fn test_ledger_info(ledger_info in any::()) { + assert_protobuf_encode_decode(&ledger_info); + } + + #[test] + fn test_ledger_info_with_signatures( + ledger_info_with_signatures in any_with::((0..11).into()) + ) { + assert_protobuf_encode_decode(&ledger_info_with_signatures); + } +} + +proptest! 
{ + // generating many key pairs are computationally heavy, limiting number of cases + #![proptest_config(ProptestConfig::with_cases(10))] + #[test] + fn test_ledger_info_with_many_signatures( + // 100 is the number we have in mind in real world, setting 200 to have a good chance of hitting it + ledger_info_with_signatures in any_with::((0..200).into()) + ) { + assert_protobuf_encode_decode(&ledger_info_with_signatures); + } +} diff --git a/types/src/unit_tests/mod.rs b/types/src/unit_tests/mod.rs new file mode 100644 index 0000000000000..951d1f472c068 --- /dev/null +++ b/types/src/unit_tests/mod.rs @@ -0,0 +1,14 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod access_path_test; +mod address_test; +mod contract_event_proto_conversion_test; +mod get_with_proof_proto_conversion_test; +mod ledger_info_proto_conversion_test; +mod transaction_proto_conversion_test; +mod transaction_test; +mod validator_change_proto_conversion_test; +mod validator_set_test; +mod vm_error_proto_conversion_test; +mod write_set_test; diff --git a/types/src/unit_tests/transaction_proto_conversion_test.rs b/types/src/unit_tests/transaction_proto_conversion_test.rs new file mode 100644 index 0000000000000..779e4f6a7d04a --- /dev/null +++ b/types/src/unit_tests/transaction_proto_conversion_test.rs @@ -0,0 +1,47 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::transaction::*; +use proptest::prelude::*; +use proto_conv::test_helper::assert_protobuf_encode_decode; + +proptest! { + #[test] + fn test_signed_txn(signed_txn in any::()) { + assert_protobuf_encode_decode(&signed_txn); + } + + #[test] + fn test_signed_txn_with_proof(signed_txn_with_proof in any::()) { + assert_protobuf_encode_decode(&signed_txn_with_proof); + } + + #[test] + fn test_raw_txn(raw_txn in any::()) { + assert_protobuf_encode_decode(&raw_txn); + } + + #[test] + fn test_program(program in any::()) { + assert_protobuf_encode_decode(&program); + } + + #[test] + fn test_transaction_info(txn_info in any::()) { + assert_protobuf_encode_decode(&txn_info); + } + + #[test] + fn test_transaction_to_commit(txn_to_commit in any::()) { + assert_protobuf_encode_decode(&txn_to_commit); + } +} + +proptest! { + #![proptest_config(ProptestConfig::with_cases(10))] + + #[test] + fn test_transaction_list_with_proof(txn_list in any::()) { + assert_protobuf_encode_decode(&txn_list); + } +} diff --git a/types/src/unit_tests/transaction_test.rs b/types/src/unit_tests/transaction_test.rs new file mode 100644 index 0000000000000..aa42bc45c8210 --- /dev/null +++ b/types/src/unit_tests/transaction_test.rs @@ -0,0 +1,42 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + account_address::AccountAddress, + transaction::{Program, RawTransaction, SignedTransaction}, +}; +use crypto::{ + signing::{generate_keypair, Signature}, + utils::keypair_strategy, +}; +use proptest::prelude::*; +use proto_conv::{FromProto, IntoProto}; + +#[test] +fn test_signed_transaction_from_proto_invalid_signature() { + let keypair = generate_keypair(); + assert!(SignedTransaction::from_proto( + SignedTransaction::new_for_test( + RawTransaction::new( + AccountAddress::random(), + 0, + Program::new(vec![], vec![], vec![]), + 0, + 0, + std::time::Duration::new(0, 0), + ), + keypair.1, + Signature::from_compact(&[0; 64]).unwrap(), + ) + .into_proto(), + ) + .is_err()); +} + +proptest! 
{ + #[test] + fn test_sig(raw_txn in any::(), (sk1, pk1) in keypair_strategy()) { + let signed_txn = raw_txn.sign(&sk1, pk1).unwrap(); + assert!(signed_txn.verify_signature().is_ok()); + } +} diff --git a/types/src/unit_tests/validator_change_proto_conversion_test.rs b/types/src/unit_tests/validator_change_proto_conversion_test.rs new file mode 100644 index 0000000000000..9a70d2fe34ef9 --- /dev/null +++ b/types/src/unit_tests/validator_change_proto_conversion_test.rs @@ -0,0 +1,17 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::validator_change::ValidatorChangeEventWithProof; +use proptest::prelude::*; +use proto_conv::test_helper::assert_protobuf_encode_decode; + +proptest! { + #![proptest_config(ProptestConfig::with_cases(20))] + + #[test] + fn test_validator_change_event_with_proof_conversion( + change in any::() + ) { + assert_protobuf_encode_decode(&change); + } +} diff --git a/types/src/unit_tests/validator_set_test.rs b/types/src/unit_tests/validator_set_test.rs new file mode 100644 index 0000000000000..4f0aa3d59374b --- /dev/null +++ b/types/src/unit_tests/validator_set_test.rs @@ -0,0 +1,21 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::validator_set::ValidatorSet; +use canonical_serialization::test_helper::assert_canonical_encode_decode; +use proptest::prelude::*; +use proto_conv::test_helper::assert_protobuf_encode_decode; + +proptest! { + #![proptest_config(ProptestConfig::with_cases(20))] + + #[test] + fn test_validator_set_protobuf_conversion(set in any::()) { + assert_protobuf_encode_decode(&set); + } + + #[test] + fn test_validator_set_canonical_serialization(set in any::()) { + assert_canonical_encode_decode(&set); + } +} diff --git a/types/src/unit_tests/vm_error_proto_conversion_test.rs b/types/src/unit_tests/vm_error_proto_conversion_test.rs new file mode 100644 index 0000000000000..1c41761f292df --- /dev/null +++ b/types/src/unit_tests/vm_error_proto_conversion_test.rs @@ -0,0 +1,57 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::vm_error::{ + ArithmeticErrorType, BinaryError, DynamicReferenceErrorType, ExecutionStatus, + VMInvariantViolationError, VMStatus, VMValidationStatus, VMVerificationError, + VMVerificationStatus, +}; +use proptest::prelude::*; +use proto_conv::test_helper::assert_protobuf_encode_decode; + +proptest! 
{ + #[test] + fn vm_validation_status_roundtrip(validation_status in any::()) { + assert_protobuf_encode_decode(&validation_status); + } + + #[test] + fn vm_verification_error_roundtrip(verification_error in any::()) { + assert_protobuf_encode_decode(&verification_error); + } + + #[test] + fn vm_verification_status_roundtrip(verification_status in any::()) { + assert_protobuf_encode_decode(&verification_status); + } + + #[test] + fn vm_invariant_violation_roundtrip(invariant_violation in any::()) { + assert_protobuf_encode_decode(&invariant_violation); + } + + #[test] + fn binary_error_roundtrip(binary_error in any::()) { + assert_protobuf_encode_decode(&binary_error); + } + + #[test] + fn dynamic_reference_error_roundtrip(dynamic_reference in any::()) { + assert_protobuf_encode_decode(&dynamic_reference); + } + + #[test] + fn arithmetic_error_roundtrip(arithmetic_error in any::()) { + assert_protobuf_encode_decode(&arithmetic_error); + } + + #[test] + fn execution_status_roundtrip(execution_status in any::()) { + assert_protobuf_encode_decode(&execution_status); + } + + #[test] + fn test_vm_status(vm_status in any::()) { + assert_protobuf_encode_decode(&vm_status); + } +} diff --git a/types/src/unit_tests/write_set_test.rs b/types/src/unit_tests/write_set_test.rs new file mode 100644 index 0000000000000..f2ae20bc3a36c --- /dev/null +++ b/types/src/unit_tests/write_set_test.rs @@ -0,0 +1,13 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::write_set::WriteSet; +use proptest::prelude::*; +use proto_conv::test_helper::assert_protobuf_encode_decode; + +proptest! { + #[test] + fn write_set_roundtrip(write_set in any::()) { + assert_protobuf_encode_decode(&write_set); + } +} diff --git a/types/src/validator_change.rs b/types/src/validator_change.rs new file mode 100644 index 0000000000000..e70fe030bc693 --- /dev/null +++ b/types/src/validator_change.rs @@ -0,0 +1,15 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![allow(clippy::unit_arg)] + +use crate::{contract_event::EventWithProof, ledger_info::LedgerInfoWithSignatures}; +use proptest_derive::Arbitrary; +use proto_conv::{FromProto, IntoProto}; + +#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, FromProto, IntoProto)] +#[ProtoType(crate::proto::validator_change::ValidatorChangeEventWithProof)] +pub struct ValidatorChangeEventWithProof { + ledger_info_with_sigs: LedgerInfoWithSignatures, + event_with_proof: EventWithProof, +} diff --git a/types/src/validator_public_keys.rs b/types/src/validator_public_keys.rs new file mode 100644 index 0000000000000..1a5444f766171 --- /dev/null +++ b/types/src/validator_public_keys.rs @@ -0,0 +1,136 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![allow(clippy::unit_arg)] + +use crate::{ + account_address::AccountAddress, + proto::validator_public_keys::ValidatorPublicKeys as ProtoValidatorPublicKeys, +}; +use canonical_serialization::{ + CanonicalDeserialize, CanonicalDeserializer, CanonicalSerialize, CanonicalSerializer, +}; +use crypto::{x25519::X25519PublicKey, PublicKey}; +use failure::Result; +use proptest_derive::Arbitrary; +use proto_conv::{FromProto, IntoProto}; + +/// After executing a special transaction that sets the validators that should be used for the +/// next epoch, consensus and networking get the new list of validators. 
Consensus will have a +/// public key to validate signed messages and networking will have a TBD public key for +/// creating secure channels of communication between validators. The validators and their +/// public keys may or may not change between epochs. +#[derive(Arbitrary, Clone, Debug, Eq, PartialEq)] +pub struct ValidatorPublicKeys { + // Hash value of the current public key of the account address + account_address: AccountAddress, + // This key can validate messages sent from this validator + consensus_public_key: PublicKey, + // This key can validate signed messages at the network layer + network_signing_public_key: PublicKey, + // This key establishes the corresponding PrivateKey holder's eligibility to join the p2p + // network + network_identity_public_key: X25519PublicKey, +} + +impl ValidatorPublicKeys { + pub fn new( + account_address: AccountAddress, + consensus_public_key: PublicKey, + network_signing_public_key: PublicKey, + network_identity_public_key: X25519PublicKey, + ) -> Self { + ValidatorPublicKeys { + account_address, + consensus_public_key, + network_signing_public_key, + network_identity_public_key, + } + } + + /// Returns the id of this validator (hash of the current public key of the + /// validator associated account address) + pub fn account_address(&self) -> &AccountAddress { + &self.account_address + } + + /// Returns the key for validating signed messages from this validator + pub fn consensus_public_key(&self) -> &PublicKey { + &self.consensus_public_key + } + + /// Returns the key for validating signed messages at the network layers + pub fn network_signing_public_key(&self) -> &PublicKey { + &self.network_signing_public_key + } + + /// Returns the key that establishes a validator's identity in the p2p network + pub fn network_identity_public_key(&self) -> &X25519PublicKey { + &self.network_identity_public_key + } +} + +impl FromProto for ValidatorPublicKeys { + type ProtoType = ProtoValidatorPublicKeys; + + fn from_proto(object: Self::ProtoType) -> Result { + let account_address = AccountAddress::from_proto(object.get_account_address().to_vec())?; + let consensus_public_key = PublicKey::from_slice(object.get_consensus_public_key())?; + let network_signing_public_key = + PublicKey::from_slice(object.get_network_signing_public_key())?; + let network_identity_public_key = + X25519PublicKey::from_slice(object.get_network_identity_public_key())?; + Ok(Self::new( + account_address, + consensus_public_key, + network_signing_public_key, + network_identity_public_key, + )) + } +} + +impl IntoProto for ValidatorPublicKeys { + type ProtoType = ProtoValidatorPublicKeys; + + fn into_proto(self) -> Self::ProtoType { + let mut proto = Self::ProtoType::new(); + proto.set_account_address(AccountAddress::into_proto(self.account_address)); + proto.set_consensus_public_key(PublicKey::to_slice(&self.consensus_public_key).to_vec()); + proto.set_network_signing_public_key( + PublicKey::to_slice(&self.network_signing_public_key).to_vec(), + ); + proto.set_network_identity_public_key( + X25519PublicKey::to_slice(&self.network_identity_public_key).to_vec(), + ); + proto + } +} + +impl CanonicalSerialize for ValidatorPublicKeys { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer + .encode_struct(&self.account_address)? + .encode_variable_length_bytes(&self.consensus_public_key.to_slice())? + .encode_variable_length_bytes(&self.network_identity_public_key.to_slice())? 
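+            // Descriptive note: the canonical order is consensus key, then network
+            // identity key, then network signing key; `deserialize` below reads them
+            // back in the same order.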
+            .encode_variable_length_bytes(&self.network_signing_public_key.to_slice())?;
+        Ok(())
+    }
+}
+
+impl CanonicalDeserialize for ValidatorPublicKeys {
+    fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result<Self> {
+        let account_address = deserializer.decode_struct::<AccountAddress>()?;
+        let consensus_public_key =
+            PublicKey::from_slice(&deserializer.decode_variable_length_bytes()?)?;
+        let network_identity_public_key =
+            X25519PublicKey::from_slice(&deserializer.decode_variable_length_bytes()?)?;
+        let network_signing_public_key =
+            PublicKey::from_slice(&deserializer.decode_variable_length_bytes()?)?;
+        Ok(ValidatorPublicKeys::new(
+            account_address,
+            consensus_public_key,
+            network_signing_public_key,
+            network_identity_public_key,
+        ))
+    }
+}
diff --git a/types/src/validator_set.rs b/types/src/validator_set.rs
new file mode 100644
index 0000000000000..d58446ab8b849
--- /dev/null
+++ b/types/src/validator_set.rs
@@ -0,0 +1,108 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    access_path::{AccessPath, Accesses},
+    account_config::core_code_address,
+    language_storage::StructTag,
+    validator_public_keys::ValidatorPublicKeys,
+};
+use canonical_serialization::{
+    CanonicalDeserialize, CanonicalDeserializer, CanonicalSerialize, CanonicalSerializer,
+    SimpleDeserializer,
+};
+use failure::prelude::*;
+use proptest_derive::Arbitrary;
+use proto_conv::{FromProto, IntoProto};
+use std::collections::btree_map::BTreeMap;
+
+pub const VALIDATOR_SET_MODULE_NAME: &str = "ValidatorSet";
+pub const VALIDATOR_SET_STRUCT_NAME: &str = "T";
+
+pub fn validator_set_tag() -> StructTag {
+    StructTag {
+        name: VALIDATOR_SET_STRUCT_NAME.to_string(),
+        address: core_code_address(),
+        module: VALIDATOR_SET_MODULE_NAME.to_string(),
+        type_params: vec![],
+    }
+}
+
+pub(crate) fn validator_set_path() -> Vec<u8> {
+    AccessPath::resource_access_vec(&validator_set_tag(), &Accesses::empty())
+}
+
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq)]
+pub struct ValidatorSet(Vec<ValidatorPublicKeys>);
+
+impl ValidatorSet {
+    /// Constructs a ValidatorSet resource.
+    pub fn new(payload: Vec<ValidatorPublicKeys>) -> Self {
+        ValidatorSet(payload)
+    }
+
+    /// Given an account map (typically from storage), retrieves the associated validator set
+    /// resource.
+    pub fn make_from(account_map: &BTreeMap<Vec<u8>, Vec<u8>>) -> Result<Self> {
+        let ap = validator_set_path();
+        match account_map.get(&ap) {
+            Some(bytes) => SimpleDeserializer::deserialize(bytes),
+            None => bail!("No data for {:?}", ap),
+        }
+    }
+
+    pub fn payload(&self) -> &[ValidatorPublicKeys] {
+        &self.0
+    }
+}
+
+impl CanonicalSerialize for ValidatorSet {
+    fn serialize(&self, mut serializer: &mut impl CanonicalSerializer) -> Result<()> {
+        // TODO: We do not use encode_vec and decode_vec because the VM serializes these
+        // differently. This will be fixed once collections are supported in the language.
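+        // Resulting byte layout (sketch of the hand-rolled encoding below):
+        //     [len as u64][ValidatorPublicKeys #0][ValidatorPublicKeys #1]...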
+ serializer = serializer.encode_u64(self.0.len() as u64)?; + for validator_public_keys in &self.0 { + serializer = serializer.encode_struct(validator_public_keys)?; + } + Ok(()) + } +} + +impl CanonicalDeserialize for ValidatorSet { + fn deserialize(deserializer: &mut impl CanonicalDeserializer) -> Result { + let size = deserializer.decode_u64()?; + let mut payload = vec![]; + for _i in 0..size { + payload.push(deserializer.decode_struct::()?); + } + Ok(ValidatorSet::new(payload)) + } +} + +impl FromProto for ValidatorSet { + type ProtoType = crate::proto::validator_set::ValidatorSet; + + fn from_proto(mut object: Self::ProtoType) -> Result { + Ok(ValidatorSet::new( + object + .take_validator_public_keys() + .into_iter() + .map(ValidatorPublicKeys::from_proto) + .collect::>>()?, + )) + } +} + +impl IntoProto for ValidatorSet { + type ProtoType = crate::proto::validator_set::ValidatorSet; + + fn into_proto(self) -> Self::ProtoType { + let mut out = Self::ProtoType::new(); + out.set_validator_public_keys(protobuf::RepeatedField::from_vec( + self.0 + .into_iter() + .map(ValidatorPublicKeys::into_proto) + .collect(), + )); + out + } +} diff --git a/types/src/validator_signer.rs b/types/src/validator_signer.rs new file mode 100644 index 0000000000000..f8ca8545d3f9c --- /dev/null +++ b/types/src/validator_signer.rs @@ -0,0 +1,154 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::account_address::{AccountAddress, ADDRESS_LENGTH}; +use crypto::{signing, HashValue, PrivateKey, PublicKey, Signature}; +use failure::Error; +use proptest::{prelude::*, sample, strategy::LazyJust}; +use std::convert::TryFrom; + +/// ValidatorSigner associates an author with public and private keys with helpers for signing and +/// validating. This struct can be used for all signing operations including block and network +/// signing, respectively. +#[derive(Debug, Clone)] +pub struct ValidatorSigner { + author: AccountAddress, + public_key: PublicKey, + private_key: PrivateKey, +} + +impl ValidatorSigner { + pub fn new( + account_address: AccountAddress, + public_key: PublicKey, + private_key: PrivateKey, + ) -> Self { + ValidatorSigner { + author: account_address, + public_key, + private_key, + } + } + + /// Generate the genesis block signer information. + pub fn genesis() -> Self { + let (private_key, public_key) = signing::generate_genesis_keypair(); + Self::new(AccountAddress::from(public_key), public_key, private_key) + } + + /// Generate a random set of public and private keys and author information. + pub fn random() -> Self { + let (private_key, public_key) = signing::generate_keypair(); + ValidatorSigner { + author: AccountAddress::from(public_key), + public_key, + private_key, + } + } + + /// For test only - makes signer with nicely looking account address that has specified integer + /// as fist byte, and rest are zeroes + pub fn from_int(num: u8) -> Self { + let mut address = [0; ADDRESS_LENGTH]; + address[0] = num; + let (private_key, public_key) = signing::generate_keypair(); + ValidatorSigner { + author: AccountAddress::try_from(&address[..]).unwrap(), + public_key, + private_key, + } + } + + /// Constructs a signature for `message` using `private_key`. + pub fn sign_message(&self, message: HashValue) -> Result { + signing::sign_message(message, &self.private_key) + } + + /// Checks that `signature` is valid for `message` using `public_key`. 
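+    /// # Example (illustrative sketch, mirroring `test_signer` below)
+    ///
+    /// ```ignore
+    /// let signer = ValidatorSigner::random();
+    /// let message = HashValue::random();
+    /// let signature = signer.sign_message(message).unwrap();
+    /// assert!(signer.verify_message(message, &signature).is_ok());
+    /// ```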
+ pub fn verify_message(&self, message: HashValue, signature: &Signature) -> Result<(), Error> { + signing::verify_message(message, signature, &self.public_key) + } + + /// Returns the author associated with this signer. + pub fn author(&self) -> AccountAddress { + self.author + } + + /// Returns the public key associated with this signer. + pub fn public_key(&self) -> PublicKey { + self.public_key + } +} + +#[allow(clippy::redundant_closure)] +pub fn arb_keypair() -> impl Strategy { + prop_oneof![ + // The no_shrink here reflects that particular keypair choices out + // of random options are irrelevant. + LazyJust::new(|| signing::generate_keypair()).no_shrink(), + LazyJust::new(|| signing::generate_genesis_keypair()), + ] +} + +prop_compose! { + pub fn signer_strategy(key_pair_strategy: impl Strategy)( + keypair in key_pair_strategy) -> ValidatorSigner { + let (private_key, public_key) = keypair; + let account_address = AccountAddress::from(public_key); + ValidatorSigner::new(account_address, public_key, private_key) + } +} + +#[allow(clippy::redundant_closure)] +pub fn rand_signer() -> impl Strategy { + // random signers warrant no shrinkage. + signer_strategy(arb_keypair()).no_shrink() +} + +#[allow(clippy::redundant_closure)] +pub fn arb_signer() -> impl Strategy { + prop_oneof![rand_signer(), LazyJust::new(|| ValidatorSigner::genesis()),] +} + +fn select_keypair( + key_pairs: Vec<(PrivateKey, PublicKey)>, +) -> impl Strategy { + // no_shrink() => shrinking is not relevant as signers are equivalent. + sample::select(key_pairs).no_shrink() +} + +pub fn mostly_in_keypair_pool( + key_pairs: Vec<(PrivateKey, PublicKey)>, +) -> impl Strategy { + prop::strategy::Union::new_weighted(vec![ + (9, signer_strategy(select_keypair(key_pairs)).boxed()), + (1, arb_signer().boxed()), + ]) +} + +#[cfg(test)] +mod tests { + use crate::{ + account_address::AccountAddress, + validator_signer::{arb_keypair, arb_signer, ValidatorSigner}, + }; + use crypto::HashValue; + use proptest::prelude::*; + + proptest! { + #[test] + fn test_new_signer(keypair in arb_keypair()){ + let (private_key, public_key) = keypair; + let signer = ValidatorSigner::new(AccountAddress::from(public_key), public_key, private_key); + prop_assert_eq!(public_key, signer.public_key()); + } + + #[test] + fn test_signer(signer in arb_signer(), message in HashValue::arbitrary()) { + let signature = signer.sign_message(message).unwrap(); + prop_assert!(signer + .verify_message(message, &signature) + .is_ok()); + } + } +} diff --git a/types/src/validator_verifier.rs b/types/src/validator_verifier.rs new file mode 100644 index 0000000000000..fe60f28a42e59 --- /dev/null +++ b/types/src/validator_verifier.rs @@ -0,0 +1,281 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::account_address::AccountAddress; +use crypto::{signing, HashValue, PublicKey, Signature}; +use failure::Fail; +use std::collections::HashMap; + +/// Errors possible during signature verification. +#[derive(Debug, Fail, PartialEq)] +pub enum VerifyError { + #[fail(display = "Author is unknown")] + /// The author for this signature is unknown by this validator. 
+    UnknownAuthor,
+    #[fail(
+        display = "The number of signatures ({}) is smaller than quorum size ({})",
+        num_of_signatures, quorum_size
+    )]
+    TooFewSignatures {
+        num_of_signatures: usize,
+        quorum_size: usize,
+    },
+    #[fail(
+        display = "The number of signatures ({}) is greater than total number of authors ({})",
+        num_of_signatures, num_of_authors
+    )]
+    TooManySignatures {
+        num_of_signatures: usize,
+        num_of_authors: usize,
+    },
+    #[fail(display = "Signature is invalid")]
+    /// The signature does not match the hash.
+    InvalidSignature,
+}
+
+/// Supports validation of signatures for known authors. This struct can be used for all signature
+/// verification operations, including block and network signature verification.
+#[derive(Clone)]
+pub struct ValidatorVerifier {
+    author_to_public_keys: HashMap<AccountAddress, PublicKey>,
+    quorum_size: usize,
+}
+
+impl ValidatorVerifier {
+    /// Initialize with a map of author to public key.
+    pub fn new(
+        author_to_public_keys: HashMap<AccountAddress, PublicKey>,
+        quorum_size: usize,
+    ) -> Self {
+        ValidatorVerifier {
+            author_to_public_keys,
+            quorum_size,
+        }
+    }
+
+    /// Helper method to initialize with a single author and public key.
+    pub fn new_single(author: AccountAddress, public_key: PublicKey) -> Self {
+        let mut author_to_public_keys = HashMap::new();
+        author_to_public_keys.insert(author, public_key);
+        ValidatorVerifier {
+            author_to_public_keys,
+            quorum_size: 1,
+        }
+    }
+
+    /// Helper method to initialize with an empty validator set.
+    pub fn new_empty() -> Self {
+        ValidatorVerifier {
+            author_to_public_keys: HashMap::new(),
+            quorum_size: 0,
+        }
+    }
+
+    /// Verify the correctness of a signature of a hash by a known author.
+    pub fn verify_signature(
+        &self,
+        author: AccountAddress,
+        hash: HashValue,
+        signature: &Signature,
+    ) -> Result<(), VerifyError> {
+        let public_key = self.author_to_public_keys.get(&author);
+        match public_key {
+            None => Err(VerifyError::UnknownAuthor),
+            Some(public_key) => {
+                if signing::verify_message(hash, signature, public_key).is_err() {
+                    Err(VerifyError::InvalidSignature)
+                } else {
+                    Ok(())
+                }
+            }
+        }
+    }
+
+    /// This function returns successfully when at least quorum_size signatures of known authors
+    /// are successfully verified. An aggregated signature is considered invalid if any of the
+    /// attached signatures is invalid or does not correspond to a known author. The latter is to
+    /// prevent malicious users from adding arbitrary content to the signature payload that would
+    /// go unnoticed.
+    pub fn verify_aggregated_signature(
+        &self,
+        hash: HashValue,
+        aggregated_signature: &HashMap<AccountAddress, Signature>,
+    ) -> Result<(), VerifyError> {
+        let num_of_signatures = aggregated_signature.len();
+        if num_of_signatures < self.quorum_size {
+            return Err(VerifyError::TooFewSignatures {
+                num_of_signatures,
+                quorum_size: self.quorum_size,
+            });
+        }
+        if num_of_signatures > self.len() {
+            return Err(VerifyError::TooManySignatures {
+                num_of_signatures,
+                num_of_authors: self.len(),
+            });
+        }
+        for (author, signature) in aggregated_signature {
+            if let Err(err) = self.verify_signature(*author, hash, signature) {
+                return Err(err);
+            }
+        }
+        Ok(())
+    }
+
+    pub fn get_public_key(&self, author: AccountAddress) -> Option<PublicKey> {
+        self.author_to_public_keys.get(&author).cloned()
+    }
+
+    /// Returns an ordered list of account addresses, from smallest to largest.
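+    /// The ordering is the byte-wise ordering of `AccountAddress`, so the result is deterministic
+    /// for a given validator set.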
+    pub fn get_ordered_account_addresses(&self) -> Vec<AccountAddress> {
+        let mut account_addresses: Vec<AccountAddress> = self
+            .author_to_public_keys
+            .keys()
+            .cloned()
+            .collect();
+        account_addresses.sort();
+        account_addresses
+    }
+
+    /// Returns the number of authors to be validated.
+    pub fn len(&self) -> usize {
+        self.author_to_public_keys.len()
+    }
+
+    /// Is there at least one author?
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// Returns quorum_size.
+    pub fn quorum_size(&self) -> usize {
+        self.quorum_size
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{
+        account_address::AccountAddress,
+        validator_signer::ValidatorSigner,
+        validator_verifier::{ValidatorVerifier, VerifyError},
+    };
+    use crypto::{HashValue, PublicKey, Signature};
+    use std::collections::HashMap;
+
+    #[test]
+    fn test_validator() {
+        let validator_signer = ValidatorSigner::random();
+        let random_hash = HashValue::random();
+        let signature = validator_signer.sign_message(random_hash).unwrap();
+        let validator =
+            ValidatorVerifier::new_single(validator_signer.author(), validator_signer.public_key());
+        assert_eq!(
+            validator.verify_signature(validator_signer.author(), random_hash, &signature),
+            Ok(())
+        );
+        let unknown_validator_signer = ValidatorSigner::random();
+        let unknown_signature = unknown_validator_signer.sign_message(random_hash).unwrap();
+        assert_eq!(
+            validator.verify_signature(
+                unknown_validator_signer.author(),
+                random_hash,
+                &unknown_signature
+            ),
+            Err(VerifyError::UnknownAuthor)
+        );
+        assert_eq!(
+            validator.verify_signature(validator_signer.author(), random_hash, &unknown_signature),
+            Err(VerifyError::InvalidSignature)
+        );
+    }
+
+    #[test]
+    fn test_quorum_validators() {
+        // Generate 7 random signers.
+        let validator_signers: Vec<ValidatorSigner> =
+            (0..7).map(|_| ValidatorSigner::random()).collect();
+        let random_hash = HashValue::random();
+
+        // Create a map from authors to public keys.
+        let mut author_to_public_key_map: HashMap<AccountAddress, PublicKey> = HashMap::new();
+        for validator in validator_signers.iter() {
+            author_to_public_key_map.insert(validator.author(), validator.public_key());
+        }
+
+        // Create a map from authors to signatures.
+        let mut author_to_signature_map: HashMap<AccountAddress, Signature> = HashMap::new();
+        for validator in validator_signers.iter() {
+            author_to_signature_map.insert(
+                validator.author(),
+                validator.sign_message(random_hash).unwrap(),
+            );
+        }
+
+        // Let's assume our verifier needs to satisfy at least 5 signatures from the original 7.
+        let validator_verifier = ValidatorVerifier::new(author_to_public_key_map, 5);
+
+        // Check against signatures == N; this will pass.
+        assert_eq!(
+            validator_verifier.verify_aggregated_signature(random_hash, &author_to_signature_map),
+            Ok(())
+        );
+
+        // Add an extra unknown signer, signatures > N; this will fail.
+        let unknown_validator_signer = ValidatorSigner::random();
+        let unknown_signature = unknown_validator_signer.sign_message(random_hash).unwrap();
+        author_to_signature_map.insert(unknown_validator_signer.author(), unknown_signature);
+        assert_eq!(
+            validator_verifier.verify_aggregated_signature(random_hash, &author_to_signature_map),
+            Err(VerifyError::TooManySignatures {
+                num_of_signatures: 8,
+                num_of_authors: 7
+            })
+        );
+
+        // Add 5 valid signers only (quorum threshold is met); this will pass.
+        author_to_signature_map.clear();
+        for validator in validator_signers.iter().take(5) {
+            author_to_signature_map.insert(
+                validator.author(),
+                validator.sign_message(random_hash).unwrap(),
+            );
+        }
+        assert_eq!(
+            validator_verifier.verify_aggregated_signature(random_hash, &author_to_signature_map),
+            Ok(())
+        );
+
+        // Add an unknown signer, but quorum is satisfied and signatures <= N; this will fail as we
+        // don't tolerate invalid signatures.
+        author_to_signature_map.insert(unknown_validator_signer.author(), unknown_signature);
+        assert_eq!(
+            validator_verifier.verify_aggregated_signature(random_hash, &author_to_signature_map),
+            Err(VerifyError::UnknownAuthor)
+        );
+
+        // Add 4 valid signers only (quorum threshold is NOT met); this will fail.
+        author_to_signature_map.clear();
+        for validator in validator_signers.iter().take(4) {
+            author_to_signature_map.insert(
+                validator.author(),
+                validator.sign_message(random_hash).unwrap(),
+            );
+        }
+        assert_eq!(
+            validator_verifier.verify_aggregated_signature(random_hash, &author_to_signature_map),
+            Err(VerifyError::TooFewSignatures {
+                num_of_signatures: 4,
+                quorum_size: 5
+            })
+        );
+
+        // Add an unknown signer: we have 5 signers, but one of them is invalid; this will fail.
+        author_to_signature_map.insert(unknown_validator_signer.author(), unknown_signature);
+        assert_eq!(
+            validator_verifier.verify_aggregated_signature(random_hash, &author_to_signature_map),
+            Err(VerifyError::UnknownAuthor)
+        );
+    }
+}
diff --git a/types/src/vm_error.rs b/types/src/vm_error.rs
new file mode 100644
index 0000000000000..064252db028d9
--- /dev/null
+++ b/types/src/vm_error.rs
@@ -0,0 +1,1117 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+#![allow(clippy::unit_arg)]
+
+use failure::prelude::*;
+use proptest::{collection::vec, prelude::*};
+use proptest_derive::Arbitrary;
+use proto_conv::{FromProto, IntoProto};
+
+// We want conversions here so that we don't have to deal with the unknown default values that we
+// deliberately include in the protobuf.
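+// For example, the zero-valued "unknown" variants that protobuf uses as defaults are decoded into
+// explicit DecodingError values below rather than being accepted as valid statuses.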
+#[derive(Arbitrary, Clone, PartialEq, Eq, Debug, Hash)] +#[proptest(no_params)] +pub enum VMValidationStatus { + InvalidSignature, + InvalidAuthKey, + SequenceNumberTooOld, + SequenceNumberTooNew, + InsufficientBalanceForTransactionFee, + TransactionExpired, + SendingAccountDoesNotExist(String), + RejectedWriteSet, + InvalidWriteSet, + ExceededMaxTransactionSize(String), + UnknownScript, + UnknownModule, + MaxGasUnitsExceedsMaxGasUnitsBound(String), + MaxGasUnitsBelowMinTransactionGasUnits(String), + GasUnitPriceBelowMinBound(String), + GasUnitPriceAboveMaxBound(String), +} + +// TODO: Add string parameters to all the other types as well +#[derive(Arbitrary, Clone, PartialEq, Eq, Debug, Hash)] +#[proptest(no_params)] +pub enum VMVerificationError { + IndexOutOfBounds(String), + RangeOutOfBounds(String), + NoModuleHandles(String), + ModuleAddressDoesNotMatchSender(String), + InvalidSignatureToken(String), + InvalidFieldDefReference(String), + RecursiveStructDefinition(String), + InvalidResourceField(String), + InvalidFallThrough(String), + JoinFailure(String), + NegativeStackSizeWithinBlock(String), + UnbalancedStack(String), + InvalidMainFunctionSignature(String), + DuplicateElement(String), + InvalidModuleHandle(String), + UnimplementedHandle(String), + InconsistentFields(String), + UnusedFields(String), + LookupFailed(String), + VisibilityMismatch(String), + TypeResolutionFailure(String), + TypeMismatch(String), + MissingDependency(String), + PopReferenceError(String), + PopResourceError(String), + ReleaseRefTypeMismatchError(String), + BrTypeMismatchError(String), + AssertTypeMismatchError(String), + StLocTypeMismatchError(String), + StLocUnsafeToDestroyError(String), + RetUnsafeToDestroyError(String), + RetTypeMismatchError(String), + FreezeRefTypeMismatchError(String), + FreezeRefExistsMutableBorrowError(String), + BorrowFieldTypeMismatchError(String), + BorrowFieldBadFieldError(String), + BorrowFieldExistsMutableBorrowError(String), + CopyLocUnavailableError(String), + CopyLocResourceError(String), + CopyLocExistsBorrowError(String), + MoveLocUnavailableError(String), + MoveLocExistsBorrowError(String), + BorrowLocReferenceError(String), + BorrowLocUnavailableError(String), + BorrowLocExistsBorrowError(String), + CallTypeMismatchError(String), + CallBorrowedMutableReferenceError(String), + PackTypeMismatchError(String), + UnpackTypeMismatchError(String), + ReadRefTypeMismatchError(String), + ReadRefResourceError(String), + ReadRefExistsMutableBorrowError(String), + WriteRefTypeMismatchError(String), + WriteRefResourceError(String), + WriteRefExistsBorrowError(String), + WriteRefNoMutableReferenceError(String), + IntegerOpTypeMismatchError(String), + BooleanOpTypeMismatchError(String), + EqualityOpTypeMismatchError(String), + ExistsResourceTypeMismatchError(String), + BorrowGlobalTypeMismatchError(String), + BorrowGlobalNoResourceError(String), + MoveFromTypeMismatchError(String), + MoveFromNoResourceError(String), + MoveToSenderTypeMismatchError(String), + MoveToSenderNoResourceError(String), + CreateAccountTypeMismatchError(String), +} + +#[derive(Arbitrary, Clone, PartialEq, Eq, Debug, Hash)] +#[proptest(no_params)] +pub enum VMVerificationStatus { + /// Verification error in a transaction script. + Script(VMVerificationError), + /// Verification error in a module -- the first element is the index of the module with the + /// error. 
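+    /// Module indices are capped at `u16::max_value()`; the protobuf decoding below rejects
+    /// anything larger with `DecodingError::ModuleIndexTooBig`.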
+ Module(u16, VMVerificationError), +} + +#[derive(Arbitrary, Clone, Copy, PartialEq, Eq, Debug, Hash)] +pub enum VMInvariantViolationError { + OutOfBoundsIndex, + OutOfBoundsRange, + EmptyValueStack, + EmptyCallStack, + PCOverflow, + LinkerError, + LocalReferenceError, + StorageError, +} + +#[derive(Arbitrary, Clone, Copy, PartialEq, Eq, Debug, Hash)] +pub enum BinaryError { + Malformed, + BadMagic, + UnknownVersion, + UnknownTableType, + UnknownSignatureType, + UnknownSerializedType, + UnknownOpcode, + BadHeaderTable, + UnexpectedSignatureType, + DuplicateTable, +} + +#[derive(Arbitrary, Clone, Copy, PartialEq, Eq, Debug, Hash)] +pub enum DynamicReferenceErrorType { + MoveOfBorrowedResource, + GlobalRefAlreadyReleased, + MissingReleaseRef, + GlobalAlreadyBorrowed, +} + +#[derive(Arbitrary, Clone, Copy, PartialEq, Eq, Debug, Hash)] +pub enum ArithmeticErrorType { + Underflow, + Overflow, + DivisionByZero, +} + +#[derive(Arbitrary, Clone, Copy, PartialEq, Eq, Debug, Hash)] +pub enum ExecutionStatus { + Executed, + OutOfGas, + ResourceDoesNotExist, + ResourceAlreadyExists, + EvictedAccountAccess, + AccountAddressAlreadyExists, + TypeError, + MissingData, + DataFormatError, + InvalidData, + RemoteDataError, + CannotWriteExistingResource, + ValueSerializationError, + ValueDeserializationError, + AssertionFailure(u64), + ArithmeticError(ArithmeticErrorType), + DynamicReferenceError(DynamicReferenceErrorType), + DuplicateModuleName, +} + +#[derive(Arbitrary, Clone, PartialEq, Eq, Debug, Hash)] +#[proptest(no_params)] +pub enum VMStatus { + Validation(VMValidationStatus), + InvariantViolation(VMInvariantViolationError), + Deserialization(BinaryError), + Execution(ExecutionStatus), + // As of version 0.9.3, proptest's union (enum) strategies are quadratic time in the number of + // variants: https://github.com/AltSysrq/proptest/issues/143 + // + // In particular, if a variant is picked, so are variants for each previous variant (which + // follows enum definition order for proptest-derive). + // + // VerificationStatus is by far the most expensive enum variant to generate since it has a + // vector of statuses. If it were listed out earlier, it would be generated in a lot more + // cases than necessary. Move VerificationStatus to the end so that the cost of generating + // the value tree is only paid when VerificationStatus is generated. + // + // Also reduce the size of the vector because VMVerificationStatus is slow to generate + // for the exact same reason. 
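+    // The `0..16` bound in the strategy below keeps generated status lists small for the same
+    // reason.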
+    #[proptest(
+        strategy = "vec(any::<VMVerificationStatus>(), 0..16).prop_map(VMStatus::Verification)"
+    )]
+    Verification(Vec<VMVerificationStatus>),
+}
+
+#[derive(Debug, Fail, Eq, PartialEq)]
+pub enum DecodingError {
+    #[fail(display = "Module index {} greater than max possible value 65535", _0)]
+    ModuleIndexTooBig(u32),
+    #[fail(display = "Unknown Validation Status Encountered")]
+    UnknownValidationStatusEncountered,
+    #[fail(display = "Unknown Verification Error Encountered")]
+    UnknownVerificationErrorEncountered,
+    #[fail(display = "Unknown Invariant Violation Error Encountered")]
+    UnknownInvariantViolationErrorEncountered,
+    #[fail(display = "Unknown Transaction Binary Decoding Error Encountered")]
+    UnknownBinaryErrorEncountered,
+    #[fail(display = "Unknown Reference Error Type Encountered")]
+    UnknownDynamicReferenceErrorTypeEncountered,
+    #[fail(display = "Unknown Arithmetic Error Type Encountered")]
+    UnknownArithmeticErrorTypeEncountered,
+    #[fail(display = "Unknown Runtime Status Encountered")]
+    UnknownRuntimeStatusEncountered,
+    #[fail(display = "Unknown/Invalid VM Status Encountered")]
+    InvalidVMStatusEncountered,
+}
+
+//***********************************
+// Decoding/Encoding to Protobuffers
+//***********************************
+impl IntoProto for VMValidationStatus {
+    type ProtoType = crate::proto::vm_errors::VMValidationStatus;
+
+    fn into_proto(self) -> Self::ProtoType {
+        use crate::proto::vm_errors::VMValidationStatusCode as ProtoCode;
+        let mut validation_status = Self::ProtoType::new();
+        validation_status.set_message("none".to_string());
+        match self {
+            VMValidationStatus::InvalidSignature => {
+                validation_status.set_code(ProtoCode::InvalidSignature)
+            }
+            VMValidationStatus::InvalidAuthKey => {
+                validation_status.set_code(ProtoCode::InvalidAuthKey)
+            }
+            VMValidationStatus::SequenceNumberTooOld => {
+                validation_status.set_code(ProtoCode::SequenceNumberTooOld)
+            }
+            VMValidationStatus::SequenceNumberTooNew => {
+                validation_status.set_code(ProtoCode::SequenceNumberTooNew)
+            }
+            VMValidationStatus::InsufficientBalanceForTransactionFee => {
+                validation_status.set_code(ProtoCode::InsufficientBalanceForTransactionFee)
+            }
+            VMValidationStatus::TransactionExpired => {
+                validation_status.set_code(ProtoCode::TransactionExpired)
+            }
+            VMValidationStatus::SendingAccountDoesNotExist(msg) => {
+                validation_status.set_message(msg);
+                validation_status.set_code(ProtoCode::SendingAccountDoesNotExist)
+            }
+            VMValidationStatus::RejectedWriteSet => {
+                validation_status.set_code(ProtoCode::RejectedWriteSet)
+            }
+            VMValidationStatus::InvalidWriteSet => {
+                validation_status.set_code(ProtoCode::InvalidWriteSet)
+            }
+            VMValidationStatus::ExceededMaxTransactionSize(msg) => {
+                validation_status.set_message(msg);
+                validation_status.set_code(ProtoCode::ExceededMaxTransactionSize)
+            }
+            VMValidationStatus::UnknownScript => {
+                validation_status.set_code(ProtoCode::UnknownScript)
+            }
+            VMValidationStatus::UnknownModule => {
+                validation_status.set_code(ProtoCode::UnknownModule)
+            }
+            VMValidationStatus::MaxGasUnitsExceedsMaxGasUnitsBound(msg) => {
+                validation_status.set_message(msg);
+                validation_status.set_code(ProtoCode::MaxGasUnitsExceedsMaxGasUnitsBound)
+            }
+            VMValidationStatus::MaxGasUnitsBelowMinTransactionGasUnits(msg) => {
+                validation_status.set_message(msg);
+                validation_status.set_code(ProtoCode::MaxGasUnitsBelowMinTransactionGasUnits)
+            }
+            VMValidationStatus::GasUnitPriceBelowMinBound(msg) => {
+                validation_status.set_message(msg);
+                validation_status.set_code(ProtoCode::GasUnitPriceBelowMinBound)
+
} + VMValidationStatus::GasUnitPriceAboveMaxBound(msg) => { + validation_status.set_message(msg); + validation_status.set_code(ProtoCode::GasUnitPriceAboveMaxBound) + } + } + validation_status + } +} + +impl FromProto for VMValidationStatus { + type ProtoType = crate::proto::vm_errors::VMValidationStatus; + + fn from_proto(mut proto_validation_status: Self::ProtoType) -> Result { + use crate::proto::vm_errors::VMValidationStatusCode as ProtoStatus; + match proto_validation_status.get_code() { + ProtoStatus::InvalidSignature => Ok(VMValidationStatus::InvalidSignature), + ProtoStatus::InvalidAuthKey => Ok(VMValidationStatus::InvalidAuthKey), + ProtoStatus::SequenceNumberTooOld => Ok(VMValidationStatus::SequenceNumberTooOld), + ProtoStatus::SequenceNumberTooNew => Ok(VMValidationStatus::SequenceNumberTooNew), + ProtoStatus::InsufficientBalanceForTransactionFee => { + Ok(VMValidationStatus::InsufficientBalanceForTransactionFee) + } + ProtoStatus::TransactionExpired => Ok(VMValidationStatus::TransactionExpired), + ProtoStatus::SendingAccountDoesNotExist => { + let msg = proto_validation_status.take_message(); + Ok(VMValidationStatus::SendingAccountDoesNotExist(msg)) + } + ProtoStatus::RejectedWriteSet => Ok(VMValidationStatus::RejectedWriteSet), + ProtoStatus::InvalidWriteSet => Ok(VMValidationStatus::InvalidWriteSet), + ProtoStatus::ExceededMaxTransactionSize => { + let msg = proto_validation_status.take_message(); + Ok(VMValidationStatus::ExceededMaxTransactionSize(msg)) + } + ProtoStatus::UnknownScript => Ok(VMValidationStatus::UnknownScript), + ProtoStatus::UnknownModule => Ok(VMValidationStatus::UnknownModule), + ProtoStatus::MaxGasUnitsExceedsMaxGasUnitsBound => { + let msg = proto_validation_status.take_message(); + Ok(VMValidationStatus::MaxGasUnitsExceedsMaxGasUnitsBound(msg)) + } + ProtoStatus::MaxGasUnitsBelowMinTransactionGasUnits => { + let msg = proto_validation_status.take_message(); + Ok(VMValidationStatus::MaxGasUnitsBelowMinTransactionGasUnits( + msg, + )) + } + ProtoStatus::GasUnitPriceBelowMinBound => { + let msg = proto_validation_status.take_message(); + Ok(VMValidationStatus::GasUnitPriceBelowMinBound(msg)) + } + ProtoStatus::GasUnitPriceAboveMaxBound => { + let msg = proto_validation_status.take_message(); + Ok(VMValidationStatus::GasUnitPriceAboveMaxBound(msg)) + } + ProtoStatus::UnknownValidationStatus => { + bail_err!(DecodingError::UnknownValidationStatusEncountered) + } + } + } +} + +impl IntoProto for VMVerificationError { + type ProtoType = (crate::proto::vm_errors::VMVerificationErrorKind, String); + + fn into_proto(self) -> Self::ProtoType { + use crate::proto::vm_errors::VMVerificationErrorKind as ProtoKind; + match self { + VMVerificationError::IndexOutOfBounds(message) => { + (ProtoKind::IndexOutOfBounds, message) + } + VMVerificationError::RangeOutOfBounds(message) => { + (ProtoKind::RangeOutOfBounds, message) + } + VMVerificationError::NoModuleHandles(message) => (ProtoKind::NoModuleHandles, message), + VMVerificationError::ModuleAddressDoesNotMatchSender(message) => { + (ProtoKind::ModuleAddressDoesNotMatchSender, message) + } + VMVerificationError::InvalidSignatureToken(message) => { + (ProtoKind::InvalidSignatureToken, message) + } + VMVerificationError::InvalidFieldDefReference(message) => { + (ProtoKind::InvalidFieldDefReference, message) + } + VMVerificationError::RecursiveStructDefinition(message) => { + (ProtoKind::RecursiveStructDefinition, message) + } + VMVerificationError::InvalidResourceField(message) => { + (ProtoKind::InvalidResourceField, 
message) + } + VMVerificationError::InvalidFallThrough(message) => { + (ProtoKind::InvalidFallThrough, message) + } + VMVerificationError::JoinFailure(message) => (ProtoKind::JoinFailure, message), + VMVerificationError::NegativeStackSizeWithinBlock(message) => { + (ProtoKind::NegativeStackSizeWithinBlock, message) + } + VMVerificationError::UnbalancedStack(message) => (ProtoKind::UnbalancedStack, message), + VMVerificationError::InvalidMainFunctionSignature(message) => { + (ProtoKind::InvalidMainFunctionSignature, message) + } + VMVerificationError::DuplicateElement(message) => { + (ProtoKind::DuplicateElement, message) + } + VMVerificationError::InvalidModuleHandle(message) => { + (ProtoKind::InvalidModuleHandle, message) + } + VMVerificationError::UnimplementedHandle(message) => { + (ProtoKind::UnimplementedHandle, message) + } + VMVerificationError::InconsistentFields(message) => { + (ProtoKind::InconsistentFields, message) + } + VMVerificationError::UnusedFields(message) => (ProtoKind::UnusedFields, message), + VMVerificationError::LookupFailed(message) => (ProtoKind::LookupFailed, message), + VMVerificationError::VisibilityMismatch(message) => { + (ProtoKind::VisibilityMismatch, message) + } + VMVerificationError::TypeResolutionFailure(message) => { + (ProtoKind::TypeResolutionFailure, message) + } + VMVerificationError::TypeMismatch(message) => (ProtoKind::TypeMismatch, message), + VMVerificationError::MissingDependency(message) => { + (ProtoKind::MissingDependency, message) + } + VMVerificationError::PopReferenceError(message) => { + (ProtoKind::PopReferenceError, message) + } + VMVerificationError::PopResourceError(message) => { + (ProtoKind::PopResourceError, message) + } + VMVerificationError::ReleaseRefTypeMismatchError(message) => { + (ProtoKind::ReleaseRefTypeMismatchError, message) + } + VMVerificationError::BrTypeMismatchError(message) => { + (ProtoKind::BrTypeMismatchError, message) + } + VMVerificationError::AssertTypeMismatchError(message) => { + (ProtoKind::AssertTypeMismatchError, message) + } + VMVerificationError::StLocTypeMismatchError(message) => { + (ProtoKind::StLocTypeMismatchError, message) + } + VMVerificationError::StLocUnsafeToDestroyError(message) => { + (ProtoKind::StLocUnsafeToDestroyError, message) + } + VMVerificationError::RetUnsafeToDestroyError(message) => { + (ProtoKind::RetUnsafeToDestroyError, message) + } + VMVerificationError::RetTypeMismatchError(message) => { + (ProtoKind::RetTypeMismatchError, message) + } + VMVerificationError::FreezeRefTypeMismatchError(message) => { + (ProtoKind::FreezeRefTypeMismatchError, message) + } + VMVerificationError::FreezeRefExistsMutableBorrowError(message) => { + (ProtoKind::FreezeRefExistsMutableBorrowError, message) + } + VMVerificationError::BorrowFieldTypeMismatchError(message) => { + (ProtoKind::BorrowFieldTypeMismatchError, message) + } + VMVerificationError::BorrowFieldBadFieldError(message) => { + (ProtoKind::BorrowFieldBadFieldError, message) + } + VMVerificationError::BorrowFieldExistsMutableBorrowError(message) => { + (ProtoKind::BorrowFieldExistsMutableBorrowError, message) + } + VMVerificationError::CopyLocUnavailableError(message) => { + (ProtoKind::CopyLocUnavailableError, message) + } + VMVerificationError::CopyLocResourceError(message) => { + (ProtoKind::CopyLocResourceError, message) + } + VMVerificationError::CopyLocExistsBorrowError(message) => { + (ProtoKind::CopyLocExistsBorrowError, message) + } + VMVerificationError::MoveLocUnavailableError(message) => { + 
(ProtoKind::MoveLocUnavailableError, message) + } + VMVerificationError::MoveLocExistsBorrowError(message) => { + (ProtoKind::MoveLocExistsBorrowError, message) + } + VMVerificationError::BorrowLocReferenceError(message) => { + (ProtoKind::BorrowLocReferenceError, message) + } + VMVerificationError::BorrowLocUnavailableError(message) => { + (ProtoKind::BorrowLocUnavailableError, message) + } + VMVerificationError::BorrowLocExistsBorrowError(message) => { + (ProtoKind::BorrowLocExistsBorrowError, message) + } + VMVerificationError::CallTypeMismatchError(message) => { + (ProtoKind::CallTypeMismatchError, message) + } + VMVerificationError::CallBorrowedMutableReferenceError(message) => { + (ProtoKind::CallBorrowedMutableReferenceError, message) + } + VMVerificationError::PackTypeMismatchError(message) => { + (ProtoKind::PackTypeMismatchError, message) + } + VMVerificationError::UnpackTypeMismatchError(message) => { + (ProtoKind::UnpackTypeMismatchError, message) + } + VMVerificationError::ReadRefTypeMismatchError(message) => { + (ProtoKind::ReadRefTypeMismatchError, message) + } + VMVerificationError::ReadRefResourceError(message) => { + (ProtoKind::ReadRefResourceError, message) + } + VMVerificationError::ReadRefExistsMutableBorrowError(message) => { + (ProtoKind::ReadRefExistsMutableBorrowError, message) + } + VMVerificationError::WriteRefTypeMismatchError(message) => { + (ProtoKind::WriteRefTypeMismatchError, message) + } + VMVerificationError::WriteRefResourceError(message) => { + (ProtoKind::WriteRefResourceError, message) + } + VMVerificationError::WriteRefExistsBorrowError(message) => { + (ProtoKind::WriteRefExistsBorrowError, message) + } + VMVerificationError::WriteRefNoMutableReferenceError(message) => { + (ProtoKind::WriteRefNoMutableReferenceError, message) + } + VMVerificationError::IntegerOpTypeMismatchError(message) => { + (ProtoKind::IntegerOpTypeMismatchError, message) + } + VMVerificationError::BooleanOpTypeMismatchError(message) => { + (ProtoKind::BooleanOpTypeMismatchError, message) + } + VMVerificationError::EqualityOpTypeMismatchError(message) => { + (ProtoKind::EqualityOpTypeMismatchError, message) + } + VMVerificationError::ExistsResourceTypeMismatchError(message) => { + (ProtoKind::ExistsResourceTypeMismatchError, message) + } + VMVerificationError::BorrowGlobalTypeMismatchError(message) => { + (ProtoKind::BorrowGlobalTypeMismatchError, message) + } + VMVerificationError::BorrowGlobalNoResourceError(message) => { + (ProtoKind::BorrowGlobalNoResourceError, message) + } + VMVerificationError::MoveFromTypeMismatchError(message) => { + (ProtoKind::MoveFromTypeMismatchError, message) + } + VMVerificationError::MoveFromNoResourceError(message) => { + (ProtoKind::MoveFromNoResourceError, message) + } + VMVerificationError::MoveToSenderTypeMismatchError(message) => { + (ProtoKind::MoveToSenderTypeMismatchError, message) + } + VMVerificationError::MoveToSenderNoResourceError(message) => { + (ProtoKind::MoveToSenderNoResourceError, message) + } + VMVerificationError::CreateAccountTypeMismatchError(message) => { + (ProtoKind::CreateAccountTypeMismatchError, message) + } + } + } +} + +impl FromProto for VMVerificationError { + type ProtoType = (crate::proto::vm_errors::VMVerificationErrorKind, String); + + fn from_proto(proto_verification_error: Self::ProtoType) -> Result { + use crate::proto::vm_errors::VMVerificationErrorKind as ProtoKind; + + let (kind, message) = proto_verification_error; + match kind { + ProtoKind::IndexOutOfBounds => 
Ok(VMVerificationError::IndexOutOfBounds(message)), + ProtoKind::RangeOutOfBounds => Ok(VMVerificationError::RangeOutOfBounds(message)), + ProtoKind::NoModuleHandles => Ok(VMVerificationError::NoModuleHandles(message)), + ProtoKind::ModuleAddressDoesNotMatchSender => Ok( + VMVerificationError::ModuleAddressDoesNotMatchSender(message), + ), + ProtoKind::InvalidSignatureToken => { + Ok(VMVerificationError::InvalidSignatureToken(message)) + } + ProtoKind::InvalidFieldDefReference => { + Ok(VMVerificationError::InvalidFieldDefReference(message)) + } + ProtoKind::RecursiveStructDefinition => { + Ok(VMVerificationError::RecursiveStructDefinition(message)) + } + ProtoKind::InvalidResourceField => { + Ok(VMVerificationError::InvalidResourceField(message)) + } + ProtoKind::InvalidFallThrough => Ok(VMVerificationError::InvalidFallThrough(message)), + ProtoKind::JoinFailure => Ok(VMVerificationError::JoinFailure(message)), + ProtoKind::NegativeStackSizeWithinBlock => { + Ok(VMVerificationError::NegativeStackSizeWithinBlock(message)) + } + ProtoKind::UnbalancedStack => Ok(VMVerificationError::UnbalancedStack(message)), + ProtoKind::InvalidMainFunctionSignature => { + Ok(VMVerificationError::InvalidMainFunctionSignature(message)) + } + ProtoKind::DuplicateElement => Ok(VMVerificationError::DuplicateElement(message)), + ProtoKind::InvalidModuleHandle => Ok(VMVerificationError::InvalidModuleHandle(message)), + ProtoKind::UnimplementedHandle => Ok(VMVerificationError::UnimplementedHandle(message)), + ProtoKind::InconsistentFields => Ok(VMVerificationError::InconsistentFields(message)), + ProtoKind::UnusedFields => Ok(VMVerificationError::UnusedFields(message)), + ProtoKind::LookupFailed => Ok(VMVerificationError::LookupFailed(message)), + ProtoKind::VisibilityMismatch => Ok(VMVerificationError::VisibilityMismatch(message)), + ProtoKind::TypeResolutionFailure => { + Ok(VMVerificationError::TypeResolutionFailure(message)) + } + ProtoKind::TypeMismatch => Ok(VMVerificationError::TypeMismatch(message)), + ProtoKind::MissingDependency => Ok(VMVerificationError::MissingDependency(message)), + ProtoKind::PopReferenceError => Ok(VMVerificationError::PopReferenceError(message)), + ProtoKind::PopResourceError => Ok(VMVerificationError::PopResourceError(message)), + ProtoKind::ReleaseRefTypeMismatchError => { + Ok(VMVerificationError::ReleaseRefTypeMismatchError(message)) + } + ProtoKind::BrTypeMismatchError => Ok(VMVerificationError::BrTypeMismatchError(message)), + ProtoKind::AssertTypeMismatchError => { + Ok(VMVerificationError::AssertTypeMismatchError(message)) + } + ProtoKind::StLocTypeMismatchError => { + Ok(VMVerificationError::StLocTypeMismatchError(message)) + } + ProtoKind::StLocUnsafeToDestroyError => { + Ok(VMVerificationError::StLocUnsafeToDestroyError(message)) + } + ProtoKind::RetUnsafeToDestroyError => { + Ok(VMVerificationError::RetUnsafeToDestroyError(message)) + } + ProtoKind::RetTypeMismatchError => { + Ok(VMVerificationError::RetTypeMismatchError(message)) + } + ProtoKind::FreezeRefTypeMismatchError => { + Ok(VMVerificationError::FreezeRefTypeMismatchError(message)) + } + ProtoKind::FreezeRefExistsMutableBorrowError => Ok( + VMVerificationError::FreezeRefExistsMutableBorrowError(message), + ), + ProtoKind::BorrowFieldTypeMismatchError => { + Ok(VMVerificationError::BorrowFieldTypeMismatchError(message)) + } + ProtoKind::BorrowFieldBadFieldError => { + Ok(VMVerificationError::BorrowFieldBadFieldError(message)) + } + ProtoKind::BorrowFieldExistsMutableBorrowError => Ok( + 
VMVerificationError::BorrowFieldExistsMutableBorrowError(message), + ), + ProtoKind::CopyLocUnavailableError => { + Ok(VMVerificationError::CopyLocUnavailableError(message)) + } + ProtoKind::CopyLocResourceError => { + Ok(VMVerificationError::CopyLocResourceError(message)) + } + ProtoKind::CopyLocExistsBorrowError => { + Ok(VMVerificationError::CopyLocExistsBorrowError(message)) + } + ProtoKind::MoveLocUnavailableError => { + Ok(VMVerificationError::MoveLocUnavailableError(message)) + } + ProtoKind::MoveLocExistsBorrowError => { + Ok(VMVerificationError::MoveLocExistsBorrowError(message)) + } + ProtoKind::BorrowLocReferenceError => { + Ok(VMVerificationError::BorrowLocReferenceError(message)) + } + ProtoKind::BorrowLocUnavailableError => { + Ok(VMVerificationError::BorrowLocUnavailableError(message)) + } + ProtoKind::BorrowLocExistsBorrowError => { + Ok(VMVerificationError::BorrowLocExistsBorrowError(message)) + } + ProtoKind::CallTypeMismatchError => { + Ok(VMVerificationError::CallTypeMismatchError(message)) + } + ProtoKind::CallBorrowedMutableReferenceError => Ok( + VMVerificationError::CallBorrowedMutableReferenceError(message), + ), + ProtoKind::PackTypeMismatchError => { + Ok(VMVerificationError::PackTypeMismatchError(message)) + } + ProtoKind::UnpackTypeMismatchError => { + Ok(VMVerificationError::UnpackTypeMismatchError(message)) + } + ProtoKind::ReadRefTypeMismatchError => { + Ok(VMVerificationError::ReadRefTypeMismatchError(message)) + } + ProtoKind::ReadRefResourceError => { + Ok(VMVerificationError::ReadRefResourceError(message)) + } + ProtoKind::ReadRefExistsMutableBorrowError => Ok( + VMVerificationError::ReadRefExistsMutableBorrowError(message), + ), + ProtoKind::WriteRefTypeMismatchError => { + Ok(VMVerificationError::WriteRefTypeMismatchError(message)) + } + ProtoKind::WriteRefResourceError => { + Ok(VMVerificationError::WriteRefResourceError(message)) + } + ProtoKind::WriteRefExistsBorrowError => { + Ok(VMVerificationError::WriteRefExistsBorrowError(message)) + } + ProtoKind::WriteRefNoMutableReferenceError => Ok( + VMVerificationError::WriteRefNoMutableReferenceError(message), + ), + ProtoKind::IntegerOpTypeMismatchError => { + Ok(VMVerificationError::IntegerOpTypeMismatchError(message)) + } + ProtoKind::BooleanOpTypeMismatchError => { + Ok(VMVerificationError::BooleanOpTypeMismatchError(message)) + } + ProtoKind::EqualityOpTypeMismatchError => { + Ok(VMVerificationError::EqualityOpTypeMismatchError(message)) + } + ProtoKind::ExistsResourceTypeMismatchError => Ok( + VMVerificationError::ExistsResourceTypeMismatchError(message), + ), + ProtoKind::BorrowGlobalTypeMismatchError => { + Ok(VMVerificationError::BorrowGlobalTypeMismatchError(message)) + } + ProtoKind::BorrowGlobalNoResourceError => { + Ok(VMVerificationError::BorrowGlobalNoResourceError(message)) + } + ProtoKind::MoveFromTypeMismatchError => { + Ok(VMVerificationError::MoveFromTypeMismatchError(message)) + } + ProtoKind::MoveFromNoResourceError => { + Ok(VMVerificationError::MoveFromNoResourceError(message)) + } + ProtoKind::MoveToSenderTypeMismatchError => { + Ok(VMVerificationError::MoveToSenderTypeMismatchError(message)) + } + ProtoKind::MoveToSenderNoResourceError => { + Ok(VMVerificationError::MoveToSenderNoResourceError(message)) + } + ProtoKind::CreateAccountTypeMismatchError => { + Ok(VMVerificationError::CreateAccountTypeMismatchError(message)) + } + ProtoKind::UnknownVerificationError => { + bail_err!(DecodingError::UnknownVerificationErrorEncountered) + } + } + } +} + +impl IntoProto for 
VMVerificationStatus { + type ProtoType = crate::proto::vm_errors::VMVerificationStatus; + + fn into_proto(self) -> Self::ProtoType { + use crate::proto::vm_errors::VMVerificationStatus_StatusKind as ProtoStatusKind; + + let mut proto_status = Self::ProtoType::new(); + + let (kind, message) = match self { + VMVerificationStatus::Script(error) => { + proto_status.set_status_kind(ProtoStatusKind::SCRIPT); + error.into_proto() + } + VMVerificationStatus::Module(module_idx, error) => { + proto_status.set_status_kind(ProtoStatusKind::MODULE); + proto_status.set_module_idx(u32::from(module_idx)); + error.into_proto() + } + }; + proto_status.set_error_kind(kind); + proto_status.set_message(message); + proto_status + } +} + +impl FromProto for VMVerificationStatus { + type ProtoType = crate::proto::vm_errors::VMVerificationStatus; + + fn from_proto(mut proto_status: Self::ProtoType) -> Result { + use crate::proto::vm_errors::VMVerificationStatus_StatusKind as ProtoStatusKind; + + let err = VMVerificationError::from_proto(( + proto_status.get_error_kind(), + proto_status.take_message(), + ))?; + + match proto_status.get_status_kind() { + ProtoStatusKind::SCRIPT => Ok(VMVerificationStatus::Script(err)), + ProtoStatusKind::MODULE => { + let module_idx = proto_status.get_module_idx(); + if module_idx > u32::from(u16::max_value()) { + bail_err!(DecodingError::ModuleIndexTooBig(module_idx)); + } + Ok(VMVerificationStatus::Module(module_idx as u16, err)) + } + } + } +} + +impl IntoProto for VMInvariantViolationError { + type ProtoType = crate::proto::vm_errors::VMInvariantViolationError; + + fn into_proto(self) -> Self::ProtoType { + use crate::proto::vm_errors::VMInvariantViolationError as ProtoStatus; + match self { + VMInvariantViolationError::OutOfBoundsIndex => ProtoStatus::OutOfBoundsIndex, + VMInvariantViolationError::OutOfBoundsRange => ProtoStatus::OutOfBoundsRange, + VMInvariantViolationError::EmptyValueStack => ProtoStatus::EmptyValueStack, + VMInvariantViolationError::EmptyCallStack => ProtoStatus::EmptyCallStack, + VMInvariantViolationError::PCOverflow => ProtoStatus::PCOverflow, + VMInvariantViolationError::LinkerError => ProtoStatus::LinkerError, + VMInvariantViolationError::LocalReferenceError => ProtoStatus::LocalReferenceError, + VMInvariantViolationError::StorageError => ProtoStatus::StorageError, + } + } +} + +impl FromProto for VMInvariantViolationError { + type ProtoType = crate::proto::vm_errors::VMInvariantViolationError; + + fn from_proto(proto_invariant_violation: Self::ProtoType) -> Result { + use crate::proto::vm_errors::VMInvariantViolationError as ProtoError; + match proto_invariant_violation { + ProtoError::OutOfBoundsIndex => Ok(VMInvariantViolationError::OutOfBoundsIndex), + ProtoError::OutOfBoundsRange => Ok(VMInvariantViolationError::OutOfBoundsRange), + ProtoError::EmptyValueStack => Ok(VMInvariantViolationError::EmptyValueStack), + ProtoError::EmptyCallStack => Ok(VMInvariantViolationError::EmptyCallStack), + ProtoError::PCOverflow => Ok(VMInvariantViolationError::PCOverflow), + ProtoError::LinkerError => Ok(VMInvariantViolationError::LinkerError), + ProtoError::LocalReferenceError => Ok(VMInvariantViolationError::LocalReferenceError), + ProtoError::StorageError => Ok(VMInvariantViolationError::StorageError), + ProtoError::UnknownInvariantViolationError => { + bail_err!(DecodingError::UnknownInvariantViolationErrorEncountered) + } + } + } +} + +impl IntoProto for BinaryError { + type ProtoType = crate::proto::vm_errors::BinaryError; + + fn into_proto(self) -> 
Self::ProtoType { + use crate::proto::vm_errors::BinaryError as ProtoStatus; + match self { + BinaryError::Malformed => ProtoStatus::Malformed, + BinaryError::BadMagic => ProtoStatus::BadMagic, + BinaryError::UnknownVersion => ProtoStatus::UnknownVersion, + BinaryError::UnknownTableType => ProtoStatus::UnknownTableType, + BinaryError::UnknownSignatureType => ProtoStatus::UnknownSignatureType, + BinaryError::UnknownSerializedType => ProtoStatus::UnknownSerializedType, + BinaryError::UnknownOpcode => ProtoStatus::UnknownOpcode, + BinaryError::BadHeaderTable => ProtoStatus::BadHeaderTable, + BinaryError::UnexpectedSignatureType => ProtoStatus::UnexpectedSignatureType, + BinaryError::DuplicateTable => ProtoStatus::DuplicateTable, + } + } +} + +impl FromProto for BinaryError { + type ProtoType = crate::proto::vm_errors::BinaryError; + + fn from_proto(proto_binary_error: Self::ProtoType) -> Result { + use crate::proto::vm_errors::BinaryError as ProtoError; + match proto_binary_error { + ProtoError::Malformed => Ok(BinaryError::Malformed), + ProtoError::BadMagic => Ok(BinaryError::BadMagic), + ProtoError::UnknownVersion => Ok(BinaryError::UnknownVersion), + ProtoError::UnknownTableType => Ok(BinaryError::UnknownTableType), + ProtoError::UnknownSignatureType => Ok(BinaryError::UnknownSignatureType), + ProtoError::UnknownSerializedType => Ok(BinaryError::UnknownSerializedType), + ProtoError::UnknownOpcode => Ok(BinaryError::UnknownOpcode), + ProtoError::BadHeaderTable => Ok(BinaryError::BadHeaderTable), + ProtoError::UnexpectedSignatureType => Ok(BinaryError::UnexpectedSignatureType), + ProtoError::DuplicateTable => Ok(BinaryError::DuplicateTable), + ProtoError::UnknownBinaryError => { + bail_err!(DecodingError::UnknownBinaryErrorEncountered) + } + } + } +} + +impl IntoProto for DynamicReferenceErrorType { + type ProtoType = crate::proto::vm_errors::DynamicReferenceError_DynamicReferenceErrorType; + + fn into_proto(self) -> Self::ProtoType { + use crate::proto::vm_errors::DynamicReferenceError_DynamicReferenceErrorType as ProtoError; + match self { + DynamicReferenceErrorType::MoveOfBorrowedResource => ProtoError::MoveOfBorrowedResource, + DynamicReferenceErrorType::GlobalRefAlreadyReleased => { + ProtoError::GlobalRefAlreadyReleased + } + DynamicReferenceErrorType::MissingReleaseRef => ProtoError::MissingReleaseRef, + DynamicReferenceErrorType::GlobalAlreadyBorrowed => ProtoError::GlobalAlreadyBorrowed, + } + } +} + +impl FromProto for DynamicReferenceErrorType { + type ProtoType = crate::proto::vm_errors::DynamicReferenceError_DynamicReferenceErrorType; + + fn from_proto(proto_ref_err_type: Self::ProtoType) -> Result { + use crate::proto::vm_errors::DynamicReferenceError_DynamicReferenceErrorType as ProtoError; + match proto_ref_err_type { + ProtoError::MoveOfBorrowedResource => { + Ok(DynamicReferenceErrorType::MoveOfBorrowedResource) + } + ProtoError::GlobalRefAlreadyReleased => { + Ok(DynamicReferenceErrorType::GlobalRefAlreadyReleased) + } + ProtoError::MissingReleaseRef => Ok(DynamicReferenceErrorType::MissingReleaseRef), + ProtoError::GlobalAlreadyBorrowed => { + Ok(DynamicReferenceErrorType::GlobalAlreadyBorrowed) + } + ProtoError::UnknownDynamicReferenceError => { + bail_err!(DecodingError::UnknownDynamicReferenceErrorTypeEncountered) + } + } + } +} + +impl IntoProto for ArithmeticErrorType { + type ProtoType = crate::proto::vm_errors::ArithmeticError_ArithmeticErrorType; + + fn into_proto(self) -> Self::ProtoType { + use crate::proto::vm_errors::ArithmeticError_ArithmeticErrorType as 
ProtoError; + match self { + ArithmeticErrorType::Underflow => ProtoError::Underflow, + ArithmeticErrorType::Overflow => ProtoError::Overflow, + ArithmeticErrorType::DivisionByZero => ProtoError::DivisionByZero, + } + } +} + +impl FromProto for ArithmeticErrorType { + type ProtoType = crate::proto::vm_errors::ArithmeticError_ArithmeticErrorType; + + fn from_proto(proto_ref_err_type: Self::ProtoType) -> Result { + use crate::proto::vm_errors::ArithmeticError_ArithmeticErrorType as ProtoError; + match proto_ref_err_type { + ProtoError::Underflow => Ok(ArithmeticErrorType::Underflow), + ProtoError::Overflow => Ok(ArithmeticErrorType::Overflow), + ProtoError::DivisionByZero => Ok(ArithmeticErrorType::DivisionByZero), + ProtoError::UnknownArithmeticError => { + bail_err!(DecodingError::UnknownArithmeticErrorTypeEncountered) + } + } + } +} + +impl IntoProto for ExecutionStatus { + type ProtoType = crate::proto::vm_errors::ExecutionStatus; + + fn into_proto(self) -> Self::ProtoType { + use crate::proto::vm_errors::{ + ArithmeticError, AssertionFailure as AssertStatus, DynamicReferenceError, + ExecutionStatus as ExecuteStatus, RuntimeStatus, + }; + let mut exec_status = ExecuteStatus::new(); + match self { + ExecutionStatus::Executed => exec_status.set_runtime_status(RuntimeStatus::Executed), + ExecutionStatus::OutOfGas => exec_status.set_runtime_status(RuntimeStatus::OutOfGas), + ExecutionStatus::ResourceDoesNotExist => { + exec_status.set_runtime_status(RuntimeStatus::ResourceDoesNotExist) + } + ExecutionStatus::ResourceAlreadyExists => { + exec_status.set_runtime_status(RuntimeStatus::ResourceAlreadyExists) + } + ExecutionStatus::EvictedAccountAccess => { + exec_status.set_runtime_status(RuntimeStatus::EvictedAccountAccess) + } + ExecutionStatus::AccountAddressAlreadyExists => { + exec_status.set_runtime_status(RuntimeStatus::AccountAddressAlreadyExists) + } + ExecutionStatus::TypeError => exec_status.set_runtime_status(RuntimeStatus::TypeError), + ExecutionStatus::MissingData => { + exec_status.set_runtime_status(RuntimeStatus::MissingData) + } + ExecutionStatus::DataFormatError => { + exec_status.set_runtime_status(RuntimeStatus::DataFormatError) + } + ExecutionStatus::InvalidData => { + exec_status.set_runtime_status(RuntimeStatus::InvalidData) + } + ExecutionStatus::RemoteDataError => { + exec_status.set_runtime_status(RuntimeStatus::RemoteDataError) + } + ExecutionStatus::CannotWriteExistingResource => { + exec_status.set_runtime_status(RuntimeStatus::CannotWriteExistingResource) + } + ExecutionStatus::ValueSerializationError => { + exec_status.set_runtime_status(RuntimeStatus::ValueSerializationError) + } + ExecutionStatus::ValueDeserializationError => { + exec_status.set_runtime_status(RuntimeStatus::ValueDeserializationError) + } + ExecutionStatus::DuplicateModuleName => { + exec_status.set_runtime_status(RuntimeStatus::DuplicateModuleName) + } + ExecutionStatus::DynamicReferenceError(err_type) => { + let mut ref_err = DynamicReferenceError::new(); + let err_code = DynamicReferenceErrorType::into_proto(err_type); + ref_err.set_error_code(err_code); + exec_status.set_reference_error(ref_err) + } + ExecutionStatus::ArithmeticError(err_type) => { + let mut arith_err = ArithmeticError::new(); + let err_code = ArithmeticErrorType::into_proto(err_type); + arith_err.set_error_code(err_code); + exec_status.set_arithmetic_error(arith_err) + } + ExecutionStatus::AssertionFailure(err_code) => { + let mut assert_error = AssertStatus::new(); + assert_error.set_assertion_error_code(err_code); + 
exec_status.set_assertion_failure(assert_error) + } + }; + exec_status + } +} + +impl FromProto for ExecutionStatus { + type ProtoType = crate::proto::vm_errors::ExecutionStatus; + + fn from_proto(mut proto_execution_status: Self::ProtoType) -> Result { + use crate::proto::vm_errors::RuntimeStatus as ProtoRuntimeStatus; + if proto_execution_status.has_runtime_status() { + match proto_execution_status.get_runtime_status() { + ProtoRuntimeStatus::Executed => Ok(ExecutionStatus::Executed), + ProtoRuntimeStatus::OutOfGas => Ok(ExecutionStatus::OutOfGas), + ProtoRuntimeStatus::ResourceDoesNotExist => { + Ok(ExecutionStatus::ResourceDoesNotExist) + } + ProtoRuntimeStatus::ResourceAlreadyExists => { + Ok(ExecutionStatus::ResourceAlreadyExists) + } + ProtoRuntimeStatus::EvictedAccountAccess => { + Ok(ExecutionStatus::EvictedAccountAccess) + } + ProtoRuntimeStatus::AccountAddressAlreadyExists => { + Ok(ExecutionStatus::AccountAddressAlreadyExists) + } + ProtoRuntimeStatus::TypeError => Ok(ExecutionStatus::TypeError), + ProtoRuntimeStatus::MissingData => Ok(ExecutionStatus::MissingData), + ProtoRuntimeStatus::DataFormatError => Ok(ExecutionStatus::DataFormatError), + ProtoRuntimeStatus::InvalidData => Ok(ExecutionStatus::InvalidData), + ProtoRuntimeStatus::RemoteDataError => Ok(ExecutionStatus::RemoteDataError), + ProtoRuntimeStatus::CannotWriteExistingResource => { + Ok(ExecutionStatus::CannotWriteExistingResource) + } + ProtoRuntimeStatus::ValueSerializationError => { + Ok(ExecutionStatus::ValueSerializationError) + } + ProtoRuntimeStatus::ValueDeserializationError => { + Ok(ExecutionStatus::ValueDeserializationError) + } + ProtoRuntimeStatus::DuplicateModuleName => Ok(ExecutionStatus::DuplicateModuleName), + ProtoRuntimeStatus::UnknownRuntimeStatus => { + bail_err!(DecodingError::UnknownRuntimeStatusEncountered) + } + } + } else if proto_execution_status.has_arithmetic_error() { + let err = proto_execution_status + .take_arithmetic_error() + .get_error_code(); + let from_proto = ArithmeticErrorType::from_proto(err)?; + Ok(ExecutionStatus::ArithmeticError(from_proto)) + } else if proto_execution_status.has_reference_error() { + let err = proto_execution_status + .take_reference_error() + .get_error_code(); + let from_proto = DynamicReferenceErrorType::from_proto(err)?; + Ok(ExecutionStatus::DynamicReferenceError(from_proto)) + } else { + // else it's an assertion error + let err_code = proto_execution_status.get_assertion_failure(); + Ok(ExecutionStatus::AssertionFailure( + err_code.assertion_error_code, + )) + } + } +} + +impl IntoProto for VMStatus { + type ProtoType = crate::proto::vm_errors::VMStatus; + + fn into_proto(self) -> Self::ProtoType { + let mut vm_status = Self::ProtoType::new(); + match self { + VMStatus::Validation(status) => vm_status.set_validation(status.into_proto()), + VMStatus::Verification(status_list) => { + use crate::proto::vm_errors::VMVerificationStatusList as ProtoStatusList; + + let proto_vec: Vec<_> = status_list + .into_iter() + .map(VMVerificationStatus::into_proto) + .collect(); + let mut proto_status_list = ProtoStatusList::new(); + proto_status_list.set_status_list(proto_vec.into()); + vm_status.set_verification(proto_status_list); + } + VMStatus::InvariantViolation(err) => { + vm_status.set_invariant_violation(err.into_proto()) + } + VMStatus::Deserialization(err) => vm_status.set_deserialization(err.into_proto()), + VMStatus::Execution(exec_status) => vm_status.set_execution(exec_status.into_proto()), + }; + vm_status + } +} + +impl FromProto for VMStatus { + 
type ProtoType = crate::proto::vm_errors::VMStatus;
+
+    fn from_proto(mut vm_status: Self::ProtoType) -> Result<Self> {
+        if vm_status.has_validation() {
+            let from_proto = VMValidationStatus::from_proto(vm_status.take_validation())?;
+            Ok(VMStatus::Validation(from_proto))
+        } else if vm_status.has_verification() {
+            let mut proto_status_list = vm_status.take_verification();
+            let proto_repeated = proto_status_list.take_status_list();
+            let status_list = proto_repeated
+                .into_iter()
+                .map(VMVerificationStatus::from_proto)
+                .collect::<Result<Vec<_>>>()?;
+            Ok(VMStatus::Verification(status_list))
+        } else if vm_status.has_invariant_violation() {
+            let from_proto =
+                VMInvariantViolationError::from_proto(vm_status.get_invariant_violation())?;
+            Ok(VMStatus::InvariantViolation(from_proto))
+        } else if vm_status.has_deserialization() {
+            let from_proto = BinaryError::from_proto(vm_status.get_deserialization())?;
+            Ok(VMStatus::Deserialization(from_proto))
+        } else if vm_status.has_execution() {
+            let from_proto = ExecutionStatus::from_proto(vm_status.take_execution())?;
+            Ok(VMStatus::Execution(from_proto))
+        } else {
+            bail_err!(DecodingError::InvalidVMStatusEncountered)
+        }
+    }
+}
diff --git a/types/src/write_set.rs b/types/src/write_set.rs
new file mode 100644
index 0000000000000..db9437e9639b1
--- /dev/null
+++ b/types/src/write_set.rs
@@ -0,0 +1,199 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! For each transaction the VM executes, the VM will output a `WriteSet` that contains each access
+//! path it updates. For each access path, the VM can either give its new value or delete it.
+
+use crate::access_path::AccessPath;
+use failure::prelude::*;
+use proto_conv::{FromProto, IntoProto};
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Eq, Hash, PartialEq, Serialize, Deserialize)]
+pub enum WriteOp {
+    Value(Vec<u8>),
+    Deletion,
+}
+
+impl WriteOp {
+    #[inline]
+    pub fn is_value(&self) -> bool {
+        match self {
+            WriteOp::Value(_) => true,
+            WriteOp::Deletion => false,
+        }
+    }
+
+    #[inline]
+    pub fn is_deletion(&self) -> bool {
+        match self {
+            WriteOp::Deletion => true,
+            WriteOp::Value(_) => false,
+        }
+    }
+}
+
+impl std::fmt::Debug for WriteOp {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            WriteOp::Value(value) => write!(f, "Value({})", String::from_utf8_lossy(value)),
+            WriteOp::Deletion => write!(f, "Deletion"),
+        }
+    }
+}
+
+/// `WriteSet` contains all access paths that one transaction modifies. Each of them is a `WriteOp`
+/// where `Value(val)` means that the serialized representation should be updated to `val`, and
+/// `Deletion` means that we are going to delete this access path.
+#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
+pub struct WriteSet(WriteSetMut);
+
+impl WriteSet {
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.0.is_empty()
+    }
+
+    #[inline]
+    pub fn iter<'a>(&'a self) -> ::std::slice::Iter<'a, (AccessPath, WriteOp)> {
+        self.into_iter()
+    }
+
+    #[inline]
+    pub fn into_mut(self) -> WriteSetMut {
+        self.0
+    }
+}
+
+/// A mutable version of `WriteSet`.
+///
+/// This is separate because it goes through validation before becoming an immutable `WriteSet`.
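+///
+/// A minimal sketch of the intended flow (assuming an `AccessPath` value `ap` is in scope):
+///
+/// ```ignore
+/// let mut ws = WriteSetMut::default();
+/// ws.push((ap, WriteOp::Value(vec![0u8])));
+/// let write_set: WriteSet = ws.freeze().unwrap();
+/// ```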
+#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
+pub struct WriteSetMut {
+    write_set: Vec<(AccessPath, WriteOp)>,
+}
+
+impl WriteSetMut {
+    pub fn new(write_set: Vec<(AccessPath, WriteOp)>) -> Self {
+        Self { write_set }
+    }
+
+    pub fn push(&mut self, item: (AccessPath, WriteOp)) {
+        self.write_set.push(item);
+    }
+
+    #[inline]
+    pub fn len(&self) -> usize {
+        self.write_set.len()
+    }
+
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.write_set.is_empty()
+    }
+
+    pub fn freeze(self) -> Result<WriteSet> {
+        // TODO: add structural validation
+        Ok(WriteSet(self))
+    }
+}
+
+impl FromProto for WriteSet {
+    type ProtoType = crate::proto::transaction::WriteSet;
+
+    fn from_proto(mut write_set: Self::ProtoType) -> Result<Self> {
+        use crate::proto::transaction::WriteOpType;
+
+        let write_set = write_set
+            .take_write_set()
+            .into_iter()
+            .map(|mut write_op| {
+                // The protobuf WriteOp is equivalent to (AccessPath, WriteOp) in Rust, so
+                // From/IntoProto can't be implemented for WriteOp and instead the conversion must
+                // be done here.
+                let access_path = AccessPath::from_proto(write_op.take_access_path())?;
+                let write_op = match write_op.get_field_type() {
+                    WriteOpType::Write => WriteOp::Value(write_op.take_value()),
+                    WriteOpType::Delete => {
+                        ensure!(
+                            write_op.get_value().is_empty(),
+                            "WriteOp with access path {:?} has WriteOpType::Delete with value",
+                            access_path,
+                        );
+                        WriteOp::Deletion
+                    }
+                };
+                Ok((access_path, write_op))
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let write_set_mut = WriteSetMut::new(write_set);
+        write_set_mut.freeze()
+    }
+}
+
+impl IntoProto for WriteSet {
+    type ProtoType = crate::proto::transaction::WriteSet;
+
+    fn into_proto(self) -> Self::ProtoType {
+        use crate::proto::transaction::{WriteOp as ProtoWriteOp, WriteOpType};
+
+        let proto_write_ops = self
+            .0
+            .write_set
+            .into_iter()
+            .map(|(access_path, write_op)| {
+                let mut proto_write_op = ProtoWriteOp::new();
+                proto_write_op.set_access_path(access_path.into_proto());
+                match write_op {
+                    WriteOp::Value(value) => {
+                        proto_write_op.set_value(value);
+                        proto_write_op.set_field_type(WriteOpType::Write);
+                    }
+                    WriteOp::Deletion => {
+                        // This should be a no-op but this code conveys the intent better.
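+                        // (The protobuf `bytes` field already defaults to an empty vector.)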
+                        proto_write_op.set_value(vec![]);
+                        proto_write_op.set_field_type(WriteOpType::Delete);
+                    }
+                };
+                proto_write_op
+            })
+            .collect();
+
+        let mut proto_write_set = Self::ProtoType::new();
+        proto_write_set.set_write_set(proto_write_ops);
+        proto_write_set
+    }
+}
+
+impl ::std::iter::FromIterator<(AccessPath, WriteOp)> for WriteSetMut {
+    fn from_iter<I: IntoIterator<Item = (AccessPath, WriteOp)>>(iter: I) -> Self {
+        let mut ws = WriteSetMut::default();
+        for write in iter {
+            ws.push((write.0, write.1));
+        }
+        ws
+    }
+}
+
+impl<'a> IntoIterator for &'a WriteSet {
+    type Item = &'a (AccessPath, WriteOp);
+    type IntoIter = ::std::slice::Iter<'a, (AccessPath, WriteOp)>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.write_set.iter()
+    }
+}
+
+impl ::std::iter::IntoIterator for WriteSet {
+    type Item = (AccessPath, WriteOp);
+    type IntoIter = ::std::vec::IntoIter<(AccessPath, WriteOp)>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.write_set.into_iter()
+    }
+}
diff --git a/vm_validator/Cargo.toml b/vm_validator/Cargo.toml
new file mode 100644
index 0000000000000..9b9989c842479
--- /dev/null
+++ b/vm_validator/Cargo.toml
@@ -0,0 +1,31 @@
+[package]
+name = "vm_validator"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+futures = "0.1.25"
+
+config = { path = "../config" }
+crypto = { path = "../crypto/legacy_crypto" }
+failure = { path = "../common/failure_ext", package = "failure_ext" }
+proto_conv = { path = "../common/proto_conv" }
+scratchpad = { path = "../storage/scratchpad" }
+state_view = { path = "../storage/state_view" }
+storage_client = { path = "../storage/storage_client" }
+types = { path = "../types" }
+vm_runtime = { path = "../language/vm/vm_runtime" }
+
+[dev-dependencies]
+grpcio = "0.4.4"
+assert_matches = "1.3.0"
+
+execution_proto = { path = "../execution/execution_proto" }
+execution_service = { path = "../execution/execution_service" }
+grpc_helpers = { path = "../common/grpc_helpers" }
+storage_service = { path = "../storage/storage_service" }
+vm_genesis = { path = "../language/vm/vm_genesis" }
+config_builder = { path = "../config/config_builder" }
diff --git a/vm_validator/src/lib.rs b/vm_validator/src/lib.rs
new file mode 100644
index 0000000000000..549fc0e91f817
--- /dev/null
+++ b/vm_validator/src/lib.rs
@@ -0,0 +1,6 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+#![feature(async_await)]
+pub mod mocks;
+pub mod vm_validator;
diff --git a/vm_validator/src/mocks/mock_vm_validator.rs b/vm_validator/src/mocks/mock_vm_validator.rs
new file mode 100644
index 0000000000000..6e764da63e132
--- /dev/null
+++ b/vm_validator/src/mocks/mock_vm_validator.rs
@@ -0,0 +1,74 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::vm_validator::TransactionValidation;
+use futures::future::{ok, Future};
+use state_view::StateView;
+use std::convert::TryFrom;
+use types::{
+    account_address::{AccountAddress, ADDRESS_LENGTH},
+    transaction::SignedTransaction,
+    vm_error::{VMStatus, VMValidationStatus},
+};
+use vm_runtime::VMVerifier;
+
+#[derive(Clone)]
+pub struct MockVMValidator;
+
+impl VMVerifier for MockVMValidator {
+    fn validate_transaction(
+        &self,
+        _transaction: SignedTransaction,
+        _state_view: &dyn StateView,
+    ) -> Option<VMStatus> {
+        None
+    }
+}
+
+impl TransactionValidation for MockVMValidator {
+    type ValidationInstance = MockVMValidator;
+    fn validate_transaction(
+        &self,
+        txn: SignedTransaction,
+    ) -> Box<dyn Future<Item = Option<VMStatus>, Error = failure::Error> +
Send> { + let sender = txn.sender(); + let account_dne_test_add = AccountAddress::try_from(&[0 as u8; ADDRESS_LENGTH]).unwrap(); + let invalid_sig_test_add = AccountAddress::try_from(&[1 as u8; ADDRESS_LENGTH]).unwrap(); + let insufficient_balance_test_add = + AccountAddress::try_from(&[2 as u8; ADDRESS_LENGTH]).unwrap(); + let seq_number_too_new_test_add = + AccountAddress::try_from(&[3 as u8; ADDRESS_LENGTH]).unwrap(); + let seq_number_too_old_test_add = + AccountAddress::try_from(&[4 as u8; ADDRESS_LENGTH]).unwrap(); + let txn_expiration_time_test_add = + AccountAddress::try_from(&[5 as u8; ADDRESS_LENGTH]).unwrap(); + let invalid_auth_key_test_add = + AccountAddress::try_from(&[6 as u8; ADDRESS_LENGTH]).unwrap(); + let ret = if sender == account_dne_test_add { + Some(VMStatus::Validation( + VMValidationStatus::SendingAccountDoesNotExist("TEST".to_string()), + )) + } else if sender == invalid_sig_test_add { + Some(VMStatus::Validation(VMValidationStatus::InvalidSignature)) + } else if sender == insufficient_balance_test_add { + Some(VMStatus::Validation( + VMValidationStatus::InsufficientBalanceForTransactionFee, + )) + } else if sender == seq_number_too_new_test_add { + Some(VMStatus::Validation( + VMValidationStatus::SequenceNumberTooNew, + )) + } else if sender == seq_number_too_old_test_add { + Some(VMStatus::Validation( + VMValidationStatus::SequenceNumberTooOld, + )) + } else if sender == txn_expiration_time_test_add { + Some(VMStatus::Validation(VMValidationStatus::TransactionExpired)) + } else if sender == invalid_auth_key_test_add { + Some(VMStatus::Validation(VMValidationStatus::InvalidAuthKey)) + } else { + None + }; + Box::new(ok(ret)) + } +} diff --git a/vm_validator/src/mocks/mod.rs b/vm_validator/src/mocks/mod.rs new file mode 100644 index 0000000000000..25e56563c5c47 --- /dev/null +++ b/vm_validator/src/mocks/mod.rs @@ -0,0 +1,4 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +pub mod mock_vm_validator; diff --git a/vm_validator/src/unit_tests/vm_validator_test.rs b/vm_validator/src/unit_tests/vm_validator_test.rs new file mode 100644 index 0000000000000..81ac010308e0a --- /dev/null +++ b/vm_validator/src/unit_tests/vm_validator_test.rs @@ -0,0 +1,489 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::vm_validator::{TransactionValidation, VMValidator}; +use assert_matches::assert_matches; +use config::config::NodeConfig; +use config_builder::util::get_test_config; +use crypto::signing::KeyPair; +use execution_proto::proto::execution_grpc; +use execution_service::ExecutionService; +use futures::future::Future; +use grpc_helpers::ServerHandle; +use grpcio::EnvBuilder; +use proto_conv::FromProto; +use std::{sync::Arc, u64}; +use storage_client::{StorageRead, StorageReadServiceClient, StorageWriteServiceClient}; +use storage_service::start_storage_service; +use types::{ + account_address, account_config, + test_helpers::transaction_test_helpers, + transaction::{Program, SignedTransaction, TransactionArgument, MAX_TRANSACTION_SIZE_IN_BYTES}, + vm_error::{VMStatus, VMValidationStatus, VMVerificationError, VMVerificationStatus}, +}; +use vm_genesis::encode_transfer_program; + +struct TestValidator { + _storage: ServerHandle, + _execution: grpcio::Server, + vm_validator: VMValidator, +} + +impl TestValidator { + fn new(config: &NodeConfig) -> Self { + let storage = start_storage_service(&config); + + // setup execution + let client_env = Arc::new(EnvBuilder::new().build()); + let 
storage_read_client: Arc<dyn StorageRead> = Arc::new(StorageReadServiceClient::new( + Arc::clone(&client_env), + &config.storage.address, + config.storage.port, + )); + let storage_write_client = Arc::new(StorageWriteServiceClient::new( + Arc::clone(&client_env), + &config.storage.address, + config.storage.port, + )); + + let handle = ExecutionService::new( + Arc::clone(&storage_read_client), + storage_write_client, + config, + ); + let service = execution_grpc::create_execution(handle); + let execution = ::grpcio::ServerBuilder::new(Arc::new(EnvBuilder::new().build())) + .register_service(service) + .bind(config.execution.address.clone(), config.execution.port) + .build() + .expect("Unable to create grpc server"); + + let vm_validator = VMValidator::new(config, storage_read_client); + + TestValidator { + _storage: storage, + _execution: execution, + vm_validator, + } + } +} + +impl std::ops::Deref for TestValidator { + type Target = VMValidator; + + fn deref(&self) -> &Self::Target { + &self.vm_validator + } +} + +// These tests are meant to test all high-level code paths that lead to a validation error in the +// verification of a transaction in the VM. However, there are a couple of notable exceptions that we +// do _not_ test here -- this is due to limitations around execution and semantics. The following +// errors are not exercised: +// * Sequence number too old -- We can't test sequence number too old here without running execution +// first in order to bump the account's sequence number. This needs to be (and is) tested in the +// vm_runtime_tests in: libra/language/vm/vm_runtime/vm_runtime_tests/src/tests/verify_txn.rs -> +// verify_simple_payment. +// * Errors arising from deserializing the code -- these are tested in +// - libra/language/vm/src/unit_tests/deserializer_tests.rs +// - libra/language/vm/tests/serializer_tests.rs +// * Errors arising from calls to `static_verify_program` -- this is tested separately in tests for +// the bytecode verifier.
+// * Testing for invalid genesis write sets -- this is tested in +// libra/language/vm/vm_runtime/vm_runtime_tests/src/tests/genesis.rs + +#[test] +fn test_validate_transaction() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let address = account_config::association_address(); + let program = encode_transfer_program(&address, 100); + let signed_txn = transaction_test_helpers::get_test_signed_txn( + address, + 0, + keypair.private_key().clone(), + keypair.public_key(), + Some(program), + ); + let ret = vm_validator + .validate_transaction(SignedTransaction::from_proto(signed_txn).unwrap()) + .wait() + .unwrap(); + assert_eq!(ret, None); +} + +#[test] +fn test_validate_invalid_signature() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let (other_private_key, _) = ::crypto::signing::generate_keypair(); + // Submit with an account using a different private/public keypair + let other_keypair = KeyPair::new(other_private_key); + + let address = account_config::association_address(); + let program = encode_transfer_program(&address, 100); + let signed_txn = transaction_test_helpers::get_unverified_test_signed_txn( + address, + 0, + other_keypair.private_key().clone(), + keypair.public_key(), + Some(program), + ); + let ret = vm_validator + .validate_transaction(signed_txn) + .wait() + .unwrap(); + assert_eq!( + ret, + Some(VMStatus::Validation(VMValidationStatus::InvalidSignature)) + ); +} + +#[test] +fn test_validate_known_script_too_large_args() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let address = account_config::association_address(); + let txn = transaction_test_helpers::get_test_signed_transaction( + address, + 0, + keypair.private_key().clone(), + keypair.public_key(), + Some(Program::new( + vec![42; MAX_TRANSACTION_SIZE_IN_BYTES], + vec![], + vec![], + )), /* generate a program with args longer than the max size */ + 0, + 0, /* max gas price */ + None, + ); + let txn = SignedTransaction::from_proto(txn).unwrap(); + let ret = vm_validator.validate_transaction(txn).wait().unwrap(); + assert_matches!( + ret, + Some(VMStatus::Validation(VMValidationStatus::ExceededMaxTransactionSize(_))) + ); +} + +#[test] +fn test_validate_max_gas_units_above_max() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let address = account_config::association_address(); + let txn = transaction_test_helpers::get_test_signed_transaction( + address, + 0, + keypair.private_key().clone(), + keypair.public_key(), + None, + 0, + 0, /* max gas price */ + Some(u64::MAX), // Max gas units + ); + let ret = vm_validator + .validate_transaction(SignedTransaction::from_proto(txn).unwrap()) + .wait() + .unwrap(); + assert_matches!( + ret, + Some(VMStatus::Validation(VMValidationStatus::MaxGasUnitsExceedsMaxGasUnitsBound(_))) + ); +} + +#[test] +fn test_validate_max_gas_units_below_min() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let address = account_config::association_address(); + let txn = transaction_test_helpers::get_test_signed_transaction( + address, + 0, + keypair.private_key().clone(), + keypair.public_key(), + None, + 0, + 0, /* max gas price */ + Some(1), // Max gas units + ); + let ret = vm_validator + .validate_transaction(SignedTransaction::from_proto(txn).unwrap()) + .wait() + .unwrap(); + assert_matches!( + ret,
Some(VMStatus::Validation(VMValidationStatus::MaxGasUnitsBelowMinTransactionGasUnits(_))) + ); +} + +#[test] +fn test_validate_max_gas_price_above_bounds() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let address = account_config::association_address(); + let txn = transaction_test_helpers::get_test_signed_transaction( + address, + 0, + keypair.private_key().clone(), + keypair.public_key(), + None, + 0, + u64::MAX, /* max gas price */ + None, + ); + let ret = vm_validator + .validate_transaction(SignedTransaction::from_proto(txn).unwrap()) + .wait() + .unwrap(); + assert_matches!( + ret, + Some(VMStatus::Validation(VMValidationStatus::GasUnitPriceAboveMaxBound(_))) + ); +} + +// NB: This test is designed to fail if/when we bump the minimum gas price to be non-zero. You will +// then need to update this price here in order to make the test pass -- uncomment the commented +// out assertion and remove the current failing assertion in this case. +#[test] +fn test_validate_max_gas_price_below_bounds() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let address = account_config::association_address(); + let program = encode_transfer_program(&address, 100); + let txn = transaction_test_helpers::get_test_signed_transaction( + address, + 0, + keypair.private_key().clone(), + keypair.public_key(), + Some(program), + 0, + 0, /* max gas price */ + None, + ); + let ret = vm_validator + .validate_transaction(SignedTransaction::from_proto(txn).unwrap()) + .wait() + .unwrap(); + assert_eq!(ret, None); + //assert_eq!( + // ret.unwrap(), + // VMStatus::Validation(VMValidationStatus::GasUnitPriceBelowMinBound) + //); +} + +#[cfg(not(feature = "allow_custom_transaction_scripts"))] +#[test] +fn test_validate_unknown_script() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let address = account_config::association_address(); + let signed_txn = transaction_test_helpers::get_test_signed_txn( + address, + 1, + keypair.private_key().clone(), + keypair.public_key(), + None, + ); + let ret = vm_validator + .validate_transaction(SignedTransaction::from_proto(signed_txn).unwrap()) + .wait() + .unwrap(); + assert_eq!( + ret, + Some(VMStatus::Validation(VMValidationStatus::UnknownScript)) + ); +} + +// Make sure that we can't publish non-whitelisted modules +#[cfg(not(feature = "allow_custom_transaction_scripts"))] +#[cfg(not(feature = "custom_modules"))] +#[test] +fn test_validate_module_publishing() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let address = account_config::association_address(); + let (program_script, args, _) = encode_transfer_program(&address, 100).into_inner(); + let program = Program::new(program_script, vec![vec![]], args); + let signed_txn = transaction_test_helpers::get_test_signed_txn( + address, + 1, + keypair.private_key().clone(), + keypair.public_key(), + Some(program), + ); + let ret = vm_validator + .validate_transaction(SignedTransaction::from_proto(signed_txn).unwrap()) + .wait() + .unwrap(); + assert_eq!( + ret, + Some(VMStatus::Validation(VMValidationStatus::UnknownModule)) + ); +} + +#[test] +fn test_validate_invalid_auth_key() { + let (config, _) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let (other_private_key, _) = ::crypto::signing::generate_keypair(); + // Submit with an account using a different private/public keypair + let
other_keypair = KeyPair::new(other_private_key); + + let address = account_config::association_address(); + let program = encode_transfer_program(&address, 100); + let signed_txn = transaction_test_helpers::get_test_signed_txn( + address, + 0, + other_keypair.private_key().clone(), + other_keypair.public_key(), + Some(program), + ); + let ret = vm_validator + .validate_transaction(SignedTransaction::from_proto(signed_txn).unwrap()) + .wait() + .unwrap(); + assert_eq!( + ret, + Some(VMStatus::Validation(VMValidationStatus::InvalidAuthKey)) + ); +} + +#[test] +fn test_validate_balance_below_gas_fee() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let address = account_config::association_address(); + let program = encode_transfer_program(&address, 100); + let signed_txn = transaction_test_helpers::get_test_signed_transaction( + address, + 0, + keypair.private_key().clone(), + keypair.public_key(), + Some(program), + 0, + // Note that this will be dependent upon the max gas price and gas amounts that are set. So + // changing those may cause this test to fail. + 10_000, /* max gas price */ + Some(1_000_000), + ); + let ret = vm_validator + .validate_transaction(SignedTransaction::from_proto(signed_txn).unwrap()) + .wait() + .unwrap(); + assert_eq!( + ret, + Some(VMStatus::Validation( + VMValidationStatus::InsufficientBalanceForTransactionFee + )) + ); +} + +#[test] +fn test_validate_account_doesnt_exist() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let address = account_config::association_address(); + let random_account_addr = account_address::AccountAddress::random(); + let program = encode_transfer_program(&address, 100); + let signed_txn = transaction_test_helpers::get_test_signed_transaction( + random_account_addr, + 0, + keypair.private_key().clone(), + keypair.public_key(), + Some(program), + 0, + 1, /* max gas price */ + None, + ); + let ret = vm_validator + .validate_transaction(SignedTransaction::from_proto(signed_txn).unwrap()) + .wait() + .unwrap(); + assert_matches!( + ret.unwrap(), + VMStatus::Validation(VMValidationStatus::SendingAccountDoesNotExist(_)) + ); +} + +#[test] +fn test_validate_sequence_number_too_new() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let address = account_config::association_address(); + let program = encode_transfer_program(&address, 100); + let signed_txn = transaction_test_helpers::get_test_signed_txn( + address, + 1, + keypair.private_key().clone(), + keypair.public_key(), + Some(program), + ); + let ret = vm_validator + .validate_transaction(SignedTransaction::from_proto(signed_txn).unwrap()) + .wait() + .unwrap(); + assert_eq!(ret, None); +} + +#[test] +fn test_validate_invalid_arguments() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let address = account_config::association_address(); + let (program_script, _, _) = encode_transfer_program(&address, 100).into_inner(); + let program = Program::new(program_script, vec![], vec![TransactionArgument::U64(42)]); + let signed_txn = transaction_test_helpers::get_test_signed_txn( + address, + 0, + keypair.private_key().clone(), + keypair.public_key(), + Some(program), + ); + let ret = vm_validator + .validate_transaction(SignedTransaction::from_proto(signed_txn).unwrap()) + .wait() + .unwrap(); + assert_eq!( + ret, + Some(VMStatus::Verification(vec![VMVerificationStatus::Script( + 
VMVerificationError::TypeMismatch("Actual Type Mismatch".to_string()) + )])) + ); +} + +#[test] +fn test_validate_non_genesis_write_set() { + let (config, keypair) = get_test_config(); + let vm_validator = TestValidator::new(&config); + + let address = account_config::association_address(); + let signed_txn = transaction_test_helpers::get_write_set_txn( + address, + 0, + keypair.private_key().clone(), + keypair.public_key(), + None, + ); + let ret = vm_validator + .validate_transaction(signed_txn) + .wait() + .unwrap(); + assert_eq!( + ret, + Some(VMStatus::Validation(VMValidationStatus::RejectedWriteSet)) + ); +} diff --git a/vm_validator/src/vm_validator.rs b/vm_validator/src/vm_validator.rs new file mode 100644 index 0000000000000..9808a3ed65d3c --- /dev/null +++ b/vm_validator/src/vm_validator.rs @@ -0,0 +1,125 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use config::config::NodeConfig; +use failure::prelude::*; +use futures::future::{err, ok, Future}; +use scratchpad::SparseMerkleTree; +use std::sync::Arc; +use storage_client::{StorageRead, VerifiedStateView}; +use types::{ + account_address::{AccountAddress, ADDRESS_LENGTH}, + account_config::get_account_resource_or_default, + get_with_proof::{RequestItem, ResponseItem}, + transaction::SignedTransaction, + vm_error::VMStatus, +}; +use vm_runtime::{MoveVM, VMVerifier}; + +#[cfg(test)] +#[path = "unit_tests/vm_validator_test.rs"] +mod vm_validator_test; + +pub trait TransactionValidation: Send + Sync { + type ValidationInstance: VMVerifier; + /// Validate a txn from client + fn validate_transaction( + &self, + _txn: SignedTransaction, + ) -> Box<dyn Future<Item = Option<VMStatus>, Error = failure::Error> + Send>; +} + +#[derive(Clone)] +pub struct VMValidator { + storage_read_client: Arc<dyn StorageRead>, + vm: MoveVM, +} + +impl VMValidator { + pub fn new(config: &NodeConfig, storage_read_client: Arc<dyn StorageRead>) -> Self { + VMValidator { + storage_read_client, + vm: MoveVM::new(&config.vm_config), + } + } +} + +impl TransactionValidation for VMValidator { + type ValidationInstance = MoveVM; + + fn validate_transaction( + &self, + txn: SignedTransaction, + ) -> Box<dyn Future<Item = Option<VMStatus>, Error = failure::Error> + Send> { + // TODO: For transaction validation, there are two options to go: + // 1. Trust storage: there is no need to get root hash from storage here. We will + // create another struct similar to `VerifiedStateView` that implements `StateView` + // but does not do verification. + // 2. Don't trust storage. This requires more work: + // 1) AC must have validator set information + // 2) Get state_root from transaction info which can be verified with signatures of + // validator set. + // 3) Create VerifiedStateView with verified state + // root. + + // Just ask something from storage. It doesn't matter what it is -- we just need the + // transaction info object in account state proof which contains the state root hash.
+ let address = AccountAddress::new([0xff; ADDRESS_LENGTH]); + let item = RequestItem::GetAccountState { address }; + + match self + .storage_read_client + .update_to_latest_ledger(/* client_known_version = */ 0, vec![item]) + { + Ok((mut items, _, _)) => { + if items.len() != 1 { + return Box::new(err(format_err!( + "Unexpected number of items ({}).", + items.len() + ) + .into())); + } + + match items.remove(0) { + ResponseItem::GetAccountState { + account_state_with_proof, + } => { + let transaction_info = account_state_with_proof.proof.transaction_info(); + let state_root = transaction_info.state_root_hash(); + let smt = SparseMerkleTree::new(state_root); + let state_view = VerifiedStateView::new( + Arc::clone(&self.storage_read_client), + state_root, + &smt, + ); + Box::new(ok(self.vm.validate_transaction(txn, &state_view))) + } + _ => panic!("Unexpected item in response."), + } + } + Err(e) => Box::new(err(e.into())), + } + } +} + +/// Reads the account state and +/// returns the account's current sequence number and balance. +pub async fn get_account_state( + storage_read_client: Arc<dyn StorageRead>, + address: AccountAddress, +) -> Result<(u64, u64)> { + let req_item = RequestItem::GetAccountState { address }; + let (response_items, _, _) = storage_read_client + .update_to_latest_ledger_async(0 /* client_known_version */, vec![req_item]) + .await?; + let account_state = match &response_items[0] { + ResponseItem::GetAccountState { + account_state_with_proof, + } => &account_state_with_proof.blob, + _ => bail!("Not account state response."), + }; + let account_resource = get_account_resource_or_default(account_state)?; + let sequence_number = account_resource.sequence_number(); + let balance = account_resource.balance(); + Ok((sequence_number, balance)) +}
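For orientation, here is a minimal, hypothetical sketch of how a caller might drive the `TransactionValidation` trait above from synchronous code, mirroring the futures-0.1 `.wait()` pattern the unit tests in this diff use. The `check_txn` helper and its wiring are illustrative assumptions, not part of this change; `validator` is assumed to be built with `VMValidator::new(&config, storage_read_client)` as in the tests.

```rust
use futures::future::Future; // futures 0.1, for `.wait()`
use types::transaction::SignedTransaction;
use vm_validator::vm_validator::{TransactionValidation, VMValidator};

// Hypothetical helper: runs validation to completion on the current thread.
fn check_txn(validator: &VMValidator, txn: SignedTransaction) {
    // `Ok(None)` means the transaction passed every VM check;
    // `Ok(Some(status))` carries the `VMStatus` explaining the rejection.
    match validator.validate_transaction(txn).wait() {
        Ok(None) => println!("transaction is valid"),
        Ok(Some(status)) => println!("rejected: {:?}", status),
        Err(e) => println!("internal validation error: {:?}", e),
    }
}
```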

{ + vec![self.get_proposer(round)] + } +} + +impl EventBasedActor for RotatingProposer { + type InputEvent = ProposalInfo; + type OutputEvent = ProposalInfo; + + fn init( + &mut self, + _: mpsc::Sender, + output_stream_sender: mpsc::Sender, + ) { + self.winning_proposals = Some(output_stream_sender); + } + + fn process_event(&self, event: Self::InputEvent) -> Pin + Send>> { + let proposer = self.get_proposer(event.proposal.round()); + let mut sender = self.winning_proposals.as_ref().unwrap().clone(); + async move { + if proposer.get_author() == event.proposer_info.get_author() { + if let Err(e) = sender.send(event).await { + debug!("Error in sending the winning proposal: {:?}", e); + } + } + } + .boxed() + } +} diff --git a/consensus/src/chained_bft/liveness/rotating_proposer_test.rs b/consensus/src/chained_bft/liveness/rotating_proposer_test.rs new file mode 100644 index 0000000000000..eb108f76ff8df --- /dev/null +++ b/consensus/src/chained_bft/liveness/rotating_proposer_test.rs @@ -0,0 +1,245 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + common::Author, + consensus_types::{block::Block, quorum_cert::QuorumCert}, + liveness::{ + proposer_election::{ProposalInfo, ProposerElection}, + rotating_proposer_election::RotatingProposer, + }, + test_utils::consensus_runtime, + }, + stream_utils::start_event_processing_loop, +}; +use futures::{executor::block_on, SinkExt, StreamExt}; +use std::sync::Arc; +use types::validator_signer::ValidatorSigner; + +#[test] +fn test_rotating_proposer() { + let runtime = consensus_runtime(); + + let chosen_validator_signer = ValidatorSigner::random(); + let chosen_author = chosen_validator_signer.author(); + let another_validator_signer = ValidatorSigner::random(); + let another_author = another_validator_signer.author(); + let proposers = vec![chosen_author, another_author]; + let mut pe = Arc::new(RotatingProposer::::new(proposers, 1)); + let (mut tx, rx) = start_event_processing_loop(&mut pe, runtime.executor()); + + // Send a proposal from both chosen author and another author, the only winning proposals + // follow the round-robin rotation. 
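+ // (Inferred from the assertions below: with proposers [chosen_author, another_author] and a single contiguous round, round 1 is assigned to another_author and round 2 to chosen_author.)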
+ + // Test genesis and the next block + let genesis_block = Block::make_genesis_block(); + let quorum_cert = QuorumCert::certificate_for_genesis(); + + let good_proposal = ProposalInfo { + proposal: Block::make_block( + &genesis_block, + 1, + 1, + 1, + quorum_cert.clone(), + &another_validator_signer, + ), + proposer_info: another_author, + timeout_certificate: None, + highest_ledger_info: quorum_cert.clone(), + }; + let bad_proposal = ProposalInfo { + proposal: Block::make_block( + &genesis_block, + 2, + 1, + 2, + quorum_cert.clone(), + &chosen_validator_signer, + ), + proposer_info: chosen_author, + timeout_certificate: None, + highest_ledger_info: quorum_cert.clone(), + }; + let next_good_proposal = ProposalInfo { + proposal: Block::make_block( + &genesis_block, + 3, + 2, + 3, + quorum_cert.clone(), + &chosen_validator_signer, + ), + proposer_info: chosen_author, + timeout_certificate: None, + highest_ledger_info: quorum_cert.clone(), + }; + block_on(async move { + tx.send(good_proposal.clone()).await.unwrap(); + tx.send(bad_proposal.clone()).await.unwrap(); + tx.send(next_good_proposal.clone()).await.unwrap(); + + assert_eq!( + rx.take(2).collect::>().await, + vec![good_proposal, next_good_proposal], + ); + assert_eq!(pe.is_valid_proposer(chosen_author, 1), None); + assert_eq!( + pe.is_valid_proposer(another_author, 1), + Some(another_author) + ); + assert_eq!(pe.is_valid_proposer(chosen_author, 2), Some(chosen_author)); + assert_eq!(pe.is_valid_proposer(another_author, 2), None); + assert_eq!(pe.get_valid_proposers(1), vec![another_author]); + assert_eq!(pe.get_valid_proposers(2), vec![chosen_author]); + }); +} + +#[test] +fn test_rotating_proposer_with_three_contiguous_rounds() { + let runtime = consensus_runtime(); + + let chosen_validator_signer = ValidatorSigner::random(); + let chosen_author = chosen_validator_signer.author(); + let another_validator_signer = ValidatorSigner::random(); + let another_author = another_validator_signer.author(); + let proposers = vec![chosen_author, another_author]; + let mut pe = Arc::new(RotatingProposer::::new(proposers, 3)); + let (mut tx, rx) = start_event_processing_loop(&mut pe, runtime.executor()); + + // Send a proposal from both chosen author and another author, the only winning proposals + // follow the round-robin rotation with 3 contiguous rounds. 
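+ // (Inferred from the assertions below: with three contiguous rounds per proposer, chosen_author is the valid proposer for both rounds 1 and 2, so another_author's round-1 proposal loses.)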
+ + // Test genesis and the next block + let genesis_block = Block::make_genesis_block(); + let quorum_cert = QuorumCert::certificate_for_genesis(); + + let good_proposal = ProposalInfo { + proposal: Block::make_block( + &genesis_block, + 1, + 1, + 1, + quorum_cert.clone(), + &chosen_validator_signer, + ), + proposer_info: chosen_author, + timeout_certificate: None, + highest_ledger_info: quorum_cert.clone(), + }; + let bad_proposal = ProposalInfo { + proposal: Block::make_block( + &genesis_block, + 2, + 1, + 2, + quorum_cert.clone(), + &another_validator_signer, + ), + proposer_info: another_author, + timeout_certificate: None, + highest_ledger_info: quorum_cert.clone(), + }; + let next_good_proposal = ProposalInfo { + proposal: Block::make_block( + &genesis_block, + 3, + 2, + 3, + quorum_cert.clone(), + &chosen_validator_signer, + ), + proposer_info: chosen_author, + timeout_certificate: None, + highest_ledger_info: quorum_cert.clone(), + }; + block_on(async move { + tx.send(good_proposal.clone()).await.unwrap(); + tx.send(bad_proposal.clone()).await.unwrap(); + tx.send(next_good_proposal.clone()).await.unwrap(); + + assert_eq!( + rx.take(2).collect::>().await, + vec![good_proposal, next_good_proposal], + ); + assert_eq!(pe.is_valid_proposer(another_author, 1), None); + assert_eq!(pe.is_valid_proposer(chosen_author, 1), Some(chosen_author)); + assert_eq!(pe.is_valid_proposer(chosen_author, 2), Some(chosen_author)); + assert_eq!(pe.is_valid_proposer(another_author, 2), None); + assert_eq!(pe.get_valid_proposers(1), vec![chosen_author]); + assert_eq!(pe.get_valid_proposers(2), vec![chosen_author]); + }); +} + +#[test] +fn test_fixed_proposer() { + let runtime = consensus_runtime(); + + let chosen_validator_signer = ValidatorSigner::random(); + let chosen_author = chosen_validator_signer.author(); + let another_validator_signer = ValidatorSigner::random(); + let another_author = another_validator_signer.author(); + let mut pe = Arc::new(RotatingProposer::::new(vec![chosen_author], 1)); + let (mut tx, rx) = start_event_processing_loop(&mut pe, runtime.executor()); + + // Send a proposal from both chosen author and another author, the only winning proposal is + // from the chosen author. 
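+ // (Inferred from the assertions below: a single-entry proposer list makes chosen_author the fixed proposer, so another_author's round-1 proposal loses.)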
+ + // Test genesis and the next block + let genesis_block = Block::make_genesis_block(); + let quorum_cert = QuorumCert::certificate_for_genesis(); + + let good_proposal = ProposalInfo { + proposal: Block::make_block( + &genesis_block, + 1, + 1, + 1, + quorum_cert.clone(), + &chosen_validator_signer, + ), + proposer_info: chosen_author, + timeout_certificate: None, + highest_ledger_info: quorum_cert.clone(), + }; + let bad_proposal = ProposalInfo { + proposal: Block::make_block( + &genesis_block, + 2, + 1, + 2, + quorum_cert.clone(), + &another_validator_signer, + ), + proposer_info: another_author, + timeout_certificate: None, + highest_ledger_info: quorum_cert.clone(), + }; + let next_good_proposal = ProposalInfo { + proposal: Block::make_block( + &genesis_block, + 2, + 2, + 3, + quorum_cert.clone(), + &chosen_validator_signer, + ), + proposer_info: chosen_author, + timeout_certificate: None, + highest_ledger_info: quorum_cert.clone(), + }; + block_on(async move { + tx.send(good_proposal.clone()).await.unwrap(); + tx.send(bad_proposal.clone()).await.unwrap(); + tx.send(next_good_proposal.clone()).await.unwrap(); + + assert_eq!( + rx.take(2).collect::>().await, + vec![good_proposal, next_good_proposal], + ); + assert_eq!(pe.is_valid_proposer(chosen_author, 1), Some(chosen_author)); + assert_eq!(pe.is_valid_proposer(another_author, 1), None); + assert_eq!(pe.get_valid_proposers(1), vec![chosen_author]); + }); +} diff --git a/consensus/src/chained_bft/mod.rs b/consensus/src/chained_bft/mod.rs new file mode 100644 index 0000000000000..180d9e91cb23b --- /dev/null +++ b/consensus/src/chained_bft/mod.rs @@ -0,0 +1,29 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod common; +mod consensus_types; +mod consensusdb; +mod liveness; +mod safety; + +mod block_storage; +pub mod chained_bft_consensus_provider; +pub use consensus_types::quorum_cert::QuorumCert; +mod chained_bft_smr; +mod event_processor; +mod network; + +pub mod persistent_storage; +mod sync_manager; + +#[cfg(test)] +mod chained_bft_smr_test; +#[cfg(test)] +mod event_processor_test; +#[cfg(test)] +mod network_tests; +#[cfg(test)] +mod proto_test; +#[cfg(test)] +pub mod test_utils; diff --git a/consensus/src/chained_bft/network.rs b/consensus/src/chained_bft/network.rs new file mode 100644 index 0000000000000..ad64a34572d83 --- /dev/null +++ b/consensus/src/chained_bft/network.rs @@ -0,0 +1,485 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + block_storage::BlockRetrievalFailure, + common::{Author, Payload}, + consensus_types::{block::Block, quorum_cert::QuorumCert}, + liveness::{ + new_round_msg::NewRoundMsg, + proposer_election::{ProposalInfo, ProposerInfo}, + }, + safety::vote_msg::VoteMsg, + }, + counters, +}; +use bytes::Bytes; +use channel; +use crypto::HashValue; +use failure; +use futures::{ + channel::oneshot, stream::select, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt, + TryStreamExt, +}; +use logger::prelude::*; +use network::{ + proto::{BlockRetrievalStatus, ConsensusMsg, RequestBlock, RespondBlock, RespondChunk}, + validator_network::{ConsensusNetworkEvents, ConsensusNetworkSender, Event, RpcError}, +}; +use proto_conv::{FromProto, IntoProto}; +use protobuf::Message; +use std::{ + sync::Arc, + time::{Duration, Instant}, +}; +use tokio::runtime::TaskExecutor; +use types::{transaction::TransactionListWithProof, validator_verifier::ValidatorVerifier}; + +/// The response sent back from 
event_processor for the BlockRetrievalRequest. +#[derive(Debug)] +pub struct BlockRetrievalResponse<T> { + pub status: BlockRetrievalStatus, + pub blocks: Vec<Block<T>>, +} + +impl<T: Payload> BlockRetrievalResponse<T> { + pub fn verify(&self, mut block_id: HashValue, num_blocks: u64) -> Result<(), failure::Error> { + if self.status == BlockRetrievalStatus::SUCCEEDED && self.blocks.len() as u64 != num_blocks + { + return Err(format_err!( + "not enough blocks returned, expected {}, got {}", + num_blocks, + self.blocks.len(), + )); + } + for block in self.blocks.iter() { + if block.id() != block_id { + return Err(format_err!( + "blocks don't form a chain: expected {}, got {}", + block_id, + block.id() + )); + } + block_id = block.parent_id(); + } + Ok(()) + } +} + +/// BlockRetrievalRequest carries a block id for the requested block as well as the +/// oneshot sender to deliver the response. +pub struct BlockRetrievalRequest<T> { + pub block_id: HashValue, + pub num_blocks: u64, + pub response_sender: oneshot::Sender<BlockRetrievalResponse<T>>, +} + +/// Represents a request to get up to batch_size transactions starting from start_version +/// with the oneshot sender to deliver the response. +pub struct ChunkRetrievalRequest { + pub start_version: u64, + pub target: QuorumCert, + pub batch_size: u64, + pub response_sender: oneshot::Sender<failure::Result<TransactionListWithProof>>, +} + +/// Just a convenience struct to keep all the network proxy receiving queues in one place. +/// 1. proposals +/// 2. votes +/// 3. block retrieval requests (the request carries a oneshot sender for returning the Block) +/// 4. pacemaker timeouts +/// Will be returned by the networking trait upon startup. +pub struct NetworkReceivers<T, P> { + pub proposals: channel::Receiver<ProposalInfo<T, P>>, + pub votes: channel::Receiver<VoteMsg>, + pub block_retrieval: channel::Receiver<BlockRetrievalRequest<T>>, + pub new_rounds: channel::Receiver<NewRoundMsg>, + pub chunk_retrieval: channel::Receiver<ChunkRetrievalRequest>, +} + +/// Implements the actual networking support for all consensus messaging. +pub struct ConsensusNetworkImpl { + author: Author, + network_sender: ConsensusNetworkSender, + network_events: Option<ConsensusNetworkEvents>, + // Self sender and self receivers provide a shortcut for sending the messages to itself. + // (self sending is not supported by the networking API). + // Note that we do not support self rpc requests as it might cause infinite recursive calls. + self_sender: channel::Sender<Result<Event<ConsensusMsg>, failure::Error>>, + self_receiver: Option<channel::Receiver<Result<Event<ConsensusMsg>, failure::Error>>>, + peers: Arc<Vec<Author>>, + validator: Arc<ValidatorVerifier>, +} + +impl Clone for ConsensusNetworkImpl { + fn clone(&self) -> Self { + Self { + author: self.author, + network_sender: self.network_sender.clone(), + network_events: None, + self_sender: self.self_sender.clone(), + self_receiver: None, + peers: self.peers.clone(), + validator: Arc::clone(&self.validator), + } + } +} + +impl ConsensusNetworkImpl { + pub fn new( + author: Author, + network_sender: ConsensusNetworkSender, + network_events: ConsensusNetworkEvents, + peers: Arc<Vec<Author>>, + validator: Arc<ValidatorVerifier>, + ) -> Self { + let (self_sender, self_receiver) = channel::new(1_024, &counters::PENDING_SELF_MESSAGES); + ConsensusNetworkImpl { + author, + network_sender, + network_events: Some(network_events), + self_sender, + self_receiver: Some(self_receiver), + peers, + validator, + } + } + + /// Establishes the initial connections with the peers and returns the receivers.
+ pub fn start<T: Payload, P: ProposerInfo>( + &mut self, + executor: &TaskExecutor, + ) -> NetworkReceivers<T, P> { + let (proposal_tx, proposal_rx) = channel::new(1_024, &counters::PENDING_PROPOSAL); + let (vote_tx, vote_rx) = channel::new(1_024, &counters::PENDING_VOTES); + let (block_request_tx, block_request_rx) = + channel::new(1_024, &counters::PENDING_BLOCK_REQUESTS); + let (chunk_request_tx, chunk_request_rx) = + channel::new(1_024, &counters::PENDING_CHUNK_REQUESTS); + let (new_round_tx, new_round_rx) = + channel::new(1_024, &counters::PENDING_NEW_ROUND_MESSAGES); + let network_events = self + .network_events + .take() + .expect("[consensus] Failed to start; network_events stream is already taken") + .map_err(Into::<failure::Error>::into); + let own_msgs = self + .self_receiver + .take() + .expect("[consensus]: self receiver is already taken"); + let all_events = select(network_events, own_msgs); + let validator = Arc::clone(&self.validator); + executor.spawn( + NetworkTask { + proposal_tx, + vote_tx, + block_request_tx, + chunk_request_tx, + new_round_tx, + all_events, + validator, + } + .run() + .boxed() + .unit_error() + .compat(), + ); + NetworkReceivers { + proposals: proposal_rx, + votes: vote_rx, + block_retrieval: block_request_rx, + new_rounds: new_round_rx, + chunk_retrieval: chunk_request_rx, + } + } + + /// Tries to retrieve `num_blocks` blocks backwards starting from `block_id` from the given + /// peer: the function returns a future that is fulfilled with either a BlockRetrievalResponse + /// or a BlockRetrievalFailure. + pub async fn request_block<T: Payload>( + &mut self, + block_id: HashValue, + num_blocks: u64, + from: Author, + timeout: Duration, + ) -> Result<BlockRetrievalResponse<T>, BlockRetrievalFailure> { + if from == self.author { + return Err(BlockRetrievalFailure::SelfRetrieval); + } + let mut req_msg = RequestBlock::new(); + req_msg.set_block_id(block_id.into()); + req_msg.set_num_blocks(num_blocks); + counters::BLOCK_RETRIEVAL_COUNT.inc_by(num_blocks as i64); + let pre_retrieval_instant = Instant::now(); + + let mut res_block = self + .network_sender + .request_block(from, req_msg, timeout) + .await?; + let mut blocks = vec![]; + for block in res_block.take_blocks().into_iter() { + if let Ok(block) = Block::from_proto(block) { + if block.verify(self.validator.as_ref()).is_err() { + return Err(BlockRetrievalFailure::InvalidSignature); + } + blocks.push(block); + } else { + return Err(BlockRetrievalFailure::InvalidResponse); + } + } + counters::BLOCK_RETRIEVAL_DURATION_MS + .observe(pre_retrieval_instant.elapsed().as_millis() as f64); + let response = BlockRetrievalResponse { + status: res_block.get_status(), + blocks, + }; + if response.verify(block_id, num_blocks).is_err() { + return Err(BlockRetrievalFailure::InvalidResponse); + } + Ok(response) + } + + /// Tries to send the given proposal (block and proposer metadata) to all the participants. + /// A validator on the receiving end is going to be notified about a new proposal in the + /// proposal queue. + /// + /// The future is fulfilled as soon as the message is put into the mpsc channel to the network + /// internals (to provide back pressure); it does not indicate that the message has been + /// delivered or sent out, nor does it give any indication of when the message reaches the + /// recipients or of any network failures.
+ pub async fn broadcast_proposal<T: Payload, P: ProposerInfo>( + &mut self, + proposal: ProposalInfo<T, P>, + ) { + let mut msg = ConsensusMsg::new(); + msg.set_proposal(proposal.into_proto()); + self.broadcast(msg).await + } + + async fn broadcast(&mut self, msg: ConsensusMsg) { + for peer in self.peers.iter() { + if self.author == *peer { + let self_msg = Event::Message((self.author, msg.clone())); + if let Err(err) = self.self_sender.send(Ok(self_msg)).await { + error!("Error delivering a self proposal: {:?}", err); + } + continue; + } + if let Err(err) = self.network_sender.send_to(*peer, msg.clone()).await { + error!( + "Error broadcasting proposal to peer: {:?}, error: {:?}, msg: {:?}", + peer, err, msg + ); + } + } + } + + /// Sends the vote to the chosen recipients (typically that would be the recipients that + /// we believe could serve as proposers in the next round). The recipients on the receiving + /// end are going to be notified about a new vote in the vote queue. + /// + /// The future is fulfilled as soon as the message is put into the mpsc channel to the network + /// internals (to provide back pressure); it does not indicate that the message has been + /// delivered or sent out, nor does it give any indication of when the message reaches the + /// recipients or of any network failures. + pub async fn send_vote(&self, vote_msg: VoteMsg, recipients: Vec<Author>) { + let mut network_sender = self.network_sender.clone(); + let mut self_sender = self.self_sender.clone(); + let mut msg = ConsensusMsg::new(); + msg.set_vote(vote_msg.into_proto()); + for peer in recipients { + if self.author == peer { + let self_msg = Event::Message((self.author, msg.clone())); + if let Err(err) = self_sender.send(Ok(self_msg)).await { + error!("Error delivering a self vote: {:?}", err); + } + continue; + } + if let Err(e) = network_sender.send_to(peer, msg.clone()).await { + error!("Failed to send a vote to peer {:?}: {:?}", peer, e); + } + } + } + + /// Broadcasts new round (including timeout) messages to all validators + pub async fn broadcast_new_round(&mut self, new_round_msg: NewRoundMsg) { + let mut msg = ConsensusMsg::new(); + msg.set_new_round(new_round_msg.into_proto()); + self.broadcast(msg).await + } +} + +struct NetworkTask<T, P, S> { + proposal_tx: channel::Sender<ProposalInfo<T, P>>, + vote_tx: channel::Sender<VoteMsg>, + block_request_tx: channel::Sender<BlockRetrievalRequest<T>>, + chunk_request_tx: channel::Sender<ChunkRetrievalRequest>, + new_round_tx: channel::Sender<NewRoundMsg>, + all_events: S, + validator: Arc<ValidatorVerifier>, +} + +impl<T, P, S> NetworkTask<T, P, S> +where + S: Stream<Item = Result<Event<ConsensusMsg>, failure::Error>> + Unpin, + T: Payload, + P: ProposerInfo, +{ + pub async fn run(mut self) { + while let Some(Ok(message)) = self.all_events.next().await { + match message { + Event::Message((peer_id, mut msg)) => { + let r = if msg.has_proposal() { + self.process_proposal(&mut msg).await + } else if msg.has_vote() { + self.process_vote(&mut msg).await + } else if msg.has_new_round() { + self.process_new_round(&mut msg).await + } else { + warn!("Unexpected msg from {}: {:?}", peer_id, msg); + continue; + }; + if let Err(e) = r { + warn!("Failed to process msg {:?}: {:?}", msg, e) + } + } + Event::RpcRequest((peer_id, mut msg, callback)) => { + let r = if msg.has_request_block() { + self.process_request_block(&mut msg, callback).await + } else if msg.has_request_chunk() { + self.process_request_chunk(&mut msg, callback).await + } else { + warn!("Unexpected RPC from {}: {:?}", peer_id, msg); + continue; + }; + if let Err(e) = r { + warn!("Failed to process RPC {:?}: {:?}", msg, e) + } + } + Event::NewPeer(peer_id) => { + debug!("Peer {}
connected", peer_id); + } + Event::LostPeer(peer_id) => { + debug!("Peer {} disconnected", peer_id); + } + } + } + } + + async fn process_proposal<'a>(&'a mut self, msg: &'a mut ConsensusMsg) -> failure::Result<()> { + let proposal = ProposalInfo::::from_proto(msg.take_proposal())?; + proposal.verify(self.validator.as_ref()).map_err(|e| { + security_log(SecurityEvent::InvalidConsensusProposal) + .error(&e) + .data(&proposal) + .log(); + e + })?; + debug!("Received proposal {}", proposal); + self.proposal_tx.send(proposal).await?; + Ok(()) + } + + async fn process_vote<'a>(&'a mut self, msg: &'a mut ConsensusMsg) -> failure::Result<()> { + let vote = VoteMsg::from_proto(msg.take_vote())?; + debug!("Received {}", vote); + vote.verify(self.validator.as_ref()).map_err(|e| { + security_log(SecurityEvent::InvalidConsensusVote) + .error(&e) + .data(&vote) + .log(); + e + })?; + self.vote_tx.send(vote).await?; + Ok(()) + } + + async fn process_new_round<'a>(&'a mut self, msg: &'a mut ConsensusMsg) -> failure::Result<()> { + let new_round = NewRoundMsg::from_proto(msg.take_new_round())?; + new_round.verify(self.validator.as_ref()).map_err(|e| { + security_log(SecurityEvent::InvalidConsensusRound) + .error(&e) + .data(&new_round) + .log(); + e + })?; + self.new_round_tx.send(new_round).await?; + Ok(()) + } + + async fn process_request_chunk<'a>( + &'a mut self, + msg: &'a mut ConsensusMsg, + callback: oneshot::Sender>, + ) -> failure::Result<()> { + let mut req = msg.take_request_chunk(); + debug!( + "Received request_chunk RPC for start version: {} target: {:?} batch_size: {}", + req.start_version, + req.get_target(), + req.batch_size + ); + let (tx, rx) = oneshot::channel(); + let target = QuorumCert::from_proto(req.take_target())?; + target.verify(self.validator.as_ref())?; + let request = ChunkRetrievalRequest { + start_version: req.start_version, + target, + batch_size: req.batch_size, + response_sender: tx, + }; + self.chunk_request_tx.send(request).await?; + callback + .send(match rx.await? 
{ + Ok(txn_list_with_proof) => { + let mut response_msg = ConsensusMsg::new(); + let mut response = RespondChunk::new(); + response.set_txn_list_with_proof(txn_list_with_proof.into_proto()); + response_msg.set_respond_chunk(response); + let response_data = Bytes::from( + response_msg + .write_to_bytes() + .expect("failed to serialize proto"), + ); + Ok(response_data) + } + Err(err) => Err(RpcError::ApplicationError(err)), + }) + .map_err(|_| format_err!("handling inbound rpc call timed out")) + } + + async fn process_request_block<'a>( + &'a mut self, + msg: &'a mut ConsensusMsg, + callback: oneshot::Sender<Result<Bytes, RpcError>>, + ) -> failure::Result<()> { + let block_id = HashValue::from_slice(msg.get_request_block().get_block_id())?; + let num_blocks = msg.get_request_block().get_num_blocks(); + debug!( + "Received request_block RPC for {} blocks from {:?}", + num_blocks, block_id + ); + let (tx, rx) = oneshot::channel(); + let request = BlockRetrievalRequest { + block_id, + num_blocks, + response_sender: tx, + }; + self.block_request_tx.send(request).await?; + let BlockRetrievalResponse { status, blocks } = rx.await?; + let mut response_msg = ConsensusMsg::new(); + let mut response = RespondBlock::new(); + response.set_status(status); + response.set_blocks(blocks.into_iter().map(IntoProto::into_proto).collect()); + response_msg.set_respond_block(response); + let response_data = Bytes::from( + response_msg + .write_to_bytes() + .expect("failed to serialize proto"), + ); + callback + .send(Ok(response_data)) + .map_err(|_| format_err!("handling inbound rpc call timed out")) + } +} diff --git a/consensus/src/chained_bft/network_tests.rs b/consensus/src/chained_bft/network_tests.rs new file mode 100644 index 0000000000000..302ebf5fa70f8 --- /dev/null +++ b/consensus/src/chained_bft/network_tests.rs @@ -0,0 +1,522 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + common::Author, + consensus_types::{block::Block, quorum_cert::QuorumCert}, + liveness::proposer_election::ProposalInfo, + network::{BlockRetrievalResponse, ConsensusNetworkImpl, NetworkReceivers}, + safety::vote_msg::VoteMsg, + test_utils::{consensus_runtime, placeholder_ledger_info}, + }, + state_replication::ExecutedState, +}; +use channel; +use crypto::{signing::generate_keypair, HashValue}; +use futures::{channel::mpsc, executor::block_on, FutureExt, SinkExt, StreamExt, TryFutureExt}; +use network::{ + interface::{NetworkNotification, NetworkRequest}, + proto::{BlockRetrievalStatus, ConsensusMsg, QuorumCert as ProtoQuorumCert, RequestChunk}, + protocols::rpc::InboundRpcRequest, + validator_network::{ConsensusNetworkEvents, ConsensusNetworkSender}, +}; +use proto_conv::FromProto; +use std::{ + collections::{HashMap, HashSet}, + sync::{Arc, Mutex, RwLock}, + time::Duration, +}; +use tokio::runtime::TaskExecutor; +use types::{ + account_address::AccountAddress, + proto::ledger_info::{LedgerInfo, LedgerInfoWithSignatures}, + test_helpers::transaction_test_helpers::get_test_signed_txn, + transaction::{SignedTransaction, TransactionInfo, TransactionListWithProof}, + validator_signer::ValidatorSigner, + validator_verifier::ValidatorVerifier, +}; + +/// `NetworkPlayground` mocks the network implementation and provides convenience +/// methods for testing. Test clients can use `wait_for_messages` or +/// `deliver_messages` to inspect the direct-send messages sent between peers. +/// They can also configure network messages to be dropped between specific peers.
+/// + /// Currently, RPC messages are delivered immediately and are not controlled by + /// `wait_for_messages` or `deliver_messages` for delivery. They are also not + /// currently dropped according to the `NetworkPlayground`'s drop config. +pub struct NetworkPlayground { + /// Maps each Author to a Sender of their inbound network notifications. + /// These events will usually be handled by the event loop spawned in + /// `ConsensusNetworkImpl`. + node_consensus_txs: Arc<Mutex<HashMap<Author, channel::Sender<NetworkNotification>>>>, + /// Nodes' outbound handlers forward their outbound non-rpc messages to this + /// queue. + outbound_msgs_tx: mpsc::Sender<(Author, NetworkRequest)>, + /// NetworkPlayground reads all nodes' outbound messages through this queue. + outbound_msgs_rx: mpsc::Receiver<(Author, NetworkRequest)>, + /// Allow test code to drop direct-send messages between peers. + drop_config: Arc<RwLock<DropConfig>>, + /// An executor for spawning node outbound network event handlers + executor: TaskExecutor, +} + +impl NetworkPlayground { + pub fn new(executor: TaskExecutor) -> Self { + let (outbound_msgs_tx, outbound_msgs_rx) = mpsc::channel(1_024); + + NetworkPlayground { + node_consensus_txs: Arc::new(Mutex::new(HashMap::new())), + outbound_msgs_tx, + outbound_msgs_rx, + drop_config: Arc::new(RwLock::new(DropConfig(HashMap::new()))), + executor, + } + } + + /// Create a new async task that handles outbound messages sent by a node. + /// + /// All non-rpc messages are forwarded to the NetworkPlayground's + /// `outbound_msgs_rx` queue, which controls delivery through the + /// `deliver_messages` and `wait_for_messages` API's. + /// + /// Rpc messages are immediately sent to the destination for handling, so + /// they don't block. + async fn start_node_outbound_handler( + drop_config: Arc<RwLock<DropConfig>>, + src: Author, + mut network_reqs_rx: channel::Receiver<NetworkRequest>, + mut outbound_msgs_tx: mpsc::Sender<(Author, NetworkRequest)>, + node_consensus_txs: Arc<Mutex<HashMap<Author, channel::Sender<NetworkNotification>>>>, + ) { + while let Some(net_req) = network_reqs_rx.next().await { + let drop_rpc = drop_config + .read() + .unwrap() + .is_message_dropped(&src, &net_req); + match net_req { + // Immediately forward rpc requests for handling. Unfortunately, + // we can't handle rpc requests in `deliver_messages` due to + // blocking issues, e.g., I want to write: + // ``` + // let block = sender.request_block(peer_id, block_id).await.unwrap(); + // playground.wait_for_messages(1).await; + // ``` + // but because the rpc call blocks and depends on the message + // delivery, we'd have to spawn the sending behaviour on a + // separate task, which is inconvenient. + NetworkRequest::SendRpc(dst, outbound_req) => { + if drop_rpc { + continue; + } + let mut node_consensus_tx = node_consensus_txs + .lock() + .unwrap() + .get(&dst.into()) + .unwrap() + .clone(); + + let inbound_req = InboundRpcRequest { + protocol: outbound_req.protocol, + data: outbound_req.data, + res_tx: outbound_req.res_tx, + }; + + node_consensus_tx + .send(NetworkNotification::RecvRpc(src.into(), inbound_req)) + .await + .unwrap(); + } + // Other NetworkRequest get buffered for `deliver_messages` to + // synchronously drain. + net_req => { + let _ = outbound_msgs_tx.send((src, net_req)).await; + } + } + } + } + + /// Add a new node to the NetworkPlayground. + pub fn add_node( + &mut self, + author: Author, + // The `Sender` of inbound network events. The `Receiver` end of this + // queue is usually wrapped in a `ConsensusNetworkEvents` adapter. + consensus_tx: channel::Sender<NetworkNotification>, + // The `Receiver` of outbound network events this node sends.
The + // `Sender` side of this queue is usually wrapped in a + // `ConsensusNetworkSender` adapter. + network_reqs_rx: channel::Receiver<NetworkRequest>, + ) { + self.node_consensus_txs + .lock() + .unwrap() + .insert(author, consensus_tx); + self.drop_config.write().unwrap().add_node(author); + + let fut = NetworkPlayground::start_node_outbound_handler( + Arc::clone(&self.drop_config), + author, + network_reqs_rx, + self.outbound_msgs_tx.clone(), + self.node_consensus_txs.clone(), + ); + self.executor.spawn(fut.boxed().unit_error().compat()); + } + + /// Deliver a `NetworkRequest` from peer `src` to the destination peer. + /// Returns a copy of the delivered message and the sending peer id. + async fn deliver_message( + &mut self, + src: Author, + msg: NetworkRequest, + ) -> (Author, ConsensusMsg) { + // extract destination peer + let dst = match &msg { + NetworkRequest::SendMessage(dst, _) => *dst, + msg => panic!("[network playground] Unexpected NetworkRequest: {:?}", msg), + }; + + // get the destination's sender + let mut node_consensus_tx = self + .node_consensus_txs + .lock() + .unwrap() + .get(&dst.into()) + .unwrap() + .clone(); + + // convert NetworkRequest to corresponding NetworkNotification + let msg_notif = match msg { + NetworkRequest::SendMessage(_dst, msg) => { + NetworkNotification::RecvMessage(src.into(), msg) + } + msg => panic!("[network playground] Unexpected NetworkRequest: {:?}", msg), + }; + + // copy message data + let msg_copy = match &msg_notif { + NetworkNotification::RecvMessage(src, msg) => { + let msg: ConsensusMsg = ::protobuf::parse_from_bytes(msg.mdata.as_ref()).unwrap(); + ((*src).into(), msg) + } + msg_notif => panic!( + "[network playground] Unexpected NetworkNotification: {:?}", + msg_notif + ), + }; + + node_consensus_tx.send(msg_notif).await.unwrap(); + msg_copy + } + + /// Wait for exactly `num_messages` to be enqueued and delivered. Return a + /// copy of all messages for verification. + /// While all the sent messages are delivered, only the messages that satisfy the given + /// msg inspector are counted. + pub async fn wait_for_messages<F>( + &mut self, + num_messages: usize, + msg_inspector: F, + ) -> Vec<(Author, ConsensusMsg)> + where + F: Fn(&(Author, ConsensusMsg)) -> bool, + { + let mut msg_copies = vec![]; + while msg_copies.len() < num_messages { + // Take the next queued message + let (src, net_req) = self.outbound_msgs_rx.next().await + .expect("[network playground] waiting for messages, but message queue has shut down unexpectedly"); + + // Deliver and copy the message if it's not dropped + if !self.is_message_dropped(&src, &net_req) { + let msg_copy = self.deliver_message(src, net_req).await; + if msg_inspector(&msg_copy) { + msg_copies.push(msg_copy); + } + } + } + assert_eq!(msg_copies.len(), num_messages); + msg_copies + } + + /// Returns true for any message + pub fn take_all(_msg_copy: &(Author, ConsensusMsg)) -> bool { + true + } + + /// Returns true for any message other than new round + pub fn exclude_new_round(msg_copy: &(Author, ConsensusMsg)) -> bool { + !msg_copy.1.has_new_round() + } + + /// Returns true for proposal messages only. + pub fn proposals_only(msg_copy: &(Author, ConsensusMsg)) -> bool { + msg_copy.1.has_proposal() + } + + /// Returns true for vote messages only. + pub fn votes_only(msg_copy: &(Author, ConsensusMsg)) -> bool { + msg_copy.1.has_vote() + } + + /// Returns true for new round messages only.
+ pub fn new_round_only(msg_copy: &(Author, ConsensusMsg)) -> bool { + msg_copy.1.has_new_round() + } + + fn is_message_dropped(&self, src: &Author, net_req: &NetworkRequest) -> bool { + self.drop_config + .read() + .unwrap() + .is_message_dropped(src, net_req) + } + + pub fn drop_message_for(&mut self, src: &Author, dst: Author) -> bool { + self.drop_config.write().unwrap().drop_message_for(src, dst) + } + + pub fn stop_drop_message_for(&mut self, src: &Author, dst: &Author) -> bool { + self.drop_config + .write() + .unwrap() + .stop_drop_message_for(src, dst) + } +} + +struct DropConfig(HashMap>); + +impl DropConfig { + pub fn is_message_dropped(&self, src: &Author, net_req: &NetworkRequest) -> bool { + match net_req { + NetworkRequest::SendMessage(dst, _) => self + .0 + .get(src.into()) + .unwrap() + .contains(&Author::from(*dst)), + NetworkRequest::SendRpc(dst, _) => self + .0 + .get(src.into()) + .unwrap() + .contains(&Author::from(*dst)), + _ => true, + } + } + + pub fn drop_message_for(&mut self, src: &Author, dst: Author) -> bool { + self.0.get_mut(src).unwrap().insert(dst) + } + + pub fn stop_drop_message_for(&mut self, src: &Author, dst: &Author) -> bool { + self.0.get_mut(src).unwrap().remove(dst) + } + + fn add_node(&mut self, src: Author) { + self.0.insert(src, HashSet::new()); + } +} + +#[test] +fn test_network_api() { + let runtime = consensus_runtime(); + let num_nodes = 5; + let mut peers = Vec::new(); + let mut receivers: Vec> = Vec::new(); + let mut playground = NetworkPlayground::new(runtime.executor()); + let mut nodes = Vec::new(); + let mut author_to_public_keys = HashMap::new(); + let mut signers = Vec::new(); + for _ in 0..num_nodes { + let random_validator_signer = ValidatorSigner::random(); + author_to_public_keys.insert( + random_validator_signer.author(), + random_validator_signer.public_key(), + ); + peers.push(random_validator_signer.author()); + signers.push(random_validator_signer); + } + let validator = Arc::new(ValidatorVerifier::new( + author_to_public_keys, + peers.len() * 2 / 3 + 1, + )); + for i in 0..num_nodes { + let (network_reqs_tx, network_reqs_rx) = channel::new_test(8); + let (consensus_tx, consensus_rx) = channel::new_test(8); + let network_sender = ConsensusNetworkSender::new(network_reqs_tx); + let network_events = ConsensusNetworkEvents::new(consensus_rx); + + playground.add_node(peers[i], consensus_tx, network_reqs_rx); + let mut node = ConsensusNetworkImpl::new( + peers[i], + network_sender, + network_events, + Arc::new(peers.clone()), + Arc::clone(&validator), + ); + receivers.push(node.start(&runtime.executor())); + nodes.push(node); + } + let vote = VoteMsg::new( + HashValue::random(), + ExecutedState::state_for_genesis(), + 1, + peers[0], + placeholder_ledger_info(), + &signers[0], + ); + let proposal = ProposalInfo { + proposal: Block::make_genesis_block(), + proposer_info: ValidatorSigner::genesis().author(), + timeout_certificate: None, + highest_ledger_info: QuorumCert::certificate_for_genesis(), + }; + block_on(async move { + nodes[0].send_vote(vote.clone(), peers[2..5].to_vec()).await; + playground + .wait_for_messages(3, NetworkPlayground::take_all) + .await; + for r in receivers.iter_mut().take(5).skip(2) { + let v = r.votes.next().await.unwrap(); + assert_eq!(v, vote); + } + nodes[4].broadcast_proposal(proposal.clone()).await; + playground + .wait_for_messages(4, NetworkPlayground::take_all) + .await; + for r in receivers.iter_mut().take(num_nodes - 1) { + let p = r.proposals.next().await.unwrap(); + assert_eq!(p, proposal); 
+ } + }); +} + +#[test] +fn test_rpc() { + let runtime = consensus_runtime(); + let num_nodes = 2; + let mut peers = Arc::new(Vec::new()); + let mut senders = Vec::new(); + let mut receivers: Vec> = Vec::new(); + let mut playground = NetworkPlayground::new(runtime.executor()); + let mut nodes = Vec::new(); + let mut author_to_public_keys = HashMap::new(); + for _ in 0..num_nodes { + let random_validator_signer = ValidatorSigner::random(); + author_to_public_keys.insert( + random_validator_signer.author(), + random_validator_signer.public_key(), + ); + Arc::get_mut(&mut peers) + .unwrap() + .push(random_validator_signer.author()); + } + let validator = Arc::new(ValidatorVerifier::new( + author_to_public_keys, + peers.len() * 2 / 3 + 1, + )); + for i in 0..num_nodes { + let (network_reqs_tx, network_reqs_rx) = channel::new_test(8); + let (consensus_tx, consensus_rx) = channel::new_test(8); + let network_sender = ConsensusNetworkSender::new(network_reqs_tx); + let network_events = ConsensusNetworkEvents::new(consensus_rx); + + playground.add_node(peers[i], consensus_tx, network_reqs_rx); + let mut node = ConsensusNetworkImpl::new( + peers[i], + network_sender.clone(), + network_events, + Arc::clone(&peers), + Arc::clone(&validator), + ); + senders.push(network_sender); + receivers.push(node.start(&runtime.executor())); + nodes.push(node); + } + let receiver_1 = receivers.remove(1); + let genesis = Arc::new(Block::::make_genesis_block()); + let genesis_clone = Arc::clone(&genesis); + + // verify request block rpc + let mut block_retrieval = receiver_1.block_retrieval; + let on_request_block = async move { + while let Some(request) = block_retrieval.next().await { + request + .response_sender + .send(BlockRetrievalResponse { + status: BlockRetrievalStatus::SUCCEEDED, + blocks: vec![Block::clone(genesis_clone.as_ref())], + }) + .unwrap(); + } + }; + runtime + .executor() + .spawn(on_request_block.boxed().unit_error().compat()); + let peer = peers[1]; + block_on(async move { + let response = nodes[0] + .request_block(genesis.id(), 1, peer, Duration::from_secs(5)) + .await + .unwrap(); + assert_eq!(response.blocks[0], *genesis); + }); + + // verify request chunk rpc + let mut chunk_retrieval = receiver_1.chunk_retrieval; + let on_request_chunk = async move { + while let Some(request) = chunk_retrieval.next().await { + let keypair = generate_keypair(); + let proto_txn = + get_test_signed_txn(AccountAddress::random(), 0, keypair.0, keypair.1, None); + let txn = SignedTransaction::from_proto(proto_txn).unwrap(); + let info = + TransactionInfo::new(HashValue::zero(), HashValue::zero(), HashValue::zero(), 0); + request + .response_sender + .send(Ok(TransactionListWithProof::new( + vec![(txn, info)], + None, + None, + None, + None, + ))) + .unwrap(); + } + }; + runtime + .executor() + .spawn(on_request_chunk.boxed().unit_error().compat()); + + block_on(async move { + let mut ledger_info = LedgerInfo::new(); + ledger_info.set_transaction_accumulator_hash(HashValue::zero().to_vec()); + ledger_info.set_consensus_block_id(HashValue::zero().to_vec()); + ledger_info.set_consensus_data_hash( + VoteMsg::vote_digest( + HashValue::zero(), + ExecutedState { + state_id: HashValue::zero(), + version: 0, + }, + 0, + ) + .to_vec(), + ); + let mut ledger_info_with_sigs = LedgerInfoWithSignatures::new(); + ledger_info_with_sigs.set_ledger_info(ledger_info); + let mut target = ProtoQuorumCert::new(); + target.set_block_id(HashValue::zero().into()); + target.set_state_id(HashValue::zero().into()); + target.set_round(0); + 
target.set_signed_ledger_info(ledger_info_with_sigs);
+        let mut req = RequestChunk::new();
+        req.set_start_version(0);
+        req.set_batch_size(1);
+        req.set_target(target);
+        let chunk = senders[0]
+            .request_chunk(peers[1], req, Duration::from_secs(5))
+            .await
+            .unwrap();
+        assert_eq!(chunk.get_txn_list_with_proof().get_transactions().len(), 1);
+    });
+}
diff --git a/consensus/src/chained_bft/persistent_storage.rs b/consensus/src/chained_bft/persistent_storage.rs
new file mode 100644
index 0000000000000..0bd2dcfc8ad86
--- /dev/null
+++ b/consensus/src/chained_bft/persistent_storage.rs
@@ -0,0 +1,380 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    chained_bft::{
+        common::Payload,
+        consensus_types::{block::Block, quorum_cert::QuorumCert},
+        consensusdb::ConsensusDB,
+        liveness::pacemaker_timeout_manager::HighestTimeoutCertificates,
+        safety::safety_rules::ConsensusState,
+    },
+    consensus_provider::create_storage_read_client,
+};
+use config::config::NodeConfig;
+use crypto::HashValue;
+use failure::Result;
+use logger::prelude::*;
+use rmp_serde::{from_slice, to_vec_named};
+use std::{
+    collections::{HashMap, HashSet},
+    sync::Arc,
+};
+
+/// Persistent storage for liveness data
+pub trait PersistentLivenessStorage: Send + Sync {
+    /// Persist the highest timeout certificate for improved liveness - proof for other replicas
+    /// to jump to this round
+    fn save_highest_timeout_cert(
+        &self,
+        highest_timeout_certs: HighestTimeoutCertificates,
+    ) -> Result<()>;
+}
+
+/// Persistent storage is essential for maintaining safety when a node crashes. Specifically,
+/// upon a restart, a correct node will not equivocate. Even if all nodes crash, safety is
+/// guaranteed. This trait also supports liveness aspects (i.e. highest timeout certificate)
+/// and supports clean up (i.e. tree pruning).
+/// Blocks persisted are proposed but not yet committed. The committed state is persisted
+/// via StateComputer.
+pub trait PersistentStorage<T>: PersistentLivenessStorage + Send + Sync
+where
+    T: Payload,
+{
+    /// Get a Box to an instance of PersistentLivenessStorage
+    /// (workaround for trait downcasting).
+    fn persistent_liveness_storage(&self) -> Box<dyn PersistentLivenessStorage>;
+
+    /// Persist the blocks and quorum certs into storage atomically.
+    fn save_tree(&self, blocks: Vec<Block<T>>, quorum_certs: Vec<QuorumCert>) -> Result<()>;
+
+    /// Delete the corresponding blocks and quorum certs atomically.
+    fn prune_tree(&self, block_ids: Vec<HashValue>) -> Result<()>;
+
+    /// Persist the consensus state.
+    fn save_consensus_state(&self, state: ConsensusState) -> Result<()>;
+
+    /// When the node restarts, construct the instance and return the data read from the db.
+    /// This guarantees we only read once during start, and we panic if the read fails.
+    /// It makes sense to be synchronous since we can't do anything else until this finishes.
+    fn start(config: &NodeConfig) -> (Arc<Self>, RecoveryData<T>)
+    where
+        Self: Sized;
+}
+
+/// The recovery data constructed from raw consensusdb data; it finds the root value and the
+/// blocks that need cleanup, or returns an error if the input data is inconsistent.
+#[derive(Debug)]
+pub struct RecoveryData<T> {
+    // Safety data
+    state: ConsensusState,
+    root: (Block<T>, QuorumCert, QuorumCert),
+    // 1. the blocks guarantee the topological ordering - parent <- child.
+    // 2. all blocks are children of the root.
+    blocks: Vec<Block<T>>,
+    quorum_certs: Vec<QuorumCert>,
+    blocks_to_prune: Option<Vec<HashValue>>,
+
+    // Liveness data
+    highest_timeout_certificates: HighestTimeoutCertificates,
+
+    // whether root is consistent with StateComputer; if not we need to do the state sync before
+    // starting
+    need_sync: bool,
+}
+
+impl<T: Payload> RecoveryData<T> {
+    pub fn new(
+        state: ConsensusState,
+        mut blocks: Vec<Block<T>>,
+        mut quorum_certs: Vec<QuorumCert>,
+        root_from_storage: HashValue,
+        highest_timeout_certificates: HighestTimeoutCertificates,
+    ) -> Result<Self> {
+        let root =
+            Self::find_root(&mut blocks, &mut quorum_certs, root_from_storage).map_err(|e| {
+                format_err!(
+                    "Blocks in db: {}\nQuorum Certs in db: {}, error: {}",
+                    blocks
+                        .iter()
+                        .map(|b| format!("\n\t{}", b))
+                        .collect::<Vec<String>>()
+                        .concat(),
+                    quorum_certs
+                        .iter()
+                        .map(|qc| format!("\n\t{}", qc))
+                        .collect::<Vec<String>>()
+                        .concat(),
+                    e,
+                )
+            })?;
+        let blocks_to_prune = Some(Self::find_blocks_to_prune(
+            root.0.id(),
+            &mut blocks,
+            &mut quorum_certs,
+        ));
+        // if the root is different than the LI(S).block, we need to sync before start
+        let need_sync = root_from_storage != root.0.id();
+        Ok(RecoveryData {
+            state,
+            root,
+            blocks,
+            quorum_certs,
+            blocks_to_prune,
+            highest_timeout_certificates,
+            need_sync,
+        })
+    }
+
+    pub fn state(&self) -> ConsensusState {
+        self.state.clone()
+    }
+
+    pub fn take(
+        self,
+    ) -> (
+        (Block<T>, QuorumCert, QuorumCert),
+        Vec<Block<T>>,
+        Vec<QuorumCert>,
+    ) {
+        (self.root, self.blocks, self.quorum_certs)
+    }
+
+    pub fn take_blocks_to_prune(&mut self) -> Vec<HashValue> {
+        self.blocks_to_prune
+            .take()
+            .expect("blocks_to_prune already taken")
+    }
+
+    pub fn highest_timeout_certificates(&self) -> &HighestTimeoutCertificates {
+        &self.highest_timeout_certificates
+    }
+
+    pub fn root_ledger_info(&self) -> QuorumCert {
+        self.root.2.clone()
+    }
+
+    pub fn need_sync(&self) -> bool {
+        self.need_sync
+    }
+
+    /// Finds the root (last committed block) and returns the root block, the QC to the root block
+    /// and the ledger info for the root block; returns an error if it cannot be found.
+    ///
+    /// LI(S) is the highest known ledger info determined by storage.
+    /// LI(C) is determined by ConsensusDB: it's the highest block id that is certified as committed
+    /// by one of the QC's ledger infos.
+    ///
+    /// We guarantee a few invariants:
+    /// 1. LI(C) must exist in blocks
+    /// 2. LI(S).block.round <= LI(C).block.round
+    ///
+    /// We use the following condition to decide the root:
+    /// 1. if LI(S) exists && LI(S) is an ancestor of LI(C) according to blocks, root = LI(S)
+    /// 2. else root = LI(C)
+    ///
+    /// In a typical case, the QC certifying a commit of a block is persisted to ConsensusDB before
+    /// this block is committed to the storage. Hence, ConsensusDB contains the
+    /// block corresponding to LI(S) id, which is going to become the root.
+    /// An additional complication is added in this code in order to tolerate a potential failure
+    /// during state synchronization. In this case LI(S) might not be found in the blocks of
+    /// ConsensusDB: we're going to start with LI(C) and invoke the state synchronizer in order to
+    /// resume the synchronization.
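+    ///
+    /// ```ignore
+    /// // Editorial sketch of the decision rule above; `descendants_of` is a hypothetical
+    /// // helper (not part of this module) returning the set of blocks reachable from a root.
+    /// let root = if descendants_of(li_s).contains(&li_c) {
+    ///     li_s // LI(C) is a descendant of LI(S): the storage root is safe to use
+    /// } else {
+    ///     li_c // e.g. a crash during state sync: fall back to LI(C) and resume syncing
+    /// };
+    /// ```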
+    fn find_root(
+        blocks: &mut Vec<Block<T>>,
+        quorum_certs: &mut Vec<QuorumCert>,
+        root_from_storage: HashValue,
+    ) -> Result<(Block<T>, QuorumCert, QuorumCert)> {
+        // sort by round to guarantee the topological order of parent <- child
+        blocks.sort_by_key(Block::round);
+        let root_from_consensus = {
+            let id_to_round: HashMap<_, _> = blocks
+                .iter()
+                .map(|block| (block.id(), block.round()))
+                .collect();
+            let mut round_and_id = None;
+            for qc in quorum_certs.iter() {
+                if let Some(committed_block_id) = qc.committed_block_id() {
+                    if let Some(round) = id_to_round.get(&committed_block_id) {
+                        match round_and_id {
+                            Some((r, _)) if r > round => (),
+                            _ => round_and_id = Some((round, committed_block_id)),
+                        }
+                    }
+                }
+            }
+            match round_and_id {
+                Some((_, id)) => id,
+                None => return Err(format_err!("No LI found in quorum certs.")),
+            }
+        };
+        let root_id = {
+            let mut tree = HashSet::new();
+            tree.insert(root_from_storage);
+            blocks.iter().for_each(|block| {
+                if tree.contains(&block.parent_id()) {
+                    tree.insert(block.id());
+                }
+            });
+            if !tree.contains(&root_from_consensus) {
+                root_from_consensus
+            } else {
+                root_from_storage
+            }
+        };
+
+        let root_idx = blocks
+            .iter()
+            .position(|block| block.id() == root_id)
+            .ok_or_else(|| format_err!("unable to find root: {}", root_id))?;
+        let root_block = blocks.remove(root_idx);
+        let root_quorum_cert = quorum_certs
+            .iter()
+            .find(|qc| qc.certified_block_id() == root_block.id())
+            .ok_or_else(|| format_err!("No QC found for root: {}", root_id))?
+            .clone();
+        let root_ledger_info = quorum_certs
+            .iter()
+            .find(|qc| qc.committed_block_id() == Some(root_block.id()))
+            .ok_or_else(|| format_err!("No LI found for root: {}", root_id))?
+            .clone();
+        Ok((root_block, root_quorum_cert, root_ledger_info))
+    }
+
+    fn find_blocks_to_prune(
+        root_id: HashValue,
+        blocks: &mut Vec<Block<T>>,
+        quorum_certs: &mut Vec<QuorumCert>,
+    ) -> Vec<HashValue> {
+        // prune all the blocks that don't have root as ancestor
+        let mut tree = HashSet::new();
+        let mut to_remove = vec![];
+        tree.insert(root_id);
+        // assume blocks are sorted by round already
+        blocks.retain(|block| {
+            if tree.contains(&block.parent_id()) {
+                tree.insert(block.id());
+                true
+            } else {
+                to_remove.push(block.id());
+                false
+            }
+        });
+        quorum_certs.retain(|qc| tree.contains(&qc.certified_block_id()));
+        to_remove
+    }
+}
+
+/// The proxy we use to persist data in libra db storage service via grpc.
+pub struct StorageWriteProxy {
+    db: Arc<ConsensusDB>,
+}
+
+impl StorageWriteProxy {
+    pub fn new(db: Arc<ConsensusDB>) -> Self {
+        StorageWriteProxy { db }
+    }
+}
+
+impl PersistentLivenessStorage for StorageWriteProxy {
+    fn save_highest_timeout_cert(
+        &self,
+        highest_timeout_certs: HighestTimeoutCertificates,
+    ) -> Result<()> {
+        self.db
+            .save_highest_timeout_certificates(to_vec_named(&highest_timeout_certs)?)
+    }
+}
+
+impl<T: Payload> PersistentStorage<T> for StorageWriteProxy {
+    fn persistent_liveness_storage(&self) -> Box<dyn PersistentLivenessStorage> {
+        Box::new(StorageWriteProxy::new(Arc::clone(&self.db)))
+    }
+
+    fn save_tree(&self, blocks: Vec<Block<T>>, quorum_certs: Vec<QuorumCert>) -> Result<()> {
+        self.db
+            .save_blocks_and_quorum_certificates(blocks, quorum_certs)
+    }
+
+    fn prune_tree(&self, block_ids: Vec<HashValue>) -> Result<()> {
+        if !block_ids.is_empty() {
+            // quorum certs that certified the block_ids will get removed
+            self.db
+                .delete_blocks_and_quorum_certificates::<T>(block_ids)?;
+        }
+        Ok(())
+    }
+
+    fn save_consensus_state(&self, state: ConsensusState) -> Result<()> {
+        self.db.save_state(to_vec_named(&state)?)
+    }
+
+    fn start(config: &NodeConfig) -> (Arc<Self>, RecoveryData<T>) {
+        info!("Start consensus recovery.");
+        let read_client = create_storage_read_client(config);
+        let db = Arc::new(ConsensusDB::new(config.storage.dir.clone()));
+        let proxy = Arc::new(Self::new(Arc::clone(&db)));
+        let initial_data = db.get_data().expect("unable to recover consensus data");
+        let consensus_state = initial_data.0.map_or_else(ConsensusState::default, |s| {
+            from_slice(&s[..]).expect("unable to deserialize consensus state")
+        });
+        debug!("Recovered consensus state: {}", consensus_state);
+        let highest_timeout_certificates = initial_data
+            .1
+            .map_or_else(HighestTimeoutCertificates::default, |s| {
+                from_slice(&s[..]).expect("unable to deserialize highest timeout certificates")
+            });
+        let mut blocks = initial_data.2;
+        let mut quorum_certs: Vec<_> = initial_data.3;
+        // bootstrap the empty store with genesis block and qc.
+        if blocks.is_empty() && quorum_certs.is_empty() {
+            blocks.push(Block::make_genesis_block());
+            quorum_certs.push(QuorumCert::certificate_for_genesis());
+            proxy
+                .save_tree(vec![blocks[0].clone()], vec![quorum_certs[0].clone()])
+                .expect("unable to bootstrap the storage with genesis block");
+        }
+        let blocks_repr: Vec<String> = blocks.iter().map(|b| format!("\n\t{}", b)).collect();
+        debug!(
+            "The following blocks were restored from ConsensusDB: {}",
+            blocks_repr.concat()
+        );
+        let qc_repr: Vec<String> = quorum_certs
+            .iter()
+            .map(|qc| format!("\n\t{}", qc))
+            .collect();
+        debug!(
+            "The following quorum certs were restored from ConsensusDB: {}",
+            qc_repr.concat()
+        );
+
+        // find the block corresponding to storage latest ledger info
+        let (_, ledger_info, _) = read_client
+            .update_to_latest_ledger(0, vec![])
+            .expect("unable to read ledger info from storage");
+        let root_from_storage = ledger_info.ledger_info().consensus_block_id();
+        debug!(
+            "The last committed block id as recorded in storage: {}",
+            root_from_storage
+        );
+
+        let mut initial_data = RecoveryData::new(
+            consensus_state,
+            blocks,
+            quorum_certs,
+            root_from_storage,
+            highest_timeout_certificates,
+        )
+        .unwrap_or_else(|e| panic!("Cannot construct recovery data due to {}", e));
+
+        <Self as PersistentStorage<T>>::prune_tree(proxy.as_ref(), initial_data.take_blocks_to_prune())
+            .expect("unable to prune dangling blocks during restart");
+
+        debug!("Consensus root to start with: {}", initial_data.root.0);
+
+        if initial_data.need_sync {
+            info!("Consensus recovery done but additional state synchronization is required.");
+        } else {
+            info!("Consensus recovery completed.")
+        }
+        (proxy, initial_data)
+    }
+}
diff --git a/consensus/src/chained_bft/proto_test.rs b/consensus/src/chained_bft/proto_test.rs
new file mode 100644
index 0000000000000..8b33fce19d582
--- /dev/null
+++ b/consensus/src/chained_bft/proto_test.rs
@@ -0,0 +1,47 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    chained_bft::{
+        consensus_types::{block::Block, quorum_cert::QuorumCert},
+        liveness::proposer_election::ProposalInfo,
+        safety::vote_msg::VoteMsg,
+        test_utils::placeholder_ledger_info,
+    },
+    state_replication::ExecutedState,
+};
+use crypto::HashValue;
+use proto_conv::test_helper::assert_protobuf_encode_decode;
+use types::validator_signer::ValidatorSigner;
+
+#[test]
+fn test_proto_convert_block() {
+    let block: Block<Vec<u64>> = Block::make_genesis_block();
+    assert_protobuf_encode_decode(&block);
+}
+
+#[test]
+fn test_proto_convert_proposal() {
+    let author = ValidatorSigner::random().author();
+    let proposal = ProposalInfo {
proposal: Block::<Vec<u64>>::make_genesis_block(),
+        proposer_info: author,
+        timeout_certificate: None,
+        highest_ledger_info: QuorumCert::certificate_for_genesis(),
+    };
+    assert_protobuf_encode_decode(&proposal);
+}
+
+#[test]
+fn test_proto_convert_vote() {
+    let signer = ValidatorSigner::random();
+    let vote = VoteMsg::new(
+        HashValue::random(),
+        ExecutedState::state_for_genesis(),
+        1,
+        signer.author(),
+        placeholder_ledger_info(),
+        &signer,
+    );
+    assert_protobuf_encode_decode(&vote);
+}
diff --git a/consensus/src/chained_bft/safety/mod.rs b/consensus/src/chained_bft/safety/mod.rs
new file mode 100644
index 0000000000000..b48be35b132f5
--- /dev/null
+++ b/consensus/src/chained_bft/safety/mod.rs
@@ -0,0 +1,5 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+pub(crate) mod safety_rules;
+pub(crate) mod vote_msg;
diff --git a/consensus/src/chained_bft/safety/safety_rules.rs b/consensus/src/chained_bft/safety/safety_rules.rs
new file mode 100644
index 0000000000000..f979bdacbc6e6
--- /dev/null
+++ b/consensus/src/chained_bft/safety/safety_rules.rs
@@ -0,0 +1,313 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    chained_bft::{
+        block_storage::BlockReader,
+        common::{Payload, Round},
+        consensus_types::{block::Block, quorum_cert::QuorumCert},
+    },
+    counters,
+};
+
+use crypto::HashValue;
+use serde::{Deserialize, Serialize};
+use std::{
+    fmt::{Display, Formatter},
+    sync::Arc,
+};
+use types::ledger_info::LedgerInfoWithSignatures;
+
+#[cfg(test)]
+#[path = "safety_rules_test.rs"]
+mod safety_rules_test;
+
+/// Vote information is returned if a proposal passes the voting rules.
+/// The caller might need to persist some of the consensus state before sending out the actual
+/// vote message.
+/// Vote info also includes the block id that is going to be committed in case this vote gathers
+/// QC.
+#[derive(Debug, Eq, PartialEq)]
+pub struct VoteInfo {
+    /// Block id of the proposed block.
+    proposal_id: HashValue,
+    /// Round of the proposed block.
+    proposal_round: Round,
+    /// Consensus state after the voting (e.g., with the updated vote round)
+    consensus_state: ConsensusState,
+    /// The block that should be committed in case this vote gathers QC.
+    /// If no block would be committed in that case, this is None.
+    potential_commit_id: Option<HashValue>,
+}
+
+impl VoteInfo {
+    pub fn proposal_id(&self) -> HashValue {
+        self.proposal_id
+    }
+
+    pub fn consensus_state(&self) -> &ConsensusState {
+        &self.consensus_state
+    }
+
+    pub fn potential_commit_id(&self) -> Option<HashValue> {
+        self.potential_commit_id
+    }
+}
+
+#[derive(Debug, Fail, Eq, PartialEq)]
+/// Different reasons for proposal rejection
+pub enum ProposalReject {
+    /// This proposal's round is less than the round of the preferred block.
+    /// Carries the round of the preferred block.
+    #[fail(
+        display = "Proposal's round is lower than round of preferred block at round {:?}",
+        preferred_block_round
+    )]
+    ProposalRoundLowerThenPreferredBlock { preferred_block_round: Round },
+
+    /// This proposal is too old - return last_vote_round
+    #[fail(
+        display = "Proposal at round {:?} is not newer than the last vote round {:?}",
+        proposal_round, last_vote_round
+    )]
+    OldProposal {
+        last_vote_round: Round,
+        proposal_round: Round,
+    },
+}
+
+/// The state required to guarantee safety of the protocol.
+/// We need to specify the specific state to be persisted for the recovery of the protocol.
+/// (e.g., last vote round and preferred block round).
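+///
+/// A minimal sketch of how these fields gate voting (illustrative only; `may_vote` is a
+/// hypothetical helper, the real checks live in `SafetyRules::voting_rule`):
+/// ```ignore
+/// fn may_vote(state: &ConsensusState, proposal_round: Round, parent_round: Round) -> bool {
+///     proposal_round > state.last_vote_round()             // never vote twice per round
+///         && parent_round >= state.preferred_block_round() // respect the preferred block
+/// }
+/// ```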
+#[derive(Serialize, Default, Deserialize, Debug, Eq, PartialEq, Clone)]
+pub struct ConsensusState {
+    last_vote_round: Round,
+    last_committed_round: Round,
+
+    // A "preferred block" is the two-chain head with the highest block round.
+    // We're using the `head` / `tail` terminology to describe chains of QCs:
+    // `head` <-- * <-- `tail`.
+
+    // A new proposal is voted for only if its parent block's round is higher than or equal to
+    // the preferred_block_round.
+    // 1) QC chains follow direct parenthood relations because a node must carry a QC to its
+    //    parent. 2) The "max round" rule applies to the HEAD of the chain and not its TAIL (one
+    //    does not necessarily imply the other).
+    preferred_block_round: Round,
+}
+
+impl Display for ConsensusState {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "ConsensusState: [\n\
+             \tlast_vote_round = {},\n\
+             \tlast_committed_round = {},\n\
+             \tpreferred_block_round = {}\n\
+             ]",
+            self.last_vote_round, self.last_committed_round, self.preferred_block_round
+        )
+    }
+}
+
+impl ConsensusState {
+    #[cfg(test)]
+    pub fn new(
+        last_vote_round: Round,
+        last_committed_round: Round,
+        preferred_block_round: Round,
+    ) -> Self {
+        Self {
+            last_vote_round,
+            last_committed_round,
+            preferred_block_round,
+        }
+    }
+
+    /// Returns the last round that was voted on
+    pub fn last_vote_round(&self) -> Round {
+        self.last_vote_round
+    }
+
+    /// Returns the last committed round
+    #[cfg(test)]
+    pub fn last_committed_round(&self) -> Round {
+        self.last_committed_round
+    }
+
+    /// Returns the preferred block round
+    pub fn preferred_block_round(&self) -> Round {
+        self.preferred_block_round
+    }
+
+    /// Set the last vote round that ensures safety. If the last vote round increases, return
+    /// the new consensus state with the updated last vote round. Otherwise, return None.
+    fn set_last_vote_round(&mut self, last_vote_round: Round) -> Option<ConsensusState> {
+        if last_vote_round <= self.last_vote_round {
+            None
+        } else {
+            self.last_vote_round = last_vote_round;
+            counters::LAST_VOTE_ROUND.set(last_vote_round as i64);
+            Some(self.clone())
+        }
+    }
+
+    /// Set the preferred block round
+    fn set_preferred_block_round(&mut self, preferred_block_round: Round) {
+        self.preferred_block_round = preferred_block_round;
+        counters::PREFERRED_BLOCK_ROUND.set(preferred_block_round as i64);
+    }
+}
+
+/// SafetyRules is responsible for two things that are critical for the safety of the consensus:
+/// 1) voting rules,
+/// 2) commit rules.
+/// The only dependency is a block tree, which is queried for ancestry relationships between
+/// the blocks and their QCs.
+/// SafetyRules is NOT THREAD SAFE (should be protected outside via e.g., RwLock).
+/// The commit decisions are returned to the caller as result of learning about a new QuorumCert.
+pub struct SafetyRules<T> {
+    // To query about the relationships between blocks and QCs.
+    block_tree: Arc<dyn BlockReader<Payload = T>>,
+    // Keeps the state.
+    state: ConsensusState,
+}
+
+impl<T: Payload> SafetyRules<T> {
+    /// Constructs a new instance of SafetyRules given the BlockTree and ConsensusState.
+    pub fn new(block_tree: Arc<dyn BlockReader<Payload = T>>, state: ConsensusState) -> Self {
+        Self { block_tree, state }
+    }
+
+    /// Learn about a new quorum certificate. Several things can happen as a result of that:
+    /// 1) update the preferred block to a higher value.
+    /// 2) commit some blocks.
+    /// In case of commits the last committed block is returned.
+    /// Requires that all the ancestors of the block are available for at least up to the last
+    /// committed block, might panic otherwise.
+    /// The update function is invoked whenever a system learns about a potentially high QC.
+    pub fn update(&mut self, qc: &QuorumCert) -> Option<Arc<Block<T>>> {
+        // Preferred block rule: choose the highest 2-chain head.
+        if let Some(one_chain_head) = self.block_tree.get_block(qc.certified_block_id()) {
+            if let Some(two_chain_head) = self.block_tree.get_block(one_chain_head.parent_id()) {
+                if two_chain_head.round() >= self.state.preferred_block_round() {
+                    self.state.set_preferred_block_round(two_chain_head.round());
+                }
+            }
+        }
+        self.process_ledger_info(qc.ledger_info())
+    }
+
+    /// Check to see if processing a new LedgerInfoWithSignatures leads to a commit. Return a
+    /// new committed block if there is one.
+    pub fn process_ledger_info(
+        &mut self,
+        ledger_info: &LedgerInfoWithSignatures,
+    ) -> Option<Arc<Block<T>>> {
+        // While voting for a block the validators have already calculated the potential commits.
+        // In case there are no commits enabled by this ledger info, the committed block id is going
+        // to carry some placeholder value (e.g., zero).
+        let committed_block_id = ledger_info.ledger_info().consensus_block_id();
+        if let Some(committed_block) = self.block_tree.get_block(committed_block_id) {
+            // We check against the root of the tree instead of last committed round to avoid
+            // double commit.
+            // Because we only persist the ConsensusState before sending out the vote, it could
+            // lag behind reality if we crash between committing and sending the vote.
+            if committed_block.round() > self.block_tree.root().round() {
+                self.state.last_committed_round = committed_block.round();
+                return Some(committed_block);
+            }
+        }
+        None
+    }
+
+    /// Check if a one-chain at round r+2 causes a commit at round r and return the committed
+    /// block at round r if possible
+    pub fn commit_rule_for_certified_block(
+        &self,
+        one_chain_head: Arc<Block<T>>,
+    ) -> Option<Arc<Block<T>>> {
+        if let Some(two_chain_head) = self.block_tree.get_block(one_chain_head.parent_id()) {
+            if let Some(three_chain_head) = self.block_tree.get_block(two_chain_head.parent_id()) {
+                // We're using a so-called 3-chain commit rule: B0 (as well as its prefix)
+                // can be committed if there exist certified blocks B1 and B2 that satisfy:
+                // 1) B0 <- B1 <- B2 <--
+                // 2) round(B0) + 1 = round(B1), and
+                // 3) round(B1) + 1 = round(B2).
+                if three_chain_head.round() + 1 == two_chain_head.round()
+                    && two_chain_head.round() + 1 == one_chain_head.round()
+                {
+                    return Some(three_chain_head);
+                }
+            }
+        }
+        None
+    }
+
+    /// Return the highest known committed round
+    pub fn last_committed_round(&self) -> Round {
+        self.state.last_committed_round
+    }
+
+    /// Return the new state if the voting round was increased, otherwise ignore. Increasing the
+    /// last vote round is always safe, but can affect liveness; it must never decrease, in order
+    /// to protect safety.
+    pub fn increase_last_vote_round(&mut self, round: Round) -> Option<ConsensusState> {
+        self.state.set_last_vote_round(round)
+    }
+
+    /// Clones the up-to-date state of consensus (for monitoring / debugging purposes)
+    #[allow(dead_code)]
+    pub fn consensus_state(&self) -> ConsensusState {
+        self.state.clone()
+    }
+
+    /// Attempts to vote for a given proposal following the voting rules.
+    /// The returned value is then going to be used for either sending the vote or doing nothing.
+    /// In case of a vote a cloned consensus state is returned (to be persisted before the vote is
+    /// sent).
+    /// Requires that all the ancestors of the block are available for at least up to the last
+    /// committed block, might panic otherwise.
+    pub fn voting_rule(
+        &mut self,
+        proposed_block: Arc<Block<T>>,
+    ) -> Result<VoteInfo, ProposalReject> {
+        if proposed_block.round() <= self.state.last_vote_round() {
+            return Err(ProposalReject::OldProposal {
+                proposal_round: proposed_block.round(),
+                last_vote_round: self.state.last_vote_round(),
+            });
+        }
+
+        let parent_block = self
+            .block_tree
+            .get_block(proposed_block.parent_id())
+            .expect("Parent block not found");
+        let parent_block_round = parent_block.round();
+        let respects_preferred_block = parent_block_round >= self.state.preferred_block_round();
+        if respects_preferred_block {
+            self.state.set_last_vote_round(proposed_block.round());
+
+            // If the vote for the given proposal is gathered into QC, then this QC might eventually
+            // commit another block following the rules defined in
+            // `commit_rule_for_certified_block()` function.
+            let potential_commit =
+                self.commit_rule_for_certified_block(Arc::clone(&proposed_block));
+            let potential_commit_id = potential_commit.map(|commit_block| commit_block.id());
+
+            Ok(VoteInfo {
+                proposal_id: proposed_block.id(),
+                proposal_round: proposed_block.round(),
+                consensus_state: self.state.clone(),
+                potential_commit_id,
+            })
+        } else {
+            Err(ProposalReject::ProposalRoundLowerThenPreferredBlock {
+                preferred_block_round: self.state.preferred_block_round(),
+            })
+        }
+    }
+}
diff --git a/consensus/src/chained_bft/safety/safety_rules_test.rs b/consensus/src/chained_bft/safety/safety_rules_test.rs
new file mode 100644
index 0000000000000..494d5a77f2c1d
--- /dev/null
+++ b/consensus/src/chained_bft/safety/safety_rules_test.rs
@@ -0,0 +1,406 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::chained_bft::{
+    block_storage::BlockReader,
+    common::Round,
+    consensus_types::block::{block_test, Block},
+    safety::safety_rules::{ConsensusState, ProposalReject, SafetyRules},
+    test_utils::{build_empty_tree, build_empty_tree_with_custom_signing, TreeInserter},
+};
+use cached::{cached_key, SizedCache};
+use crypto::HashValue;
+use proptest::prelude::*;
+use std::{
+    collections::{hash_map::DefaultHasher, BTreeMap},
+    hash::{Hash, Hasher},
+    sync::Arc,
+};
+use types::{account_address::AccountAddress, validator_signer::ValidatorSigner};
+
+fn calculate_hash<T: Hash>(t: &T) -> u64 {
+    let mut s = DefaultHasher::new();
+    t.hash(&mut s);
+    s.finish()
+}
+
+cached_key! {
+    // We memoize the DFS of max_chain_depth. The size limit reflects that if we call more
+    // times than the forest size or so, we have probably changed the
+    // map's state.
+ LENGTH: SizedCache)> = SizedCache::with_size(50); + Key = { format!("{}{:?}{:?}", calculate_hash(children_table), query, initial_contiguous_links ) }; + // This returns the length of the maximal chain constructible from the + // (block_id, block_round) node, along with an example of such a chain + // (they are not unique) + fn max_chain_depth(children_table: &BTreeMap>, query: (HashValue, Round), initial_contiguous_links: i64) -> (usize, Vec) = { + if let Some(children) = children_table.get(&query.0) { + if let Some((depth, mut subchain)) = children.iter().cloned().flat_map(|(child_id, child_round)| + if initial_contiguous_links > 0 && child_round > query.1 + 1 { + // we're asked for a contiguous chain link and can't deliver on this child + None + } else { + Some(max_chain_depth(children_table, (child_id, child_round), std::cmp::max(initial_contiguous_links - 1, 0))) + }).max_by(|x, y| x.0.cmp(&y.0)) { + subchain.push(query.0); + (depth + 1, subchain) + } else { + (0, vec![query.0]) + } + } else { + (0, vec![query.0]) + } + } +} + +proptest! { + #[test] + fn test_blocks_commits_safety_rules( + (keypairs, blocks) in block_test::block_forest_and_its_keys( + // quorum size + 10, + // recursion depth + 50) + ) { + let (priv_key, pub_key) = keypairs.first().expect("several keypairs generated"); + let signer = ValidatorSigner::new(AccountAddress::from(*pub_key), *pub_key, priv_key.clone()); + + let mut qc_signers = vec![]; + for (priv_key, pub_key) in keypairs { + qc_signers.push(ValidatorSigner::new(AccountAddress::from(pub_key), pub_key, priv_key.clone())); + } + let block_tree = build_empty_tree_with_custom_signing(signer.clone()); + let mut inserter = TreeInserter::new(block_tree.clone()); + let mut safety_rules = SafetyRules::new(block_tree.clone(), ConsensusState::default()); + + // This commit_candidate tracks the commit that would get + // committed if the current block would get a QC + let mut commit_candidate = block_tree.root().id(); + + // children_table contains a map from parent block id to + // [(block_id, block_round), ...] 
of its children + let mut children_table = BTreeMap::new(); + + // inserted contains the blocks newly inserted in the tree + let mut inserted = Vec::new(); + + for block in blocks { + let known_parent = block_tree.block_exists(block.parent_id()); + if !known_parent { + continue; + } + + let insert_res = inserter.insert_pre_made_block(block.clone(), &signer, qc_signers.clone()); + let id_and_qc = |ref block: Arc>>| { (block.id(), block.quorum_cert().clone()) }; + let (inserted_id, inserted_qc) = id_and_qc(insert_res.clone()); + safety_rules.update(&inserted_qc); + + let siblings = children_table.entry(block.parent_id()).or_insert_with(|| vec![]); + siblings.push((inserted_id, block.round())); + + inserted.push((inserted_id, block.round())); + + let long_chained_blocks: Vec<(&HashValue, &Round, usize, Vec)> = inserted.iter().map(|(b_id, b_round)| { + let (chain_depth, chain) = max_chain_depth(&children_table, (*b_id, *b_round), 0); + (b_id, b_round, chain_depth, chain) + }).collect(); + + // The preferred_block is the latest (highest round-wise) 2-chain + let preferred_block_round = safety_rules.consensus_state().preferred_block_round; + let highest_two_chain = long_chained_blocks.clone().iter() + .filter(|(_bid, _bround, chain_depth, _chain)| *chain_depth >= 2) + // highest = max by round + .max_by(|b1, b2| (*b1.1).cmp(b2.1)) + .map_or((block_tree.root().id(), 0, vec![]), |(bid, bround, _, chain)| (**bid, **bround, chain.to_vec())); + prop_assert_eq!(highest_two_chain.1, preferred_block_round, + "Preferred block mismatch, expected {:?} because of chain {:#?}\n", highest_two_chain.0, highest_two_chain.2); + + let long_contiguous_chained_blocks: Vec<(&HashValue, &Round, usize, Vec)> = inserted.iter().map(|(b_id, b_round)| { + // We ask for 2 contiguous initial rounds this time + let (chain_depth, chain) = max_chain_depth(&children_table, (*b_id, *b_round), 2); + (b_id, b_round, chain_depth, chain) + }).collect(); + + let highest_contiguous_3_chain_prefix = long_contiguous_chained_blocks.iter() + // We have a chain of 3 blocks (two links) which are contiguous + .filter(|(_bid, _bround, chain_depth, _chain)| *chain_depth == 2) + // max by round + .max_by(|b1, b2| (*b1.1).cmp(b2.1)) + .map_or((block_tree.root().id(), 0, vec![]), |(bid, bround, _, chain)| (**bid, **bround, chain.to_vec())); + + if highest_contiguous_3_chain_prefix.0 != commit_candidate { + // We have a potential change of commit candidate -> + // the current block can be voted on and if gathered a + // QC, would trigger a different commit + let block_arc = block_tree.get_block(inserted_id).expect("we just inserted this"); + let vote_info = safety_rules.voting_rule(block_arc).and_then(|x| Ok(x.potential_commit_id())); + prop_assert_eq!(vote_info, Ok(Some(highest_contiguous_3_chain_prefix.0)), + "Commit mismatch: expected committing {:?} upon hearing about {:?} with preferred block {:?} because of chain {:#?}\n", highest_contiguous_3_chain_prefix.0, block.id(), highest_two_chain.0, highest_contiguous_3_chain_prefix.2 + ); + commit_candidate = highest_contiguous_3_chain_prefix.0; + } + + + } + + } +} + +#[test] +fn test_initial_state() { + // Start from scratch, verify the state + let block_tree = build_empty_tree(); + + let safety_rules = SafetyRules::new(block_tree.clone(), ConsensusState::default()); + let state = safety_rules.consensus_state(); + assert_eq!(state.last_vote_round(), 0); + assert_eq!(state.last_committed_round(), block_tree.root().round()); + assert_eq!(state.preferred_block_round(), block_tree.root().round()); +} 
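+
+// A minimal sketch added for illustration (not part of the original suite): the last vote
+// round can only move forward, which is the invariant the voting rule relies on.
+#[test]
+fn test_last_vote_round_monotonic() {
+    let block_tree = build_empty_tree();
+    let mut safety_rules = SafetyRules::new(block_tree.clone(), ConsensusState::default());
+    // Moving forward returns the updated state to persist.
+    assert!(safety_rules.increase_last_vote_round(1).is_some());
+    // Repeating the same (or a lower) round is a no-op and returns None.
+    assert!(safety_rules.increase_last_vote_round(1).is_none());
+    assert_eq!(safety_rules.consensus_state().last_vote_round(), 1);
+}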
+ +#[test] +fn test_preferred_block_rule() { + // Preferred block is the highest 2-chain head. + let block_tree = build_empty_tree(); + let mut inserter = TreeInserter::new(block_tree.clone()); + let mut safety_rules = SafetyRules::new(block_tree.clone(), ConsensusState::default()); + + // build a tree of the following form: + // _____ _____ + // / \ / \ + // genesis---a1 b1 b2 a2 b3 a3---a4 + // \_____/ \_____/ \_____/ + // + // PB should change from genesis to b1, and then to a2. + let genesis = block_tree.root(); + let a1 = inserter.insert_block(genesis.as_ref(), 1); + let b1 = inserter.insert_block(genesis.as_ref(), 2); + let b2 = inserter.insert_block(a1.as_ref(), 3); + let a2 = inserter.insert_block(b1.as_ref(), 4); + let b3 = inserter.insert_block(b2.as_ref(), 5); + let a3 = inserter.insert_block(a2.as_ref(), 6); + let a4 = inserter.insert_block(a3.as_ref(), 7); + + safety_rules.update(a1.quorum_cert()); + assert_eq!( + safety_rules.consensus_state().preferred_block_round(), + genesis.round() + ); + + safety_rules.update(b1.quorum_cert()); + assert_eq!( + safety_rules.consensus_state().preferred_block_round(), + genesis.round() + ); + + safety_rules.update(a2.quorum_cert()); + assert_eq!( + safety_rules.consensus_state().preferred_block_round(), + genesis.round() + ); + + safety_rules.update(b2.quorum_cert()); + assert_eq!( + safety_rules.consensus_state().preferred_block_round(), + genesis.round() + ); + + safety_rules.update(a3.quorum_cert()); + assert_eq!( + safety_rules.consensus_state().preferred_block_round(), + b1.round() + ); + + safety_rules.update(b3.quorum_cert()); + assert_eq!( + safety_rules.consensus_state().preferred_block_round(), + b1.round() + ); + + safety_rules.update(a4.quorum_cert()); + assert_eq!( + safety_rules.consensus_state().preferred_block_round(), + a2.round() + ); +} + +#[test] +fn test_voting() { + let block_tree = build_empty_tree(); + let mut inserter = TreeInserter::new(block_tree.clone()); + let mut safety_rules = SafetyRules::new(block_tree.clone(), ConsensusState::default()); + + // build a tree of the following form: + // _____ __________ + // / \ / \ + // genesis---a1 b1 b2 a2---a3 b3 a4 b4 + // \_____/ \_____/ \______/ / + // \__________________/ + // + // + // We'll introduce the votes in the following order: + // a1 (ok), potential_commit is None + // b1 (ok), potential commit is None + // a2 (ok), potential_commit is None + // b2 (old proposal) + // a3 (ok), potential commit is None + // b3 (ok), potential commit is None + // a4 (ok), potential commit is None + // a4 (old proposal) + // b4 (round lower then round of pb. 
PB: a2, parent(b4)=b2) + let genesis = block_tree.root(); + let a1 = inserter.insert_block(genesis.as_ref(), 1); + let b1 = inserter.insert_block(genesis.as_ref(), 2); + let b2 = inserter.insert_block(a1.as_ref(), 3); + let a2 = inserter.insert_block(b1.as_ref(), 4); + let a3 = inserter.insert_block(a2.as_ref(), 5); + let b3 = inserter.insert_block(b2.as_ref(), 6); + let a4 = inserter.insert_block(a3.as_ref(), 7); + let b4 = inserter.insert_block(b2.as_ref(), 8); + + safety_rules.update(a1.quorum_cert()); + let mut voting_info = safety_rules.voting_rule(a1.clone()).unwrap(); + assert_eq!(voting_info.potential_commit_id, None); + + safety_rules.update(b1.quorum_cert()); + voting_info = safety_rules.voting_rule(b1.clone()).unwrap(); + assert_eq!(voting_info.potential_commit_id, None); + + safety_rules.update(a2.quorum_cert()); + voting_info = safety_rules.voting_rule(a2.clone()).unwrap(); + assert_eq!(voting_info.potential_commit_id, None); + + safety_rules.update(b2.quorum_cert()); + assert_eq!( + safety_rules.voting_rule(b2.clone()), + Err(ProposalReject::OldProposal { + last_vote_round: 4, + proposal_round: 3, + }) + ); + + safety_rules.update(a3.quorum_cert()); + voting_info = safety_rules.voting_rule(a3.clone()).unwrap(); + assert_eq!(voting_info.potential_commit_id, None); + + safety_rules.update(b3.quorum_cert()); + voting_info = safety_rules.voting_rule(b3.clone()).unwrap(); + assert_eq!(voting_info.potential_commit_id, None); + + safety_rules.update(a4.quorum_cert()); + voting_info = safety_rules.voting_rule(a4.clone()).unwrap(); + assert_eq!(voting_info.potential_commit_id, None); + + safety_rules.update(a4.quorum_cert()); + assert_eq!( + safety_rules.voting_rule(a4.clone()), + Err(ProposalReject::OldProposal { + last_vote_round: 7, + proposal_round: 7, + }) + ); + safety_rules.update(b4.quorum_cert()); + assert_eq!( + safety_rules.voting_rule(b4.clone()), + Err(ProposalReject::ProposalRoundLowerThenPreferredBlock { + preferred_block_round: 4, + }) + ); +} + +#[test] +/// Test the potential ledger info that we're going to use in case of voting +fn test_voting_potential_commit_id() { + let block_tree = build_empty_tree(); + let mut inserter = TreeInserter::new(block_tree.clone()); + let mut safety_rules = SafetyRules::new(block_tree.clone(), ConsensusState::default()); + + // build a tree of the following form: + // _____ + // / \ + // genesis--a1 b1 a2--a3--a4--a5 + // \_____/ + // + // All the votes before a4 cannot produce any potential commits. + // A potential commit for proposal a4 is a2, a potential commit for proposal a5 is a3. 
+ + let genesis = block_tree.root(); + let a1 = inserter.insert_block(genesis.as_ref(), 1); + let b1 = inserter.insert_block(genesis.as_ref(), 2); + let a2 = inserter.insert_block(a1.as_ref(), 3); + let a3 = inserter.insert_block(a2.as_ref(), 4); + let a4 = inserter.insert_block(a3.as_ref(), 5); + let a5 = inserter.insert_block(a4.as_ref(), 6); + + let vec_with_no_potential_commits = vec![a1.clone(), b1.clone(), a2.clone(), a3.clone()]; + for b in vec_with_no_potential_commits { + safety_rules.update(b.quorum_cert()); + let voting_info = safety_rules.voting_rule(b.clone()).unwrap(); + assert_eq!(voting_info.potential_commit_id, None); + } + safety_rules.update(a4.quorum_cert()); + assert_eq!( + safety_rules + .voting_rule(a4.clone()) + .unwrap() + .potential_commit_id, + Some(a2.id()) + ); + safety_rules.update(a5.quorum_cert()); + assert_eq!( + safety_rules + .voting_rule(a5.clone()) + .unwrap() + .potential_commit_id, + Some(a3.id()) + ); +} + +#[test] +fn test_commit_rule_consecutive_rounds() { + let block_tree = build_empty_tree(); + let mut inserter = TreeInserter::new(block_tree.clone()); + let safety_rules = SafetyRules::new(block_tree.clone(), ConsensusState::default()); + + // build a tree of the following form: + // ___________ + // / \ + // genesis---a1 b1---b2 a2---a3---a4 + // \_____/ + // + // a1 cannot be committed after a3 gathers QC because a1 and a2 are not consecutive + // a2 can be committed after a4 gathers QC + + let genesis = block_tree.root(); + let a1 = inserter.insert_block(genesis.as_ref(), 1); + let b1 = inserter.insert_block(genesis.as_ref(), 2); + let b2 = inserter.insert_block(b1.as_ref(), 3); + let a2 = inserter.insert_block(a1.as_ref(), 4); + let a3 = inserter.insert_block(a2.as_ref(), 5); + let a4 = inserter.insert_block(a3.as_ref(), 6); + + assert_eq!( + safety_rules.commit_rule_for_certified_block(a1.clone()), + None + ); + assert_eq!( + safety_rules.commit_rule_for_certified_block(b1.clone()), + None + ); + assert_eq!( + safety_rules.commit_rule_for_certified_block(b2.clone()), + None + ); + assert_eq!( + safety_rules.commit_rule_for_certified_block(a2.clone()), + None + ); + assert_eq!( + safety_rules.commit_rule_for_certified_block(a3.clone()), + None + ); + assert_eq!( + safety_rules.commit_rule_for_certified_block(a4.clone()), + Some(a2) + ); +} diff --git a/consensus/src/chained_bft/safety/vote_msg.rs b/consensus/src/chained_bft/safety/vote_msg.rs new file mode 100644 index 0000000000000..610e89c08d375 --- /dev/null +++ b/consensus/src/chained_bft/safety/vote_msg.rs @@ -0,0 +1,226 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::common::{Author, Round}, + state_replication::ExecutedState, +}; +use canonical_serialization::{CanonicalSerialize, CanonicalSerializer, SimpleSerializer}; +use crypto::{ + hash::{CryptoHash, CryptoHasher, VoteMsgHasher}, + HashValue, Signature, +}; +use failure::Result as ProtoResult; +use network::proto::Vote as ProtoVote; +use proto_conv::{FromProto, IntoProto}; +use serde::{Deserialize, Serialize}; +use std::{ + convert::TryFrom, + fmt::{Display, Formatter}, +}; +use types::{ + ledger_info::LedgerInfo, + validator_signer::ValidatorSigner, + validator_verifier::{ValidatorVerifier, VerifyError}, +}; + +/// VoteMsg verification errors. +#[derive(Debug, Fail, PartialEq)] +pub enum VoteMsgVerificationError { + /// The internal consensus data of LedgerInfo doesn't match the vote info. 
+    #[fail(display = "ConsensusDataMismatch")]
+    ConsensusDataMismatch,
+    /// The signature doesn't pass verification
+    #[fail(display = "SigVerifyError: {}", _0)]
+    SigVerifyError(VerifyError),
+}
+
+// Internal use only. Contains all the fields of VoteMsg that contribute to the
+// computation of its hash.
+struct VoteMsgSerializer {
+    proposed_block_id: HashValue,
+    executed_state: ExecutedState,
+    round: Round,
+}
+
+impl CanonicalSerialize for VoteMsgSerializer {
+    fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> failure::Result<()> {
+        serializer.encode_raw_bytes(self.proposed_block_id.as_ref())?;
+        serializer.encode_struct(&self.executed_state)?;
+        serializer.encode_u64(self.round)?;
+        Ok(())
+    }
+}
+
+impl CryptoHash for VoteMsgSerializer {
+    type Hasher = VoteMsgHasher;
+
+    fn hash(&self) -> HashValue {
+        let mut state = Self::Hasher::default();
+        state.write(
+            SimpleSerializer::<Vec<u8>>::serialize(self)
+                .expect("Should serialize.")
+                .as_ref(),
+        );
+        state.finish()
+    }
+}
+
+/// VoteMsg is the struct that is ultimately sent by the voter in response to
+/// receiving a proposal.
+/// VoteMsg carries the `LedgerInfo` of a block that is going to be committed in case this vote
+/// gathers a QuorumCertificate (see the detailed explanation in the comments of `LedgerInfo`).
+#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)]
+pub struct VoteMsg {
+    /// The id of the proposed block.
+    proposed_block_id: HashValue,
+    /// The state generated by the StateExecutor after executing the proposed block.
+    executed_state: ExecutedState,
+    /// The round of the block.
+    round: Round,
+    /// The identity of the voter.
+    author: Author,
+    /// LedgerInfo of a block that is going to be committed in case this vote gathers QC.
+    ledger_info: LedgerInfo,
+    /// Signature of the LedgerInfo
+    signature: Signature,
+}
+
+impl Display for VoteMsg {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "Vote: [block id: {}, round: {:02}, author: {}, {}]",
+            self.proposed_block_id,
+            self.round,
+            self.author.short_str(),
+            self.ledger_info
+        )
+    }
+}
+
+impl VoteMsg {
+    pub fn new(
+        proposed_block_id: HashValue,
+        executed_state: ExecutedState,
+        round: Round,
+        author: Author,
+        mut ledger_info_placeholder: LedgerInfo,
+        validator_signer: &ValidatorSigner,
+    ) -> Self {
+        ledger_info_placeholder.set_consensus_data_hash(Self::vote_digest(
+            proposed_block_id,
+            executed_state,
+            round,
+        ));
+        let li_sig = validator_signer
+            .sign_message(ledger_info_placeholder.hash())
+            .expect("Failed to sign LedgerInfo");
+        Self {
+            proposed_block_id,
+            executed_state,
+            round,
+            author,
+            ledger_info: ledger_info_placeholder,
+            signature: li_sig,
+        }
+    }
+
+    /// Return the proposed block id
+    pub fn proposed_block_id(&self) -> HashValue {
+        self.proposed_block_id
+    }
+
+    /// Return the executed state of the proposed block
+    pub fn executed_state(&self) -> ExecutedState {
+        self.executed_state
+    }
+
+    /// Return the round of the block
+    pub fn round(&self) -> Round {
+        self.round
+    }
+
+    /// Return the author of the vote
+    pub fn author(&self) -> Author {
+        self.author
+    }
+
+    /// Return the LedgerInfo associated with this vote
+    pub fn ledger_info(&self) -> &LedgerInfo {
+        &self.ledger_info
+    }
+
+    /// Return the signature of the vote
+    pub fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// Verifies that the consensus data hash of LedgerInfo corresponds to the vote info,
+    /// and then verifies the signature.
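+    ///
+    /// Illustrative usage (editorial sketch only; `validator` would be a ValidatorVerifier
+    /// built as in the network tests):
+    /// ```ignore
+    /// match vote.verify(&validator) {
+    ///     Ok(()) => { /* safe to count this vote toward a quorum */ }
+    ///     Err(e) => warn!("Vote rejected: {:?}", e),
+    /// }
+    /// ```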
+    pub fn verify(&self, validator: &ValidatorVerifier) -> Result<(), VoteMsgVerificationError> {
+        if self.ledger_info.consensus_data_hash() != self.vote_hash() {
+            return Err(VoteMsgVerificationError::ConsensusDataMismatch);
+        }
+        validator
+            .verify_signature(self.author(), self.ledger_info.hash(), self.signature())
+            .map_err(VoteMsgVerificationError::SigVerifyError)
+    }
+
+    /// Return the hash of this struct
+    pub fn vote_hash(&self) -> HashValue {
+        Self::vote_digest(self.proposed_block_id, self.executed_state, self.round)
+    }
+
+    /// Return a digest of the vote
+    pub fn vote_digest(
+        proposed_block_id: HashValue,
+        executed_state: ExecutedState,
+        round: Round,
+    ) -> HashValue {
+        VoteMsgSerializer {
+            proposed_block_id,
+            executed_state,
+            round,
+        }
+        .hash()
+    }
+}
+
+impl IntoProto for VoteMsg {
+    type ProtoType = ProtoVote;
+
+    fn into_proto(self) -> Self::ProtoType {
+        let mut proto = Self::ProtoType::new();
+        proto.set_proposed_block_id(self.proposed_block_id().into());
+        proto.set_executed_state_id(self.executed_state().state_id.into());
+        proto.set_version(self.executed_state().version);
+        proto.set_round(self.round);
+        proto.set_author(self.author.into());
+        proto.set_ledger_info(self.ledger_info.into_proto());
+        proto.set_signature(self.signature.to_compact().as_ref().into());
+        proto
+    }
+}
+
+impl FromProto for VoteMsg {
+    type ProtoType = ProtoVote;
+
+    fn from_proto(mut object: Self::ProtoType) -> ProtoResult<Self> {
+        let proposed_block_id = HashValue::from_slice(object.get_proposed_block_id())?;
+        let state_id = HashValue::from_slice(object.get_executed_state_id())?;
+        let version = object.get_version();
+        let round = object.get_round();
+        let author = Author::try_from(object.take_author())?;
+        let ledger_info = LedgerInfo::from_proto(object.take_ledger_info())?;
+        let signature = Signature::from_compact(object.get_signature())?;
+        Ok(VoteMsg {
+            proposed_block_id,
+            executed_state: ExecutedState { state_id, version },
+            round,
+            author,
+            ledger_info,
+            signature,
+        })
+    }
+}
diff --git a/consensus/src/chained_bft/sync_manager.rs b/consensus/src/chained_bft/sync_manager.rs
new file mode 100644
index 0000000000000..f8da64961d0c3
--- /dev/null
+++ b/consensus/src/chained_bft/sync_manager.rs
@@ -0,0 +1,353 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    chained_bft::{
+        block_storage::{BlockReader, BlockStore, InsertError},
+        common::{Author, Payload},
+        consensus_types::{block::Block, quorum_cert::QuorumCert},
+        network::ConsensusNetworkImpl,
+        persistent_storage::PersistentStorage,
+    },
+    counters,
+    state_replication::StateComputer,
+    state_synchronizer::SyncStatus,
+};
+use failure::{Fail, Result};
+use logger::prelude::*;
+use network::proto::BlockRetrievalStatus;
+use rand::{prelude::*, Rng};
+use std::{
+    clone::Clone,
+    sync::Arc,
+    time::{Duration, Instant},
+};
+use termion::color::*;
+use types::{account_address::AccountAddress, transaction::TransactionListWithProof};
+
+/// SyncManager is responsible for fetching dependencies and 'catching up' for given qc/ledger info
+pub struct SyncManager<T> {
+    block_store: Arc<BlockStore<T>>,
+    storage: Arc<dyn PersistentStorage<T>>,
+    network: ConsensusNetworkImpl,
+    state_computer: Arc<dyn StateComputer<Payload = T>>,
+}
+
+/// This struct describes where we sync to
+pub struct SyncInfo {
+    /// Highest ledger info to invoke state sync for.
+    /// This is optional for now, because a vote does not have it
+    pub highest_ledger_info: QuorumCert,
+    /// Quorum certificate to be inserted into the block tree
+    pub highest_quorum_cert: QuorumCert,
/// Author of messages that triggered this sync.
+    /// For now we sync from this peer. In the future we are going to use peers from quorum
+    /// certs, and this field is going to be mostly informational.
+    pub peer: Author,
+}
+
+impl<T> SyncManager<T>
+where
+    T: Payload,
+{
+    pub fn new(
+        block_store: Arc<BlockStore<T>>,
+        storage: Arc<dyn PersistentStorage<T>>,
+        network: ConsensusNetworkImpl,
+        state_computer: Arc<dyn StateComputer<Payload = T>>,
+    ) -> SyncManager<T> {
+        // Our counters are initialized via lazy_static, so they're not going to appear in
+        // Prometheus if some conditions never happen. Invoking the get() function forces creation.
+        counters::BLOCK_RETRIEVAL_COUNT.get();
+        counters::STATE_SYNC_COUNT.get();
+        counters::STATE_SYNC_TXN_REPLAYED.get();
+        SyncManager {
+            block_store,
+            storage,
+            network,
+            state_computer,
+        }
+    }
+
+    /// Fetches dependencies for the given sync_info.quorum_cert.
+    /// If the gap is large, performs state sync using process_highest_ledger_info.
+    /// Inserts sync_info.quorum_cert into the block store as the last step.
+    pub async fn sync_to(&mut self, deadline: Instant, sync_info: SyncInfo) -> Result<()> {
+        let highest_ledger_info = sync_info.highest_ledger_info.clone();
+
+        self.process_highest_ledger_info(highest_ledger_info, sync_info.peer, deadline)
+            .await?;
+
+        self.fetch_quorum_cert(
+            sync_info.highest_quorum_cert.clone(),
+            sync_info.peer,
+            deadline,
+        )
+        .await?;
+        Ok(())
+    }
+
+    /// Get a chunk of transactions as a batch
+    pub async fn get_chunk(
+        &self,
+        start_version: u64,
+        target_version: u64,
+        batch_size: u64,
+    ) -> Result<TransactionListWithProof> {
+        self.state_computer
+            .get_chunk(start_version, target_version, batch_size)
+            .await
+    }
+
+    /// Insert the quorum certificate separately from the block, used to split the processing of
+    /// updating the consensus state (with the qc) and deciding whether to vote (with the block).
+    /// The missing ancestors are going to be retrieved from the given peer. If a given peer
+    /// fails to provide the missing ancestors, the qc is not going to be added.
+    pub async fn fetch_quorum_cert(
+        &self,
+        qc: QuorumCert,
+        preferred_peer: Author,
+        deadline: Instant,
+    ) -> std::result::Result<(), InsertError> {
+        let mut pending = vec![];
+        let network = self.network.clone();
+        let mut retriever = BlockRetriever {
+            network,
+            deadline,
+            preferred_peer,
+        };
+        let mut retrieve_qc = qc.clone();
+        while !self
+            .block_store
+            .block_exists(retrieve_qc.certified_block_id())
+        {
+            let mut blocks = retriever.retrieve_block_for_qc(&retrieve_qc, 1).await?;
+            // retrieve_block_for_qc guarantees that blocks has exactly 1 element
+            let block = blocks.remove(0);
+            retrieve_qc = block.quorum_cert().clone();
+            pending.push(block);
+        }
+        // insert the qc <- block pair
+        while let Some(block) = pending.pop() {
+            let block_qc = block.quorum_cert().clone();
+            self.block_store.insert_single_quorum_cert(block_qc).await?;
+            self.block_store.execute_and_insert_block(block).await?;
+        }
+        self.block_store.insert_single_quorum_cert(qc).await
+    }
+
+    /// Check the highest ledger info sent by a peer to see if we're behind and start a fast
+    /// forward sync if the committed block doesn't exist in our tree.
+    /// It works as follows:
+    /// 1. request the committed 3-chain from the peer; if C2 is the highest_ledger_info,
+    ///    we request B0 <- C0 <- B1 <- C1 <- B2 (<- C2)
+    /// 2. We persist the 3-chain to storage before starting sync to ensure we can restart if we
+    ///    crash in the middle of the sync.
+    /// 3. We prune the old tree and replace it with a new tree built from the 3-chain.
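+    ///
+    /// ```text
+    /// Worked example of the steps above (names illustrative): with C2 as the
+    /// highest_ledger_info, the peer returns the blocks [B2, B1, B0] (newest first);
+    /// we persist them together with the certificates [C2, C1, C0], invoke the state
+    /// synchronizer, and rebuild the tree with B0 as the new root.
+    /// ```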
+ async fn process_highest_ledger_info( + &self, + highest_ledger_info: QuorumCert, + peer: Author, + deadline: Instant, + ) -> Result<()> { + let committed_block_id = highest_ledger_info + .committed_block_id() + .ok_or_else(|| format_err!("highest ledger info has no committed block"))?; + if !self + .block_store + .need_sync_for_quorum_cert(committed_block_id, &highest_ledger_info) + { + return Ok(()); + } + debug!( + "Start state sync with peer: {}, to block: {}, round: {} from {}", + peer.short_str(), + committed_block_id, + highest_ledger_info.certified_block_round() - 2, + self.block_store.root() + ); + let network = self.network.clone(); + let mut retriever = BlockRetriever { + network, + deadline, + preferred_peer: peer, + }; + let mut blocks = retriever + .retrieve_block_for_qc(&highest_ledger_info, 3) + .await?; + assert_eq!( + blocks.last().expect("should have 3-chain").id(), + committed_block_id + ); + let mut quorum_certs = vec![]; + quorum_certs.push(highest_ledger_info.clone()); + quorum_certs.push(blocks[0].quorum_cert().clone()); + quorum_certs.push(blocks[1].quorum_cert().clone()); + // If a node restarts in the middle of state synchronization, it is going to try to catch up + // to the stored quorum certs as the new root. + self.storage + .save_tree(blocks.clone(), quorum_certs.clone())?; + let pre_sync_instance = Instant::now(); + match self + .state_computer + .sync_to(highest_ledger_info.clone()) + .await + { + Ok(SyncStatus::Finished) => (), + Ok(e) => panic!( + "state synchronizer failure: {:?}, this validator will be killed as it can not \ + recover from this error. After the validator is restarted, synchronization will \ + be retried.", + e + ), + Err(e) => panic!( + "state synchronizer failure: {:?}, this validator will be killed as it can not \ + recover from this error. After the validator is restarted, synchronization will \ + be retried.", + e + ), + }; + counters::STATE_SYNC_DURATION_MS.observe(pre_sync_instance.elapsed().as_millis() as f64); + let root = ( + blocks.pop().expect("should have 3-chain"), + quorum_certs.last().expect("should have 3-chain").clone(), + highest_ledger_info.clone(), + ); + debug!("{}Sync to{} {}", Fg(Blue), Fg(Reset), root.0); + // ensure it's [b1, b2] + blocks.reverse(); + self.block_store.rebuild(root, blocks, quorum_certs).await; + Ok(()) + } +} + +/// BlockRetriever is used internally to retrieve blocks +struct BlockRetriever { + network: ConsensusNetworkImpl, + deadline: Instant, + preferred_peer: Author, +} + +#[derive(Debug, Fail)] +enum BlockRetrieverError { + #[fail(display = "All peers failed")] + AllPeersFailed, + #[fail(display = "Round deadline reached")] + RoundDeadlineReached, +} + +impl From for InsertError { + fn from(_error: BlockRetrieverError) -> Self { + InsertError::AncestorRetrievalError + } +} + +impl BlockRetriever { + /// Retrieve chain of n blocks for given QC + /// + /// Returns Result with Vec that has a guaranteed size of num_blocks + /// This guarantee is based on BlockRetrievalResponse::verify that ensures that number of + /// blocks in response is equal to number of blocks requested. This method will + /// continue until either the round deadline is reached or the quorum certificate members all + /// fail to return the missing chain. + /// + /// The first attempt of block retrieval will always be sent to preferred_peer to allow the + /// leader to drive quorum certificate creation The other peers from the quorum certificate + /// will be randomly tried next. 
If all members of the quorum certificate are exhausted, an + /// error is returned + pub async fn retrieve_block_for_qc<'a, T>( + &'a mut self, + qc: &'a QuorumCert, + num_blocks: u64, + ) -> std::result::Result>, BlockRetrieverError> + where + T: Payload, + { + let block_id = qc.certified_block_id(); + let mut peers: Vec<&AccountAddress> = qc.ledger_info().signatures().keys().collect(); + let mut attempt = 0_u32; + loop { + if peers.is_empty() { + warn!( + "Failed to fetch block {} in {} attempts: no more peers available", + block_id, attempt + ); + return Err(BlockRetrieverError::AllPeersFailed); + } + let peer = self.pick_peer(attempt, &mut peers); + attempt += 1; + + let timeout = retrieval_timeout(&self.deadline, attempt); + let timeout = if let Some(timeout) = timeout { + timeout + } else { + warn!("Failed to fetch block {} from {}, attempt {}: round deadline was reached, won't make more attempts", block_id, peer, attempt); + return Err(BlockRetrieverError::RoundDeadlineReached); + }; + debug!( + "Fetching {} from {}, attempt {}", + block_id, + peer.short_str(), + attempt + ); + let response = self + .network + .request_block(block_id, num_blocks, peer, timeout) + .await; + let response = match response { + Err(e) => { + warn!( + "Failed to fetch block {} from {}: {:?}, trying another peer", + block_id, peer, e + ); + continue; + } + Ok(response) => response, + }; + if response.status != BlockRetrievalStatus::SUCCEEDED { + warn!( + "Failed to fetch block {} from {}: {:?}, trying another peer", + block_id, peer, response.status + ); + continue; + } + return Ok(response.blocks); + } + } + + fn pick_peer(&self, attempt: u32, peers: &mut Vec<&AccountAddress>) -> AccountAddress { + assert!(!peers.is_empty(), "pick_peer on empty peer list"); + + if attempt == 0 { + // remove preferred_peer if it's in the list of peers + // (strictly speaking it is not required to be there) + for i in 0..peers.len() { + if *peers[i] == self.preferred_peer { + peers.remove(i); + break; + } + } + return self.preferred_peer; + } + let peer_idx = thread_rng().gen_range(0, peers.len()); + *peers.remove(peer_idx) + } +} + +// Max timeout is 16s=RETRIEVAL_INITIAL_TIMEOUT*(2^RETRIEVAL_MAX_EXP) +const RETRIEVAL_INITIAL_TIMEOUT: Duration = Duration::from_secs(1); +const RETRIEVAL_MAX_EXP: u32 = 4; + +/// Returns exponentially increasing timeout with +/// limit of RETRIEVAL_INITIAL_TIMEOUT*(2^RETRIEVAL_MAX_EXP) +fn retrieval_timeout(deadline: &Instant, attempt: u32) -> Option { + assert!(attempt > 0, "retrieval_timeout attempt can't be 0"); + let exp = RETRIEVAL_MAX_EXP.min(attempt - 1); // [0..RETRIEVAL_MAX_EXP] + let request_timeout = RETRIEVAL_INITIAL_TIMEOUT * 2_u32.pow(exp); + let deadline_timeout = deadline.checked_duration_since(Instant::now()); + if let Some(deadline_timeout) = deadline_timeout { + Some(request_timeout.min(deadline_timeout)) + } else { + None + } +} diff --git a/consensus/src/chained_bft/test_utils/mock_state_computer.rs b/consensus/src/chained_bft/test_utils/mock_state_computer.rs new file mode 100644 index 0000000000000..434ed84e48df0 --- /dev/null +++ b/consensus/src/chained_bft/test_utils/mock_state_computer.rs @@ -0,0 +1,80 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::consensus_types::quorum_cert::QuorumCert, + state_replication::{StateComputeResult, StateComputer}, + state_synchronizer::SyncStatus, +}; +use crypto::{hash::ACCUMULATOR_PLACEHOLDER_HASH, HashValue}; +use failure::Result; +use futures::{channel::mpsc,
Future, FutureExt}; +use logger::prelude::*; +use std::pin::Pin; +use termion::color::*; +use types::{ledger_info::LedgerInfoWithSignatures, transaction::TransactionListWithProof}; + +pub struct MockStateComputer { + commit_callback: mpsc::UnboundedSender, +} + +impl MockStateComputer { + pub fn new(commit_callback: mpsc::UnboundedSender) -> Self { + MockStateComputer { commit_callback } + } +} + +impl StateComputer for MockStateComputer { + type Payload = Vec; + fn compute( + &self, + _parent_id: HashValue, + _block_id: HashValue, + _transactions: &Self::Payload, + ) -> Pin> + Send>> { + async move { + Ok(StateComputeResult { + new_state_id: *ACCUMULATOR_PLACEHOLDER_HASH, + compute_status: vec![], + num_successful_txns: 0, + validators: None, + }) + } + .boxed() + } + + fn commit( + &self, + commit: LedgerInfoWithSignatures, + ) -> Pin> + Send>> { + self.commit_callback + .unbounded_send(commit) + .expect("Failed to notify about commit."); + async { Ok(()) }.boxed() + } + + fn sync_to( + &self, + commit: QuorumCert, + ) -> Pin> + Send>> { + debug!( + "{}Fake sync{} to block id {}", + Fg(Blue), + Fg(Reset), + commit.ledger_info().ledger_info().consensus_block_id() + ); + self.commit_callback + .unbounded_send(commit.ledger_info().clone()) + .expect("Failed to notify about sync"); + async { Ok(SyncStatus::Finished) }.boxed() + } + + fn get_chunk( + &self, + _: u64, + _: u64, + _: u64, + ) -> Pin> + Send>> { + async move { Err(format_err!("not implemented")) }.boxed() + } +} diff --git a/consensus/src/chained_bft/test_utils/mock_storage.rs b/consensus/src/chained_bft/test_utils/mock_storage.rs new file mode 100644 index 0000000000000..919074b3f095d --- /dev/null +++ b/consensus/src/chained_bft/test_utils/mock_storage.rs @@ -0,0 +1,217 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::chained_bft::{ + common::Payload, + consensus_types::{block::Block, quorum_cert::QuorumCert}, + liveness::pacemaker_timeout_manager::HighestTimeoutCertificates, + persistent_storage::{PersistentLivenessStorage, PersistentStorage, RecoveryData}, + safety::safety_rules::ConsensusState, +}; +use config::config::{NodeConfig, NodeConfigHelpers}; +use crypto::HashValue; +use failure::Result; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +pub struct MockSharedStorage { + // Safety state + pub block: Mutex>>, + pub qc: Mutex>, + pub state: Mutex, + + // Liveness state + pub highest_timeout_certificates: Mutex, +} + +/// A storage that simulates the operations in-memory, used in the tests that care about storage +/// consistency.
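// Editor's note: a hypothetical, std-only illustration (SharedState is not project code)
// of the sharing pattern MockStorage uses below: several handles hold one Arc'd
// MockSharedStorage, so state saved through one handle is visible to consistency checks
// made through another.
//
//     use std::collections::HashMap;
//     use std::sync::{Arc, Mutex};
//
//     struct SharedState { blocks: Mutex<HashMap<u64, String>> }
//
//     fn main() {
//         let shared = Arc::new(SharedState { blocks: Mutex::new(HashMap::new()) });
//         let writer = Arc::clone(&shared); // e.g. the persistent liveness storage handle
//         writer.blocks.lock().unwrap().insert(1, "block-1".to_string());
//         // the original handle observes the write, which verify_consistency relies on
//         assert_eq!(shared.blocks.lock().unwrap().len(), 1);
//     }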
+pub struct MockStorage { + pub shared_storage: Arc>, +} + +impl MockStorage { + pub fn new(shared_storage: Arc>) -> Self { + MockStorage { shared_storage } + } + + pub fn get_recovery_data(&self) -> Result> { + let mut blocks: Vec<_> = self + .shared_storage + .block + .lock() + .unwrap() + .clone() + .into_iter() + .map(|(_, v)| v) + .collect(); + let quorum_certs = self + .shared_storage + .qc + .lock() + .unwrap() + .clone() + .into_iter() + .map(|(_, v)| v) + .collect(); + // There is no root_from_storage in MockStorage (unit tests), hence we use the consensus + // root value. + blocks.sort_by_key(Block::round); + let root_from_storage = blocks[0].id(); + RecoveryData::new( + self.shared_storage.state.lock().unwrap().clone(), + blocks, + quorum_certs, + root_from_storage, + self.shared_storage + .highest_timeout_certificates + .lock() + .unwrap() + .clone(), + ) + } + + pub fn verify_consistency(&self) -> Result<()> { + self.get_recovery_data().map(|_| ()) + } + + pub fn start_for_testing() -> (Arc, RecoveryData) { + Self::start(&NodeConfigHelpers::get_single_node_test_config(false)) + } +} + +impl PersistentLivenessStorage for MockStorage { + fn save_highest_timeout_cert( + &self, + highest_timeout_certificates: HighestTimeoutCertificates, + ) -> Result<()> { + *self + .shared_storage + .highest_timeout_certificates + .lock() + .unwrap() = highest_timeout_certificates; + Ok(()) + } +} + +// An impl that always starts from genesis. +impl PersistentStorage for MockStorage { + fn persistent_liveness_storage(&self) -> Box { + Box::new(MockStorage { + shared_storage: Arc::clone(&self.shared_storage), + }) + } + + fn save_tree(&self, blocks: Vec>, quorum_certs: Vec) -> Result<()> { + for block in blocks { + self.shared_storage + .block + .lock() + .unwrap() + .insert(block.id(), block); + } + for qc in quorum_certs { + self.shared_storage + .qc + .lock() + .unwrap() + .insert(qc.certified_block_id(), qc); + } + if let Err(e) = self.verify_consistency() { + panic!("invalid db after save tree: {}", e); + } + Ok(()) + } + + fn prune_tree(&self, block_id: Vec) -> Result<()> { + for id in block_id { + self.shared_storage.block.lock().unwrap().remove(&id); + self.shared_storage.qc.lock().unwrap().remove(&id); + } + if let Err(e) = self.verify_consistency() { + panic!("invalid db after prune tree: {}", e); + } + Ok(()) + } + + fn save_consensus_state(&self, state: ConsensusState) -> Result<()> { + *self.shared_storage.state.lock().unwrap() = state; + Ok(()) + } + + fn start(_config: &NodeConfig) -> (Arc, RecoveryData) { + let shared_storage = Arc::new(MockSharedStorage { + block: Mutex::new(HashMap::new()), + qc: Mutex::new(HashMap::new()), + state: Mutex::new(ConsensusState::default()), + highest_timeout_certificates: Mutex::new(HighestTimeoutCertificates::new(None, None)), + }); + let storage = MockStorage { + shared_storage: Arc::clone(&shared_storage), + }; + + // The current assumption is that the genesis block version is 0. + storage + .save_tree( + vec![Block::make_genesis_block()], + vec![QuorumCert::certificate_for_genesis()], + ) + .unwrap(); + ( + Arc::new(Self::new(shared_storage)), + storage.get_recovery_data().unwrap(), + ) + } +} + +/// A storage that ignores any requests, used in the tests that don't care about the storage.
+pub struct EmptyStorage; + +impl EmptyStorage { + pub fn start_for_testing() -> (Arc, RecoveryData) { + Self::start(&NodeConfigHelpers::get_single_node_test_config(false)) + } +} + +impl PersistentLivenessStorage for EmptyStorage { + fn save_highest_timeout_cert(&self, _: HighestTimeoutCertificates) -> Result<()> { + Ok(()) + } +} + +impl PersistentStorage for EmptyStorage { + fn persistent_liveness_storage(&self) -> Box { + Box::new(EmptyStorage) + } + + fn save_tree(&self, _: Vec>, _: Vec) -> Result<()> { + Ok(()) + } + + fn prune_tree(&self, _: Vec) -> Result<()> { + Ok(()) + } + + fn save_consensus_state(&self, _: ConsensusState) -> Result<()> { + Ok(()) + } + + fn start(_: &NodeConfig) -> (Arc, RecoveryData) { + let genesis = Block::make_genesis_block(); + let genesis_qc = QuorumCert::certificate_for_genesis(); + let htc = HighestTimeoutCertificates::new(None, None); + ( + Arc::new(EmptyStorage), + RecoveryData::new( + ConsensusState::default(), + vec![genesis], + vec![genesis_qc], + HashValue::random(), + htc, + ) + .unwrap(), + ) + } +} diff --git a/consensus/src/chained_bft/test_utils/mock_txn_manager.rs b/consensus/src/chained_bft/test_utils/mock_txn_manager.rs new file mode 100644 index 0000000000000..0827c7d380f63 --- /dev/null +++ b/consensus/src/chained_bft/test_utils/mock_txn_manager.rs @@ -0,0 +1,85 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::state_replication::{StateComputeResult, TxnManager}; +use failure::Result; +use futures::{channel::mpsc, Future, FutureExt, SinkExt}; +use std::{ + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, RwLock, + }, +}; + +pub type MockTransaction = usize; + +/// Trivial mock: generates MockTransactions on the fly. Each next transaction is the next value. +pub struct MockTransactionManager { + next_val: AtomicUsize, + committed_txns: Arc>>, + commit_receiver: Option>, + commit_sender: mpsc::Sender, +} + +impl MockTransactionManager { + pub fn new() -> Self { + let (commit_sender, commit_receiver) = mpsc::channel(1024); + Self { + next_val: AtomicUsize::new(0), + committed_txns: Arc::new(RwLock::new(vec![])), + commit_receiver: Some(commit_receiver), + commit_sender, + } + } + + pub fn get_committed_txns(&self) -> Vec { + self.committed_txns.read().unwrap().clone() + } + + /// Pulls the receiver out of the manager to let the clients receive notifications about the + /// commits. 
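// Editor's note: a tiny, hypothetical demo of the Option::take hand-off that
// take_commit_receiver below implements: the receiver can be extracted exactly once,
// and a second extraction would panic via expect.
//
//     use std::sync::mpsc;
//
//     struct Owner { rx: Option<mpsc::Receiver<usize>> }
//
//     fn main() {
//         let (tx, rx) = mpsc::channel();
//         let mut owner = Owner { rx: Some(rx) };
//         let rx = owner.rx.take().expect("The receiver has already been pulled out.");
//         tx.send(7).unwrap();
//         assert_eq!(rx.recv().unwrap(), 7);
//         assert!(owner.rx.is_none());
//     }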
+ pub fn take_commit_receiver(&mut self) -> mpsc::Receiver { + self.commit_receiver + .take() + .expect("The receiver has already been pulled out.") + } +} + +impl TxnManager for MockTransactionManager { + type Payload = Vec; + + /// The returned future is fulfilled with the vector of SignedTransactions + fn pull_txns( + &self, + max_size: u64, + _exclude_txns: Vec<&Self::Payload>, + ) -> Pin> + Send>> { + let next_value = self.next_val.load(Ordering::SeqCst); + let upper_bound = next_value + max_size as usize; + let res = (next_value..upper_bound).collect(); + self.next_val.store(upper_bound, Ordering::SeqCst); + async move { Ok(res) }.boxed() + } + + fn commit_txns<'a>( + &'a self, + txns: &Self::Payload, + _compute_result: &StateComputeResult, + _timestamp_usecs: u64, + ) -> Pin> + Send + 'a>> { + let committed_txns = txns.clone(); + let mut commit_sender = self.commit_sender.clone(); + async move { + for txn in committed_txns { + self.committed_txns.write().unwrap().push(txn); + } + commit_sender + .send(self.committed_txns.read().unwrap().len()) + .await + .expect("Failed to notify about mempool commit"); + Ok(()) + } + .boxed() + } +} diff --git a/consensus/src/chained_bft/test_utils/mod.rs b/consensus/src/chained_bft/test_utils/mod.rs new file mode 100644 index 0000000000000..0b57c9d08b20b --- /dev/null +++ b/consensus/src/chained_bft/test_utils/mod.rs @@ -0,0 +1,190 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{ + block_storage::BlockStore, + common::Round, + consensus_types::{block::Block, quorum_cert::QuorumCert}, + safety::vote_msg::VoteMsg, + }, + state_replication::ExecutedState, +}; +use crypto::{hash::CryptoHash, HashValue}; +use futures::{channel::mpsc, executor::block_on}; +use logger::{set_simple_logger, set_simple_logger_prefix}; +use std::{collections::HashMap, sync::Arc}; +use termion::color::*; +use tokio::runtime; +use tools::output_capture::OutputCapture; +use types::{ + ledger_info::{LedgerInfo, LedgerInfoWithSignatures}, + validator_signer::ValidatorSigner, +}; + +mod mock_state_computer; +mod mock_storage; +mod mock_txn_manager; + +pub use mock_state_computer::MockStateComputer; +pub use mock_storage::{EmptyStorage, MockStorage}; +pub use mock_txn_manager::MockTransactionManager; + +pub type TestPayload = Vec; + +pub fn build_empty_tree() -> Arc>> { + let signer = ValidatorSigner::random(); + build_empty_tree_with_custom_signing(signer.clone()) +} + +pub fn build_empty_tree_with_custom_signing( + my_signer: ValidatorSigner, +) -> Arc>> { + let (commit_cb_sender, _commit_cb_receiver) = mpsc::unbounded::(); + let (storage, initial_data) = EmptyStorage::start_for_testing(); + Arc::new(block_on(BlockStore::new( + storage, + initial_data, + my_signer, + Arc::new(MockStateComputer::new(commit_cb_sender)), + true, + 10, // max pruned blocks in mem + ))) +} + +pub struct TreeInserter { + payload_val: usize, + block_store: Arc>>, +} + +impl TreeInserter { + pub fn new(block_store: Arc>>) -> Self { + Self { + payload_val: 0, + block_store, + } + } + + /// This function generates a placeholder QC for a block's parent that is signed by a single + /// signer kept by the block store. If a more sophisticated QC is required, please use + /// `insert_block_with_qc`.
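// Editor's note: pull_txns in mock_txn_manager.rs above reserves its range with a
// separate load() and store(), which is fine for a single-threaded test; a hypothetical
// fetch_add variant (sketch only, not project code) would make the reservation atomic
// under concurrent callers:
//
//     use std::ops::Range;
//     use std::sync::atomic::{AtomicUsize, Ordering};
//
//     fn reserve(next_val: &AtomicUsize, max_size: usize) -> Range<usize> {
//         let start = next_val.fetch_add(max_size, Ordering::SeqCst);
//         start..start + max_size
//     }
//
//     fn main() {
//         let counter = AtomicUsize::new(0);
//         assert_eq!(reserve(&counter, 3), 0..3);
//         assert_eq!(reserve(&counter, 2), 3..5);
//     }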
+ pub fn insert_block( + &mut self, + parent: &Block>, + round: Round, + ) -> Arc>> { + // Node must carry a QC to its parent + let parent_qc = placeholder_certificate_for_block( + vec![self.block_store.signer().clone()], + parent.id(), + parent.round(), + ); + + self.insert_block_with_qc(parent_qc, parent, round) + } + + pub fn insert_block_with_qc( + &mut self, + parent_qc: QuorumCert, + parent: &Block>, + round: Round, + ) -> Arc>> { + self.payload_val += 1; + block_on(self.block_store.insert_block_with_qc(Block::make_block( + parent, + vec![self.payload_val], + round, + parent.timestamp_usecs() + 1, + parent_qc, + self.block_store.signer(), + ))) + .unwrap() + } + + pub fn insert_pre_made_block( + &mut self, + block: Block>, + block_signer: &ValidatorSigner, + qc_signers: Vec, + ) -> Arc>> { + self.payload_val += 1; + let new_round = if block.round() > 0 { + block.round() - 1 + } else { + 0 + }; + let parent_qc = placeholder_certificate_for_block(qc_signers, block.parent_id(), new_round); + let new_block = Block::new_internal( + block.get_payload().clone(), + block.parent_id(), + block.round(), + block.height(), + block.timestamp_usecs(), + parent_qc, + block_signer, + ); + block_on(self.block_store.insert_block_with_qc(new_block)).unwrap() + } +} + +pub fn placeholder_ledger_info() -> LedgerInfo { + LedgerInfo::new( + 0, + HashValue::zero(), + HashValue::zero(), + HashValue::zero(), + 0, + 0, + ) +} + +pub fn placeholder_certificate_for_block( + signers: Vec, + certified_block_id: HashValue, + certified_block_round: u64, +) -> QuorumCert { + // Assuming executed state to be Genesis state. + let certified_block_state = ExecutedState::state_for_genesis(); + let consensus_data_hash = VoteMsg::vote_digest( + certified_block_id, + certified_block_state, + certified_block_round, + ); + + // This ledger info doesn't carry any meaningful information: it is all zeros except for + // the consensus data hash that carries the actual vote. 
+ let mut ledger_info_placeholder = placeholder_ledger_info(); + ledger_info_placeholder.set_consensus_data_hash(consensus_data_hash); + + let mut signatures = HashMap::new(); + for signer in signers { + let li_sig = signer + .sign_message(ledger_info_placeholder.hash()) + .expect("Failed to sign LedgerInfo"); + signatures.insert(signer.author(), li_sig); + } + + QuorumCert::new( + certified_block_id, + certified_block_state, + certified_block_round, + LedgerInfoWithSignatures::new(ledger_info_placeholder, signatures), + ) +} + +pub fn consensus_runtime() -> runtime::Runtime { + set_simple_logger("consensus"); + let capture = OutputCapture::grab(); + runtime::Builder::new() + .after_start(move || capture.apply()) + .build() + .expect("Failed to create Tokio runtime!") +} + +pub fn with_smr_id(id: String) -> impl Fn() { + let capture = OutputCapture::grab(); + move || { + capture.apply(); + set_simple_logger_prefix(format!("{}[{}]{}", Fg(LightBlack), id.clone(), Fg(Reset))) + } +} diff --git a/consensus/src/consensus_provider.rs b/consensus/src/consensus_provider.rs new file mode 100644 index 0000000000000..6c85582da9920 --- /dev/null +++ b/consensus/src/consensus_provider.rs @@ -0,0 +1,71 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use config::config::NodeConfig; +use failure::prelude::*; +use network::validator_network::{ConsensusNetworkEvents, ConsensusNetworkSender}; + +use crate::chained_bft::chained_bft_consensus_provider::ChainedBftProvider; +use execution_proto::proto::execution_grpc::ExecutionClient; +use grpcio::{ChannelBuilder, EnvBuilder}; +use mempool::proto::mempool_grpc::MempoolClient; +use std::sync::Arc; +use storage_client::{StorageRead, StorageReadServiceClient}; + +/// Public interface to a consensus protocol. +pub trait ConsensusProvider { + /// Spawns new threads, starts the consensus operations (retrieve txns, consensus protocol, + /// execute txns, commit txns, update txn status in the mempool, etc). + /// The function returns after consensus has recovered its initial state, + /// and has established the required connections (e.g., to mempool and + /// executor). + fn start(&mut self) -> Result<()>; + + /// Stop the consensus operations. The function returns after graceful shutdown. 
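// Editor's note: a hypothetical no-op implementation of the ConsensusProvider trait
// (NoopProvider is not project code), only to illustrate the start/stop contract:
//
//     struct NoopProvider;
//
//     impl ConsensusProvider for NoopProvider {
//         fn start(&mut self) -> Result<()> {
//             // a real provider connects to mempool/execution and spawns threads here
//             Ok(())
//         }
//         fn stop(&mut self) {
//             // a real provider joins its worker threads here
//         }
//     }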
+ fn stop(&mut self); +} + +/// Helper function to create a ConsensusProvider based on configuration +pub fn make_consensus_provider( + node_config: &NodeConfig, + network_sender: ConsensusNetworkSender, + network_receiver: ConsensusNetworkEvents, +) -> Box { + Box::new(ChainedBftProvider::new( + node_config, + network_sender, + network_receiver, + create_mempool_client(node_config), + create_execution_client(node_config), + )) +} + +/// Create a mempool client assuming the mempool is running on localhost +fn create_mempool_client(config: &NodeConfig) -> Arc { + let port = config.mempool.mempool_service_port; + let connection_str = format!("localhost:{}", port); + + let env = Arc::new(EnvBuilder::new().name_prefix("grpc-con-mem-").build()); + Arc::new(MempoolClient::new( + ChannelBuilder::new(env).connect(&connection_str), + )) +} + +/// Create an execution client assuming the execution service is running on localhost +fn create_execution_client(config: &NodeConfig) -> Arc { + let connection_str = format!("localhost:{}", config.execution.port); + + let env = Arc::new(EnvBuilder::new().name_prefix("grpc-con-exe-").build()); + Arc::new(ExecutionClient::new( + ChannelBuilder::new(env).connect(&connection_str), + )) +} + +/// Create a storage read client based on the config +pub fn create_storage_read_client(config: &NodeConfig) -> Arc { + let env = Arc::new(EnvBuilder::new().name_prefix("grpc-con-sto-").build()); + Arc::new(StorageReadServiceClient::new( + env, + &config.storage.address, + config.storage.port, + )) +} diff --git a/consensus/src/counters.rs b/consensus/src/counters.rs new file mode 100644 index 0000000000000..f92ab1b2c7c41 --- /dev/null +++ b/consensus/src/counters.rs @@ -0,0 +1,184 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use lazy_static; +use metrics::OpMetrics; +use prometheus::{Histogram, IntCounter, IntGauge}; + +lazy_static::lazy_static! { + pub static ref OP_COUNTERS: OpMetrics = OpMetrics::new_and_registered("consensus"); +} + +lazy_static::lazy_static! { +////////////////////// +// HEALTH COUNTERS +////////////////////// +/// This counter is set to the round of the highest committed block. +pub static ref LAST_COMMITTED_ROUND: IntGauge = OP_COUNTERS.gauge("last_committed_round"); + +/// The counter corresponds to the version of the last committed ledger info. +pub static ref LAST_COMMITTED_VERSION: IntGauge = OP_COUNTERS.gauge("last_committed_version"); + +/// This counter is set to the round of the highest voted block. +pub static ref LAST_VOTE_ROUND: IntGauge = OP_COUNTERS.gauge("last_vote_round"); + +/// This counter is set to the round of the preferred block (highest 2-chain head). +pub static ref PREFERRED_BLOCK_ROUND: IntGauge = OP_COUNTERS.gauge("preferred_block_round"); + +/// This counter is set to the last round reported by the local pacemaker. +pub static ref CURRENT_ROUND: IntGauge = OP_COUNTERS.gauge("current_round"); + +/// Count of the block proposals sent by this validator since last restart. +pub static ref PROPOSALS_COUNT: IntCounter = OP_COUNTERS.counter("proposals_count"); + +/// Count of the committed blocks since last restart. +pub static ref COMMITTED_BLOCKS_COUNT: IntCounter = OP_COUNTERS.counter("committed_blocks_count"); + +/// Count of the committed transactions since last restart. +pub static ref COMMITTED_TXNS_COUNT: IntCounter = OP_COUNTERS.counter("committed_txns_count"); + +/// Count of successful txns in the blocks committed by this validator since last restart.
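// Editor's note: a hypothetical sketch of how these metrics are driven, using the
// prometheus value types they wrap (IntGauge::set, IntCounter::inc / inc_by,
// Histogram::observe); record_committed_block is not project code:
//
//     fn record_committed_block(round: u64, num_txns: usize, commit_ms: f64) {
//         LAST_COMMITTED_ROUND.set(round as i64);
//         COMMITTED_BLOCKS_COUNT.inc();
//         COMMITTED_TXNS_COUNT.inc_by(num_txns as i64);
//         NUM_TXNS_PER_BLOCK.observe(num_txns as f64);
//         BLOCK_COMMIT_DURATION_MS.observe(commit_ms);
//     }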
+pub static ref SUCCESS_TXNS_COUNT: IntCounter = OP_COUNTERS.counter("success_txns_count"); + +/// Count of failed txns in the committed blocks since last restart. +/// FAILED_TXNS_COUNT + SUCCESS_TXN_COUNT == COMMITTED_TXNS_COUNT +pub static ref FAILED_TXNS_COUNT: IntCounter = OP_COUNTERS.counter("failed_txns_count"); + +////////////////////// +// PACEMAKER COUNTERS +////////////////////// +/// Count of the rounds that gathered QC since last restart. +pub static ref QC_ROUNDS_COUNT: IntCounter = OP_COUNTERS.counter("qc_rounds_count"); + +/// Count of the timeout rounds since last restart (close to 0 in happy path). +pub static ref TIMEOUT_ROUNDS_COUNT: IntCounter = OP_COUNTERS.counter("timeout_rounds_count"); + +/// Count the number of timeouts a node experienced since last restart (close to 0 in happy path). +/// This count is different from `TIMEOUT_ROUNDS_COUNT`, because not every time a node has +/// a timeout there is an ultimate decision to move to the next round (it might take multiple +/// timeouts to get the timeout certificate). +pub static ref TIMEOUT_COUNT: IntCounter = OP_COUNTERS.counter("timeout_count"); + +/// The timeout of the current round. +pub static ref ROUND_TIMEOUT_MS: IntGauge = OP_COUNTERS.gauge("round_timeout_ms"); + +//////////////////////// +// SYNCMANAGER COUNTERS +//////////////////////// +/// Count the number of times we invoked state synchronization since last restart. +pub static ref STATE_SYNC_COUNT: IntCounter = OP_COUNTERS.counter("state_sync_count"); + +/// Count the overall number of transactions state synchronizer has retrieved since last restart. +/// Large values mean that a node has been significantly behind and had to replay a lot of txns. +pub static ref STATE_SYNC_TXN_REPLAYED: IntCounter = OP_COUNTERS.counter("state_sync_txns_replayed"); + +/// Count the number of block retrieval requests issued since last restart. +pub static ref BLOCK_RETRIEVAL_COUNT: IntCounter = OP_COUNTERS.counter("block_retrieval_count"); + +/// Histogram of block retrieval duration. +pub static ref BLOCK_RETRIEVAL_DURATION_MS: Histogram = OP_COUNTERS.histogram("block_retrieval_duration_ms"); + +/// Histogram of state sync duration. +pub static ref STATE_SYNC_DURATION_MS: Histogram = OP_COUNTERS.histogram("state_sync_duration_ms"); + +////////////////////// +// BLOCK STORE COUNTERS +////////////////////// +/// Counter for the number of blocks in the block tree (including the root). +/// In a "happy path" with no collisions and timeouts, should be equal to 3 or 4. +pub static ref NUM_BLOCKS_IN_TREE: IntGauge = OP_COUNTERS.gauge("num_blocks_in_tree"); + +////////////////////// +// PERFORMANCE COUNTERS +////////////////////// +/// Histogram of execution time (ms) of non-empty blocks. +pub static ref BLOCK_EXECUTION_DURATION_MS: Histogram = OP_COUNTERS.histogram("block_execution_duration_ms"); + +/// Histogram of duration of a commit procedure (the time it takes for the execution / storage to +/// commit a block once we decide to do so). +pub static ref BLOCK_COMMIT_DURATION_MS: Histogram = OP_COUNTERS.histogram("block_commit_duration_ms"); + +/// Histogram for the number of txns per (committed) blocks. +pub static ref NUM_TXNS_PER_BLOCK: Histogram = OP_COUNTERS.histogram("num_txns_per_block"); + +/// Histogram of per-transaction execution time (ms) of non-empty blocks +/// (calculated as the overall execution time of a block divided by the number of transactions). 
+pub static ref TXN_EXECUTION_DURATION_MS: Histogram = OP_COUNTERS.histogram("txn_execution_duration_ms"); + +/// Histogram of execution time (ms) of empty blocks. +pub static ref EMPTY_BLOCK_EXECUTION_DURATION_MS: Histogram = OP_COUNTERS.histogram("empty_block_execution_duration_ms"); + +/// Histogram of the time it takes for a block to get committed. +/// Measured as the commit time minus block's timestamp. +pub static ref CREATION_TO_COMMIT_MS: Histogram = OP_COUNTERS.histogram("creation_to_commit_ms"); + +/// Duration from block generation time until the moment it gathers a full QC +pub static ref CREATION_TO_QC_MS: Histogram = OP_COUNTERS.histogram("creation_to_qc_ms"); + +/// Duration from block generation time until the moment it is received and ready for execution. +pub static ref CREATION_TO_RECEIVAL_MS: Histogram = OP_COUNTERS.histogram("creation_to_receival_ms"); + +//////////////////////////////////// +// PROPOSAL/VOTE TIMESTAMP COUNTERS +//////////////////////////////////// +/// Count of the proposals that passed the timestamp rules and did not have to wait +pub static ref PROPOSAL_NO_WAIT_REQUIRED_COUNT: IntCounter = OP_COUNTERS.counter("proposal_no_wait_required_count"); + +/// Count of the proposals where passing the timestamp rules required waiting +pub static ref PROPOSAL_WAIT_WAS_REQUIRED_COUNT: IntCounter = OP_COUNTERS.counter("proposal_wait_was_required_count"); + +/// Count of the proposals that were not made due to the waiting period exceeding the maximum allowed duration, breaking timestamp rules +pub static ref PROPOSAL_MAX_WAIT_EXCEEDED_COUNT: IntCounter = OP_COUNTERS.counter("proposal_max_wait_exceeded_count"); + +/// Count of the proposals that were not made because the wait to ensure the current time exceeds min_duration_since_epoch failed, breaking timestamp rules +pub static ref PROPOSAL_WAIT_FAILED_COUNT: IntCounter = OP_COUNTERS.counter("proposal_wait_failed_count"); + +/// Histogram of time waited for successfully proposing a proposal (both those that waited and didn't wait) after following timestamp rules +pub static ref PROPOSAL_SUCCESS_WAIT_MS: Histogram = OP_COUNTERS.histogram("proposal_success_wait_ms"); + +/// Histogram of time waited for failing to propose a proposal (both those that waited and didn't wait) while trying to follow timestamp rules +pub static ref PROPOSAL_FAILURE_WAIT_MS: Histogram = OP_COUNTERS.histogram("proposal_failure_wait_ms"); + +/// Count of the votes that passed the timestamp rules and did not have to wait +pub static ref VOTE_NO_WAIT_REQUIRED_COUNT: IntCounter = OP_COUNTERS.counter("vote_no_wait_required_count"); + +/// Count of the votes where passing the timestamp rules required waiting +pub static ref VOTE_WAIT_WAS_REQUIRED_COUNT: IntCounter = OP_COUNTERS.counter("vote_wait_was_required_count"); + +/// Count of the votes that were not made due to the waiting period exceeding the maximum allowed duration, breaking timestamp rules +pub static ref VOTE_MAX_WAIT_EXCEEDED_COUNT: IntCounter = OP_COUNTERS.counter("vote_max_wait_exceeded_count"); + +/// Count of the votes that were not made because the wait to ensure the current time exceeds min_duration_since_epoch failed, breaking timestamp rules +pub static ref VOTE_WAIT_FAILED_COUNT: IntCounter = OP_COUNTERS.counter("vote_wait_failed_count"); + +/// Histogram of time waited for successfully having the ability to vote (both those that waited and didn't wait) after following timestamp rules. +/// A success only means that a replica has an opportunity to vote.
It may not vote if it doesn't pass the voting rules. +pub static ref VOTE_SUCCESS_WAIT_MS: Histogram = OP_COUNTERS.histogram("vote_success_wait_ms"); + +/// Histogram of time waited for failing to have the ability to vote (both those that waited and didn't wait) while trying to follow timestamp rules +pub static ref VOTE_FAILURE_WAIT_MS: Histogram = OP_COUNTERS.histogram("vote_failure_wait_ms"); + +/////////////////// +// CHANNEL COUNTERS +/////////////////// +/// Count of the pending messages sent to itself in the channel +pub static ref PENDING_SELF_MESSAGES: IntGauge = OP_COUNTERS.gauge("pending_self_messages"); + +/// Count of the pending inbound proposals +pub static ref PENDING_PROPOSAL: IntGauge = OP_COUNTERS.gauge("pending_proposal"); + +/// Count of the pending inbound votes +pub static ref PENDING_VOTES: IntGauge = OP_COUNTERS.gauge("pending_votes"); + +/// Count of the pending inbound block requests +pub static ref PENDING_BLOCK_REQUESTS: IntGauge = OP_COUNTERS.gauge("pending_block_requests"); + +/// Count of the pending inbound chunk requests +pub static ref PENDING_CHUNK_REQUESTS: IntGauge = OP_COUNTERS.gauge("pending_chunk_requests"); + +/// Count of the pending inbound new round messages +pub static ref PENDING_NEW_ROUND_MESSAGES: IntGauge = OP_COUNTERS.gauge("pending_new_round_messages"); + +/// Count of the pending outbound pacemaker timeouts +pub static ref PENDING_PACEMAKER_TIMEOUTS: IntGauge = OP_COUNTERS.gauge("pending_pacemaker_timeouts"); +} diff --git a/consensus/src/lib.rs b/consensus/src/lib.rs new file mode 100644 index 0000000000000..5ee237bd79bee --- /dev/null +++ b/consensus/src/lib.rs @@ -0,0 +1,39 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Consensus for the Libra Core blockchain +//! +//! Encapsulates public consensus traits and any implementations of those traits. +//! Currently, the only consensus protocol supported is LibraBFT (based on +//! [HotStuff](https://arxiv.org/pdf/1803.05069.pdf)). + +#![deny(missing_docs)] +#![feature(async_await, slice_patterns)] +#![feature(drain_filter)] +#![feature(checked_duration_since)] +#![feature(crate_visibility_modifier)] +#![recursion_limit = "128"] +#[macro_use] +extern crate failure; + +mod chained_bft; + +/// Defines the public consensus provider traits to implement for +/// use in the Libra Core blockchain. +pub mod consensus_provider; + +mod counters; + +mod state_computer; +mod state_replication; +mod state_synchronizer; +mod stream_utils; +mod time_service; +mod txn_manager; + +#[cfg(test)] +mod mock_time_service; +#[cfg(test)] +mod stream_utils_test; +#[cfg(test)] +mod time_service_test; diff --git a/consensus/src/mock_time_service.rs b/consensus/src/mock_time_service.rs new file mode 100644 index 0000000000000..fccfb28fcd476 --- /dev/null +++ b/consensus/src/mock_time_service.rs @@ -0,0 +1,138 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::time_service::{ScheduledTask, TimeService}; +use futures::{Future, FutureExt}; +use logger::prelude::*; +use std::{ + pin::Pin, + sync::{Arc, Mutex}, + time::Duration, +}; + +/// SimulatedTimeService implements TimeService; however, it does not depend on actual time. +/// There are multiple ways to use it: +/// SimulatedTimeService::new will create a time service that is simply 'stuck' at time 0. +/// SimulatedTimeService::update_auto_advance_limit can then be used to allow time to advance up to +/// a certain limit.
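// Editor's note: a hypothetical usage sketch (not project code; `task` stands for some
// Box<ScheduledTask>) of the behavior described in this comment block:
//
//     let time = SimulatedTimeService::auto_advance_until(Duration::from_secs(10));
//     // the deadline (2s) is within the 10s limit, so the task runs immediately
//     // and the simulated clock jumps straight to the deadline
//     time.run_after(Duration::from_secs(2), task);
//     assert_eq!(time.get_current_timestamp(), Duration::from_secs(2));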
SimulatedTimeService::auto_advance_until will create a time service that will 'run' +/// until a certain time limit. Note that SimulatedTimeService does not actually wait for any timeouts; +/// the notion of time in it is abstract. Tasks run as soon as possible, as long as they are scheduled before the configured +/// time limit. +pub struct SimulatedTimeService { + inner: Arc>, +} + +struct SimulatedTimeServiceInner { + now: Duration, + pending: Vec<(Duration, Box)>, + time_limit: Duration, + /// Maximum duration self.now is allowed to advance to + max: Duration, +} + +impl TimeService for SimulatedTimeService { + fn run_after(&self, timeout: Duration, mut t: Box) { + let mut inner = self.inner.lock().unwrap(); + let now = inner.now; + let deadline = now + timeout; + if deadline > inner.time_limit { + debug!( + "sched for deadline: {}, now: {}, limit: {}", + deadline.as_millis(), + now.as_millis(), + inner.time_limit.as_millis() + ); + inner.pending.push((deadline, t)); + } else { + debug!( + "exec deadline: {}, now: {}", + deadline.as_millis(), + now.as_millis() + ); + inner.now = deadline; + if inner.now > inner.max { + inner.now = inner.max; + } + // Perhaps this could be done better, but I think it's good enough for tests... + futures::executor::block_on(t.run()); + } + } + + fn get_current_timestamp(&self) -> Duration { + self.inner.lock().unwrap().now + } + + fn sleep(&self, t: Duration) -> Pin + Send>> { + let inner = self.inner.clone(); + let fut = async move { + let mut inner = inner.lock().unwrap(); + inner.now += t; + if inner.now > inner.max { + inner.now = inner.max; + } + }; + fut.boxed() + } +} + +impl SimulatedTimeService { + /// Creates a new SimulatedTimeService in disabled state (time not running) + pub fn new() -> SimulatedTimeService { + SimulatedTimeService { + inner: Arc::new(Mutex::new(SimulatedTimeServiceInner { + now: Duration::from_secs(0), + pending: vec![], + time_limit: Duration::from_secs(0), + max: Duration::from_secs(std::u64::MAX), + })), + } + } + + /// Creates a new SimulatedTimeService in disabled state (time not running) with a max duration + pub fn max(max: Duration) -> SimulatedTimeService { + SimulatedTimeService { + inner: Arc::new(Mutex::new(SimulatedTimeServiceInner { + now: Duration::from_secs(0), + pending: vec![], + time_limit: Duration::from_secs(0), + max, + })), + } + } + + /// Creates a new SimulatedTimeService that automatically advances time up to time_limit + pub fn auto_advance_until(time_limit: Duration) -> SimulatedTimeService { + SimulatedTimeService { + inner: Arc::new(Mutex::new(SimulatedTimeServiceInner { + now: Duration::from_secs(0), + pending: vec![], + time_limit, + max: Duration::from_secs(std::u64::MAX), + })), + } + } + + /// Update time_limit of this SimulatedTimeService instance and run pending tasks that have a + /// deadline lower than the new time_limit + #[allow(dead_code)] + pub fn update_auto_advance_limit(&mut self, time: Duration) { + let mut inner = self.inner.lock().unwrap(); + inner.time_limit += time; + let time_limit = inner.time_limit; + let drain = inner + .pending + .drain_filter(move |(deadline, _)| *deadline <= time_limit); + for (_, mut t) in drain { + // probably could be done better than that, but for now I feel it's good enough for tests + futures::executor::block_on(t.run()); + } + } +} + +impl Clone for SimulatedTimeService { + fn clone(&self) -> SimulatedTimeService { + SimulatedTimeService { + inner: self.inner.clone(), + } + } +} diff --git a/consensus/src/state_computer.rs b/consensus/src/state_computer.rs new file mode 100644 index
0000000000000..3029a10204571 --- /dev/null +++ b/consensus/src/state_computer.rs @@ -0,0 +1,175 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::QuorumCert, + counters, + state_replication::{StateComputeResult, StateComputer}, + state_synchronizer::{StateSynchronizer, SyncStatus}, +}; +use crypto::HashValue; +use execution_proto::proto::{ + execution::{CommitBlockRequest, CommitBlockStatus, ExecuteBlockRequest, ExecuteBlockResponse}, + execution_grpc::ExecutionClient, +}; +use failure::Result; +use futures::{compat::Future01CompatExt, Future, FutureExt}; +use proto_conv::{FromProto, IntoProto}; +use std::{pin::Pin, sync::Arc, time::Instant}; +use types::{ + ledger_info::LedgerInfoWithSignatures, + transaction::{SignedTransaction, TransactionListWithProof, TransactionStatus}, +}; + +/// Basic communication with the Execution module; +/// implements the StateComputer trait. +pub struct ExecutionProxy { + execution: Arc, + synchronizer: Arc, +} + +impl ExecutionProxy { + pub fn new(execution: Arc, synchronizer: Arc) -> Self { + Self { + execution: Arc::clone(&execution), + synchronizer, + } + } + + fn process_exec_response( + response: ExecuteBlockResponse, + pre_execution_instant: Instant, + ) -> StateComputeResult { + let execution_block_response = execution_proto::ExecuteBlockResponse::from_proto(response) + .expect("Couldn't decode ExecuteBlockResponse from protobuf"); + let execution_duration_ms = pre_execution_instant.elapsed().as_millis(); + let num_txns = execution_block_response.status().len(); + if num_txns == 0 { + // no txns in that block + counters::EMPTY_BLOCK_EXECUTION_DURATION_MS.observe(execution_duration_ms as f64); + } else { + counters::BLOCK_EXECUTION_DURATION_MS.observe(execution_duration_ms as f64); + let per_txn_duration = (execution_duration_ms as f64) / (num_txns as f64); + counters::TXN_EXECUTION_DURATION_MS.observe(per_txn_duration); + } + let mut compute_status = vec![]; + let mut num_successful_txns = 0; + for vm_status in execution_block_response.status() { + let status = match vm_status { + TransactionStatus::Keep(_) => { + num_successful_txns += 1; + true + } + TransactionStatus::Discard(_) => false, + }; + compute_status.push(status); + } + + StateComputeResult { + new_state_id: execution_block_response.root_hash(), + compute_status, + num_successful_txns, + validators: execution_block_response.validators().clone(), + } + } +} + +impl StateComputer for ExecutionProxy { + type Payload = Vec; + + fn compute( + &self, + // The id of a parent block, on top of which the given transactions should be executed. + parent_block_id: HashValue, + // The id of a current block. + block_id: HashValue, + // Transactions to execute.
+ transactions: &Self::Payload, + ) -> Pin> + Send>> { + let mut exec_req = ExecuteBlockRequest::new(); + exec_req.set_parent_block_id(parent_block_id.to_vec()); + exec_req.set_block_id(block_id.to_vec()); + exec_req.set_transactions(::protobuf::RepeatedField::from_vec( + transactions + .clone() + .into_iter() + .map(IntoProto::into_proto) + .collect(), + )); + + let pre_execution_instant = Instant::now(); + match self.execution.execute_block_async(&exec_req) { + Ok(receiver) => { + // convert from grpcio enum to failure::Error + async move { + match receiver.compat().await { + Ok(response) => { + Ok(Self::process_exec_response(response, pre_execution_instant)) + } + Err(e) => Err(e.into()), + } + } + .boxed() + } + Err(e) => async move { Err(e.into()) }.boxed(), + } + } + + /// Send a successful commit. A future is fulfilled when the state is finalized. + fn commit( + &self, + commit: LedgerInfoWithSignatures, + ) -> Pin> + Send>> { + counters::LAST_COMMITTED_VERSION.set(commit.ledger_info().version() as i64); + let mut commit_req = CommitBlockRequest::new(); + commit_req.set_ledger_info_with_sigs(commit.into_proto()); + + let pre_commit_instant = Instant::now(); + match self.execution.commit_block_async(&commit_req) { + Ok(receiver) => { + // convert from grpcio enum to failure::Error + async move { + match receiver.compat().await { + Ok(response) => { + if response.get_status() == CommitBlockStatus::SUCCEEDED { + let commit_duration_ms = pre_commit_instant.elapsed().as_millis(); + counters::BLOCK_COMMIT_DURATION_MS + .observe(commit_duration_ms as f64); + Ok(()) + } else { + Err(grpcio::Error::RpcFailure(grpcio::RpcStatus::new( + grpcio::RpcStatusCode::Unknown, + Some("Commit failure!".to_string()), + )) + .into()) + } + } + Err(e) => Err(e.into()), + } + } + .boxed() + } + Err(e) => async move { Err(e.into()) }.boxed(), + } + } + + /// Synchronize to a commit that is not present locally. + fn sync_to( + &self, + commit: QuorumCert, + ) -> Pin> + Send>> { + counters::STATE_SYNC_COUNT.inc(); + self.synchronizer.sync_to(commit).boxed() + } + + fn get_chunk( + &self, + start_version: u64, + target_version: u64, + batch_size: u64, + ) -> Pin> + Send>> { + self.synchronizer + .get_chunk(start_version, target_version, batch_size) + .boxed() + } +} diff --git a/consensus/src/state_replication.rs b/consensus/src/state_replication.rs new file mode 100644 index 0000000000000..7c7fcfc1f3146 --- /dev/null +++ b/consensus/src/state_replication.rs @@ -0,0 +1,141 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{chained_bft::QuorumCert, state_synchronizer::SyncStatus}; +use canonical_serialization::{CanonicalSerialize, CanonicalSerializer}; +use crypto::{hash::ACCUMULATOR_PLACEHOLDER_HASH, HashValue}; +use failure::Result; +use futures::Future; +use serde::{Deserialize, Serialize}; +use std::{pin::Pin, sync::Arc}; +use types::{ + ledger_info::LedgerInfoWithSignatures, + transaction::{TransactionListWithProof, Version}, + validator_set::ValidatorSet, +}; + +/// A structure that specifies the result of the execution. +/// The execution is responsible for generating the ID of the new state, which is returned in the +/// result. +/// +/// Not every transaction in the payload succeeds: the returned vector keeps the boolean status +/// of success / failure of the transactions. +/// Note that the specific details of compute_status are opaque to StateMachineReplication, +/// which is going to simply pass the results between StateComputer and TxnManager.
+pub struct StateComputeResult { + /// The new state generated after the execution. + pub new_state_id: HashValue, + /// The compute status (success/failure) of the given payload. The specific details are opaque + /// for StateMachineReplication, which is merely passing it between StateComputer and + /// TxnManager. + pub compute_status: Vec, + /// Counts the number of `true` values in the `compute_status` field. + pub num_successful_txns: u64, + /// If set, these are the validator public keys that will be used to start the next epoch + /// immediately after this state is committed + /// TODO [Reconfiguration] the validators are currently ignored, no reconfiguration yet. + pub validators: Option, +} + +/// Retrieves and updates the status of transactions on demand (e.g., via talking with Mempool) +pub trait TxnManager: Send + Sync { + type Payload; + + /// Brings new transactions to be applied. + /// The `exclude_txns` list includes the transactions that are already pending in the + /// branch of blocks consensus is trying to extend. + fn pull_txns( + &self, + max_size: u64, + exclude_txns: Vec<&Self::Payload>, + ) -> Pin> + Send>>; + + /// Notifies TxnManager about the payload of the committed block including the state compute + /// result, which includes the specifics of what transactions succeeded and failed. + fn commit_txns<'a>( + &'a self, + txns: &Self::Payload, + compute_result: &StateComputeResult, + // Monotonic timestamp_usecs of committed blocks is used to GC expired transactions. + timestamp_usecs: u64, + ) -> Pin> + Send + 'a>>; +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct ExecutedState { + pub state_id: HashValue, + pub version: Version, +} + +impl ExecutedState { + pub fn state_for_genesis() -> Self { + ExecutedState { + state_id: *ACCUMULATOR_PLACEHOLDER_HASH, + version: 0, + } + } +} + +impl CanonicalSerialize for ExecutedState { + fn serialize(&self, serializer: &mut impl CanonicalSerializer) -> Result<()> { + serializer.encode_raw_bytes(self.state_id.as_ref())?; + serializer.encode_u64(self.version)?; + Ok(()) + } +} + +/// While Consensus is managing proposed blocks, `StateComputer` is managing the results of the +/// (speculative) execution of their payload. +/// StateComputer is using proposed block ids for identifying the transactions. +pub trait StateComputer: Send + Sync { + type Payload; + + /// How to execute a sequence of transactions and obtain the next state. While some of the + /// transactions succeed, some of them can fail. + /// In case all the transactions fail, new_state_id is equal to the previous state id. + fn compute( + &self, + // The id of a parent block, on top of which the given transactions should be executed. + // We're going to use a special GENESIS_BLOCK_ID constant defined in crypto::hash module to + // refer to the block id of the Genesis block, which is executed in a special way. + parent_block_id: HashValue, + // The id of a current block. + block_id: HashValue, + // Transactions to execute. + transactions: &Self::Payload, + ) -> Pin> + Send>>; + + /// Send a successful commit. A future is fulfilled when the state is finalized. + fn commit( + &self, + commit: LedgerInfoWithSignatures, + ) -> Pin> + Send>>; + + /// Synchronize to a commit that is not present locally.
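// Editor's note: a tiny, hypothetical helper making the invariant documented on
// StateComputeResult above concrete: num_successful_txns is the number of `true`
// entries in compute_status.
//
//     fn count_successful(compute_status: &[bool]) -> u64 {
//         compute_status.iter().filter(|kept| **kept).count() as u64
//     }
//
//     fn main() {
//         assert_eq!(count_successful(&[true, false, true, true]), 3);
//     }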
+ fn sync_to( + &self, + commit: QuorumCert, + ) -> Pin> + Send>>; + + /// Get a chunk of transactions as a batch + fn get_chunk( + &self, + start_version: u64, + target_version: u64, + batch_size: u64, + ) -> Pin> + Send>>; +} + +pub trait StateMachineReplication { + type Payload; + /// The function is synchronous: it returns when the state is initialized / recovered from + /// persisted storage and all the threads have been started. + fn start( + &mut self, + txn_manager: Arc>, + state_computer: Arc>, + ) -> Result<()>; + + /// Stop is synchronous: returns when all the threads are shutdown and the state is persisted. + fn stop(&mut self); +} diff --git a/consensus/src/state_synchronizer/coordinator.rs b/consensus/src/state_synchronizer/coordinator.rs new file mode 100644 index 0000000000000..eb285a0b5d505 --- /dev/null +++ b/consensus/src/state_synchronizer/coordinator.rs @@ -0,0 +1,284 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{chained_bft::QuorumCert, counters, state_synchronizer::downloader::FetchChunkMsg}; +use config::config::NodeConfig; +use execution_proto::proto::{ + execution::{ExecuteChunkRequest, ExecuteChunkResponse}, + execution_grpc::ExecutionClient, +}; +use failure::prelude::*; +use futures::{ + channel::{mpsc, oneshot}, + Future, FutureExt, SinkExt, StreamExt, +}; +use grpc_helpers::convert_grpc_response; +use grpcio::{ChannelBuilder, EnvBuilder}; +use logger::prelude::*; +use proto_conv::IntoProto; +use std::{collections::BTreeMap, pin::Pin, sync::Arc}; +use storage_client::{StorageRead, StorageReadServiceClient}; +use types::proto::transaction::TransactionListWithProof; + +/// unified message used for communication with Coordinator +pub enum CoordinatorMsg { + // is sent from Synchronizer to Coordinator to request a new sync + Requested(QuorumCert, oneshot::Sender), + // is sent from Downloader to Coordinator to indicate that new batch is ready + Fetched(Result, QuorumCert), +} + +#[derive(Clone, Debug, PartialEq)] +pub enum SyncStatus { + Finished, + ExecutionFailed, + StorageReadFailed, + DownloadFailed, + DownloaderNotAvailable, + ChunkIsEmpty, +} + +/// used to coordinate synchronization process +/// handles Consensus requests and drives sync with remote peers +pub struct SyncCoordinator { + // communication with SyncCoordinator is done via this channel + receiver: mpsc::UnboundedReceiver, + // connection to transaction fetcher + sender_to_downloader: mpsc::Sender, + + // last committed version that validator is aware of + known_version: u64, + // target state to sync to + target: Option, + // used to track progress of synchronization + sync_position: u64, + // subscribers of synchronization + // each of them will be notified once their target version is ready + subscribers: BTreeMap>>, + executor_proxy: T, +} + +impl SyncCoordinator { + pub fn new( + receiver: mpsc::UnboundedReceiver, + sender_to_downloader: mpsc::Sender, + executor_proxy: T, + ) -> Self { + Self { + receiver, + sender_to_downloader, + + known_version: 0, + target: None, + sync_position: 0, + subscribers: BTreeMap::new(), + executor_proxy, + } + } + + /// main routine. 
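// Editor's note: the subscriber bookkeeping in notify_subscribers further below hinges
// on BTreeMap::split_off; a standalone demo (not project code):
//
//     use std::collections::BTreeMap;
//
//     fn main() {
//         let mut subs: BTreeMap<u64, &str> = BTreeMap::new();
//         subs.insert(5, "wants v5");
//         subs.insert(10, "wants v10");
//         let sync_position = 8;
//         // keys >= sync_position stay pending; everything below is ready to notify
//         let still_pending = subs.split_off(&sync_position);
//         assert_eq!(subs.keys().copied().collect::<Vec<_>>(), vec![5]);
//         assert_eq!(still_pending.keys().copied().collect::<Vec<_>>(), vec![10]);
//     }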
starts sync coordinator that listens for CoordinatorMsg + pub async fn start(mut self) { + while let Some(msg) = self.receiver.next().await { + match msg { + CoordinatorMsg::Requested(qc, subscriber) => { + self.handle_request(qc, subscriber).await; + } + CoordinatorMsg::Fetched(Ok(txn_list_with_proof), ledger_info_with_sigs) => { + self.process_transactions(txn_list_with_proof, ledger_info_with_sigs) + .await; + } + CoordinatorMsg::Fetched(Err(_), _) => { + self.notify_subscribers(SyncStatus::DownloadFailed); + } + } + } + } + + fn target_version(&self) -> u64 { + match &self.target { + Some(qc) => qc.ledger_info().ledger_info().version(), + None => 0, + } + } + + /// Consensus request handler + async fn handle_request(&mut self, qc: QuorumCert, subscriber: oneshot::Sender) { + let requested_version = qc.ledger_info().ledger_info().version(); + let committed_version = self.executor_proxy.get_latest_version().await; + + // if requested version equals to current committed, just pass ledger info to executor + // there might be still empty blocks between committed state and requested + if let Ok(version) = committed_version { + if version == requested_version { + let status = match self + .store_transactions(TransactionListWithProof::new(), qc) + .await + { + Ok(_) => SyncStatus::Finished, + Err(_) => SyncStatus::ExecutionFailed, + }; + if subscriber.send(status).is_err() { + log_collector_error!( + "[state synchronizer] coordinator failed to notify subscriber" + ); + } + return; + } + } + + if requested_version > self.target_version() { + self.target = Some(qc.clone()); + } + + self.subscribers + .entry(requested_version) + .or_insert_with(|| vec![]) + .push(subscriber); + + if self.sync_position == 0 { + // start new fetch + match committed_version { + Ok(version) => { + self.known_version = version; + self.sync_position = self.known_version + 1; + // send request to Downloader + let fetch_request = FetchChunkMsg { + start_version: self.sync_position, + target: qc, + }; + if self.sender_to_downloader.send(fetch_request).await.is_err() { + self.notify_subscribers(SyncStatus::DownloaderNotAvailable); + } + } + Err(_) => { + self.notify_subscribers(SyncStatus::StorageReadFailed); + } + } + } + } + + /// processes batch of transactions downloaded by fetcher + /// executes transactions, updates progress state, notifies subscribers if some sync is finished + async fn process_transactions( + &mut self, + txn_list_with_proof: TransactionListWithProof, + qc: QuorumCert, + ) { + let chunk_size = txn_list_with_proof.get_transactions().len() as u64; + if chunk_size == 0 { + self.notify_subscribers(SyncStatus::ChunkIsEmpty); + } + self.sync_position += chunk_size; + + if let Some(target) = self.target.clone() { + if self.sync_position <= self.target_version() { + let fetch_msg = FetchChunkMsg { + start_version: self.sync_position, + target, + }; + // start download of next batch + if self.sender_to_downloader.send(fetch_msg).await.is_err() { + self.notify_subscribers(SyncStatus::DownloaderNotAvailable); + return; + } + } + } + + let status = match self.store_transactions(txn_list_with_proof, qc).await { + Ok(_) => SyncStatus::Finished, + Err(_) => SyncStatus::ExecutionFailed, + }; + counters::STATE_SYNC_TXN_REPLAYED.inc_by(chunk_size as i64); + self.notify_subscribers(status); + } + + fn notify_subscribers(&mut self, result: SyncStatus) { + let mut active_subscribers = match result { + SyncStatus::Finished => self.subscribers.split_off(&self.sync_position), + _ => BTreeMap::new(), + }; + + // notify 
subscribers if some syncs are ready + for channels in self.subscribers.values_mut() { + channels.drain(..).for_each(|ch| { + if ch.send(result.clone()).is_err() { + log_collector_error!( + "[state synchronizer] coordinator failed to notify subscriber" + ); + } + }); + } + self.subscribers.clear(); + self.subscribers.append(&mut active_subscribers); + // reset sync state if done + if self.subscribers.is_empty() { + self.sync_position = 0; + } + } + + async fn store_transactions( + &self, + txn_list_with_proof: TransactionListWithProof, + qc: QuorumCert, + ) -> Result { + let mut req = ExecuteChunkRequest::new(); + req.set_txn_list_with_proof(txn_list_with_proof); + req.set_ledger_info_with_sigs(qc.ledger_info().clone().into_proto()); + self.executor_proxy.execute_chunk(req).await + } +} + +/// Proxy execution for state synchronization +pub trait ExecutorProxyTrait: Sync + Send { + /// Return the latest known version + fn get_latest_version(&self) -> Pin> + Send>>; + + /// Execute and commit a batch of transactions + fn execute_chunk( + &self, + request: ExecuteChunkRequest, + ) -> Pin> + Send>>; +} + +pub(crate) struct ExecutorProxy { + storage_client: Arc, + execution_client: Arc, +} + +impl ExecutorProxy { + pub fn new(config: &NodeConfig) -> Self { + let connection_str = format!("localhost:{}", config.execution.port); + let env = Arc::new(EnvBuilder::new().name_prefix("grpc-coord-").build()); + let execution_client = Arc::new(ExecutionClient::new( + ChannelBuilder::new(Arc::clone(&env)).connect(&connection_str), + )); + let storage_client = Arc::new(StorageReadServiceClient::new( + env, + &config.storage.address, + config.storage.port, + )); + Self { + storage_client, + execution_client, + } + } +} + +impl ExecutorProxyTrait for ExecutorProxy { + fn get_latest_version(&self) -> Pin> + Send>> { + let client = Arc::clone(&self.storage_client); + async move { + let resp = client.update_to_latest_ledger_async(0, vec![]).await?; + Ok(resp.1.ledger_info().version()) + } + .boxed() + } + + fn execute_chunk( + &self, + request: ExecuteChunkRequest, + ) -> Pin> + Send>> { + let client = Arc::clone(&self.execution_client); + convert_grpc_response(client.execute_chunk_async(&request)).boxed() + } +} diff --git a/consensus/src/state_synchronizer/downloader.rs b/consensus/src/state_synchronizer/downloader.rs new file mode 100644 index 0000000000000..f55775f0ce833 --- /dev/null +++ b/consensus/src/state_synchronizer/downloader.rs @@ -0,0 +1,109 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::QuorumCert, + counters::OP_COUNTERS, + state_synchronizer::{coordinator::CoordinatorMsg, PeerId}, +}; +use failure::prelude::*; +use futures::{channel::mpsc, SinkExt, StreamExt}; +use logger::prelude::*; +use network::{proto::RequestChunk, validator_network::ConsensusNetworkSender}; +use proto_conv::IntoProto; +use rand::{thread_rng, Rng}; +use std::time::Duration; +use types::proto::transaction::TransactionListWithProof; + +/// Used for communication between coordinator and downloader +/// and represents a single fetch request +#[derive(Clone)] +pub struct FetchChunkMsg { + // target version that we want to fetch + pub target: QuorumCert, + // version from which to start fetching (the offset version) + pub start_version: u64, +} + +/// Used to download chunks of transactions from peers +pub struct Downloader { + receiver_from_coordinator: mpsc::Receiver, + sender_to_coordinator: mpsc::UnboundedSender, + network: ConsensusNetworkSender, + 
batch_size: u64, + retries: usize, +} + +impl Downloader { + pub fn new( + receiver_from_coordinator: mpsc::Receiver, + sender_to_coordinator: mpsc::UnboundedSender, + network: ConsensusNetworkSender, + batch_size: u64, + retries: usize, + ) -> Self { + Self { + receiver_from_coordinator, + sender_to_coordinator, + network, + batch_size, + retries, + } + } + + /// Starts chunk downloader that listens to FetchChunkMsgs + pub async fn start(mut self) { + while let Some(msg) = self.receiver_from_coordinator.next().await { + for attempt in 0..self.retries { + let peer_id = self.pick_peer_id(&msg); + let download_result = self.download_chunk(peer_id, msg.clone()).await; + if download_result.is_ok() || attempt == self.retries - 1 { + let send_result = self + .sender_to_coordinator + .send(CoordinatorMsg::Fetched(download_result, msg.target)) + .await; + if send_result.is_err() { + log_collector_error!("[state synchronizer] failed to send chunk from downloader to coordinator"); + } + break; + } + } + } + } + + /// Downloads a chunk from another validator or from a cloud provider. + /// It then verifies that the data in the chunk is valid and returns the validated data. + async fn download_chunk( + &mut self, + peer_id: PeerId, + msg: FetchChunkMsg, + ) -> Result { + // Construct the message and use rpc call via network stack + let mut req = RequestChunk::new(); + req.set_start_version(msg.start_version); + req.set_target(msg.target.clone().into_proto()); + req.set_batch_size(self.batch_size); + // Longer-term, we will read from a cloud provider. But for testnet, just read + // from the node which is proposing this block + let mut resp = self + .network + .request_chunk(peer_id, req, Duration::from_millis(1000)) + .await?; + + OP_COUNTERS.inc_by( + "download", + resp.get_txn_list_with_proof().get_transactions().len(), + ); + Ok(resp.take_txn_list_with_proof()) + } + + fn pick_peer_id(&self, msg: &FetchChunkMsg) -> PeerId { + let signatures = msg.target.ledger_info().signatures(); + let idx = thread_rng().gen_range(0, signatures.len()); + signatures + .keys() + .nth(idx) + .cloned() + .expect("[state synchronizer] failed to pick peer from qc") + } +} diff --git a/consensus/src/state_synchronizer/mocks.rs b/consensus/src/state_synchronizer/mocks.rs new file mode 100644 index 0000000000000..d679687360ff6 --- /dev/null +++ b/consensus/src/state_synchronizer/mocks.rs @@ -0,0 +1,66 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::state_synchronizer::coordinator::ExecutorProxyTrait; +use crypto::HashValue; +use execution_proto::proto::execution::{ExecuteChunkRequest, ExecuteChunkResponse}; +use failure::Result; +use futures::{Future, FutureExt}; +use proto_conv::FromProto; +use std::{ + pin::Pin, + sync::atomic::{AtomicU64, Ordering}, +}; +use types::{ + account_address::AccountAddress, + proof::AccumulatorProof, + test_helpers::transaction_test_helpers::get_test_signed_txn, + transaction::{SignedTransaction, TransactionInfo, TransactionListWithProof}, +}; +use vm_genesis::{encode_transfer_program, GENESIS_KEYPAIR}; + +#[derive(Default)] +pub struct MockExecutorProxy { + version: AtomicU64, +} + +impl ExecutorProxyTrait for MockExecutorProxy { + fn get_latest_version(&self) -> Pin> + Send>> { + let version = self.version.load(Ordering::Relaxed); + async move { Ok(version) }.boxed() + } + + fn execute_chunk( + &self, + _request: ExecuteChunkRequest, + ) -> Pin> + Send>> { + self.version.fetch_add(1, Ordering::Relaxed); + async move { 
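+        // mock behavior: the version counter was already bumped above; reply with an empty
+        // response as if the chunk had been executed and committed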
Ok(ExecuteChunkResponse::new()) }.boxed() + } +} + +pub fn gen_txn_list(sequence_number: u64) -> TransactionListWithProof { + let sender = AccountAddress::from(GENESIS_KEYPAIR.1); + let receiver = AccountAddress::new([0xff; 32]); + let program = encode_transfer_program(&receiver, 1); + let transaction = get_test_signed_txn( + sender.into(), + sequence_number, + GENESIS_KEYPAIR.0.clone(), + GENESIS_KEYPAIR.1, + Some(program), + ); + + let txn_info = TransactionInfo::new(HashValue::zero(), HashValue::zero(), HashValue::zero(), 0); + let accumulator_proof = AccumulatorProof::new(vec![]); + TransactionListWithProof::new( + vec![( + SignedTransaction::from_proto(transaction).unwrap(), + txn_info, + )], + None, + Some(0), + Some(accumulator_proof), + None, + ) +} diff --git a/consensus/src/state_synchronizer/mod.rs b/consensus/src/state_synchronizer/mod.rs new file mode 100644 index 0000000000000..8ab7efb252aed --- /dev/null +++ b/consensus/src/state_synchronizer/mod.rs @@ -0,0 +1,39 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This library is used to perform synchronization between validators for committed states. +//! This is used for restarts and catching up +//! +//! It consists of three components: `SyncCoordinator`, `Downloader` and `StateSynchronizer` +//! +//! `Downloader` is used to download chunks of transactions from peers +//! +//! `SyncCoordinator` drives synchronization process. It handles new requests from Consensus and +//! drives whole sync flow +//! +//! `StateSynchronizer` is an external interface for module. +//! It's used for convenient communication with `SyncCoordinator`. +//! To set it up do: **let synchronizer = StateSynchronizer::setup(network, executor, config)**. +//! +//! It will spawn coordinator and downloader routines and return handle for communication with +//! coordinator. +//! To request synchronization call: **synchronizer.sync_to(peer_id, version).await** +//! +//! Note that it's possible to issue multiple synchronization requests at the same time. +//! 
`SyncCoordinator` handles it and makes sure each chunk will be downloaded only once + +pub use self::coordinator::SyncStatus; + +mod coordinator; +mod downloader; +mod synchronizer; + +pub use self::synchronizer::{setup_state_synchronizer, StateSynchronizer}; +use types::account_address::AccountAddress; + +#[cfg(test)] +mod mocks; +#[cfg(test)] +mod sync_test; + +pub type PeerId = AccountAddress; diff --git a/consensus/src/state_synchronizer/sync_test.rs b/consensus/src/state_synchronizer/sync_test.rs new file mode 100644 index 0000000000000..74d1cb455c5c3 --- /dev/null +++ b/consensus/src/state_synchronizer/sync_test.rs @@ -0,0 +1,250 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::{test_utils, QuorumCert}, + state_replication::ExecutedState, + state_synchronizer::{ + coordinator::SyncStatus, + mocks::{gen_txn_list, MockExecutorProxy}, + PeerId, StateSynchronizer, + }, +}; +use bytes::Bytes; +use config::config::NodeConfig; +use config_builder::util::get_test_config; +use crypto::{signing, x25519, HashValue}; +use failure::prelude::*; +use futures::{ + executor::block_on, + future::{join_all, TryFutureExt}, + stream::StreamExt, + FutureExt, +}; +use metrics::get_all_metrics; +use network::{ + proto::{ConsensusMsg, RespondChunk}, + validator_network::{ + network_builder::{NetworkBuilder, TransportType}, + Event, RpcError, CONSENSUS_RPC_PROTOCOL, + }, + NetworkPublicKeys, ProtocolId, +}; +use parity_multiaddr::Multiaddr; +use proto_conv::IntoProto; +use protobuf::Message; +use rusty_fork::{rusty_fork_id, rusty_fork_test, rusty_fork_test_name}; +use std::{ + collections::HashMap, + sync::atomic::{AtomicUsize, Ordering}, +}; +use tokio::runtime::Runtime; +use types::{ + ledger_info::{LedgerInfo, LedgerInfoWithSignatures}, + transaction::TransactionListWithProof, +}; + +struct SynchronizerEnv { + synchronizers: Vec, + peers: Vec, + _runtime: Runtime, +} + +impl SynchronizerEnv { + fn new() -> Self { + let handler = Box::new(|| -> Result { Ok(gen_txn_list(0)) }); + Self::new_with(handler, None) + } + + fn new_with( + handler: Box Result + Send + 'static>, + opt_config: Option, + ) -> Self { + let mut runtime = test_utils::consensus_runtime(); + let config = opt_config.unwrap_or_else(|| { + let (config_inner, _) = get_test_config(); + config_inner + }); + let peers = vec![PeerId::random(), PeerId::random()]; + + // setup network + let addr: Multiaddr = "/memory/0".parse().unwrap(); + let protocols = vec![ProtocolId::from_static(CONSENSUS_RPC_PROTOCOL)]; + + // Setup signing public keys. + let (a_signing_private_key, a_signing_public_key) = signing::generate_keypair(); + let (b_signing_private_key, b_signing_public_key) = signing::generate_keypair(); + // Setup identity public keys. 
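+        // (x25519 static keys, registered in NetworkPublicKeys below so that the network layer
+        // can authenticate the transport between the two peers)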
+ let (a_identity_private_key, a_identity_public_key) = x25519::generate_keypair(); + let (b_identity_private_key, b_identity_public_key) = x25519::generate_keypair(); + + let trusted_peers: HashMap<_, _> = vec![ + ( + peers[0], + NetworkPublicKeys { + signing_public_key: a_signing_public_key, + identity_public_key: a_identity_public_key, + }, + ), + ( + peers[1], + NetworkPublicKeys { + signing_public_key: b_signing_public_key, + identity_public_key: b_identity_public_key, + }, + ), + ] + .into_iter() + .collect(); + + let (_, (sender_b, mut events_b), listener_addr) = + NetworkBuilder::new(runtime.executor(), peers[1], addr.clone()) + .signing_keys((b_signing_private_key, b_signing_public_key)) + .identity_keys((b_identity_private_key, b_identity_public_key)) + .trusted_peers(trusted_peers.clone()) + .transport(TransportType::Memory) + .consensus_protocols(protocols.clone()) + .rpc_protocols(protocols.clone()) + .build(); + + let (sender_a, mut events_a) = + NetworkBuilder::new(runtime.executor(), peers[0], addr.clone()) + .transport(TransportType::Memory) + .signing_keys((a_signing_private_key, a_signing_public_key)) + .identity_keys((a_identity_private_key, a_identity_public_key)) + .trusted_peers(trusted_peers.clone()) + .seed_peers([(peers[1], vec![listener_addr])].iter().cloned().collect()) + .consensus_protocols(protocols.clone()) + .rpc_protocols(protocols) + .build() + .1; + + // await peer discovery + block_on(events_a.next()).unwrap().unwrap(); + block_on(events_b.next()).unwrap().unwrap(); + + // create synchronizers + let synchronizers = vec![ + StateSynchronizer::new( + sender_a, + runtime.executor(), + &config, + MockExecutorProxy::default(), + ), + StateSynchronizer::new( + sender_b, + runtime.executor(), + &config, + MockExecutorProxy::default(), + ), + ]; + + let rpc_handler = async move { + while let Some(event) = events_b.next().await { + if let Ok(Event::RpcRequest((_, _, callback))) = event { + match handler() { + Ok(txn_list) => { + let mut response_msg = ConsensusMsg::new(); + let mut response = RespondChunk::new(); + response.set_txn_list_with_proof(txn_list.into_proto()); + response_msg.set_respond_chunk(response); + let response_data = Bytes::from(response_msg.write_to_bytes().unwrap()); + callback.send(Ok(response_data)).unwrap(); + } + Err(err) => { + callback.send(Err(RpcError::ApplicationError(err))).unwrap(); + } + } + } + } + }; + runtime.spawn(rpc_handler.boxed().unit_error().compat()); + + Self { + synchronizers, + peers, + _runtime: runtime, + } + } + + fn gen_commit(&self, version: u64) -> QuorumCert { + let ledger_info = LedgerInfo::new( + version, + HashValue::zero(), + HashValue::zero(), + HashValue::zero(), + 0, + 0, + ); + let mut signatures = HashMap::new(); + let private_key = signing::generate_genesis_keypair().0; + let signature = signing::sign_message(HashValue::zero(), &private_key).unwrap(); + signatures.insert(self.peers[1], signature); + QuorumCert::new( + HashValue::zero(), + ExecutedState::state_for_genesis(), + 0, + LedgerInfoWithSignatures::new(ledger_info, signatures), + ) + } +} + +#[test] +fn test_basic_flow() { + let env = SynchronizerEnv::new(); + + // test small sequential syncs + for version in 1..5 { + let status = block_on(env.synchronizers[0].sync_to(env.gen_commit(version))); + assert_eq!(status.unwrap(), SyncStatus::Finished); + } + // test batch sync for multiple transactions + let status = block_on(env.synchronizers[0].sync_to(env.gen_commit(10))); + assert_eq!(status.unwrap(), SyncStatus::Finished); +} + 
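+// Forked into a separate process (rusty_fork) so that the global metrics counters inspected
+// below are isolated from other tests running in the same binary.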
+rusty_fork_test! { +#[test] +fn test_concurrent_requests() { + let env = SynchronizerEnv::new(); + + let requests = vec![ + env.synchronizers[0].sync_to(env.gen_commit(1)), + env.synchronizers[0].sync_to(env.gen_commit(2)), + env.synchronizers[0].sync_to(env.gen_commit(3)), + ]; + // ensure we can execute requests in parallel + block_on(join_all(requests)); + // ensure we downloaded each chunk exactly 1 time + let metrics = get_all_metrics(); + assert_eq!(metrics["consensus{op=download}"].parse::().unwrap(), 3); +} +} + +#[test] +fn test_download_failure() { + // create handler that causes errors + let handler = Box::new(|| -> Result { bail!("chunk fetch failed") }); + + let env = SynchronizerEnv::new_with(handler, None); + let status = block_on(env.synchronizers[0].sync_to(env.gen_commit(5))); + assert_eq!(status.unwrap(), SyncStatus::DownloadFailed); +} + +#[test] +fn test_download_retry() { + // create handler that causes error, but has successful retries + let attempt = AtomicUsize::new(0); + let handler = Box::new(move || -> Result { + let fail_request = attempt.load(Ordering::Relaxed) == 0; + attempt.fetch_add(1, Ordering::Relaxed); + if fail_request { + bail!("chunk fetch failed") + } else { + Ok(gen_txn_list(0)) + } + }); + let env = SynchronizerEnv::new_with(handler, None); + let status = block_on(env.synchronizers[0].sync_to(env.gen_commit(1))); + assert_eq!(status.unwrap(), SyncStatus::Finished); +} diff --git a/consensus/src/state_synchronizer/synchronizer.rs b/consensus/src/state_synchronizer/synchronizer.rs new file mode 100644 index 0000000000000..64453b33f014c --- /dev/null +++ b/consensus/src/state_synchronizer/synchronizer.rs @@ -0,0 +1,122 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + chained_bft::QuorumCert, + state_synchronizer::{ + coordinator::{ + CoordinatorMsg, ExecutorProxy, ExecutorProxyTrait, SyncCoordinator, SyncStatus, + }, + downloader::Downloader, + }, +}; +use config::config::NodeConfig; +use failure::prelude::*; +use futures::{ + channel::{mpsc, oneshot}, + future::{Future, FutureExt, TryFutureExt}, + SinkExt, +}; +use grpcio::EnvBuilder; +use logger::prelude::*; +use network::validator_network::ConsensusNetworkSender; +use std::sync::Arc; +use storage_client::{StorageRead, StorageReadServiceClient}; +use tokio::runtime::TaskExecutor; +use types::transaction::TransactionListWithProof; + +/// Used for synchronization between validators for committed states +pub struct StateSynchronizer { + synchronizer_to_coordinator: mpsc::UnboundedSender, + storage_read_client: Arc, +} + +impl StateSynchronizer { + /// Setup state synchronizer. 
Spawns the coordinator and downloader routines on the given executor.
+    pub fn new<E: ExecutorProxyTrait + 'static>(
+        network: ConsensusNetworkSender,
+        executor: TaskExecutor,
+        config: &NodeConfig,
+        executor_proxy: E,
+    ) -> Self {
+        let (coordinator_sender, coordinator_receiver) = mpsc::unbounded();
+        let (fetcher_sender, fetcher_receiver) = mpsc::channel(1);
+
+        let coordinator =
+            SyncCoordinator::new(coordinator_receiver, fetcher_sender, executor_proxy);
+        let downloader = Downloader::new(
+            fetcher_receiver,
+            coordinator_sender.clone(),
+            network,
+            config.base.node_sync_batch_size,
+            config.base.node_sync_retries,
+        );
+
+        executor.spawn(coordinator.start().boxed().unit_error().compat());
+        executor.spawn(downloader.start().boxed().unit_error().compat());
+
+        let env = Arc::new(EnvBuilder::new().name_prefix("grpc-sync-").build());
+        let storage_read_client = Arc::new(StorageReadServiceClient::new(
+            env,
+            &config.storage.address,
+            config.storage.port,
+        ));
+
+        Self {
+            synchronizer_to_coordinator: coordinator_sender,
+            storage_read_client,
+        }
+    }
+
+    /// Sync the validator's state up to the version of the given `QuorumCert`
+    pub fn sync_to(&self, qc: QuorumCert) -> impl Future<Output = Result<SyncStatus>> {
+        let mut sender = self.synchronizer_to_coordinator.clone();
+        let (cb_sender, cb_receiver) = oneshot::channel();
+        async move {
+            sender
+                .send(CoordinatorMsg::Requested(qc, cb_sender))
+                .await?;
+            let sync_status = cb_receiver.await?;
+            Ok(sync_status)
+        }
+    }
+
+    /// Get a batch of transactions
+    pub fn get_chunk(
+        &self,
+        start_version: u64,
+        target_version: u64,
+        batch_size: u64,
+    ) -> impl Future<Output = Result<TransactionListWithProof>> {
+        let client = Arc::clone(&self.storage_read_client);
+        async move {
+            let txn_list_with_proof = client
+                .get_transactions_async(
+                    start_version,
+                    batch_size,
+                    target_version,
+                    false, /* fetch_events */
+                )
+                .await?;
+
+            if txn_list_with_proof.transaction_and_infos.is_empty() {
+                log_collector_warn!(
+                    "Unable to get transactions from version {} for {} items",
+                    start_version,
+                    batch_size
+                );
+            }
+            Ok(txn_list_with_proof)
+        }
+    }
+}
+
+/// Make the state synchronizer
+pub fn setup_state_synchronizer(
+    network: ConsensusNetworkSender,
+    executor: TaskExecutor,
+    config: &NodeConfig,
+) -> StateSynchronizer {
+    let executor_proxy = ExecutorProxy::new(config);
+    StateSynchronizer::new(network, executor, config, executor_proxy)
+}
diff --git a/consensus/src/stream_utils.rs b/consensus/src/stream_utils.rs
new file mode 100644
index 0000000000000..3263b937f5351
--- /dev/null
+++ b/consensus/src/stream_utils.rs
@@ -0,0 +1,84 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use futures::{channel::mpsc, Future, FutureExt, StreamExt, TryFutureExt};
+use std::{pin::Pin, sync::Arc};
+use tokio::runtime::TaskExecutor;
+
+/// The EventBasedActor trait represents actor-style objects that are driven by an input
+/// stream of events and that generate an output stream of events in response.
+///
+/// Please see an example of the usage in stream_utils_test.rs
+pub trait EventBasedActor {
+    type InputEvent;
+    type OutputEvent;
+
+    /// Called synchronously at EventBasedActor startup
+    fn init(
+        &mut self,
+        input_stream_sender: mpsc::Sender<Self::InputEvent>,
+        output_stream_sender: mpsc::Sender<Self::OutputEvent>,
+    );
+
+    /// Called before the main event loop starts.
+    /// The returned future is chained s.t. the first event processing starts after the completion
+    /// of the startup preparations.
+    fn on_startup(&self) -> Pin<Box<dyn Future<Output = ()> + Send>> {
+        async {}.boxed()
+    }
+
+    /// Process a new event from the stream.
+    /// An implementation can generate new input / output events as a result.
+    /// The returned future is chained by the main event processing loop s.t.
+    /// the total order of the returned futures is preserved.
+    fn process_event(&self, event: Self::InputEvent) -> Pin<Box<dyn Future<Output = ()> + Send>>;
+}
+
+/// Starts a loop of event processing for a given actor.
+/// Events are received via the given input receiver and processed one by one by the actor,
+/// which is kept alive by the processing loop for as long as the input stream is open.
+pub fn start_event_processing_loop<A>(
+    actor: &mut Arc<A>,
+    executor: TaskExecutor,
+) -> (mpsc::Sender<A::InputEvent>, mpsc::Receiver<A::OutputEvent>)
+where
+    A: EventBasedActor + Send + Sync + 'static + ?Sized,
+    A::InputEvent: Send,
+{
+    let (input_tx, mut input_rx, output_tx, output_rx) = prep_channels::<A>();
+    Arc::get_mut(actor)
+        .expect("can not clone Arc before start")
+        .init(input_tx.clone(), output_tx);
+
+    let actor = Arc::clone(actor);
+
+    let processing_loop = async move {
+        actor.on_startup().await;
+
+        while let Some(event) = input_rx.next().await {
+            actor.process_event(event).await;
+        }
+    };
+    executor.spawn(processing_loop.boxed().unit_error().compat());
+    (input_tx, output_rx)
+}
+
+/// Generates the mpsc channels for input and output events.
+pub fn prep_channels<A>() -> (
+    mpsc::Sender<A::InputEvent>,
+    mpsc::Receiver<A::InputEvent>,
+    mpsc::Sender<A::OutputEvent>,
+    mpsc::Receiver<A::OutputEvent>,
+)
+where
+    A: EventBasedActor + Send + 'static + ?Sized,
+{
+    let (input_events_tx, input_events_rx) = mpsc::channel(1_024);
+    let (output_events_tx, output_events_rx) = mpsc::channel(1_024);
+    (
+        input_events_tx,
+        input_events_rx,
+        output_events_tx,
+        output_events_rx,
+    )
+}
diff --git a/consensus/src/stream_utils_test.rs b/consensus/src/stream_utils_test.rs
new file mode 100644
index 0000000000000..4dcdefa51d1ed
--- /dev/null
+++ b/consensus/src/stream_utils_test.rs
@@ -0,0 +1,74 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::stream_utils::{start_event_processing_loop, EventBasedActor};
+use futures::{channel::mpsc, executor::block_on, Future, FutureExt, SinkExt, StreamExt};
+use logger::prelude::*;
+use std::{
+    pin::Pin,
+    sync::{Arc, RwLock},
+};
+use tokio::runtime;
+
+struct FibonacciActor {
+    a: u32,
+    b: u32,
+    output_stream: Option<mpsc::Sender<u32>>,
+}
+
+impl FibonacciActor {
+    pub fn new() -> Self {
+        Self {
+            a: 0,
+            b: 1,
+            output_stream: None,
+        }
+    }
+}
+
+impl EventBasedActor for RwLock<FibonacciActor> {
+    type InputEvent = ();
+    type OutputEvent = u32;
+
+    fn init(
+        &mut self,
+        _: mpsc::Sender<Self::InputEvent>,
+        output_stream: mpsc::Sender<Self::OutputEvent>,
+    ) {
+        self.write().unwrap().output_stream = Some(output_stream);
+    }
+
+    fn process_event(&self, _: Self::InputEvent) -> Pin<Box<dyn Future<Output = ()> + Send>> {
+        let mut guard = self.write().unwrap();
+        let next = guard.a + guard.b;
+        let mut sender = guard.output_stream.as_ref().unwrap().clone();
+        let a = guard.a;
+        let send_fut = async move {
+            if let Err(e) = sender.send(a).await {
+                debug!("Error in sending output event {:?}", e);
+            }
+        };
+        guard.a = guard.b;
+        guard.b = next;
+        send_fut.boxed()
+    }
+}
+
+#[test]
+fn test_event_loop() {
+    let runtime = runtime::Builder::new()
+        .build()
+        .expect("Failed to create Tokio runtime!");
+    let mut fib_actor = Arc::new(RwLock::new(FibonacciActor::new()));
+    let (mut input_tx, output_rx) = start_event_processing_loop(&mut fib_actor, runtime.executor());
+
+    block_on(async move {
+        for _ in 0..5 {
+            input_tx.send(()).await.unwrap();
+        }
+        assert_eq!(
+            output_rx.take(5).collect::<Vec<u32>>().await,
+            vec![0, 1, 1, 2, 3,]
+        );
+    });
+}
diff --git a/consensus/src/time_service.rs
b/consensus/src/time_service.rs
new file mode 100644
index 0000000000000..5905db2c6f644
--- /dev/null
+++ b/consensus/src/time_service.rs
@@ -0,0 +1,206 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use futures::{
+    channel::mpsc::Sender, compat::Future01CompatExt, Future, FutureExt, SinkExt, TryFutureExt,
+};
+use logger::prelude::*;
+use std::{
+    pin::Pin,
+    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
+};
+use tokio::{executor::Executor, runtime::TaskExecutor, timer::Delay};
+
+/// Time service is an abstraction for operations that depend on time.
+/// It supports implementations that simulate time as well as ones that depend on the actual time.
+/// We can use simulated time in tests so tests can run faster and be more stable;
+/// see SimulatedTime for the implementation that tests should use.
+/// Time service also leaves room for future optimizations: for example, instead of scheduling
+/// O(N) tasks in TaskExecutor we could have more optimal code that only keeps a single task
+/// in TaskExecutor.
+pub trait TimeService: Send + Sync {
+    /// Sends a message to the given sender after the timeout
+    fn run_after(&self, timeout: Duration, task: Box<dyn ScheduledTask>);
+
+    /// Retrieve the current timestamp as a Duration (assuming it is on or after the UNIX_EPOCH)
+    fn get_current_timestamp(&self) -> Duration;
+
+    /// Makes a future that will sleep for the given Duration.
+    /// This function guarantees that get_current_timestamp will increase at least by the
+    /// given duration, e.g.
+    /// X = time_service::get_current_timestamp();
+    /// time_service::sleep(Y).await;
+    /// Z = time_service::get_current_timestamp();
+    /// assert(Z >= X + Y)
+    fn sleep(&self, t: Duration) -> Pin<Box<dyn Future<Output = ()> + Send>>;
+}
+
+/// This trait represents an abstract task that can be submitted to TimeService::run_after
+pub trait ScheduledTask: Send {
+    /// TimeService::run_after will run this method when the time expires.
+    /// It is expected that this function is lightweight and does not take a long time to complete.
+    fn run(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send>>;
+}
+
+/// This task sends a message to the given Sender
+pub struct SendTask<T>
+where
+    T: Send + 'static,
+{
+    sender: Option<Sender<T>>,
+    message: Option<T>,
+}
+
+impl<T> SendTask<T>
+where
+    T: Send + 'static,
+{
+    /// Makes a new SendTask for the given sender and message and wraps it in a Box
+    pub fn make(sender: Sender<T>, message: T) -> Box<dyn ScheduledTask> {
+        Box::new(SendTask {
+            sender: Some(sender),
+            message: Some(message),
+        })
+    }
+}
+
+impl<T> ScheduledTask for SendTask<T>
+where
+    T: Send + 'static,
+{
+    fn run(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send>> {
+        let mut sender = self.sender.take().unwrap();
+        let message = self.message.take().unwrap();
+        let r = async move {
+            if let Err(e) = sender.send(message).await {
+                error!("Error on send: {:?}", e);
+            };
+        };
+        r.boxed()
+    }
+}
+
+/// TimeService implementation that uses the actual clock to schedule tasks
+pub struct ClockTimeService {
+    executor: TaskExecutor,
+}
+
+impl ClockTimeService {
+    /// Creates a new TimeService that runs tasks based on the actual clock.
+    /// It needs an executor on which to schedule the internal tasks that facilitate its work.
+    pub fn new(executor: TaskExecutor) -> ClockTimeService {
+        ClockTimeService { executor }
+    }
+}
+
+impl TimeService for ClockTimeService {
+    fn run_after(&self, timeout: Duration, mut t: Box<dyn ScheduledTask>) {
+        let task = async move {
+            let timeout_time = Instant::now() + timeout;
+            if let Err(e) = Delay::new(timeout_time).compat().await {
+                error!("Error on delay: {:?}", e);
+            };
+            t.run().await;
+        };
+        let task = task.boxed().unit_error().compat();
+        let mut executor =
self.executor.clone(); + if let Err(e) = Executor::spawn(&mut executor, Box::new(task)) { + warn!("Failed to submit task to runtime: {:?}", e) + } + } + + fn get_current_timestamp(&self) -> Duration { + duration_since_epoch() + } + + fn sleep(&self, t: Duration) -> Pin + Send>> { + async move { Delay::new(Instant::now() + t).compat().await.unwrap() }.boxed() + } +} + +/// Return the duration since the UNIX_EPOCH +pub fn duration_since_epoch() -> Duration { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Timestamp generated is before the UNIX_EPOCH!") +} + +/// Success states for wait_if_possible +#[derive(Debug, PartialEq, Eq)] +pub enum WaitingSuccess { + /// No waiting to complete and includes the current duration since epoch and the difference + /// between the current duration since epoch and min_duration_since_epoch + NoWaitRequired { + current_duration_since_epoch: Duration, + early_duration: Duration, + }, + /// Waiting was required and includes the current duration since epoch and the duration + /// slept to finish waiting + WaitWasRequired { + current_duration_since_epoch: Duration, + wait_duration: Duration, + }, +} + +/// Error states for wait_if_possible +#[derive(Debug, PartialEq, Eq)] +pub enum WaitingError { + /// The waiting period exceeds the maximum allowed duration, returning immediately + MaxWaitExceeded, + /// Waiting to ensure the current time exceeds min_duration_since_epoch failed + WaitFailed { + current_duration_since_epoch: Duration, + wait_duration: Duration, + }, +} + +/// Attempt to wait until the current time exceeds the min_duration_since_epoch if possible +/// +/// If the waiting time exceeds max_instant then fail immediately. +/// There are 4 potential outcomes, 2 successful and 2 errors, each represented by +/// WaitingSuccess and WaitingError. +pub async fn wait_if_possible( + time_service: &TimeService, + min_duration_since_epoch: Duration, + max_instant: Instant, +) -> Result { + // Fail early if waiting for min_duration_since_epoch would exceed max_instant + // Ideally, comparing min_duration_since_epoch and max_instant would be straightforward, but + // min_duration_since_epoch is relative to UNIX_EPOCH and Instant is not comparable. Therefore, + // we use relative differences to do the comparison. 
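+    // Illustrative numbers: if the current time is 10s since the epoch, min_duration_since_epoch
+    // is 13s, and max_instant is 2s away, then duration_to_min_time (3s) exceeds
+    // duration_to_max_time (2s), so we fail fast with MaxWaitExceeded instead of sleeping.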
+ let current_instant = Instant::now(); + let current_duration_since_epoch = time_service.get_current_timestamp(); + if current_instant <= max_instant { + let duration_to_max_time = max_instant.duration_since(current_instant); + if current_duration_since_epoch <= min_duration_since_epoch { + let duration_to_min_time = min_duration_since_epoch - current_duration_since_epoch; + if duration_to_max_time < duration_to_min_time { + return Err(WaitingError::MaxWaitExceeded); + } + } + } + + if current_duration_since_epoch <= min_duration_since_epoch { + // Delay has millisecond granularity, add 1 millisecond to ensure a higher timestamp + let sleep_duration = + min_duration_since_epoch - current_duration_since_epoch + Duration::from_millis(1); + time_service.sleep(sleep_duration).await; + let waited_duration_since_epoch = time_service.get_current_timestamp(); + if waited_duration_since_epoch > min_duration_since_epoch { + Ok(WaitingSuccess::WaitWasRequired { + current_duration_since_epoch: waited_duration_since_epoch, + wait_duration: sleep_duration, + }) + } else { + Err(WaitingError::WaitFailed { + current_duration_since_epoch: waited_duration_since_epoch, + wait_duration: sleep_duration, + }) + } + } else { + Ok(WaitingSuccess::NoWaitRequired { + current_duration_since_epoch, + early_duration: current_duration_since_epoch - min_duration_since_epoch, + }) + } +} diff --git a/consensus/src/time_service_test.rs b/consensus/src/time_service_test.rs new file mode 100644 index 0000000000000..c90b0e414c045 --- /dev/null +++ b/consensus/src/time_service_test.rs @@ -0,0 +1,81 @@ +use crate::{ + mock_time_service::SimulatedTimeService, + time_service::{wait_if_possible, TimeService, WaitingError, WaitingSuccess}, +}; +use futures::executor::block_on; +use std::time::{Duration, Instant}; + +#[test] +fn wait_if_possible_test_waiting_required() { + let simulated_time = SimulatedTimeService::new(); + let min_duration_since_epoch = Duration::from_secs(1); + let max_instant = Instant::now() + Duration::from_secs(2); + let result = block_on(wait_if_possible( + &simulated_time, + min_duration_since_epoch, + max_instant, + )); + + assert_eq!( + result.ok().unwrap(), + WaitingSuccess::WaitWasRequired { + current_duration_since_epoch: min_duration_since_epoch + Duration::from_millis(1), + wait_duration: min_duration_since_epoch + Duration::from_millis(1), + } + ); +} + +#[test] +fn wait_if_possible_test_no_wait_required() { + let simulated_time = SimulatedTimeService::new(); + block_on(simulated_time.sleep(Duration::from_secs(3))); + let min_duration_since_epoch = Duration::from_secs(1); + let max_instant = Instant::now() + Duration::from_secs(5); + let result = block_on(wait_if_possible( + &simulated_time, + min_duration_since_epoch, + max_instant, + )); + + assert_eq!( + result.ok().unwrap(), + WaitingSuccess::NoWaitRequired { + current_duration_since_epoch: Duration::from_secs(3), + early_duration: Duration::from_secs(2) + } + ); +} + +#[test] +fn wait_if_possible_test_max_duration_exceeded() { + let simulated_time = SimulatedTimeService::new(); + let min_duration_since_epoch = Duration::from_secs(3); + let max_instant = Instant::now() + Duration::from_secs(2); + let result = block_on(wait_if_possible( + &simulated_time, + min_duration_since_epoch, + max_instant, + )); + + assert_eq!(result.err().unwrap(), WaitingError::MaxWaitExceeded); +} + +#[test] +fn wait_if_possible_test_sleep_failed() { + let simulated_time = SimulatedTimeService::max(Duration::from_secs(1)); + let min_duration_since_epoch = 
Duration::from_secs(2); + let max_instant = Instant::now() + Duration::from_secs(3); + let result = block_on(wait_if_possible( + &simulated_time, + min_duration_since_epoch, + max_instant, + )); + + assert_eq!( + result.err().unwrap(), + WaitingError::WaitFailed { + wait_duration: min_duration_since_epoch + Duration::from_millis(1), + current_duration_since_epoch: Duration::from_secs(1) + } + ); +} diff --git a/consensus/src/txn_manager.rs b/consensus/src/txn_manager.rs new file mode 100644 index 0000000000000..1225d90031337 --- /dev/null +++ b/consensus/src/txn_manager.rs @@ -0,0 +1,141 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + counters, + state_replication::{StateComputeResult, TxnManager}, +}; +use failure::Result; +use futures::{compat::Future01CompatExt, Future, FutureExt}; +use logger::prelude::*; +use mempool::proto::{ + mempool::{ + CommitTransactionsRequest, CommittedTransaction, GetBlockRequest, TransactionExclusion, + }, + mempool_grpc::MempoolClient, +}; +use proto_conv::FromProto; +use std::{pin::Pin, sync::Arc}; +use types::transaction::SignedTransaction; + +/// Proxy interface to mempool +pub struct MempoolProxy { + mempool: Arc, +} + +impl MempoolProxy { + pub fn new(mempool: Arc) -> Self { + Self { + mempool: Arc::clone(&mempool), + } + } + + /// Generate mempool commit transactions request given the set of txns and their status + fn gen_commit_transactions_request( + txns: &[SignedTransaction], + compute_result: &StateComputeResult, + timestamp_usecs: u64, + ) -> CommitTransactionsRequest { + let mut all_updates = Vec::new(); + assert_eq!(txns.len(), compute_result.compute_status.len()); + for (txn, success) in txns.iter().zip(compute_result.compute_status.iter()) { + let mut transaction = CommittedTransaction::new(); + transaction.set_sender(txn.sender().as_ref().to_vec()); + transaction.set_sequence_number(txn.sequence_number()); + if *success { + counters::SUCCESS_TXNS_COUNT.inc(); + transaction.set_is_rejected(false); + } else { + counters::FAILED_TXNS_COUNT.inc(); + transaction.set_is_rejected(true); + } + all_updates.push(transaction); + } + let mut req = CommitTransactionsRequest::new(); + req.set_transactions(::protobuf::RepeatedField::from_vec(all_updates)); + req.set_block_timestamp_usecs(timestamp_usecs); + req + } + + /// Submit the request and return the future, which is fulfilled when the response is received. 
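+    /// Both a submission failure and a broken response channel are surfaced as an `Err` here.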
+ fn submit_commit_transactions_request( + &self, + req: CommitTransactionsRequest, + ) -> Pin> + Send>> { + match self.mempool.commit_transactions_async(&req) { + Ok(receiver) => async move { + match receiver.compat().await { + Ok(_) => Ok(()), + Err(e) => Err(e.into()), + } + } + .boxed(), + Err(e) => async move { Err(e.into()) }.boxed(), + } + } +} + +impl TxnManager for MempoolProxy { + type Payload = Vec; + + /// The returned future is fulfilled with the vector of SignedTransactions + fn pull_txns( + &self, + max_size: u64, + exclude_payloads: Vec<&Self::Payload>, + ) -> Pin> + Send>> { + let mut exclude_txns = vec![]; + for payload in exclude_payloads { + for signed_txn in payload { + let mut txn_meta = TransactionExclusion::new(); + txn_meta.set_sender(signed_txn.sender().into()); + txn_meta.set_sequence_number(signed_txn.sequence_number()); + exclude_txns.push(txn_meta); + } + } + let mut get_block_request = GetBlockRequest::new(); + get_block_request.set_max_block_size(max_size); + get_block_request.set_transactions(::protobuf::RepeatedField::from_vec(exclude_txns)); + match self.mempool.get_block_async(&get_block_request) { + Ok(receiver) => async move { + match receiver.compat().await { + Ok(mut response) => Ok(response + .take_block() + .take_transactions() + .into_iter() + .filter_map(|proto_txn| { + match SignedTransaction::from_proto(proto_txn.clone()) { + Ok(t) => Some(t), + Err(e) => { + security_log(SecurityEvent::InvalidTransactionConsensus) + .error(&e) + .data(&proto_txn) + .log(); + None + } + } + }) + .collect()), + Err(e) => Err(e.into()), + } + } + .boxed(), + Err(e) => async move { Err(e.into()) }.boxed(), + } + } + + fn commit_txns<'a>( + &'a self, + txns: &Self::Payload, + compute_result: &StateComputeResult, + // Monotonic timestamp_usecs of committed blocks is used to GC expired transactions. 
+ timestamp_usecs: u64, + ) -> Pin> + Send + 'a>> { + counters::COMMITTED_BLOCKS_COUNT.inc(); + counters::COMMITTED_TXNS_COUNT.inc_by(txns.len() as i64); + counters::NUM_TXNS_PER_BLOCK.observe(txns.len() as f64); + let req = + Self::gen_commit_transactions_request(txns.as_slice(), compute_result, timestamp_usecs); + self.submit_commit_transactions_request(req) + } +} diff --git a/contributing/corporate-cla.pdf b/contributing/corporate-cla.pdf new file mode 100644 index 0000000000000..010b03f000fbc Binary files /dev/null and b/contributing/corporate-cla.pdf differ diff --git a/contributing/individual-cla.pdf b/contributing/individual-cla.pdf new file mode 100644 index 0000000000000..99a33204c96ee Binary files /dev/null and b/contributing/individual-cla.pdf differ diff --git a/crypto/legacy_crypto/Cargo.toml b/crypto/legacy_crypto/Cargo.toml new file mode 100644 index 0000000000000..ef4acc82706c8 --- /dev/null +++ b/crypto/legacy_crypto/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "crypto" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +bincode = "1.1.1" +bytes = "0.4.12" +curve25519-dalek = "1.1.3" +derive_deref = "1.0.2" +ed25519-dalek = { version = "1.0.0-pre.1", features = ["serde"] } +hex = "0.3" +lazy_static = "1.3.0" +pairing = "0.14.2" +proptest = "0.9.1" +proptest-derive = "0.1.0" +rand = "0.6.5" +serde = { version = "1.0.89", features = ["derive"] } +threshold_crypto = "0.3" +tiny-keccak = "1.4.2" +x25519-dalek = "0.5.2" +digest = "0.8.0" +hmac = "0.7.0" +sha3 = "0.8.2" +sha2 = "0.8.0" + +failure = { path = "../../common/failure_ext", package = "failure_ext" } +crypto-derive = { path = "./src/macros" } +proto_conv = { path = "../../common/proto_conv" } + +[dev-dependencies] +bitvec = "0.10.1" +byteorder = "1.3.1" +ripemd160 = "0.8.0" diff --git a/crypto/legacy_crypto/README.md b/crypto/legacy_crypto/README.md new file mode 100644 index 0000000000000..19382b0c5dcb3 --- /dev/null +++ b/crypto/legacy_crypto/README.md @@ -0,0 +1,33 @@ +--- +id: crypto +title: Crypto +custom_edit_url: https://github.com/libra/libra/edit/master/crypto/legacy_crypto/README.md +--- +# Legacy Crypto + +The crypto component hosts all the implementations of cryptographic primitives we use in Libra: hashing, signing, and key derivation/generation. The NextGen directory contains implementations of cryptographic primitives that will be used in the upcoming versions: new crypto API Enforcing type safety, verifiable random functions, BLS signatures. + +## Overview + +Libra makes use of several cryptographic algorithms: + +* SHA-3 as the main hash function. It is standardized in [FIPS 202](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.202.pdf). It is based on the [tiny_keccak](https://docs.rs/tiny-keccak/1.4.2/tiny_keccak/) library. +* X25519 to perform key exchanges. It is used to secure communications between validators via the [Noise Protocol Framework](http://www.noiseprotocol.org/noise.html). It is based on the x25519-dalek library. +* Ed25519 to perform signatures. It is used both for consensus signatures and for transaction signatures. EdDSA is planned to be added to the next revision of FIPS 186 as mentioned in [NIST SP 800-133 Rev. 1](https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-133r1-draft.pdf). It is based on the [ed25519-dalek](https://docs.rs/ed25519-dalek/1.0.0-pre.1/ed25519_dalek/) library with additional security checks (e.g., for malleability). 
+* HKDF: HMAC-based Extract-and-Expand Key Derivation Function (HKDF) based on [RFC 5869](https://tools.ietf.org/html/rfc5869). It is used to generate keys from a salt (optional), seed, and application-info (optional). + +## How is this module organized? +``` + legacy_crypto/src + β”œβ”€β”€ signing.rs # Ed25519 signature scheme + β”œβ”€β”€ hash.rs # Hash function (SHA-3) + β”œβ”€β”€ hkdf.rs # HKDF implementation (HMAC-based Extract-and-Expand Key Derivation Function based on RFC 5869) + β”œβ”€β”€ x25519.rs # X25519 keys generation + β”œβ”€β”€ macros/ # Derivations for SilentDebug and SilentDisplay + β”œβ”€β”€ utils.rs # Serialization utility functions + β”œβ”€β”€ unit_tests # Tests + └── lib.rs +``` + +Currently `x25519.rs` only exposes the logic for managing keys. The relevant cryptographic primitives to the Noise Protocol Framework are under the [snow](https://docs.rs/snow/0.5.2/snow/) crate. + diff --git a/crypto/legacy_crypto/src/hash.rs b/crypto/legacy_crypto/src/hash.rs new file mode 100644 index 0000000000000..819e41487e816 --- /dev/null +++ b/crypto/legacy_crypto/src/hash.rs @@ -0,0 +1,615 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines traits and implementations of +//! [cryptographic hash functions](https://en.wikipedia.org/wiki/Cryptographic_hash_function) +//! for the Libra project. +//! +//! It is designed to help authors protect against two types of real world attacks: +//! +//! 1. **Domain Ambiguity**: imagine that Alice has a private key and is using +//! two different applications, X and Y. X asks Alice to sign a message saying +//! "I am Alice". naturally, Alice is willing to sign this message, since she +//! is in fact Alice. However, unbeknownst to Alice, in application Y, +//! messages beginning with the letter "I" represent transfers. " am " +//! represents a transfer of 500 coins and "Alice" can be interpreted as a +//! destination address. When Alice signed the message she needed to be +//! aware of how other applications might interpret that message. +//! +//! 2. **Format Ambiguity**: imagine a program that hashes a pair of strings. +//! To hash the strings `a` and `b` it hashes `a + "||" + b`. The pair of +//! strings `a="foo||", b = "bar"` and `a="foo", b = "||bar"` result in the +//! same input to the hash function and therefore the same hash. This +//! creates a collision. +//! +//! # Examples +//! +//! ``` +//! use crypto::hash::{CryptoHasher, TestOnlyHasher}; +//! +//! let mut hasher = TestOnlyHasher::default(); +//! hasher.write("Test message".as_bytes()); +//! let hash_value = hasher.finish(); +//! ``` +//! The output is of type [`HashValue`], which can be used as an input for signing. +//! +//! # Implementing new hashers +//! +//! For any new structure `MyNewStruct` that needs to be hashed, the developer should define a +//! new hasher with: +//! +//! ``` +//! # // To get around that there's no way to doc-test a non-exported macro: +//! # macro_rules! define_hasher { ($e:expr) => () } +//! define_hasher! { (MyNewStructHasher, MY_NEW_STRUCT_HASHER, b"MyNewStruct") } +//! ``` +//! +//! **Note**: The last argument for the `define_hasher` macro must be a unique string. +//! +//! Then, the `CryptoHash` trait should be implemented: +//! ``` +//! # use crypto::hash::*; +//! # #[derive(Default)] +//! # struct MyNewStructHasher; +//! # impl CryptoHasher for MyNewStructHasher { +//! # fn finish(self) -> HashValue { unimplemented!() } +//! 
# fn write(&mut self, bytes: &[u8]) -> &mut Self { unimplemented!() } +//! # } +//! struct MyNewStruct; +//! +//! impl CryptoHash for MyNewStruct { +//! type Hasher = MyNewStructHasher; // use the above defined hasher here +//! +//! fn hash(&self) -> HashValue { +//! let mut state = Self::Hasher::default(); +//! state.write(b"Struct serialized into bytes here"); +//! state.finish() +//! } +//! } +//! ``` + +#![allow(clippy::unit_arg)] + +use bytes::Bytes; +use failure::prelude::*; +use lazy_static::lazy_static; +use proptest_derive::Arbitrary; +use proto_conv::{FromProto, IntoProto}; +use rand::{rngs::EntropyRng, Rng}; +use serde::{Deserialize, Serialize}; +use std::{self, convert::AsRef, fmt}; +use tiny_keccak::Keccak; + +const LIBRA_HASH_SUFFIX: &[u8] = b"@@$$LIBRA$$@@"; + +#[cfg(test)] +#[path = "unit_tests/hash_test.rs"] +mod hash_test; + +/// Output value of our hash function. Intentionally opaque for safety and modularity. +#[derive(Clone, Copy, Eq, Hash, PartialEq, Serialize, Deserialize, PartialOrd, Ord, Arbitrary)] +pub struct HashValue { + hash: [u8; HashValue::LENGTH], +} + +impl HashValue { + /// The length of the hash in bytes. + pub const LENGTH: usize = 32; + /// The length of the hash in bits. + pub const LENGTH_IN_BITS: usize = HashValue::LENGTH * 8; + + /// Create a new [`HashValue`] from a byte array. + pub fn new(hash: [u8; HashValue::LENGTH]) -> Self { + HashValue { hash } + } + + /// Create from a slice (e.g. retrieved from storage). + pub fn from_slice(src: &[u8]) -> Result { + ensure!( + src.len() == HashValue::LENGTH, + "HashValue decoding failed due to length mismatch. HashValue \ + length: {}, src length: {}", + HashValue::LENGTH, + src.len() + ); + let mut value = Self::zero(); + value.hash.copy_from_slice(src); + Ok(value) + } + + /// Dumps into a vector. + pub fn to_vec(&self) -> Vec { + self.hash.to_vec() + } + + /// Creates a zero-initialized instance. + pub fn zero() -> Self { + HashValue { + hash: [0; HashValue::LENGTH], + } + } + + /// Check if the hash value is zero. + pub fn is_zero(&self) -> bool { + *self == HashValue::zero() + } + + /// Create a cryptographically random instance. + pub fn random() -> Self { + let mut rng = EntropyRng::new(); + let hash: [u8; HashValue::LENGTH] = rng.gen(); + HashValue { hash } + } + + /// Creates a random instance with given rng. Useful in unit tests. + pub fn random_with_rng(rng: &mut R) -> Self { + let hash: [u8; HashValue::LENGTH] = rng.gen(); + HashValue { hash } + } + + /// Get the size of the hash. + pub fn len() -> usize { + HashValue::LENGTH + } + + /// Get the last n bytes as a String. + pub fn last_n_bytes(&self, bytes: usize) -> String { + let mut string = String::from("HashValue(.."); + for byte in &self.hash[(HashValue::LENGTH - bytes)..] { + string.push_str(&format!("{:02x}", byte)); + } + string.push_str(")"); + string + } + + // Intentionally not public. + fn from_sha3(buffer: &[u8]) -> Self { + let mut sha3 = Keccak::new_sha3_256(); + sha3.update(buffer); + HashValue::from_keccak(sha3) + } + + #[cfg(test)] + pub fn from_iter_sha3<'a, I>(buffers: I) -> Self + where + I: IntoIterator, + { + let mut sha3 = Keccak::new_sha3_256(); + for buffer in buffers { + sha3.update(buffer); + } + HashValue::from_keccak(sha3) + } + + fn as_ref_mut(&mut self) -> &mut [u8] { + &mut self.hash[..] + } + + fn from_keccak(state: Keccak) -> Self { + let mut hash = Self::zero(); + state.finalize(hash.as_ref_mut()); + hash + } + + /// Returns a `HashValueBitIterator` over all the bits that represent this `HashValue`. 
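+    /// For example, every bit of the zero hash is unset (illustrative doc-test):
+    /// ```
+    /// use crypto::hash::HashValue;
+    ///
+    /// let zero = HashValue::zero();
+    /// assert!(zero.iter_bits().all(|bit| !bit));
+    /// assert_eq!(zero.iter_bits().count(), HashValue::LENGTH_IN_BITS);
+    /// ```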
+ pub fn iter_bits(&self) -> HashValueBitIterator<'_> { + HashValueBitIterator::new(self) + } + + /// Returns the length of common prefix of `self` and `other` in bits. + pub fn common_prefix_bits_len(&self, other: HashValue) -> usize { + self.iter_bits() + .zip(other.iter_bits()) + .take_while(|(x, y)| x == y) + .count() + } +} + +impl Default for HashValue { + fn default() -> Self { + HashValue::zero() + } +} + +impl AsRef<[u8; HashValue::LENGTH]> for HashValue { + fn as_ref(&self) -> &[u8; HashValue::LENGTH] { + &self.hash + } +} + +impl std::ops::Index for HashValue { + type Output = u8; + + fn index(&self, s: usize) -> &u8 { + self.hash.index(s) + } +} + +impl fmt::Binary for HashValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for byte in &self.hash { + write!(f, "{:08b}", byte)?; + } + Ok(()) + } +} + +impl fmt::LowerHex for HashValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for byte in &self.hash { + write!(f, "{:02x}", byte)?; + } + Ok(()) + } +} + +impl fmt::Debug for HashValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "HashValue(")?; + ::fmt(self, f)?; + write!(f, ")")?; + Ok(()) + } +} + +/// Will print shortened (4 bytes) hash +impl fmt::Display for HashValue { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + for byte in self.hash.iter().take(4) { + write!(f, "{:02x}", byte)?; + } + Ok(()) + } +} + +impl FromProto for HashValue { + type ProtoType = Vec; + + fn from_proto(bytes: Self::ProtoType) -> Result { + HashValue::from_slice(&bytes) + } +} + +impl IntoProto for HashValue { + type ProtoType = Vec; + + fn into_proto(self) -> Self::ProtoType { + self.to_vec() + } +} + +impl From for Bytes { + fn from(value: HashValue) -> Bytes { + value.hash.as_ref().into() + } +} + +/// An iterator over `HashValue` that generates one bit for each iteration. +pub struct HashValueBitIterator<'a> { + /// The reference to the bytes that represent the `HashValue`. + hash_bytes: &'a [u8], + pos: std::ops::Range, +} + +impl<'a> HashValueBitIterator<'a> { + /// Constructs a new `HashValueBitIterator` using given `HashValue`. + fn new(hash_value: &'a HashValue) -> Self { + HashValueBitIterator { + hash_bytes: hash_value.as_ref(), + pos: (0..HashValue::LENGTH_IN_BITS), + } + } + + /// Returns the `index`-th bit in the bytes. + fn get_bit(&self, index: usize) -> bool { + let pos = index / 8; + let bit = 7 - index % 8; + (self.hash_bytes[pos] >> bit) & 1 != 0 + } +} + +impl<'a> std::iter::Iterator for HashValueBitIterator<'a> { + type Item = bool; + + fn next(&mut self) -> Option { + self.pos.next().and_then(|x| Some(self.get_bit(x))) + } + + fn size_hint(&self) -> (usize, Option) { + self.pos.size_hint() + } +} + +impl<'a> std::iter::DoubleEndedIterator for HashValueBitIterator<'a> { + fn next_back(&mut self) -> Option { + self.pos.next_back().and_then(|x| Some(self.get_bit(x))) + } +} + +impl<'a> std::iter::ExactSizeIterator for HashValueBitIterator<'a> {} + +/// A type that implements `CryptoHash` can be hashed by a cryptographic hash function and produce +/// a `HashValue`. Each type needs to have its own `Hasher` type. +pub trait CryptoHash { + /// The associated `Hasher` type which comes with a unique salt for this type. + type Hasher: CryptoHasher; + + /// Hashes the object and produces a `HashValue`. + fn hash(&self) -> HashValue; +} + +/// A trait for generating hash from arbitrary stream of bytes. +/// +/// Instances of `CryptoHasher` usually represent state that is changed while hashing data. 
+/// Similar to `std::hash::Hasher` but not same. CryptoHasher cannot be reused after finish() has +/// been called. +pub trait CryptoHasher: Default { + /// Finish constructing the [`HashValue`]. + fn finish(self) -> HashValue; + /// Write bytes into the hasher. + fn write(&mut self, bytes: &[u8]) -> &mut Self; + /// Write a single byte into the hasher. + fn write_u8(&mut self, byte: u8) { + self.write(&[byte]); + } +} + +/// Our preferred hashing schema, outputing [`HashValue`]s. +/// * Hashing is parameterized by a `domain` to prevent domain +/// ambiguity attacks. +/// * The existence of serialization/deserialization function rules +/// out any formatting ambiguity. +/// * Assuming that the `domain` seed is used only once per Rust type, +/// or that the serialization carries enough type information to avoid +/// ambiguities within a same domain. +/// * Only used internally within this crate +#[derive(Clone)] +struct DefaultHasher { + state: Keccak, +} + +impl CryptoHasher for DefaultHasher { + fn finish(self) -> HashValue { + let mut hasher = HashValue::default(); + self.state.finalize(hasher.as_ref_mut()); + hasher + } + + fn write(&mut self, bytes: &[u8]) -> &mut Self { + self.state.update(bytes); + self + } +} + +impl Default for DefaultHasher { + fn default() -> Self { + DefaultHasher { + state: Keccak::new_sha3_256(), + } + } +} + +impl DefaultHasher { + fn new_with_salt(typename: &[u8]) -> Self { + let mut state = Keccak::new_sha3_256(); + if !typename.is_empty() { + let mut salt = typename.to_vec(); + salt.extend_from_slice(LIBRA_HASH_SUFFIX); + state.update(HashValue::from_sha3(&salt[..]).as_ref()); + } + DefaultHasher { state } + } +} + +macro_rules! define_hasher { + ( + $(#[$attr:meta])* + ($hasher_type: ident, $hasher_name: ident, $salt: expr) + ) => { + + #[derive(Clone)] + $(#[$attr])* + pub struct $hasher_type(DefaultHasher); + + impl $hasher_type { + fn new() -> Self { + $hasher_type(DefaultHasher::new_with_salt($salt)) + } + } + + impl Default for $hasher_type { + fn default() -> Self { + $hasher_name.clone() + } + } + + impl CryptoHasher for $hasher_type { + fn finish(self) -> HashValue { + self.0.finish() + } + + fn write(&mut self, bytes: &[u8]) -> &mut Self { + self.0.write(bytes); + self + } + } + + lazy_static! { + static ref $hasher_name: $hasher_type = { $hasher_type::new() }; + } + }; +} + +define_hasher! { + /// The hasher used to compute the hash of an AccessPath object. + (AccessPathHasher, ACCESS_PATH_HASHER, b"VM_ACCESS_PATH") +} + +define_hasher! { + /// The hasher used to compute the hash of an AccountAddress object. + ( + AccountAddressHasher, + ACCOUNT_ADDRESS_HASHER, + b"AccountAddress" + ) +} + +define_hasher! { + /// The hasher used to compute the hash of a LedgerInfo object. + (LedgerInfoHasher, LEDGER_INFO_HASHER, b"LedgerInfo") +} + +define_hasher! { + /// The hasher used to compute the hash of an internal node in the transaction accumulator. + ( + TransactionAccumulatorHasher, + TRANSACTION_ACCUMULATOR_HASHER, + b"TransactionAccumulator" + ) +} + +define_hasher! { + /// The hasher used to compute the hash of an internal node in the event accumulator. + ( + EventAccumulatorHasher, + EVENT_ACCUMULATOR_HASHER, + b"EventAccumulator" + ) +} + +define_hasher! { + /// The hasher used to compute the hash of an internal node in the Sparse Merkle Tree. + ( + SparseMerkleInternalHasher, + SPARSE_MERKLE_INTERNAL_HASHER, + b"SparseMerkleInternal" + ) +} + +define_hasher! { + /// The hasher used to compute the hash of a leaf node in the Sparse Merkle Tree. 
+ ( + SparseMerkleLeafHasher, + SPARSE_MERKLE_LEAF_HASHER, + b"SparseMerkleLeaf" + ) +} + +define_hasher! { + /// The hasher used to compute the hash of the blob content of an account. + ( + AccountStateBlobHasher, + ACCOUNT_STATE_BLOB_HASHER, + b"AccountStateBlob" + ) +} + +define_hasher! { + /// The hasher used to compute the hash of a TransactionInfo object. + ( + TransactionInfoHasher, + TRANSACTION_INFO_HASHER, + b"TransactionInfo" + ) +} + +define_hasher! { + /// The hasher used to compute the hash of a RawTransaction object. + ( + RawTransactionHasher, + RAW_TRANSACTION_HASHER, + b"RawTransaction" + ) +} + +define_hasher! { + /// The hasher used to compute the hash of a SignedTransaction object. + ( + SignedTransactionHasher, + SIGNED_TRANSACTION_HASHER, + b"SignedTransaction" + ) +} + +define_hasher! { + /// The hasher used to compute the hash (block_id) of a Block object. + (BlockHasher, BLOCK_HASHER, b"BlockId") +} + +define_hasher! { + /// The hasher used to compute the hash of a PacemakerTimeout object. + (PacemakerTimeoutHasher, PACEMAKER_TIMEOUT_HASHER, b"PacemakerTimeout") +} + +define_hasher! { + /// The hasher used to compute the hash of a NewRoundMsgHasher object. + (NewRoundMsgHasher, NEW_ROUND_MSG_HASHER, b"NewRoundMsg") +} + +define_hasher! { + /// The hasher used to compute the hash of a VoteMsg object. + (VoteMsgHasher, VOTE_MSG_HASHER, b"VoteMsg") +} + +define_hasher! { + /// The hasher used to compute the hash of a ContractEvent object. + (ContractEventHasher, CONTRACT_EVENT_HASHER, b"ContractEvent") +} + +define_hasher! { + /// The hasher used only for testing. It doesn't have a salt. + (TestOnlyHasher, TEST_ONLY_HASHER, b"") +} + +define_hasher! { + /// The hasher used to compute the hash of a DiscoveryMsg object. + (DiscoveryMsgHasher, DISCOVERY_MSG_HASHER, b"DiscoveryMsg") +} + +fn create_literal_hash(word: &str) -> HashValue { + let mut s = word.as_bytes().to_vec(); + assert!(s.len() <= HashValue::LENGTH); + s.resize(HashValue::LENGTH, 0); + HashValue::from_slice(&s).expect("Cannot fail") +} + +lazy_static! { + /// Placeholder hash of `Accumulator`. + pub static ref ACCUMULATOR_PLACEHOLDER_HASH: HashValue = + create_literal_hash("ACCUMULATOR_PLACEHOLDER_HASH"); + + /// Placeholder hash of `SparseMerkleTree`. + pub static ref SPARSE_MERKLE_PLACEHOLDER_HASH: HashValue = + create_literal_hash("SPARSE_MERKLE_PLACEHOLDER_HASH"); + + /// Block id reserved as the id of parent block of the genesis block. + pub static ref PRE_GENESIS_BLOCK_ID: HashValue = + create_literal_hash("PRE_GENESIS_BLOCK_ID"); + + /// Genesis block id is used as a parent of the very first block executed by the executor. + pub static ref GENESIS_BLOCK_ID: HashValue = + create_literal_hash("GENESIS_BLOCK_ID"); +} + +/// Provides a test_only_hash() method that can be used in tests on types that implement +/// `serde::Serialize`. +/// +/// # Example +/// ``` +/// use crypto::hash::TestOnlyHash; +/// +/// b"hello world".test_only_hash(); +/// ``` +pub trait TestOnlyHash { + /// Generates a hash used only for tests. 
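+    /// The result is deterministic: the value is serialized with bincode and fed to
+    /// `TestOnlyHasher`, so equal values always produce equal hashes.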
+ fn test_only_hash(&self) -> HashValue; +} + +impl TestOnlyHash for T { + fn test_only_hash(&self) -> HashValue { + let bytes = ::bincode::serialize(self).expect("serialize failed during hash."); + let mut hasher = TestOnlyHasher::default(); + hasher.write(&bytes); + hasher.finish() + } +} diff --git a/crypto/legacy_crypto/src/hkdf.rs b/crypto/legacy_crypto/src/hkdf.rs new file mode 100644 index 0000000000000..9d903ea4c2eb2 --- /dev/null +++ b/crypto/legacy_crypto/src/hkdf.rs @@ -0,0 +1,195 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! An implementation of HKDF, the HMAC-based Extract-and-Expand Key Derivation Function for the +//! Libra project based on [RFC 5869](https://tools.ietf.org/html/rfc5869). +//! +//! The key derivation function (KDF) is intended to support a wide range of applications and +//! requirements, and is conservative in its use of cryptographic hash functions. In particular, +//! this implementation is compatible with hash functions that output 256 bits or more, such as +//! SHA256, SHA3-256 and SHA512. +//! +//! HKDF follows the "extract-then-expand" paradigm, where the KDF logically consists of two +//! modules: the first stage takes the input keying material (the seed) and "extracts" from it a +//! fixed-length pseudorandom key, and then the second stage "expands" this key into several +//! additional pseudorandom keys (the output of the KDF). For convenience, a function that runs both +//! steps in a single call is provided. Note that along with an initial high-entropy seed, a user +//! can optionally provide salt and app-info byte-arrays for extra security guarantees and domain +//! separation. +//! +//! # Applications +//! +//! HKDF is intended for use in a wide variety of KDF applications (see [Key derivation function](https://en.wikipedia.org/wiki/Key_derivation_function)), including: +//! a) derivation of keys from an origin high-entropy master seed. This is the recommended approach +//! for generating keys in Libra, especially when a True Random Generator is not available. +//! b) derivation of session keys from a shared Diffie-Hellman value in a key-agreement protocol. +//! c) combining entropy from multiple sources of randomness, such as entropy collected +//! from system events, user's keystrokes, /dev/urandom etc. The combined seed can then be used to +//! generate cryptographic keys for account, network and transaction signing keys among the others. +//! d) hierarchical private key derivation, similarly to Bitcoin's BIP32 protocol for easier key +//! management. +//! e) hybrid key generation that combines a master seed with a PRNG output for extra security +//! guarantees against a master seed leak or low PRNG entropy. +//! +//! # Recommendations +//! +//! **Salt** +//! HKDF can operate with and without random 'salt'. The use of salt adds to the strength of HKDF, +//! ensuring independence between different uses of the hash function, supporting +//! "source-independent" extraction, and strengthening the HKDF use. The salt value should be a +//! random string of the same length as the hash output. A shorter or less random salt value can +//! still make a contribution to the security of the output key material. Salt values should be +//! independent of the input keying material. In particular, an application needs to make sure that +//! salt values are not chosen or manipulated by an attacker. +//! +//! *Application info* +//! 
+
+use digest::{
+    generic_array::{self, ArrayLength, GenericArray},
+    BlockInput, FixedOutput, Input, Reset,
+};
+use generic_array::typenum::Unsigned;
+use hmac::{Hmac, Mac};
+use std::marker::PhantomData;
+
+/// Structure representing the HKDF, capable of HKDF-Extract and HKDF-Expand operations, as defined
+/// in RFC 5869.
+#[derive(Clone, Debug)]
+pub struct Hkdf<D>
+where
+    D: Input + BlockInput + FixedOutput + Reset + Default + Clone,
+    D::OutputSize: ArrayLength<u8>,
+{
+    _marker: PhantomData<D>,
+}
+
+impl<D> Hkdf<D>
+where
+    D: Input + BlockInput + FixedOutput + Reset + Default + Clone,
+    D::BlockSize: ArrayLength<u8> + Clone,
+    D::OutputSize: ArrayLength<u8>,
+{
+    /// Minimum acceptable output length for the underlying hash function is 32 bytes.
+    const D_MINIMUM_SIZE: usize = 32;
+
+    /// The RFC 5869 HKDF-Extract operation.
+    pub fn extract(salt: Option<&[u8]>, ikm: &[u8]) -> Result<Vec<u8>, HkdfError> {
+        let d_output_size = D::OutputSize::to_usize();
+        if d_output_size < Hkdf::<D>::D_MINIMUM_SIZE {
+            return Err(HkdfError::NotSupportedHashFunctionError);
+        }
+
+        let mut hmac = match salt {
+            Some(s) => Hmac::<D>::new_varkey(s).map_err(|_| HkdfError::MACKeyError)?,
+            None => Hmac::<D>::new(&Default::default()),
+        };
+
+        hmac.input(ikm);
+
+        Ok(hmac.result().code().to_vec())
+    }
+
+    /// The RFC 5869 HKDF-Expand operation.
+    pub fn expand(prk: &[u8], info: Option<&[u8]>, length: usize) -> Result<Vec<u8>, HkdfError> {
+        let hmac_output_bytes = D::OutputSize::to_usize();
+        if prk.len() < hmac_output_bytes {
+            return Err(HkdfError::WrongPseudorandomKeyError);
+        }
+        // According to RFC 5869, MAX_OUTPUT_LENGTH <= 255 * HashLen.
+        // We specifically exclude zero size as well.
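+        // The loop below implements the RFC 5869 recurrence
+        //     T(N) = HMAC-Hash(PRK, T(N-1) | info | N),  with T(0) = empty,
+        // and OKM is the first `length` bytes of T(1) | T(2) | ...
+        // The block counter N is a single byte starting at 1, which is why no
+        // more than 255 blocks (255 * HashLen bytes) can ever be produced.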
+        if length == 0 || length > hmac_output_bytes * 255 {
+            return Err(HkdfError::InvalidOutputLengthError);
+        }
+        let mut okm = vec![0u8; length];
+        let mut prev: Option<GenericArray<u8, D::OutputSize>> = None;
+        let mut hmac = Hmac::<D>::new_varkey(prk).map_err(|_| HkdfError::MACKeyError)?;
+
+        for (blocknum, okm_block) in okm.chunks_mut(hmac_output_bytes).enumerate() {
+            if let Some(ref prev) = prev {
+                hmac.input(prev)
+            }
+            if let Some(_info) = info {
+                hmac.input(_info);
+            }
+            hmac.input(&[blocknum as u8 + 1]);
+
+            let output = hmac.result_reset().code();
+            okm_block.copy_from_slice(&output[..okm_block.len()]);
+
+            prev = Some(output);
+        }
+
+        Ok(okm)
+    }
+
+    /// HKDF Extract-then-Expand operation as a single step.
+    pub fn extract_then_expand(
+        salt: Option<&[u8]>,
+        ikm: &[u8],
+        info: Option<&[u8]>,
+        length: usize,
+    ) -> Result<Vec<u8>, HkdfError> {
+        let prk = Hkdf::<D>::extract(salt, ikm)?;
+        Hkdf::<D>::expand(&prk, info, length)
+    }
+}
+
+/// An error type for HKDF key derivation issues.
+///
+/// This enum reflects the various causes of HKDF failures, including:
+/// a) the requested HKDF output size exceeds the maximum allowed, or is zero.
+/// b) hash functions whose output is less than 32 bytes are not supported (e.g., SHA1 is not
+///    supported).
+/// c) the PRK input to HKDF-Expand is shorter than the underlying hash output, per RFC 5869.
+/// d) any other underlying HMAC error.
+#[derive(Clone, Debug, PartialEq, Eq, failure::prelude::Fail)]
+pub enum HkdfError {
+    /// HKDF expand output exceeds the maximum allowed or is zero.
+    #[fail(
+        display = "HKDF expand error - requested output size exceeds the maximum allowed or is zero"
+    )]
+    InvalidOutputLengthError,
+    /// Hash function is not supported because its output is less than 32 bytes.
+    #[fail(display = "HKDF error - the hash function is not supported because \
+                      its output is less than 32 bytes")]
+    NotSupportedHashFunctionError,
+    /// The PRK input to HKDF-Expand must be at least as long as the underlying hash output.
+    #[fail(
+        display = "HKDF expand error - the pseudorandom key input ('prk' in RFC 5869) \
+                   is shorter than the underlying hash output"
+    )]
+    WrongPseudorandomKeyError,
+    /// HMAC key related error; unlikely to happen because every key size is accepted in HMAC.
+    #[fail(display = "HMAC key error")]
+    MACKeyError,
+}
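As a complement to the error enum above, a short sketch (assuming the `crypto` crate layout from this diff) of how the output-length bound surfaces to callers:

```rust
use crypto::hkdf::{Hkdf, HkdfError};
use sha2::Sha256;

fn main() {
    let prk = Hkdf::<Sha256>::extract(None, &[0u8; 32]).unwrap();

    // SHA-256 caps HKDF output at 255 * 32 = 8160 bytes; one byte more fails.
    let too_long = Hkdf::<Sha256>::expand(&prk, None, 255 * 32 + 1);
    assert_eq!(too_long.unwrap_err(), HkdfError::InvalidOutputLengthError);

    // Zero-length requests are rejected as well.
    let zero = Hkdf::<Sha256>::expand(&prk, None, 0);
    assert_eq!(zero.unwrap_err(), HkdfError::InvalidOutputLengthError);
}
```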
diff --git a/crypto/legacy_crypto/src/lib.rs b/crypto/legacy_crypto/src/lib.rs
new file mode 100644
index 0000000000000..311e4167610e4
--- /dev/null
+++ b/crypto/legacy_crypto/src/lib.rs
@@ -0,0 +1,25 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! A library supplying various cryptographic primitives used in Libra.
+
+#![deny(missing_docs)]
+#![feature(test)]
+#![feature(trait_alias)]
+
+pub mod hash;
+pub mod hkdf;
+pub mod signing;
+pub mod utils;
+pub mod x25519;
+
+#[cfg(test)]
+extern crate test;
+
+#[cfg(test)]
+mod unit_tests;
+
+pub use crate::{
+    hash::HashValue,
+    signing::{PrivateKey, PublicKey, Signature},
+};
diff --git a/crypto/legacy_crypto/src/macros/Cargo.toml b/crypto/legacy_crypto/src/macros/Cargo.toml
new file mode 100644
index 0000000000000..1a99cd338f602
--- /dev/null
+++ b/crypto/legacy_crypto/src/macros/Cargo.toml
@@ -0,0 +1,12 @@
+[package]
+name = "crypto-derive"
+version = "0.1.0"
+description = "Custom derives for `crypto`"
+edition = "2018"
+
+[lib]
+proc-macro = true
+
+[dependencies]
+syn = { version = "0.15.26", features = ["derive"] }
+quote = "0.6.11"
diff --git a/crypto/legacy_crypto/src/macros/src/lib.rs b/crypto/legacy_crypto/src/macros/src/lib.rs
new file mode 100644
index 0000000000000..3aefdcf75a436
--- /dev/null
+++ b/crypto/legacy_crypto/src/macros/src/lib.rs
@@ -0,0 +1,38 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+extern crate proc_macro;
+
+use proc_macro::TokenStream;
+use quote::quote;
+use syn::DeriveInput;
+
+#[proc_macro_derive(SilentDisplay)]
+pub fn silent_display(source: TokenStream) -> TokenStream {
+    let ast: DeriveInput = syn::parse(source).expect("Infallible");
+    let name = &ast.ident;
+    let gen = quote! {
+        // In order to ensure that secrets are never leaked, Display is elided
+        impl ::std::fmt::Display for #name {
+            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
+                write!(f, "<elided secret for {}>", stringify!(#name))
+            }
+        }
+    };
+    gen.into()
+}
+
+#[proc_macro_derive(SilentDebug)]
+pub fn silent_debug(source: TokenStream) -> TokenStream {
+    let ast: DeriveInput = syn::parse(source).expect("Infallible");
+    let name = &ast.ident;
+    let gen = quote! {
+        // In order to ensure that secrets are never leaked, Debug is elided
+        impl ::std::fmt::Debug for #name {
+            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
+                write!(f, "<elided secret for {}>", stringify!(#name))
+            }
+        }
+    };
+    gen.into()
+}
diff --git a/crypto/legacy_crypto/src/signing.rs b/crypto/legacy_crypto/src/signing.rs
new file mode 100644
index 0000000000000..45b1870b0d5d0
--- /dev/null
+++ b/crypto/legacy_crypto/src/signing.rs
@@ -0,0 +1,529 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! An implementation of the associated functions for handling
+//! [cryptographic signatures](https://en.wikipedia.org/wiki/Digital_signature)
+//! for the Libra project.
+//!
+//! This is an API for [Pure Ed25519 EdDSA signatures](https://tools.ietf.org/html/rfc8032).
+//!
+//! Warning: This API will soon be updated in the [`nextgen`] module.
+//!
+//! # Example
+//!
+//! Note that the signing and verifying functions take an input message of type [`HashValue`].
+//! ```
+//! use crypto::{hash::*, signing::*};
+//!
+//! let mut hasher = TestOnlyHasher::default();
+//! hasher.write("Test message".as_bytes());
+//! let hashed_message = hasher.finish();
+//!
+//! let (private_key, public_key) = generate_keypair();
+//! let signature = sign_message(hashed_message, &private_key).unwrap();
+//! assert!(verify_message(hashed_message, &signature, &public_key).is_ok());
+//!
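+//! // (Illustrative addition) A tampered message must fail verification:
+//! let mut hasher = TestOnlyHasher::default();
+//! hasher.write("Tampered message".as_bytes());
+//! let tampered_message = hasher.finish();
+//! assert!(verify_message(tampered_message, &signature, &public_key).is_err());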
``` + +use crate::{hash::HashValue, hkdf::Hkdf, utils::*}; +use bincode::{deserialize, serialize}; +use curve25519_dalek::scalar::Scalar; +use ed25519_dalek; +use failure::prelude::*; +use proptest::{ + arbitrary::any, + prelude::{Arbitrary, BoxedStrategy}, + strategy::*, +}; +use rand::{ + rngs::{EntropyRng, StdRng}, + CryptoRng, RngCore, SeedableRng, +}; +use serde::{de, export, ser, Deserialize, Serialize}; +use sha2::Sha256; +use std::{clone::Clone, fmt, hash::Hash}; + +/// An ed25519 private key. +pub struct PrivateKey { + value: ed25519_dalek::SecretKey, +} + +/// An ed25519 public key. +#[derive(Copy, Clone, Default)] +pub struct PublicKey { + value: ed25519_dalek::PublicKey, +} + +/// An ed25519 signature. +#[derive(Copy, Clone)] +pub struct Signature { + value: ed25519_dalek::Signature, +} + +/// A public and private key pair. +#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)] +pub struct KeyPair { + public_key: PublicKey, + private_key: PrivateKey, +} + +impl KeyPair { + /// Produces a new keypair from a private key. + pub fn new(private_key: PrivateKey) -> Self { + let public: ed25519_dalek::PublicKey = (&private_key.value).into(); + let public_key = PublicKey { value: public }; + Self { + private_key, + public_key, + } + } + + /// Returns the public key. + pub fn public_key(&self) -> PublicKey { + self.public_key + } + + /// Returns the private key. + pub fn private_key(&self) -> &PrivateKey { + &self.private_key + } +} + +/// Constructs a signature for `message` using `private_key`. +pub fn sign_message(message: HashValue, private_key: &PrivateKey) -> Result { + let public_key: ed25519_dalek::PublicKey = (&private_key.value).into(); + let expanded_secret_key: ed25519_dalek::ExpandedSecretKey = + ed25519_dalek::ExpandedSecretKey::from(&private_key.value); + Ok(Signature { + value: expanded_secret_key.sign(message.as_ref(), &public_key), + }) +} + +/// Checks that `signature` is valid for `message` using `public_key`. +pub fn verify_message( + message: HashValue, + signature: &Signature, + public_key: &PublicKey, +) -> Result<()> { + signature.is_valid()?; + public_key + .value + .verify(message.as_ref(), &signature.value)?; + Ok(()) +} + +/// Generates a well-known keypair `(PrivateKey, PublicKey)` for special use +/// in the genesis block. +/// +/// **Warning**: This function will soon be updated to return a [`KeyPair`] struct +pub fn generate_genesis_keypair() -> (PrivateKey, PublicKey) { + let mut buf = [0u8; ed25519_dalek::SECRET_KEY_LENGTH]; + buf[ed25519_dalek::SECRET_KEY_LENGTH - 1] = 1; + let secret_key: ed25519_dalek::SecretKey = ed25519_dalek::SecretKey::from_bytes(&buf).unwrap(); + let public: ed25519_dalek::PublicKey = (&secret_key).into(); + ( + PrivateKey { value: secret_key }, + PublicKey { value: public }, + ) +} + +/// Generates a random keypair `(PrivateKey, PublicKey)`. +/// +/// **Warning**: This function will soon be updated to return a [`KeyPair`] struct. +pub fn generate_keypair() -> (PrivateKey, PublicKey) { + let mut rng = EntropyRng::new(); + generate_keypair_from_rng(&mut rng) +} + +/// Derives a keypair `(PrivateKey, PublicKey)` from +/// a) salt (optional) - denoted as 'salt' in RFC 5869 +/// b) seed - denoted as 'IKM' in RFC 5869 +/// c) application info (optional) - denoted as 'info' in RFC 5869 +/// +/// using the HKDF key derivation protocol, as defined in RFC 5869. +/// This implementation uses the full extract-then-expand HKDF steps +/// based on the SHA-256 hash function. 
+/// +/// **Warning**: This function will soon be updated to return a [`KeyPair`] struct. +pub fn derive_keypair_from_seed( + salt: Option<&[u8]>, + seed: &[u8], + app_info: Option<&[u8]>, +) -> (PrivateKey, PublicKey) { + let derived_bytes = + Hkdf::::extract_then_expand(salt, seed, app_info, ed25519_dalek::SECRET_KEY_LENGTH); + + let secret = ed25519_dalek::SecretKey::from_bytes(&derived_bytes.unwrap()).unwrap(); + let public: ed25519_dalek::PublicKey = (&secret).into(); + (PrivateKey { value: secret }, PublicKey { value: public }) +} + +/// Generates a random keypair `(PrivateKey, PublicKey)` and returns a tuple of string +/// representations: +/// 1. human readable public key +/// 2. hex encoded serialized public key +/// 3. hex encoded serialized private key +pub fn generate_and_encode_keypair() -> (String, String, String) { + let (private_key, public_key) = generate_keypair(); + let pub_key_human = hex::encode(public_key.value.to_bytes()); + let public_key_serialized_str = encode_to_string(&public_key); + let private_key_serialized_str = encode_to_string(&private_key); + ( + pub_key_human, + public_key_serialized_str, + private_key_serialized_str, + ) +} + +/// Generates consistent keypair `(PrivateKey, PublicKey)` for unit tests. +/// +/// **Warning**: This function will soon be updated to return a [`KeyPair`] struct. +pub fn generate_keypair_for_testing(rng: &mut R) -> (PrivateKey, PublicKey) +where + R: SeedableRng + RngCore + CryptoRng, +{ + generate_keypair_from_rng(rng) +} + +/// Generates a keypair `(PrivateKey, PublicKey)` based on an RNG. +pub fn generate_keypair_from_rng(rng: &mut R) -> (PrivateKey, PublicKey) +where + R: RngCore + CryptoRng, +{ + let keypair = ed25519_dalek::Keypair::generate(rng); + ( + PrivateKey { + value: keypair.secret, + }, + PublicKey { + value: keypair.public, + }, + ) +} + +/// Generates a random keypair `(PrivateKey, PublicKey)` by combining the output of `EntropyRng` +/// with a user-provided seed. This concatenated seed is used as the seed to HKDF (RFC 5869). +/// +/// Similarly to `derive_keypair_from_seed` the user provides the following inputs: +/// a) salt (optional) - denoted as 'salt' in RFC 5869 +/// b) seed - denoted as 'IKM' in RFC 5869 +/// c) application info (optional) - denoted as 'info' in RFC 5869 +/// +/// Note that this method is not deterministic, but the (random + static seed) key +/// generation makes it safer against low entropy pools and weak RNGs. +/// +/// **Warning**: This function will soon be updated to return a [`KeyPair`] struct. 
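+///
+/// # Example
+///
+/// An illustrative sketch; the constant seed below is for demonstration only.
+/// ```
+/// use crypto::signing::generate_keypair_hybrid;
+///
+/// let seed = [5u8; 32];
+/// let (_, public_key1) = generate_keypair_hybrid(None, &seed, None);
+/// let (_, public_key2) = generate_keypair_hybrid(None, &seed, None);
+/// // Fresh EntropyRng output is mixed into each call, so the keys differ.
+/// assert_ne!(public_key1, public_key2);
+/// ```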
+pub fn generate_keypair_hybrid( + salt: Option<&[u8]>, + seed: &[u8], + app_info: Option<&[u8]>, +) -> (PrivateKey, PublicKey) { + let mut rng = EntropyRng::new(); + let mut seed_from_rng = [0u8; ed25519_dalek::SECRET_KEY_LENGTH]; + rng.fill_bytes(&mut seed_from_rng); + + let mut final_seed = seed.to_vec(); + final_seed.extend_from_slice(&seed_from_rng); + + derive_keypair_from_seed(salt, &final_seed, app_info) +} + +impl Clone for PrivateKey { + fn clone(&self) -> Self { + let encoded_privkey: Vec = serialize(&self.value).unwrap(); + let temp = deserialize::<::ed25519_dalek::SecretKey>(&encoded_privkey).unwrap(); + PrivateKey { value: temp } + } +} + +impl PartialEq for PrivateKey { + fn eq(&self, other: &PrivateKey) -> bool { + let encoded_privkey: Vec = serialize(&self.value).unwrap(); + let other_encoded_privkey: Vec = serialize(&other.value).unwrap(); + encoded_privkey == other_encoded_privkey + } +} + +impl Eq for PrivateKey {} + +impl fmt::Display for PrivateKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "") + } +} + +// ed25519_dalek doesn't implement Hash, which we need to put signatures into +// containers. For now, the derive_hash_xor_eq has no impact. +impl Hash for PublicKey { + fn hash(&self, state: &mut H) { + let encoded_pubkey: Vec = serialize(&self.value).unwrap(); + state.write(&encoded_pubkey); + } +} + +impl Hash for Signature { + fn hash(&self, state: &mut H) { + let encoded_signature: Vec = serialize(&self.value).unwrap(); + state.write(&encoded_signature); + } +} + +impl PartialEq for PublicKey { + fn eq(&self, other: &PublicKey) -> bool { + serialize(&self.value).unwrap() == serialize(&other.value).unwrap() + } +} + +impl Eq for PublicKey {} + +impl PartialEq for Signature { + fn eq(&self, other: &Signature) -> bool { + serialize(&self.value).unwrap() == serialize(&other.value).unwrap() + } +} + +impl Eq for Signature {} + +impl PublicKey { + /// The length of the public key in bytes. + pub const LENGTH: usize = ed25519_dalek::PUBLIC_KEY_LENGTH; + + /// Obtain a public key from a slice. + pub fn from_slice(data: &[u8]) -> Result { + match ed25519_dalek::PublicKey::from_bytes(data) { + Ok(key) => Ok(PublicKey { value: key }), + Err(err) => bail!("Public key decode error: {}", err), + } + } + + /// Convert the public key into a slice. 
+ pub fn to_slice(&self) -> [u8; Self::LENGTH] { + let mut out = [0u8; Self::LENGTH]; + let temp = self.value.as_bytes(); + out.copy_from_slice(&temp[..]); + out + } +} + +fn public_key_strategy() -> impl Strategy { + any::<[u8; 32]>() + .prop_map(|seed| { + let mut rng: StdRng = SeedableRng::from_seed(seed); + let (_, public_key) = generate_keypair_for_testing(&mut rng); + public_key + }) + .no_shrink() +} + +impl Arbitrary for PublicKey { + type Parameters = (); + type Strategy = BoxedStrategy; + + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + public_key_strategy().boxed() + } +} + +impl From<&PrivateKey> for PublicKey { + fn from(private_key: &PrivateKey) -> Self { + let public_key = (&private_key.value).into(); + Self { value: public_key } + } +} + +impl Signature { + /// Obtains a signature from a byte representation + pub fn from_compact(data: &[u8]) -> Result { + match ed25519_dalek::Signature::from_bytes(data) { + Ok(sig) => Ok(Signature { value: sig }), + Err(_error) => bail!("error"), + } + } + + /// Converts the signature to its byte representation + pub fn to_compact(&self) -> [u8; ed25519_dalek::SIGNATURE_LENGTH] { + let mut out = [0u8; ed25519_dalek::SIGNATURE_LENGTH]; + let temp = self.value.to_bytes(); + out.copy_from_slice(&temp); + out + } + + // Check for malleable signatures. This method ensures that the S part is of canonical form + // and R does not lie on a small group (S and R as defined in RFC 8032). + fn is_valid(&self) -> Result<()> { + let bytes = self.to_compact(); + + let mut s_bits: [u8; 32] = [0u8; 32]; + s_bits.copy_from_slice(&bytes[32..]); + + // Check if S is of canonical form. + // We actually test if S < order_of_the_curve to capture malleable signatures. + let s = Scalar::from_canonical_bytes(s_bits); + if s == None { + bail!( + "Non canonical signature detected: As mentioned in RFC 8032, the 'S' part of the \ + signature should be smaller than the curve order. Consider reducing 'S' by mod 'L', \ + where 'L' is the order of the ed25519 curve."); + } + + // Check if the R lies on a small subgroup. + // Even though the security implications of a small order R are unclear, + // points of order <= 8 are rejected. + let mut r_bits: [u8; 32] = [0u8; 32]; + r_bits.copy_from_slice(&bytes[..32]); + + let compressed = curve25519_dalek::edwards::CompressedEdwardsY(r_bits); + let point = compressed.decompress(); + + match point { + Some(p) => { + if p.is_small_order() { + bail!( + "Non canonical signature detected: the 'R' part of the signature, \ + as defined in RFC 8032, lies on a small subgroup." 
+ ) + } else { + Ok(()) + } + } + None => bail!("Malformed signature detected, the 'R' part of the signature is invalid"), + } + } +} + +impl ser::Serialize for PrivateKey { + fn serialize(&self, serializer: S) -> export::Result + where + S: ser::Serializer, + { + ed25519_dalek::SecretKey::serialize(&self.value, serializer) + } +} + +impl ser::Serialize for PublicKey { + fn serialize(&self, serializer: S) -> export::Result + where + S: ser::Serializer, + { + ed25519_dalek::PublicKey::serialize(&self.value, serializer) + } +} + +impl ser::Serialize for Signature { + fn serialize(&self, serializer: S) -> export::Result + where + S: ser::Serializer, + { + ed25519_dalek::Signature::serialize(&self.value, serializer) + } +} + +struct PrivateKeyVisitor; + +struct PublicKeyVisitor; + +struct SignatureVisitor; + +impl<'de> de::Visitor<'de> for PrivateKeyVisitor { + type Value = PrivateKey; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("ed25519_dalek private key in bytes") + } + + fn visit_bytes(self, value: &[u8]) -> export::Result + where + E: de::Error, + { + match ed25519_dalek::SecretKey::from_bytes(value) { + Ok(key) => Ok(PrivateKey { value: key }), + Err(error) => Err(E::custom(error)), + } + } +} + +impl<'de> de::Visitor<'de> for PublicKeyVisitor { + type Value = PublicKey; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("public key in bytes") + } + + fn visit_bytes(self, value: &[u8]) -> export::Result + where + E: de::Error, + { + match ed25519_dalek::PublicKey::from_bytes(value) { + Ok(key) => Ok(PublicKey { value: key }), + Err(error) => Err(E::custom(error)), + } + } +} + +impl<'de> de::Visitor<'de> for SignatureVisitor { + type Value = Signature; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("ed25519_dalek signature in compact encoding") + } + + fn visit_bytes(self, value: &[u8]) -> export::Result + where + E: de::Error, + { + match ::ed25519_dalek::Signature::from_bytes(value) { + Ok(key) => Ok(Signature { value: key }), + Err(error) => Err(E::custom(error)), + } + } +} + +impl<'de> de::Deserialize<'de> for PrivateKey { + fn deserialize(deserializer: D) -> export::Result + where + D: de::Deserializer<'de>, + { + deserializer.deserialize_bytes(PrivateKeyVisitor {}) + } +} + +impl<'de> de::Deserialize<'de> for PublicKey { + fn deserialize(deserializer: D) -> export::Result + where + D: de::Deserializer<'de>, + { + deserializer.deserialize_bytes(PublicKeyVisitor {}) + } +} + +impl<'de> de::Deserialize<'de> for Signature { + fn deserialize(deserializer: D) -> export::Result + where + D: de::Deserializer<'de>, + { + deserializer.deserialize_bytes(SignatureVisitor {}) + } +} + +impl fmt::Debug for PrivateKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.value.fmt(f) + } +} + +impl fmt::Debug for PublicKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", hex::encode(&self.value.to_bytes()[..])) + } +} + +impl fmt::Display for PublicKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", hex::encode(&self.value.to_bytes()[..])) + } +} + +impl fmt::Debug for Signature { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.value.fmt(f) + } +} diff --git a/crypto/legacy_crypto/src/unit_tests/hash_test.rs b/crypto/legacy_crypto/src/unit_tests/hash_test.rs new file mode 100644 index 0000000000000..6033a04223040 --- /dev/null +++ 
b/crypto/legacy_crypto/src/unit_tests/hash_test.rs @@ -0,0 +1,220 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::hash::*; +use bitvec::BitVec; +use byteorder::{LittleEndian, WriteBytesExt}; +use proptest::{collection::vec, prelude::*}; +use proto_conv::{FromProto, IntoProto}; +use rand::{rngs::StdRng, SeedableRng}; + +#[derive(Serialize)] +struct Foo(u32); + +#[test] +fn test_default_hasher() { + assert_eq!( + Foo(3).test_only_hash(), + HashValue::from_iter_sha3(vec![::bincode::serialize(&Foo(3)).unwrap().as_slice()]), + ); + assert_eq!( + format!("{:x}", b"hello".test_only_hash()), + "3338be694f50c5f338814986cdf0686453a888b84f424d792af4b9202398f392", + ); + assert_eq!( + format!("{:x}", b"world".test_only_hash()), + "420baf620e3fcd9b3715b42b92506e9304d56e02d3a103499a3a292560cb66b2", + ); +} + +#[test] +fn test_primitive_type() { + let x = 0xf312_u16; + let mut wtr: Vec = vec![]; + wtr.write_u16::(x).unwrap(); + assert_eq!(x.test_only_hash(), HashValue::from_sha3(&wtr[..])); + + let x = 0x_ff001234_u32; + let mut wtr: Vec = vec![]; + wtr.write_u32::(x).unwrap(); + assert_eq!(x.test_only_hash(), HashValue::from_sha3(&wtr[..])); + + let x = 0x_89abcdef_01234567_u64; + let mut wtr: Vec = vec![]; + wtr.write_u64::(x).unwrap(); + assert_eq!(x.test_only_hash(), HashValue::from_sha3(&wtr[..])); +} + +#[test] +fn test_from_slice() { + { + let zero_byte_vec = vec![0; 32]; + assert_eq!( + HashValue::from_slice(&zero_byte_vec).unwrap(), + HashValue::zero() + ); + } + { + // The length is mismatched. + let zero_byte_vec = vec![0; 31]; + assert!(HashValue::from_slice(&zero_byte_vec).is_err()); + } +} + +#[test] +fn test_random_with_rng() { + let mut seed = [0u8; 32]; + seed[..4].copy_from_slice(&[1, 2, 3, 4]); + let hash1; + let hash2; + { + let mut rng: StdRng = SeedableRng::from_seed(seed); + hash1 = HashValue::random_with_rng(&mut rng); + } + { + let mut rng: StdRng = SeedableRng::from_seed(seed); + hash2 = HashValue::random_with_rng(&mut rng); + } + assert_eq!(hash1, hash2); +} + +#[test] +fn test_hash_value_iter_bits() { + let hash = b"hello".test_only_hash(); + let bits = hash.iter_bits().collect::>(); + + assert_eq!(bits.len(), HashValue::LENGTH_IN_BITS); + assert_eq!(bits[0], false); + assert_eq!(bits[1], false); + assert_eq!(bits[2], true); + assert_eq!(bits[3], true); + assert_eq!(bits[4], false); + assert_eq!(bits[5], false); + assert_eq!(bits[6], true); + assert_eq!(bits[7], true); + assert_eq!(bits[248], true); + assert_eq!(bits[249], false); + assert_eq!(bits[250], false); + assert_eq!(bits[251], true); + assert_eq!(bits[252], false); + assert_eq!(bits[253], false); + assert_eq!(bits[254], true); + assert_eq!(bits[255], false); + + let mut bits_rev = hash.iter_bits().rev().collect::>(); + bits_rev.reverse(); + assert_eq!(bits, bits_rev); +} + +#[test] +fn test_hash_value_iterator_exact_size() { + let hash = b"hello".test_only_hash(); + + let mut iter = hash.iter_bits(); + assert_eq!(iter.len(), HashValue::LENGTH_IN_BITS); + iter.next(); + assert_eq!(iter.len(), HashValue::LENGTH_IN_BITS - 1); + iter.next_back(); + assert_eq!(iter.len(), HashValue::LENGTH_IN_BITS - 2); + + let iter_rev = hash.iter_bits().rev(); + assert_eq!(iter_rev.len(), HashValue::LENGTH_IN_BITS); + + let iter_skip = hash.iter_bits().skip(100); + assert_eq!(iter_skip.len(), HashValue::LENGTH_IN_BITS - 100); +} + +#[test] +fn test_fmt_binary() { + let hash = b"hello".test_only_hash(); + let hash_str = format!("{:b}", hash); + assert_eq!(hash_str.len(), 
HashValue::LENGTH_IN_BITS); + for (bit, chr) in hash.iter_bits().zip(hash_str.chars()) { + assert_eq!(chr, if bit { '1' } else { '0' }); + } +} + +#[test] +fn test_common_prefix_bits_len() { + { + let hash1 = b"hello".test_only_hash(); + let hash2 = b"HELLO".test_only_hash(); + assert_eq!(hash1[0], 0b0011_0011); + assert_eq!(hash2[0], 0b1011_1000); + assert_eq!(hash1.common_prefix_bits_len(hash2), 0); + } + { + let hash1 = b"hello".test_only_hash(); + let hash2 = b"world".test_only_hash(); + assert_eq!(hash1[0], 0b0011_0011); + assert_eq!(hash2[0], 0b0100_0010); + assert_eq!(hash1.common_prefix_bits_len(hash2), 1); + } + { + let hash1 = b"hello".test_only_hash(); + let hash2 = b"100011001000".test_only_hash(); + assert_eq!(hash1[0], 0b0011_0011); + assert_eq!(hash2[0], 0b0011_0011); + assert_eq!(hash1[1], 0b0011_1000); + assert_eq!(hash2[1], 0b0010_0010); + assert_eq!(hash1.common_prefix_bits_len(hash2), 11); + } + { + let hash1 = b"hello".test_only_hash(); + let hash2 = b"hello".test_only_hash(); + assert_eq!( + hash1.common_prefix_bits_len(hash2), + HashValue::LENGTH_IN_BITS + ); + } +} + +#[test] +fn test_from_proto_invalid_length() { + let bytes = vec![1; 123]; + assert!(HashValue::from_proto(bytes).is_err()); +} + +proptest! { + #[test] + fn test_hashvalue_to_bits_roundtrip(hash in any::()) { + let bitvec: BitVec = hash.iter_bits().collect(); + let bytes: Vec = bitvec.into(); + let hash2 = HashValue::from_slice(&bytes).unwrap(); + prop_assert_eq!(hash, hash2); + } + + #[test] + fn test_hashvalue_to_bits_inverse_roundtrip(bits in vec(any::(), HashValue::LENGTH_IN_BITS)) { + let bitvec: BitVec = bits.iter().cloned().collect(); + let bytes: Vec = bitvec.into(); + let hash = HashValue::from_slice(&bytes).unwrap(); + let bits2: Vec = hash.iter_bits().collect(); + prop_assert_eq!(bits, bits2); + } + + #[test] + fn test_hashvalue_iter_bits_rev(hash in any::()) { + let bits1: Vec = hash.iter_bits().collect(); + let mut bits2: Vec = hash.iter_bits().rev().collect(); + bits2.reverse(); + prop_assert_eq!(bits1, bits2); + } + + #[test] + fn test_hashvalue_to_rev_bits_roundtrip(hash in any::()) { + let bitvec: BitVec = hash.iter_bits().rev().collect(); + let mut bytes: Vec = bitvec.into(); + bytes.reverse(); + let hash2 = HashValue::from_slice(&bytes).unwrap(); + prop_assert_eq!(hash, hash2); + } + + #[test] + fn test_hashvalue_proto_conversion_roundtrip(hash in any::()) { + let bytes = hash.into_proto(); + prop_assert_eq!(bytes.clone(), hash.as_ref()); + let hash2 = HashValue::from_proto(bytes).unwrap(); + prop_assert_eq!(hash, hash2); + } +} diff --git a/crypto/legacy_crypto/src/unit_tests/hkdf_test.rs b/crypto/legacy_crypto/src/unit_tests/hkdf_test.rs new file mode 100644 index 0000000000000..e3664df91e5f2 --- /dev/null +++ b/crypto/legacy_crypto/src/unit_tests/hkdf_test.rs @@ -0,0 +1,216 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::hkdf::*; +use sha2::{Sha256, Sha512}; +use sha3::Sha3_256; + +// Testing against sha256 test vectors. Unfortunately the rfc does not provide test vectors for +// sha3 and sha512. 
+#[test] +fn test_sha256_test_vectors() { + let tests = test_vectors_sha256(); + for t in tests.iter() { + let ikm = hex::decode(&t.ikm).unwrap(); + let salt = hex::decode(&t.salt).unwrap(); + let info = hex::decode(&t.info).unwrap(); + + let hkdf_extract = Hkdf::::extract(Option::from(&salt[..]), &ikm[..]).unwrap(); + let hkdf_expand = Hkdf::::expand(&hkdf_extract, Some(&info[..]), t.length); + + assert!(hkdf_expand.is_ok()); + assert_eq!(t.prk, hex::encode(hkdf_extract)); + assert_eq!(t.okm, hex::encode(hkdf_expand.unwrap())); + } +} + +// Testing against sha256 test vectors for the extract_then_expand function. +#[test] +fn test_extract_then_expand() { + let tests = test_vectors_sha256(); + for t in tests.iter() { + let ikm = hex::decode(&t.ikm).unwrap(); + let salt = hex::decode(&t.salt).unwrap(); + let info = hex::decode(&t.info).unwrap(); + + let hkdf_full = Hkdf::::extract_then_expand( + Option::from(&salt[..]), + &ikm[..], + Option::from(&info[..]), + t.length, + ); + + assert!(hkdf_full.is_ok()); + assert_eq!(t.okm, hex::encode(hkdf_full.unwrap())); + } +} + +#[test] +fn test_sha256_output_length() { + // According to rfc, max_sha256_length <= 255 * HashLen bytes + let max_hash_length: usize = 255 * 32; // = 8160 + + // We extract once, then we reuse it. + let hkdf_extract = Hkdf::::extract(None, &[]).unwrap(); + + // Test for max allowed (expected to pass) + let hkdf_expand = Hkdf::::expand(&hkdf_extract, None, max_hash_length); + assert!(hkdf_expand.is_ok()); + assert_eq!(hkdf_expand.unwrap().len(), max_hash_length); + + // Test for max + 1 (expected to fail) + let hkdf_expand = Hkdf::::expand(&hkdf_extract, None, max_hash_length + 1); + assert_eq!( + hkdf_expand.unwrap_err(), + HkdfError::InvalidOutputLengthError + ); + + // Test for 10_000 > max (expected to fail) + let hkdf_expand = Hkdf::::expand(&hkdf_extract, None, 10_000); + assert_eq!( + hkdf_expand.unwrap_err(), + HkdfError::InvalidOutputLengthError + ); + + // Test for zero size output (expected to fail) + let hkdf_expand = Hkdf::::expand(&hkdf_extract, None, 0); + assert_eq!( + hkdf_expand.unwrap_err(), + HkdfError::InvalidOutputLengthError + ); +} + +// FIPS 202 approves HMAC-SHA3 and specifies the block sizes (see top of page 22). +// SP 800-56C approves of HKDF-HMAC-hash as a randomness extractor with any approved hash function. +// But in contrast, I can't find any NIST statement that explicitly approves the use of KMAC +// as a randomness extractor. +// But, it's debatable if this is a pointless construct, as HMAC only exists to cover up weaknesses +// in Merkle-Damgard hashes, but sha3 (and Keccak) are sponge constructions, immune to length +// extension attacks. 
+#[test] +fn test_sha3_256_output_length() { + let max_hash_length: usize = 255 * 32; // = 8160 + + let hkdf_extract = Hkdf::::extract(None, &[]).unwrap(); + + // Test for max allowed (expected to pass) + let hkdf_expand = Hkdf::::expand(&hkdf_extract, None, max_hash_length); + assert!(hkdf_expand.is_ok()); + assert_eq!(hkdf_expand.unwrap().len(), max_hash_length); + + // Test for max + 1 (expected to fail) + let hkdf_expand = Hkdf::::expand(&hkdf_extract, None, max_hash_length + 1); + assert_eq!( + hkdf_expand.unwrap_err(), + HkdfError::InvalidOutputLengthError + ); + + // Test for 10_000 > max (expected to fail) + let hkdf_expand = Hkdf::::expand(&hkdf_extract, None, 10_000); + assert_eq!( + hkdf_expand.unwrap_err(), + HkdfError::InvalidOutputLengthError + ); + + // Test for zero size output (expected to fail) + let hkdf_expand = Hkdf::::expand(&hkdf_extract, None, 0); + assert_eq!( + hkdf_expand.unwrap_err(), + HkdfError::InvalidOutputLengthError + ); +} + +#[test] +fn test_sha512_output_length() { + let max_hash_length: usize = 255 * 64; // = 16320 + + let hkdf_extract = Hkdf::::extract(None, &[]).unwrap(); + + // Test for max allowed (expected to pass) + let hkdf_expand = Hkdf::::expand(&hkdf_extract, None, max_hash_length); + assert!(hkdf_expand.is_ok()); + assert_eq!(hkdf_expand.unwrap().len(), max_hash_length); + + // Test for max + 1 (expected to fail) + let hkdf_expand = Hkdf::::expand(&hkdf_extract, None, max_hash_length + 1); + assert_eq!( + hkdf_expand.unwrap_err(), + HkdfError::InvalidOutputLengthError + ); + + // Test for 10_000 > max (expected to fail) + let hkdf_expand = Hkdf::::expand(&hkdf_extract, None, 20_000); + assert_eq!( + hkdf_expand.unwrap_err(), + HkdfError::InvalidOutputLengthError + ); + + // Test for zero size output (expected to fail) + let hkdf_expand = Hkdf::::expand(&hkdf_extract, None, 0); + assert_eq!( + hkdf_expand.unwrap_err(), + HkdfError::InvalidOutputLengthError + ); +} + +#[test] +fn test_unsupported_hash_functions() { + // Test for ripemd160, output_length < 256 + let ripemd160_hkdf = Hkdf::::extract(None, &[]); + assert_eq!( + ripemd160_hkdf.unwrap_err(), + HkdfError::NotSupportedHashFunctionError + ); +} + +// Test Vectors for sha256 from https://tools.ietf.org/html/rfc5869. 
+struct Test<'a> { + ikm: &'a str, + salt: &'a str, + info: &'a str, + length: usize, + prk: &'a str, + okm: &'a str, +} + +fn test_vectors_sha256<'a>() -> Vec> { + vec![ + Test { + // Test Case 1 + ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b", + salt: "000102030405060708090a0b0c", + info: "f0f1f2f3f4f5f6f7f8f9", + length: 42, + prk: "077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5", + okm: "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b8\ + 87185865", + }, + Test { + // Test Case 2 + ikm: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425\ + 262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b\ + 4c4d4e4f", + salt: "606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f80818283848\ + 5868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aa\ + abacadaeaf", + info: "b0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d\ + 5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fa\ + fbfcfdfeff", + length: 82, + prk: "06a6b88c5853361a06104c9ceb35b45cef760014904671014a193f40c15fc244", + okm: "b11e398dc80327a1c8e7f78c596a49344f012eda2d4efad8a050cc4c19afa97c59045a99cac7\ + 827271cb41c65e590e09da3275600c2f09b8367793a9aca3db71cc30c58179ec3e87c14c01d5\ + c1f3434f1d87", + }, + Test { + // Test Case 3 + ikm: "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b", + salt: "", + info: "", + length: 42, + prk: "19ef24a32c717b167f33a91d6f648bdf96596776afdb6377ac434c1c293ccb04", + okm: "8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d9d201395faa4\ + b61a96c8", + }, + ] +} diff --git a/crypto/legacy_crypto/src/unit_tests/mod.rs b/crypto/legacy_crypto/src/unit_tests/mod.rs new file mode 100644 index 0000000000000..9ba7ec34996d7 --- /dev/null +++ b/crypto/legacy_crypto/src/unit_tests/mod.rs @@ -0,0 +1,6 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod hkdf_test; +mod signing_test; +mod x25519_test; diff --git a/crypto/legacy_crypto/src/unit_tests/signing_test.rs b/crypto/legacy_crypto/src/unit_tests/signing_test.rs new file mode 100644 index 0000000000000..0fd2a6099d147 --- /dev/null +++ b/crypto/legacy_crypto/src/unit_tests/signing_test.rs @@ -0,0 +1,310 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{hash::HashValue, signing::*, utils::*}; +use bincode::{deserialize, serialize}; +use core::ops::{Index, IndexMut}; +use proptest::prelude::*; +use rand::{rngs::StdRng, SeedableRng}; +use test::Bencher; + +#[test] +fn test_generate_and_encode_keypair() { + let (pub_key_human, pub_key_serialized, priv_key_serialized) = generate_and_encode_keypair(); + assert!(!pub_key_human.is_empty()); + assert!(!pub_key_serialized.is_empty()); + assert!(!priv_key_serialized.is_empty()); + + let public_key_out = from_encoded_string::(pub_key_serialized); + let _private_key_out = from_encoded_string::(priv_key_serialized); + + let public_key_human_out = format!("{:?}", public_key_out); + assert_eq!(pub_key_human, public_key_human_out) +} + +#[test] +fn test_default_key_pair() { + let mut seed: [u8; 32] = [0u8; 32]; + seed[..4].copy_from_slice(&[1, 2, 3, 4]); + let keypair1; + let keypair2; + { + let mut rng: StdRng = SeedableRng::from_seed(seed); + keypair1 = generate_keypair_for_testing(&mut rng); + } + { + let mut rng: StdRng = SeedableRng::from_seed(seed); + keypair2 = generate_keypair_for_testing(&mut rng); + } + assert_eq!(keypair1.1, 
keypair2.1); +} + +#[test] +fn test_hkdf_key_pair() { + // HKDF without salt and info. + let salt = None; + let seed = [0u8; 32]; + let info = None; + let (private_key, public_key) = derive_keypair_from_seed(salt, &seed, info); + let hash = HashValue::zero(); + let signature = sign_message(hash, &private_key).unwrap(); + assert!(verify_message(hash, &signature, &public_key).is_ok()); + + // HKDF with salt and info. + let raw_bytes = [2u8; 10]; + let salt = Some(&raw_bytes[0..4]); + let seed = [3u8; 32]; + let info = Some(&raw_bytes[4..10]); + + let (private_key1, public_key1) = derive_keypair_from_seed(salt, &seed, info); + let (_, public_key2) = derive_keypair_from_seed(salt, &seed, info); + assert_eq!(public_key1, public_key2); // Ensure determinism. + + let hash = HashValue::zero(); + let signature = sign_message(hash, &private_key1).unwrap(); + assert!(verify_message(hash, &signature, &public_key1).is_ok()); +} + +#[test] +fn test_generate_key_pair_with_seed() { + let salt = &b"some salt"[..]; + // In production, ensure seed has at least 256 bits of entropy. + let seed = [5u8; 32]; // seed is denoted as IKM in HKDF RFC 5869. + let info = &b"some app info"[..]; + let (_, public_key1) = generate_keypair_hybrid(Some(salt), &seed, Some(info)); + let (_, public_key2) = generate_keypair_hybrid(Some(salt), &seed, Some(info)); + assert_ne!(public_key1, public_key2); + + // Ensure that the deterministic generate_keypair_from_seed returns a completely different key. + let (_, public_key3) = derive_keypair_from_seed(Some(salt), &seed, Some(info)); + assert_ne!(public_key3, public_key1); + assert_ne!(public_key3, public_key2); +} + +#[bench] +pub fn bench_sign(bh: &mut Bencher) { + let (private_key, _) = generate_keypair(); + let hash = HashValue::zero(); + + bh.iter(|| { + let _ = sign_message(hash, &private_key); + }); +} + +#[bench] +pub fn bench_verify(bh: &mut Bencher) { + let (private_key, public_key) = generate_keypair(); + let hash = HashValue::zero(); + let signature = sign_message(hash, &private_key).unwrap(); + + bh.iter(|| { + verify_message(hash, &signature, &public_key).unwrap(); + }); +} + +proptest! 
{ + #[test] + fn test_keys_encode((private_key, public_key) in keypair_strategy()) { + { + let serialized = serialize(&private_key).unwrap(); + let encoded = ::hex::encode(&serialized); + let decoded = from_encoded_string::(encoded); + prop_assert_eq!(private_key, decoded); + } + { + let serialized = serialize(&public_key).unwrap(); + let encoded = ::hex::encode(&serialized); + let decoded = from_encoded_string::(encoded); + prop_assert_eq!(public_key, decoded); + } + } + + #[test] + fn test_keys_serde((private_key, public_key) in keypair_strategy()) { + { + let serialized = serialize(&private_key).unwrap(); + let deserialized = deserialize::(&serialized).unwrap(); + prop_assert_eq!(private_key, deserialized); + } + { + let serialized = serialize(&public_key).unwrap(); + let deserialized = deserialize::(&serialized).unwrap(); + prop_assert_eq!(public_key, deserialized); + } + } + + #[test] + fn test_signature_serde( + hash in any::(), + (private_key, _public_key) in keypair_strategy() + ) { + let signature = sign_message(hash, &private_key).unwrap(); + let serialized = serialize(&signature).unwrap(); + let deserialized = deserialize::(&serialized).unwrap(); + assert_eq!(signature, deserialized); + } + + #[test] + fn test_sign_and_verify( + hash in any::(), + (private_key, public_key) in keypair_strategy() + ) { + let signature = sign_message(hash, &private_key).unwrap(); + prop_assert!(verify_message(hash, &signature, &public_key).is_ok()); + } + + // Check for canonical s and malleable signatures. + #[test] + fn test_signature_malleability( + hash in any::(), + (private_key, public_key) in keypair_strategy() + ) { + // ed25519-dalek signing ensures a canonical s value. + let signature = sign_message(hash, &private_key).unwrap(); + // Canonical signatures can be verified as expected. + prop_assert!(verify_message(hash, &signature, &public_key).is_ok()); + + let mut serialized = signature.to_compact(); + + let mut r_bits: [u8; 32] = [0u8; 32]; + r_bits.copy_from_slice(&serialized[..32]); + + let mut s_bits: [u8; 32] = [0u8; 32]; + s_bits.copy_from_slice(&serialized[32..]); + + // ed25519-dalek signing ensures a canonical s value. + let s = Scalar52::from_bytes(&s_bits); + + // adding L (order of the base point) so that s + L > L + let malleable_s = Scalar52::add(&s, &L); + let malleable_s_bits = malleable_s.to_bytes(); + // Update the signature (the s part); the signature gets not canonical. + serialized[32..].copy_from_slice(&malleable_s_bits); + + let non_canonical_sig = Signature::from_compact(&serialized).unwrap(); + + // Check that malleable signatures will pass verification and deserialization in dalek. + // Construct the corresponding dalek public key. + let dalek_public_key = ed25519_dalek::PublicKey::from_bytes(&public_key.to_slice()).unwrap(); + + // Construct the corresponding dalek Signature. This signature is not canonical. + let dalek_sig = ed25519_dalek::Signature::from_bytes(&non_canonical_sig.to_compact()); + + // ed25519_dalek will verify malleable signatures as valid. + prop_assert!(dalek_public_key.verify(hash.as_ref(), &dalek_sig.unwrap()).is_ok()); + + // Malleable signatures will fail to verify in our implementation, even if for some reason + // we received one. We detect non canonical signatures. + prop_assert!(verify_message(hash, &non_canonical_sig, &public_key).is_err()); + } +} + +/// The `Scalar52` struct represents an element in +/// β„€/β„“β„€ as 5 52-bit limbs. +struct Scalar52(pub [u64; 5]); + +/// `L` is the order of base point, i.e. 
2^252 + 27742317777372353535851937790883648493 +const L: Scalar52 = Scalar52([ + 0x0002_631a_5cf5_d3ed, + 0x000d_ea2f_79cd_6581, + 0x0000_0000_0014_def9, + 0x0000_0000_0000_0000, + 0x0000_1000_0000_0000, +]); + +impl Scalar52 { + /// Return the zero scalar + fn zero() -> Scalar52 { + Scalar52([0, 0, 0, 0, 0]) + } + + /// Unpack a 256 bit scalar into 5 52-bit limbs. + fn from_bytes(bytes: &[u8; 32]) -> Scalar52 { + let mut words = [0u64; 4]; + for i in 0..4 { + for j in 0..8 { + words[i] |= u64::from(bytes[(i * 8) + j]) << (j * 8); + } + } + + let mask = (1u64 << 52) - 1; + let top_mask = (1u64 << 48) - 1; + let mut s = Scalar52::zero(); + + s[0] = words[0] & mask; + s[1] = ((words[0] >> 52) | (words[1] << 12)) & mask; + s[2] = ((words[1] >> 40) | (words[2] << 24)) & mask; + s[3] = ((words[2] >> 28) | (words[3] << 36)) & mask; + s[4] = (words[3] >> 16) & top_mask; + + s + } + + /// Pack the limbs of this `Scalar52` into 32 bytes + fn to_bytes(&self) -> [u8; 32] { + let mut s = [0u8; 32]; + + s[0] = self.0[0] as u8; + s[1] = (self.0[0] >> 8) as u8; + s[2] = (self.0[0] >> 16) as u8; + s[3] = (self.0[0] >> 24) as u8; + s[4] = (self.0[0] >> 32) as u8; + s[5] = (self.0[0] >> 40) as u8; + s[6] = ((self.0[0] >> 48) | (self.0[1] << 4)) as u8; + s[7] = (self.0[1] >> 4) as u8; + s[8] = (self.0[1] >> 12) as u8; + s[9] = (self.0[1] >> 20) as u8; + s[10] = (self.0[1] >> 28) as u8; + s[11] = (self.0[1] >> 36) as u8; + s[12] = (self.0[1] >> 44) as u8; + s[13] = self.0[2] as u8; + s[14] = (self.0[2] >> 8) as u8; + s[15] = (self.0[2] >> 16) as u8; + s[16] = (self.0[2] >> 24) as u8; + s[17] = (self.0[2] >> 32) as u8; + s[18] = (self.0[2] >> 40) as u8; + s[19] = ((self.0[2] >> 48) | (self.0[3] << 4)) as u8; + s[20] = (self.0[3] >> 4) as u8; + s[21] = (self.0[3] >> 12) as u8; + s[22] = (self.0[3] >> 20) as u8; + s[23] = (self.0[3] >> 28) as u8; + s[24] = (self.0[3] >> 36) as u8; + s[25] = (self.0[3] >> 44) as u8; + s[26] = self.0[4] as u8; + s[27] = (self.0[4] >> 8) as u8; + s[28] = (self.0[4] >> 16) as u8; + s[29] = (self.0[4] >> 24) as u8; + s[30] = (self.0[4] >> 32) as u8; + s[31] = (self.0[4] >> 40) as u8; + + s + } + + /// Compute `a + b` (without mod β„“) + fn add(a: &Scalar52, b: &Scalar52) -> Scalar52 { + let mut sum = Scalar52::zero(); + let mask = (1u64 << 52) - 1; + + // a + b + let mut carry: u64 = 0; + for i in 0..5 { + carry = a[i] + b[i] + (carry >> 52); + sum[i] = carry & mask; + } + + sum + } +} + +impl Index for Scalar52 { + type Output = u64; + fn index(&self, _index: usize) -> &u64 { + &(self.0[_index]) + } +} + +impl IndexMut for Scalar52 { + fn index_mut(&mut self, _index: usize) -> &mut u64 { + &mut (self.0[_index]) + } +} diff --git a/crypto/legacy_crypto/src/unit_tests/x25519_test.rs b/crypto/legacy_crypto/src/unit_tests/x25519_test.rs new file mode 100644 index 0000000000000..2e9b4cc4e61ea --- /dev/null +++ b/crypto/legacy_crypto/src/unit_tests/x25519_test.rs @@ -0,0 +1,70 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{utils::from_encoded_string, x25519::*}; +use rand::{rngs::StdRng, SeedableRng}; + +#[test] +fn test_generate_and_encode_keypair() { + let (pub_key_human, pub_key_serialized, priv_key_serialized) = generate_and_encode_keypair(); + assert!(!pub_key_human.is_empty()); + assert!(!pub_key_serialized.is_empty()); + assert!(!priv_key_serialized.is_empty()); + + let public_key_out = from_encoded_string::(pub_key_serialized); + let public_key_human_out = format!("{:?}", public_key_out); + assert_eq!(pub_key_human, 
public_key_human_out) +} + +#[test] +fn test_default_key_pair() { + let mut seed: [u8; 32] = [0u8; 32]; + seed[..4].copy_from_slice(&[1, 2, 3, 4]); + let keypair1; + let keypair2; + { + let mut rng: StdRng = SeedableRng::from_seed(seed); + keypair1 = generate_keypair_for_testing(&mut rng); + } + { + let mut rng: StdRng = SeedableRng::from_seed(seed); + keypair2 = generate_keypair_for_testing(&mut rng); + } + assert_eq!(keypair1.1, keypair2.1); +} + +#[test] +fn test_hkdf_key_pair() { + // HKDF without salt and info. + let salt = None; + let seed = [0u8; 32]; + let info = None; + let (_, public_key1) = derive_keypair_from_seed(salt, &seed, info); + let (_, public_key2) = derive_keypair_from_seed(salt, &seed, info); + assert_eq!(public_key1, public_key2); + + // HKDF with salt and info. + let raw_bytes = [2u8; 10]; + let salt = Some(&raw_bytes[0..4]); + let seed = [3u8; 32]; + let info = Some(&raw_bytes[4..10]); + let (_, public_key1) = derive_keypair_from_seed(salt, &seed, info); + let (_, public_key2) = derive_keypair_from_seed(salt, &seed, info); + assert_eq!(public_key1, public_key2); +} + +#[test] +fn test_generate_key_pair_with_seed() { + let salt = &b"some salt"[..]; + // In production, ensure seed has at least 256 bits of entropy. + let seed = [5u8; 32]; // seed is denoted as IKM in HKDF RFC 5869. + let info = &b"some app info"[..]; + let (_, public_key1) = generate_keypair_hybrid(Some(salt), &seed, Some(info)); + let (_, public_key2) = generate_keypair_hybrid(Some(salt), &seed, Some(info)); + assert_ne!(public_key1, public_key2); + + // Ensure that the deterministic generate_keypair_from_seed returns a completely different key. + let (_, public_key3) = derive_keypair_from_seed(Some(salt), &seed, Some(info)); + assert_ne!(public_key3, public_key1); + assert_ne!(public_key3, public_key2); +} diff --git a/crypto/legacy_crypto/src/utils.rs b/crypto/legacy_crypto/src/utils.rs new file mode 100644 index 0000000000000..a287c5be8ba39 --- /dev/null +++ b/crypto/legacy_crypto/src/utils.rs @@ -0,0 +1,42 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module contains various utility functions for testing and debugging purposes + +use crate::signing::{generate_keypair_for_testing, PrivateKey, PublicKey}; +use bincode::{deserialize, serialize}; +use proptest::prelude::*; +use rand::{rngs::StdRng, SeedableRng}; +use serde::{de::DeserializeOwned, Serialize}; + +/// Used to produce keypairs from a seed for testing purposes +pub fn keypair_strategy() -> impl Strategy { + // The no_shrink is because keypairs should be fixed -- shrinking would cause a different + // keypair to be generated, which appears to not be very useful. 
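+    // Deriving the keypair from a proptest-supplied 32-byte seed keeps the
+    // strategy deterministic per input, so failing cases can be replayed from
+    // proptest's persisted regression seeds.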
+ any::<[u8; 32]>() + .prop_map(|seed| { + let mut rng: StdRng = SeedableRng::from_seed(seed); + let (private_key, public_key) = generate_keypair_for_testing(&mut rng); + (private_key, public_key) + }) + .no_shrink() +} + +/// Generically deserializes a string into a struct `T`, used for producing test cases +pub fn from_encoded_string(encoded_str: String) -> T +where + T: DeserializeOwned, +{ + assert!(!encoded_str.is_empty()); + let bytes_out = ::hex::decode(encoded_str).unwrap(); + deserialize::(&bytes_out).unwrap() +} + +/// Generically serializes a string from a struct `T`, used for producing test cases +pub fn encode_to_string(to_encode: &T) -> String +where + T: Serialize, +{ + let bytes = serialize(to_encode).unwrap(); + ::hex::encode(&bytes) +} diff --git a/crypto/legacy_crypto/src/x25519.rs b/crypto/legacy_crypto/src/x25519.rs new file mode 100644 index 0000000000000..d94c3c89c1bd0 --- /dev/null +++ b/crypto/legacy_crypto/src/x25519.rs @@ -0,0 +1,344 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! An implementation of x25519 elliptic curve key pairs required for [Diffie-Hellman key exchange](https://en.wikipedia.org/wiki/Diffie%E2%80%93Hellman_key_exchange) +//! in the Libra project. +//! +//! This is an API for [Elliptic Curves for Security - RFC 7748](https://tools.ietf.org/html/rfc7748) +//! and it only deals with long-term key generation and handling. +//! +//! Warning: This API will soon be updated in the [`nextgen`] module. +//! +//! # Examples +//! +//! ``` +//! use crypto::x25519::{ +//! derive_keypair_from_seed, generate_keypair, generate_keypair_from_rng, +//! generate_keypair_hybrid, +//! }; +//! use rand::{rngs::StdRng, SeedableRng}; +//! +//! // Derive an X25519 from seed using the extract-then-expand HKDF method from RFC 5869. +//! let salt = &b"some salt"[..]; +//! // In production, ensure seed has at least 256 bits of entropy. +//! let seed = [5u8; 32]; // seed is denoted as IKM in HKDF RFC 5869. +//! let info = &b"some app info"[..]; +//! +//! let (private_key1, public_key1) = derive_keypair_from_seed(Some(salt), &seed, Some(info)); +//! let (private_key2, public_key2) = derive_keypair_from_seed(Some(salt), &seed, Some(info)); +//! assert_eq!(public_key1, public_key2); +//! +//! // Generate a random X25519 key pair. +//! let (private_key, public_key) = generate_keypair(); +//! +//! // Generate an X25519 key pair from an RNG (in this example a SeedableRng). +//! let seed = [1u8; 32]; +//! let mut rng: StdRng = SeedableRng::from_seed(seed); +//! let (private_key, public_key) = generate_keypair_from_rng(&mut rng); +//! +//! // Generate an X25519 key pair from an RNG and a user-provided seed. +//! let salt = &b"some salt"[..]; +//! // In production, ensure seed has at least 256 bits of entropy. +//! let seed = [5u8; 32]; // seed is denoted as IKM in HKDF RFC 5869. +//! let info = &b"some app info"[..]; +//! let (private_key1, public_key1) = generate_keypair_hybrid(Some(salt), &seed, Some(info)); +//! let (private_key2, public_key2) = generate_keypair_hybrid(Some(salt), &seed, Some(info)); +//! assert_ne!(public_key1, public_key2); +//! 
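+//! // (Illustrative sketch) Both key types deref to their x25519_dalek
+//! // counterparts, so a Diffie-Hellman exchange can be written directly;
+//! // this assumes x25519_dalek 0.5's `diffie_hellman`/`as_bytes` API.
+//! let shared_12 = private_key1.diffie_hellman(&public_key2);
+//! let shared_21 = private_key2.diffie_hellman(&public_key1);
+//! assert_eq!(shared_12.as_bytes(), shared_21.as_bytes());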
+use crate::{hkdf::Hkdf, utils::*};
+use core::fmt;
+use crypto_derive::{SilentDebug, SilentDisplay};
+use derive_deref::Deref;
+use failure::prelude::*;
+use proptest::{
+    arbitrary::any,
+    prelude::{Arbitrary, BoxedStrategy},
+    strategy::*,
+};
+use rand::{
+    rngs::{EntropyRng, StdRng},
+    CryptoRng, RngCore, SeedableRng,
+};
+use serde::{de, export, ser};
+use sha2::Sha256;
+use std::fmt::{Debug, Display};
+use x25519_dalek;
+
+/// An x25519 private key.
+#[derive(Deref, SilentDisplay, SilentDebug)]
+pub struct X25519PrivateKey {
+    value: x25519_dalek::StaticSecret,
+}
+
+/// An x25519 public key.
+#[derive(Copy, Clone, Deref)]
+pub struct X25519PublicKey {
+    value: x25519_dalek::PublicKey,
+}
+
+#[deprecated(
+    since = "1.0.0",
+    note = "This will be superseded by the new cryptography API"
+)]
+impl Clone for X25519PrivateKey {
+    fn clone(&self) -> Self {
+        let bytes = self.to_bytes();
+        X25519PrivateKey {
+            value: x25519_dalek::StaticSecret::from(bytes),
+        }
+    }
+}
+
+impl X25519PrivateKey {
+    /// Length of the private key in bytes.
+    pub const LENGTH: usize = 32;
+}
+
+impl X25519PublicKey {
+    /// Length of the public key in bytes.
+    pub const LENGTH: usize = 32;
+
+    /// Obtain a public key from a slice.
+    pub fn from_slice(data: &[u8]) -> Result<X25519PublicKey> {
+        assert_eq!(
+            data.len(),
+            X25519PublicKey::LENGTH,
+            "X25519 Public key wrong length error; expected {} but received {}",
+            X25519PublicKey::LENGTH,
+            data.len()
+        );
+        let mut fixed_size_data: [u8; X25519PublicKey::LENGTH] = Default::default();
+        fixed_size_data.copy_from_slice(data);
+        let key = x25519_dalek::PublicKey::from(fixed_size_data);
+        Ok(X25519PublicKey { value: key })
+    }
+
+    /// Convert a public key into a slice.
+    pub fn to_slice(&self) -> [u8; Self::LENGTH] {
+        *self.value.as_bytes()
+    }
+
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", hex::encode(&self.to_slice()))
+    }
+}
+
+impl PartialEq for X25519PublicKey {
+    fn eq(&self, other: &X25519PublicKey) -> bool {
+        self.as_bytes() == other.as_bytes()
+    }
+}
+
+impl Eq for X25519PublicKey {}
+
+impl Display for X25519PublicKey {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        X25519PublicKey::fmt(self, f)
+    }
+}
+
+impl Debug for X25519PublicKey {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        X25519PublicKey::fmt(self, f)
+    }
+}
+
+impl ser::Serialize for X25519PrivateKey {
+    fn serialize<S>(&self, serializer: S) -> export::Result<S::Ok, S::Error>
+    where
+        S: ser::Serializer,
+    {
+        serializer.serialize_bytes(&self.to_bytes())
+    }
+}
+
+impl ser::Serialize for X25519PublicKey {
+    fn serialize<S>(&self, serializer: S) -> export::Result<S::Ok, S::Error>
+    where
+        S: ser::Serializer,
+    {
+        serializer.serialize_bytes(self.as_bytes())
+    }
+}
+
+impl From<&X25519PrivateKey> for X25519PublicKey {
+    fn from(private_key: &X25519PrivateKey) -> Self {
+        let public_key = (&private_key.value).into();
+        Self { value: public_key }
+    }
+}
+
+struct X25519PrivateKeyVisitor;
+struct X25519PublicKeyVisitor;
+
+impl<'de> de::Visitor<'de> for X25519PrivateKeyVisitor {
+    type Value = X25519PrivateKey;
+
+    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+        formatter.write_str("X25519_dalek private key in bytes")
+    }
+
+    fn visit_bytes<E>(self, value: &[u8]) -> export::Result<X25519PrivateKey, E>
+    where
+        E: de::Error,
+    {
+        let mut fixed_size_data: [u8; X25519PrivateKey::LENGTH] = Default::default();
+        fixed_size_data.copy_from_slice(value);
+        let key = x25519_dalek::StaticSecret::from(fixed_size_data);
+        Ok(X25519PrivateKey { value: key })
+    }
+}
+
+impl<'de> de::Visitor<'de> for X25519PublicKeyVisitor {
+    type Value = X25519PublicKey;
+
+    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
+        formatter.write_str("X25519_dalek public key in bytes")
+    }
+
+    fn visit_bytes<E>(self, value: &[u8]) -> export::Result<X25519PublicKey, E>
+    where
+        E: de::Error,
+    {
+        let mut fixed_size_data: [u8; X25519PublicKey::LENGTH] = Default::default();
+        fixed_size_data.copy_from_slice(value);
+        let key = x25519_dalek::PublicKey::from(fixed_size_data);
+        Ok(X25519PublicKey { value: key })
+    }
+}
+
+impl<'de> de::Deserialize<'de> for X25519PrivateKey {
+    fn deserialize<D>(deserializer: D) -> export::Result<Self, D::Error>
+    where
+        D: de::Deserializer<'de>,
+    {
+        deserializer.deserialize_bytes(X25519PrivateKeyVisitor {})
+    }
+}
+
+impl<'de> de::Deserialize<'de> for X25519PublicKey {
+    fn deserialize<D>(deserializer: D) -> export::Result<Self, D::Error>
+    where
+        D: de::Deserializer<'de>,
+    {
+        deserializer.deserialize_bytes(X25519PublicKeyVisitor {})
+    }
+}
+
+impl Arbitrary for X25519PublicKey {
+    type Parameters = ();
+    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
+        public_key_strategy().boxed()
+    }
+
+    type Strategy = BoxedStrategy<Self>;
+}
+
+fn public_key_strategy() -> impl Strategy<Value = X25519PublicKey> {
+    any::<[u8; X25519PublicKey::LENGTH]>()
+        .prop_map(|seed| {
+            let mut rng: StdRng = SeedableRng::from_seed(seed);
+            let (_, public_key) = generate_keypair_from_rng(&mut rng);
+            public_key
+        })
+        .no_shrink()
+}
+
+/// Generates a consistent keypair `(X25519PrivateKey, X25519PublicKey)` for unit tests.
+pub fn generate_keypair_for_testing<R>(rng: &mut R) -> (X25519PrivateKey, X25519PublicKey)
+where
+    R: ::rand::SeedableRng + ::rand::RngCore + ::rand::CryptoRng,
+{
+    generate_keypair_from_rng(rng)
+}
+
+/// Generates a random key-pair `(X25519PrivateKey, X25519PublicKey)`.
+pub fn generate_keypair() -> (X25519PrivateKey, X25519PublicKey) {
+    let mut rng = EntropyRng::new();
+    generate_keypair_from_rng(&mut rng)
+}
+
+/// Derives a keypair `(X25519PrivateKey, X25519PublicKey)` from
+/// a) salt (optional) - denoted as 'salt' in RFC 5869
+/// b) seed - denoted as 'IKM' in RFC 5869
+/// c) application info (optional) - denoted as 'info' in RFC 5869
+///
+/// using the HKDF key derivation protocol, as defined in RFC 5869.
+/// This implementation uses the full extract-then-expand HKDF steps
+/// based on the SHA-256 hash function.
+///
+/// **Warning**: This function will soon be updated to return a KeyPair struct.
+pub fn derive_keypair_from_seed(
+    salt: Option<&[u8]>,
+    seed: &[u8],
+    app_info: Option<&[u8]>,
+) -> (X25519PrivateKey, X25519PublicKey) {
+    let derived_bytes =
+        Hkdf::<Sha256>::extract_then_expand(salt, seed, app_info, X25519PrivateKey::LENGTH);
+    let mut key_bytes = [0u8; X25519PrivateKey::LENGTH];
+    key_bytes.copy_from_slice(derived_bytes.unwrap().as_slice());
+
+    let secret: x25519_dalek::StaticSecret = x25519_dalek::StaticSecret::from(key_bytes);
+    let public: x25519_dalek::PublicKey = (&secret).into();
+    (
+        X25519PrivateKey { value: secret },
+        X25519PublicKey { value: public },
+    )
+}
+
+/// Generates a keypair `(X25519PrivateKey, X25519PublicKey)` based on an RNG.
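+///
+/// A usage sketch (assuming this module is exported as `crypto::x25519`; any
+/// `RngCore + CryptoRng` source works, and a seeded `StdRng` is only
+/// appropriate for reproducible tests):
+///
+/// ```ignore
+/// use crypto::x25519::generate_keypair_from_rng;
+/// use rand::{rngs::StdRng, SeedableRng};
+///
+/// let mut rng: StdRng = SeedableRng::from_seed([7u8; 32]);
+/// let (private_key, public_key) = generate_keypair_from_rng(&mut rng);
+/// ```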
+pub fn generate_keypair_from_rng<R>(rng: &mut R) -> (X25519PrivateKey, X25519PublicKey)
+where
+    R: RngCore + CryptoRng,
+{
+    let secret: x25519_dalek::StaticSecret = x25519_dalek::StaticSecret::new(rng);
+    let public: x25519_dalek::PublicKey = (&secret).into();
+    (
+        X25519PrivateKey { value: secret },
+        X25519PublicKey { value: public },
+    )
+}
+
+/// Generates a random keypair `(X25519PrivateKey, X25519PublicKey)` and returns a tuple of
+/// string representations:
+/// 1. human-readable public key.
+/// 2. hex-encoded serialized public key.
+/// 3. hex-encoded serialized private key.
+pub fn generate_and_encode_keypair() -> (String, String, String) {
+    let (private_key, public_key) = generate_keypair();
+    let pub_key_human = hex::encode(public_key.to_slice());
+    let public_key_serialized_str = encode_to_string(&public_key);
+    let private_key_serialized_str = encode_to_string(&private_key);
+    (
+        pub_key_human,
+        public_key_serialized_str,
+        private_key_serialized_str,
+    )
+}
+
+/// Generates a random keypair `(PrivateKey, PublicKey)` by combining the output of `EntropyRng`
+/// with a user-provided seed. This concatenated seed is used as the seed to HKDF (RFC 5869).
+///
+/// Similarly to `derive_keypair_from_seed` the user provides the following inputs:
+/// a) salt (optional) - denoted as 'salt' in RFC 5869
+/// b) seed - denoted as 'IKM' in RFC 5869
+/// c) application info (optional) - denoted as 'info' in RFC 5869
+///
+/// Note that this method is not deterministic, but the (random + static seed) key
+/// generation makes it safer against low entropy pools and weak RNGs.
+///
+/// **Warning**: This function will soon be updated to return a [`KeyPair`] struct.
+pub fn generate_keypair_hybrid(
+    salt: Option<&[u8]>,
+    seed: &[u8],
+    app_info: Option<&[u8]>,
+) -> (X25519PrivateKey, X25519PublicKey) {
+    let mut rng = EntropyRng::new();
+    let mut seed_from_rng = [0u8; ed25519_dalek::SECRET_KEY_LENGTH];
+    rng.fill_bytes(&mut seed_from_rng);
+
+    let mut final_seed = seed.to_vec();
+    final_seed.extend_from_slice(&seed_from_rng);
+
+    derive_keypair_from_seed(salt, &final_seed, app_info)
+}
diff --git a/crypto/nextgen_crypto/Cargo.toml b/crypto/nextgen_crypto/Cargo.toml
new file mode 100644
index 0000000000000..e5ae924d6c4d1
--- /dev/null
+++ b/crypto/nextgen_crypto/Cargo.toml
@@ -0,0 +1,40 @@
+[package]
+name = "nextgen_crypto"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+bincode = "1.1.1"
+byteorder = "1.3.1"
+bytes = "0.4.12"
+curve25519-dalek = "1.1.3"
+derive_deref = "1.0.2"
+ed25519-dalek = { version = "1.0.0-pre.1", features = ["serde"] }
+hex = "0.3"
+lazy_static = "1.3.0"
+pairing = "0.14.2"
+proptest = "0.9.1"
+proptest-derive = "0.1.0"
+rand = "0.6.5"
+serde = { version = "1.0.89", features = ["derive"] }
+threshold_crypto = "0.3"
+tiny-keccak = "1.4.2"
+x25519-dalek = "0.5.2"
+digest = "0.8.0"
+hmac = "0.7.0"
+sha3 = "0.8.2"
+sha2 = "0.8.0"
+regex = "1.1.6"
+
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+crypto-derive = { path = "../legacy_crypto/src/macros" }
+proto_conv = { path = "../../common/proto_conv" }
+crypto = { path = "../legacy_crypto" }
+
+[dev-dependencies]
+bitvec = "0.10.1"
+byteorder = "1.3.1"
+ripemd160 = "0.8.0"
diff --git a/crypto/nextgen_crypto/README.md b/crypto/nextgen_crypto/README.md
new file mode 100644
index 0000000000000..7c4e47e996218
--- /dev/null
+++ b/crypto/nextgen_crypto/README.md
@@ -0,0 +1,33 @@
+---
+id: crypto
+title: NextGen Crypto
+custom_edit_url: https://github.com/libra/libra/edit/master/crypto/nextgen_crypto/README.md
+---
+# NextGen
+
+The nextgen folder hosts the future version of the Libra crypto API and several algorithms that will be used in upcoming versions.
+
+## Overview
+
+Nextgen contains the following implementations:
+
+* traits.rs introduces new abstractions for the crypto API.
+* Ed25519 implements signatures using the new API design, based on the [ed25519-dalek](https://docs.rs/ed25519-dalek/1.0.0-pre.1/ed25519_dalek/) library with additional security checks (e.g. for malleability).
+* BLS12381 implements signatures using the new API design, based on the [threshold_crypto](https://github.com/poanetwork/threshold_crypto) library. BLS signatures are currently undergoing a [standardization process](https://tools.ietf.org/html/draft-boneh-bls-signature-00).
+* ECVRF implements a verifiable random function (VRF) according to [draft-irtf-cfrg-vrf-04](https://tools.ietf.org/html/draft-irtf-cfrg-vrf-04) over curve25519.
+* SLIP-0010 implements universal hierarchical key derivation for Ed25519 according to [SLIP-0010](https://github.com/satoshilabs/slips/blob/master/slip-0010.md).
+
+## How is this module organized?
+    nextgen_crypto/src
+    ├── bls12381.rs        # Bls12-381 implementation of the signing/verification API in traits.rs
+    ├── ed25519.rs         # Ed25519 implementation of the signing/verification API in traits.rs
+    ├── lib.rs
+    ├── slip0010.rs        # SLIP-0010 universal hierarchical key derivation for Ed25519
+    ├── test_utils.rs
+    ├── traits.rs          # New API design and the necessary abstractions
+    ├── unit_tests/        # Tests
+    └── vrf/
+        ├── ecvrf.rs       # ECVRF implementation using curve25519 and SHA512
+        ├── mod.rs
+        └── unit_tests     # Tests
diff --git a/crypto/nextgen_crypto/src/bls12381.rs b/crypto/nextgen_crypto/src/bls12381.rs
new file mode 100644
index 0000000000000..215725d385087
--- /dev/null
+++ b/crypto/nextgen_crypto/src/bls12381.rs
@@ -0,0 +1,254 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module implements the Verifying/Signing API for signatures on the
+//! [BLS12-381 curve](https://tools.ietf.org/id/draft-yonezawa-pairing-friendly-curves-00.html).
+//!
+//! # Example
+//!
+//! ```
+//! use crypto::hash::{CryptoHasher, TestOnlyHasher};
+//! use nextgen_crypto::{
+//!     bls12381::*,
+//!     traits::{Signature, SigningKey, Uniform},
+//! };
+//! use rand::{rngs::StdRng, SeedableRng};
+//!
+//! let mut hasher = TestOnlyHasher::default();
+//! hasher.write("Test message".as_bytes());
+//! let hashed_message = hasher.finish();
+//!
+//! let mut rng: StdRng = SeedableRng::from_seed([0; 32]);
+//! let private_key = BLS12381PrivateKey::generate_for_testing(&mut rng);
+//! let public_key: BLS12381PublicKey = (&private_key).into();
+//! let signature = private_key.sign_message(&hashed_message);
+//! assert!(signature.verify(&hashed_message, &public_key).is_ok());
+//! ```
+//! **Note**: The above example generates a private key using a private function intended only for
+//! testing purposes. Production code should find an alternate means for secure key generation.
+//!
+//! This module is not currently used, but could be included in the future for improved
+//! performance in consensus.
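+//!
+//! Keys and signatures also round-trip through serde bytes; below is a sketch
+//! mirroring this crate's serde unit tests (the bincode encoding is a test
+//! convention here, not a documented wire-format guarantee):
+//!
+//! ```
+//! use nextgen_crypto::{bls12381::*, traits::Uniform};
+//! use bincode::serialize;
+//! use rand::{rngs::StdRng, SeedableRng};
+//! use std::convert::TryFrom;
+//!
+//! let mut rng: StdRng = SeedableRng::from_seed([0; 32]);
+//! let private_key = BLS12381PrivateKey::generate_for_testing(&mut rng);
+//! let public_key: BLS12381PublicKey = (&private_key).into();
+//! let bytes = serialize(&public_key).unwrap();
+//! let recovered = BLS12381PublicKey::try_from(bytes.as_slice()).unwrap();
+//! assert_eq!(public_key, recovered);
+//! ```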
+
+use crate::traits::*;
+use bincode::{deserialize, serialize};
+use core::convert::TryFrom;
+use crypto::hash::HashValue;
+use crypto_derive::{SilentDebug, SilentDisplay};
+use derive_deref::Deref;
+use failure::prelude::*;
+use pairing::{
+    bls12_381::{Fr, FrRepr},
+    PrimeField,
+};
+use rand::Rng;
+use serde::{Deserialize, Serialize};
+use threshold_crypto;
+
+// type alias for this unwieldy type
+type ThresholdBLSPrivateKey =
+    threshold_crypto::serde_impl::SerdeSecret<threshold_crypto::SecretKey>;
+
+/// A BLS12-381 private key
+#[derive(Serialize, Deserialize, Deref, SilentDisplay, SilentDebug)]
+pub struct BLS12381PrivateKey(ThresholdBLSPrivateKey);
+
+/// A BLS12-381 public key
+#[derive(Clone, Hash, Serialize, Deserialize, Deref, Debug, PartialEq, Eq)]
+pub struct BLS12381PublicKey(threshold_crypto::PublicKey);
+
+/// A BLS12-381 signature
+#[derive(Clone, Hash, Serialize, Deserialize, Deref, Debug, PartialEq, Eq)]
+pub struct BLS12381Signature(threshold_crypto::Signature);
+
+impl BLS12381PublicKey {
+    /// Serializes a BLS12381PublicKey
+    pub fn to_bytes(&self) -> [u8; threshold_crypto::PK_SIZE] {
+        self.0.to_bytes()
+    }
+}
+
+impl BLS12381Signature {
+    /// Serializes a BLS12381Signature
+    pub fn to_bytes(&self) -> [u8; threshold_crypto::SIG_SIZE] {
+        self.0.to_bytes()
+    }
+}
+
+impl BLS12381PrivateKey {
+    #[allow(dead_code)]
+    /// Deserialize a [`BLS12381PrivateKey`]. This method DOES NOT check for key validity.
+    fn from_bytes_unchecked(
+        mut fr_repr: [u64; 4usize],
+    ) -> std::result::Result<BLS12381PrivateKey, CryptoMaterialError> {
+        // Since the field modulus is a 381-bit prime, drop the 3 highest-order bits
+        fr_repr[3] &= 0x1FFF_FFFF_FFFF_FFFF;
+        let mut fr = Fr::from_repr(FrRepr(fr_repr)).unwrap();
+        Ok(BLS12381PrivateKey(
+            threshold_crypto::serde_impl::SerdeSecret(threshold_crypto::SecretKey::from_mut(
+                &mut fr,
+            )),
+        ))
+    }
+}
+
+///////////////////////
+// PrivateKey Traits //
+///////////////////////
+
+impl PrivateKey for BLS12381PrivateKey {
+    type PublicKeyMaterial = BLS12381PublicKey;
+}
+
+impl SigningKey for BLS12381PrivateKey {
+    type VerifyingKeyMaterial = BLS12381PublicKey;
+    type SignatureMaterial = BLS12381Signature;
+
+    fn sign_message(&self, message: &HashValue) -> BLS12381Signature {
+        let secret_key: &ThresholdBLSPrivateKey = self;
+        let sig = secret_key.sign(message.as_ref());
+        BLS12381Signature(sig)
+    }
+}
+
+impl Uniform for BLS12381PrivateKey {
+    fn generate_for_testing<R>(rng: &mut R) -> Self
+    where
+        R: SeedableCryptoRng,
+    {
+        let mut fr_repr: [u64; 4usize] = rng.gen();
+        // Since the field modulus is a 381-bit prime, drop the 3 highest-order bits
+        fr_repr[3] &= 0x1FFF_FFFF_FFFF_FFFF;
+        let mut fr = Fr::from_repr(FrRepr(fr_repr)).unwrap();
+        BLS12381PrivateKey(threshold_crypto::serde_impl::SerdeSecret(
+            threshold_crypto::SecretKey::from_mut(&mut fr),
+        ))
+    }
+}
+
+impl std::cmp::PartialEq for BLS12381PrivateKey {
+    fn eq(&self, other: &Self) -> bool {
+        serialize(self).unwrap() == serialize(other).unwrap()
+    }
+}
+
+impl std::cmp::Eq for BLS12381PrivateKey {}
+
+impl TryFrom<&[u8]> for BLS12381PrivateKey {
+    type Error = CryptoMaterialError;
+
+    fn try_from(bytes: &[u8]) -> std::result::Result<BLS12381PrivateKey, CryptoMaterialError> {
+        // first we deserialize raw bytes, which may or may not work
+        let key_res = deserialize::<BLS12381PrivateKey>(bytes);
+        // Note that the underlying implementation of SerdeSecret validates the key during
+        // deserialisation. Also, we don't need to check the validity of the derived PublicKey,
+        // as a correct SerdeSecret (checked during deserialisation to be in the field) will
+        // always produce a valid PublicKey.
+        key_res.or(Err(CryptoMaterialError::DeserializationError))
+    }
+}
+
+impl ValidKey for BLS12381PrivateKey {
+    // TODO(ladi): implement!
+    fn to_bytes(&self) -> Vec<u8> {
+        unimplemented!("ask ladi!")
+    }
+}
+
+impl Genesis for BLS12381PrivateKey {
+    fn genesis() -> Self {
+        let mut buf = [0u8; threshold_crypto::PK_SIZE];
+        buf[threshold_crypto::PK_SIZE - 1] = 1;
+        Self::try_from(buf.as_ref()).unwrap()
+    }
+}
+
+//////////////////////
+// PublicKey Traits //
+//////////////////////
+impl From<&BLS12381PrivateKey> for BLS12381PublicKey {
+    fn from(secret_key: &BLS12381PrivateKey) -> Self {
+        let secret: &ThresholdBLSPrivateKey = secret_key;
+        let public: threshold_crypto::PublicKey = secret.public_key();
+        BLS12381PublicKey(public)
+    }
+}
+
+// We deduce PublicKey from this
+// we get the ability to do `pubkey.validate(msg, signature)`
+impl PublicKey for BLS12381PublicKey {
+    type PrivateKeyMaterial = BLS12381PrivateKey;
+    fn length() -> usize {
+        threshold_crypto::PK_SIZE
+    }
+}
+
+impl VerifyingKey for BLS12381PublicKey {
+    type SigningKeyMaterial = BLS12381PrivateKey;
+    type SignatureMaterial = BLS12381Signature;
+}
+
+impl std::fmt::Display for BLS12381PublicKey {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", hex::encode(&self.to_bytes()[..]))
+    }
+}
+
+impl TryFrom<&[u8]> for BLS12381PublicKey {
+    type Error = CryptoMaterialError;
+
+    fn try_from(bytes: &[u8]) -> std::result::Result<BLS12381PublicKey, CryptoMaterialError> {
+        // first we deserialize raw bytes, which may or may not work
+        let key_res = deserialize::<BLS12381PublicKey>(bytes);
+        // TODO: add real validation; for now deserialization success is the only check.
+        key_res.or(Err(CryptoMaterialError::DeserializationError))
+    }
+}
+
+impl ValidKey for BLS12381PublicKey {
+    fn to_bytes(&self) -> Vec<u8> {
+        self.0.to_bytes().to_vec()
+    }
+}
+
+//////////////////////
+// Signature Traits //
+//////////////////////
+
+impl Signature for BLS12381Signature {
+    type VerifyingKeyMaterial = BLS12381PublicKey;
+
+    /// Checks that `signature` is valid for `message` using `public_key`.
+    fn verify(&self, message: &HashValue, public_key: &BLS12381PublicKey) -> Result<()> {
+        self.verify_arbitrary_msg(message.as_ref(), public_key)
+    }
+
+    /// Checks that `signature` is valid for an arbitrary &[u8] `message` using `public_key`.
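+    /// Verification dispatches directly to `threshold_crypto`'s pairing-based
+    /// check; no extra canonicality validation is layered on top here.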
+ fn verify_arbitrary_msg(&self, message: &[u8], public_key: &BLS12381PublicKey) -> Result<()> { + if public_key.verify(self, message.as_ref()) { + Ok(()) + } else { + bail!("The provided signature is not valid on this PublicKey and Message") + } + } + + fn to_bytes(&self) -> Vec { + self.0.to_bytes().to_vec() + } +} + +impl TryFrom<&[u8]> for BLS12381Signature { + type Error = CryptoMaterialError; + + fn try_from(bytes: &[u8]) -> std::result::Result { + let l = bytes.len(); + if l > threshold_crypto::SIG_SIZE { + return Err(CryptoMaterialError::WrongLengthError); + } + let mut tmp = [0u8; threshold_crypto::SIG_SIZE]; + tmp[..l].copy_from_slice(&bytes[..l]); + let sig = threshold_crypto::Signature::from_bytes(&tmp) + .map_err(|_err| CryptoMaterialError::ValidationError)?; + Ok(BLS12381Signature(sig)) + } +} diff --git a/crypto/nextgen_crypto/src/ed25519.rs b/crypto/nextgen_crypto/src/ed25519.rs new file mode 100644 index 0000000000000..2dfeeb93ab065 --- /dev/null +++ b/crypto/nextgen_crypto/src/ed25519.rs @@ -0,0 +1,353 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module provides an API for the PureEdDSA signature scheme over the ed25519 twisted +//! Edwards curve as defined in [RFC8032](https://tools.ietf.org/html/rfc8032). +//! +//! Signature verification also checks and rejects non-canonical signatures. +//! +//! # Examples +//! +//! ``` +//! use crypto::hash::{CryptoHasher, TestOnlyHasher}; +//! use nextgen_crypto::{ +//! ed25519::*, +//! traits::{Signature, SigningKey, Uniform}, +//! }; +//! use rand::{rngs::StdRng, SeedableRng}; +//! +//! let mut hasher = TestOnlyHasher::default(); +//! hasher.write("Test message".as_bytes()); +//! let hashed_message = hasher.finish(); +//! +//! let mut rng: StdRng = SeedableRng::from_seed([0; 32]); +//! let private_key = Ed25519PrivateKey::generate_for_testing(&mut rng); +//! let public_key: Ed25519PublicKey = (&private_key).into(); +//! let signature = private_key.sign_message(&hashed_message); +//! assert!(signature.verify(&hashed_message, &public_key).is_ok()); +//! ``` +//! **Note**: The above example generates a private key using a private function intended only for +//! testing purposes. Production code should find an alternate means for secure key generation. + +use crate::traits::*; +use core::convert::TryFrom; +use crypto::hash::HashValue; +use crypto_derive::{SilentDebug, SilentDisplay}; +use curve25519_dalek::scalar::Scalar; +use ed25519_dalek; +use failure::prelude::*; +use serde::{Deserialize, Serialize}; + +/// An Ed25519 private key +#[derive(Serialize, Deserialize, SilentDisplay, SilentDebug)] +pub struct Ed25519PrivateKey(ed25519_dalek::SecretKey); + +/// An Ed25519 public key +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Ed25519PublicKey(ed25519_dalek::PublicKey); + +/// An Ed25519 signature +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct Ed25519Signature(ed25519_dalek::Signature); + +impl Ed25519PrivateKey { + /// Serialize an Ed25519PrivateKey. + pub fn to_bytes(&self) -> [u8; ed25519_dalek::SECRET_KEY_LENGTH] { + self.0.to_bytes() + } + + /// Deserialize an Ed25519PrivateKey without any validation checks apart from expected key size. 
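+    /// (The 32-byte length check itself happens inside `ed25519_dalek::SecretKey::from_bytes`;
+    /// the validating deserialization path is the `TryFrom<&[u8]>` impl further below.)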
+ fn from_bytes_unchecked( + bytes: &[u8], + ) -> std::result::Result { + match ed25519_dalek::SecretKey::from_bytes(bytes) { + Ok(dalek_secret_key) => Ok(Ed25519PrivateKey(dalek_secret_key)), + Err(_) => Err(CryptoMaterialError::DeserializationError), + } + } +} + +impl Ed25519PublicKey { + /// Serialize an Ed25519PublicKey. + pub fn to_bytes(&self) -> [u8; ed25519_dalek::PUBLIC_KEY_LENGTH] { + self.0.to_bytes() + } + + /// Deserialize an Ed25519PublicKey without any validation checks apart from expected key size. + pub(crate) fn from_bytes_unchecked( + bytes: &[u8], + ) -> std::result::Result { + match ed25519_dalek::PublicKey::from_bytes(bytes) { + Ok(dalek_public_key) => Ok(Ed25519PublicKey(dalek_public_key)), + Err(_) => Err(CryptoMaterialError::DeserializationError), + } + } +} + +impl Ed25519Signature { + /// Serialize an Ed25519Signature. + pub fn to_bytes(&self) -> [u8; ed25519_dalek::SIGNATURE_LENGTH] { + self.0.to_bytes() + } + + /// Deserialize an Ed25519Signature without any validation checks (malleability) + /// apart from expected key size. + pub(crate) fn from_bytes_unchecked( + bytes: &[u8], + ) -> std::result::Result { + match ed25519_dalek::Signature::from_bytes(bytes) { + Ok(dalek_signature) => Ok(Ed25519Signature(dalek_signature)), + Err(_) => Err(CryptoMaterialError::DeserializationError), + } + } + + /// Check for correct size and malleability issues. + /// This method ensures s is of canonical form and R does not lie on a small group. + fn is_valid(bytes: &[u8]) -> std::result::Result<(), CryptoMaterialError> { + if bytes.len() != ed25519_dalek::SIGNATURE_LENGTH { + return Err(CryptoMaterialError::WrongLengthError); + } + + let mut s_bits: [u8; 32] = [0u8; 32]; + s_bits.copy_from_slice(&bytes[32..]); + + // Check if s is of canonical form. + // We actually test if s < order_of_the_curve to capture malleable signatures. + let s = Scalar::from_canonical_bytes(s_bits); + if s == None { + return Err(CryptoMaterialError::CanonicalRepresentationError); + } + + // Check if the R lies on a small subgroup. + // Even though the security implications of a small order R are unclear, + // points of order <= 8 are rejected. 
+ let mut r_bits: [u8; 32] = [0u8; 32]; + r_bits.copy_from_slice(&bytes[..32]); + + let compressed = curve25519_dalek::edwards::CompressedEdwardsY(r_bits); + let point = compressed.decompress(); + + match point { + Some(p) => { + if p.is_small_order() { + Err(CryptoMaterialError::SmallSubgroupError) + } else { + Ok(()) + } + } + None => Err(CryptoMaterialError::DeserializationError), + } + } +} + +/////////////////////// +// PrivateKey Traits // +/////////////////////// + +impl PrivateKey for Ed25519PrivateKey { + type PublicKeyMaterial = Ed25519PublicKey; +} + +impl SigningKey for Ed25519PrivateKey { + type VerifyingKeyMaterial = Ed25519PublicKey; + type SignatureMaterial = Ed25519Signature; + + fn sign_message(&self, message: &HashValue) -> Ed25519Signature { + let secret_key: &ed25519_dalek::SecretKey = &self.0; + let public_key: Ed25519PublicKey = self.into(); + let expanded_secret_key: ed25519_dalek::ExpandedSecretKey = + ed25519_dalek::ExpandedSecretKey::from(secret_key); + let sig = expanded_secret_key.sign(message.as_ref(), &public_key.0); + Ed25519Signature(sig) + } +} + +impl Uniform for Ed25519PrivateKey { + fn generate_for_testing(rng: &mut R) -> Self + where + R: SeedableCryptoRng, + { + Ed25519PrivateKey(ed25519_dalek::SecretKey::generate(rng)) + } +} + +impl PartialEq for Ed25519PrivateKey { + fn eq(&self, other: &Self) -> bool { + self.to_bytes() == other.to_bytes() + } +} + +impl Eq for Ed25519PrivateKey {} + +// We could have a distinct kind of validation for the PrivateKey, for +// ex. checking the derived PublicKey is valid? +impl TryFrom<&[u8]> for Ed25519PrivateKey { + type Error = CryptoMaterialError; + + /// Deserialize an Ed25519PrivateKey. This method will also check for key validity. + fn try_from(bytes: &[u8]) -> std::result::Result { + // Note that the only requirement is that the size of the key is 32 bytes, something that + // is already checked during deserialization of ed25519_dalek::SecretKey + // Also, the underlying ed25519_dalek implementation ensures that the derived public key + // is safe and it will not lie in a small-order group, thus no extra check for PublicKey + // validation is required. 
+ Ed25519PrivateKey::from_bytes_unchecked(bytes) + } +} +impl ValidKey for Ed25519PrivateKey { + fn to_bytes(&self) -> Vec { + self.to_bytes().to_vec() + } +} + +impl Genesis for Ed25519PrivateKey { + fn genesis() -> Self { + let mut buf = [0u8; ed25519_dalek::SECRET_KEY_LENGTH]; + buf[ed25519_dalek::SECRET_KEY_LENGTH - 1] = 1; + Self::try_from(buf.as_ref()).unwrap() + } +} + +////////////////////// +// PublicKey Traits // +////////////////////// + +// Implementing From<&PrivateKey<...>> allows to derive a public key in a more elegant fashion +impl From<&Ed25519PrivateKey> for Ed25519PublicKey { + fn from(secret_key: &Ed25519PrivateKey) -> Self { + let secret: &ed25519_dalek::SecretKey = &secret_key.0; + let public: ed25519_dalek::PublicKey = secret.into(); + Ed25519PublicKey(public) + } +} + +// We deduce PublicKey from this +impl PublicKey for Ed25519PublicKey { + type PrivateKeyMaterial = Ed25519PrivateKey; + fn length() -> usize { + ed25519_dalek::PUBLIC_KEY_LENGTH + } +} + +impl std::hash::Hash for Ed25519PublicKey { + fn hash(&self, state: &mut H) { + let encoded_pubkey = self.to_bytes(); + state.write(&encoded_pubkey); + } +} + +// Those are required by the implementation of hash above +impl PartialEq for Ed25519PublicKey { + fn eq(&self, other: &Ed25519PublicKey) -> bool { + self.to_bytes() == other.to_bytes() + } +} + +impl Eq for Ed25519PublicKey {} + +// We deduce VerifyingKey from pointing to the signature material +// we get the ability to do `pubkey.validate(msg, signature)` +impl VerifyingKey for Ed25519PublicKey { + type SigningKeyMaterial = Ed25519PrivateKey; + type SignatureMaterial = Ed25519Signature; +} + +impl std::fmt::Display for Ed25519PublicKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", hex::encode(&self.0.to_bytes()[..])) + } +} + +impl TryFrom<&[u8]> for Ed25519PublicKey { + type Error = CryptoMaterialError; + + /// Deserialize an Ed25519PublicKey. This method will also check for key validity, for instance + /// it will only deserialize keys that are safe against small subgroup attacks. + fn try_from(bytes: &[u8]) -> std::result::Result { + // We need to access the Edwards point which is not directly accessible from + // ed25519_dalek::PublicKey, so we need to do some custom deserialization. + if bytes.len() != ed25519_dalek::PUBLIC_KEY_LENGTH { + return Err(CryptoMaterialError::WrongLengthError); + } + + let mut bits = [0u8; ed25519_dalek::PUBLIC_KEY_LENGTH]; + bits.copy_from_slice(&bytes[..ed25519_dalek::PUBLIC_KEY_LENGTH]); + + let compressed = curve25519_dalek::edwards::CompressedEdwardsY(bits); + let point = compressed + .decompress() + .ok_or(CryptoMaterialError::DeserializationError)?; + + // Check if the point lies on a small subgroup. This is required + // when using curves with a small cofactor (in ed25519, cofactor = 8). + if point.is_small_order() { + return Err(CryptoMaterialError::SmallSubgroupError); + } + + // Unfortunately, tuple struct `PublicKey` is private so we cannot + // Ok(Ed25519PublicKey(ed25519_dalek::PublicKey(compressed, point))) + // and we have to again invoke deserialization. + Ed25519PublicKey::from_bytes_unchecked(bytes) + } +} + +impl ValidKey for Ed25519PublicKey { + fn to_bytes(&self) -> Vec { + self.0.to_bytes().to_vec() + } +} + +////////////////////// +// Signature Traits // +////////////////////// + +impl Signature for Ed25519Signature { + type VerifyingKeyMaterial = Ed25519PublicKey; + + /// Checks that `self` is valid for `message` using `public_key`. 
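+    /// (This forwards the 32 bytes of the `HashValue` to `verify_arbitrary_msg`.)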
+    fn verify(&self, message: &HashValue, public_key: &Ed25519PublicKey) -> Result<()> {
+        self.verify_arbitrary_msg(message.as_ref(), public_key)
+    }
+
+    /// Checks that `self` is valid for an arbitrary &[u8] `message` using `public_key`.
+    /// Outside of this crate, this particular function should only be used for native signature
+    /// verification in Move.
+    fn verify_arbitrary_msg(&self, message: &[u8], public_key: &Ed25519PublicKey) -> Result<()> {
+        Ed25519Signature::is_valid(&self.to_bytes())?;
+
+        public_key
+            .0
+            .verify(message, &self.0)
+            .map_err(std::convert::Into::into)
+            .and(Ok(()))
+    }
+
+    fn to_bytes(&self) -> Vec<u8> {
+        self.0.to_bytes().to_vec()
+    }
+}
+
+impl std::hash::Hash for Ed25519Signature {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        let encoded_signature = self.to_bytes();
+        state.write(&encoded_signature);
+    }
+}
+
+impl TryFrom<&[u8]> for Ed25519Signature {
+    type Error = CryptoMaterialError;
+
+    fn try_from(bytes: &[u8]) -> std::result::Result<Ed25519Signature, CryptoMaterialError> {
+        Ed25519Signature::is_valid(bytes)?;
+        Ed25519Signature::from_bytes_unchecked(bytes)
+    }
+}
+
+// Those are required by the implementation of hash above
+impl PartialEq for Ed25519Signature {
+    fn eq(&self, other: &Ed25519Signature) -> bool {
+        self.to_bytes().as_ref() == other.to_bytes().as_ref()
+    }
+}
+
+impl Eq for Ed25519Signature {}
diff --git a/crypto/nextgen_crypto/src/lib.rs b/crypto/nextgen_crypto/src/lib.rs
new file mode 100644
index 0000000000000..f1d742766607c
--- /dev/null
+++ b/crypto/nextgen_crypto/src/lib.rs
@@ -0,0 +1,24 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! A library supplying various cryptographic primitives that will be used in the next version.
+
+#![deny(missing_docs)]
+#![feature(test)]
+#![feature(trait_alias)]
+
+pub mod bls12381;
+pub mod ed25519;
+pub mod slip0010;
+pub mod traits;
+pub mod vrf;
+
+#[cfg(test)]
+mod unit_tests;
+
+mod test_utils;
+
+pub use crypto::{
+    hash::HashValue,
+    signing::{PrivateKey, PublicKey, Signature},
+};
diff --git a/crypto/nextgen_crypto/src/slip0010.rs b/crypto/nextgen_crypto/src/slip0010.rs
new file mode 100644
index 0000000000000..1aff2b66867f6
--- /dev/null
+++ b/crypto/nextgen_crypto/src/slip0010.rs
@@ -0,0 +1,188 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module provides an API for SLIP-0010 key derivation on the Ed25519 curve, based on
+//! [SLIP-0010 : Universal private key derivation from master private key](https://github.com/satoshilabs/slips/blob/master/slip-0010.md).
+//!
+//! SLIP-0010 describes how to derive a master private/public key for various curves and how a
+//! [BIP-0032](https://github.com/bitcoin/bips/blob/master/bip-0032.mediawiki)-like derivation is used to hierarchically generate keys from an original seed.
+//!
+//! Note that SLIP-0010 only supports private parent key → private child key hardened key generation
+//! for Ed25519. This key derivation protocol should be preferred to HKDF if compatibility with
+//! hardware wallets and the BIP32 protocol is required.
+use byteorder::{BigEndian, ByteOrder};
+use ed25519_dalek::{PublicKey, SecretKey};
+use hmac::{Hmac, Mac};
+use regex::Regex;
+use sha2::Sha512;
+
+/// Extended private key that includes the full key path and a chain code.
+pub struct ExtendedPrivKey {
+    /// Full key path.
+    pub key_path: String,
+    /// Private key.
+    pub private_key: SecretKey,
+    /// Chain code.
+    pub chain_code: [u8; 32],
+}
+
+impl ExtendedPrivKey {
+    /// Construct a new ExtendedPrivKey.
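+    ///
+    /// A construction sketch (illustrative values only; the path must satisfy
+    /// `Slip0010::is_valid_path`, otherwise `Slip0010Error::KeyPathError` is
+    /// returned):
+    ///
+    /// ```
+    /// use ed25519_dalek::SecretKey;
+    /// use nextgen_crypto::slip0010::ExtendedPrivKey;
+    ///
+    /// let private_key = SecretKey::from_bytes(&[1u8; 32]).unwrap();
+    /// let xprv = ExtendedPrivKey::new("m/0".to_string(), private_key, [0u8; 32]).unwrap();
+    /// assert_eq!(xprv.key_path, "m/0");
+    /// ```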
+ pub fn new( + key_path: String, + private_key: SecretKey, + chain_code: [u8; 32], + ) -> Result { + if Slip0010::is_valid_path(&key_path) { + Ok(Self { + key_path, + private_key, + chain_code, + }) + } else { + Err(Slip0010Error::KeyPathError) + } + } + + /// Get public key. + pub fn get_public(&self) -> PublicKey { + (&self.private_key).into() + } + + /// Derive a child key from this key and a child number. + fn child_key(&self, child_number: u32) -> Result { + let hardened_child_number = if child_number < Slip0010::HARDENED_START { + child_number + Slip0010::HARDENED_START + } else { + child_number + }; + + let mut hmac = + Hmac::::new_varkey(&self.chain_code).map_err(|_| Slip0010Error::MACKeyError)?; + + let mut be_n = [0u8; 4]; + BigEndian::write_u32(&mut be_n, hardened_child_number); + + hmac.input(&[0u8]); + hmac.input(self.private_key.as_bytes()); + hmac.input(&be_n); + + let hmac_output = hmac.result_reset().code(); + + let mut chain_code_bits: [u8; 32] = [0u8; 32]; + chain_code_bits.copy_from_slice(&hmac_output[32..]); + + let mut new_child_path = self.key_path.to_owned(); + new_child_path.push_str("/"); + new_child_path.push_str(&*child_number.to_string()); + + let secret_key = + SecretKey::from_bytes(&hmac_output[..32]).map_err(|_| Slip0010Error::SecretKeyError)?; + + ExtendedPrivKey::new(new_child_path, secret_key, chain_code_bits) + } +} + +/// SLIP-0010 structure. +pub struct Slip0010 {} + +impl Slip0010 { + /// Curve = "ed25519 seed" for the ed25519 curve. + const ED25519_CURVE: &'static [u8] = b"ed25519 seed"; + /// Hardened keys start from 2^31 = 2147483648. + const HARDENED_START: u32 = 2_147_483_648; + + /// Generate master key from seed. + pub fn generate_master(seed: &[u8]) -> Result { + let mut hmac = Hmac::::new_varkey(&Slip0010::ED25519_CURVE) + .map_err(|_| Slip0010Error::MACKeyError)?; + hmac.input(seed); + let hmac_output = hmac.result_reset().code(); + + let mut chain_code_bits: [u8; 32] = [0u8; 32]; + chain_code_bits.copy_from_slice(&hmac_output[32..]); + + let secret_key = + SecretKey::from_bytes(&hmac_output[..32]).map_err(|_| Slip0010Error::SecretKeyError)?; + + ExtendedPrivKey::new("m".to_string(), secret_key, chain_code_bits) + } + + /// Generate a child private key. + pub fn derive_child_key( + parent_key: ExtendedPrivKey, + child_number: u32, + ) -> Result { + parent_key.child_key(child_number) + } + /// Match a valid path of the form "m/A/B.."; each sub-path after m is smaller than 2147483648. + pub fn is_valid_path(path: &str) -> bool { + // match path where each node is [0..2147483647] + if !Regex::new( + r"^m(/([0-9]|[1-8][0-9]|9[0-9]|[1-8][0-9]{2}|9[0-8][0-9]|99[0-9]|[1-8] +[0-9]{3}|9[0-8][0-9]{2}|99[0-8][0-9]|999[0-9]|[1-8][0-9]{4}|9[0-8][0-9]{3}|99[0-8][0-9]{2}|999[0-8] +[0-9]|9999[0-9]|[1-8][0-9]{5}|9[0-8][0-9]{4}|99[0-8][0-9]{3}|999[0-8][0-9]{2}|9999[0-8][0-9]|99999 +[0-9]|[1-8][0-9]{6}|9[0-8][0-9]{5}|99[0-8][0-9]{4}|999[0-8][0-9]{3}|9999[0-8][0-9]{2}|99999[0-8] +[0-9]|999999[0-9]|[1-8][0-9]{7}|9[0-8][0-9]{6}|99[0-8][0-9]{5}|999[0-8][0-9]{4}|9999[0-8][0-9]{3} +|99999[0-8][0-9]{2}|999999[0-8][0-9]|9999999[0-9]|[1-8][0-9]{8}|9[0-8][0-9]{7}|99[0-8][0-9]{6}|999 +[0-8][0-9]{5}|9999[0-8][0-9]{4}|99999[0-8][0-9]{3}|999999[0-8][0-9]{2}|9999999[0-8][0-9]|99999999 +[0-9]|1[0-9]{9}|20[0-9]{8}|21[0-3][0-9]{7}|214[0-6][0-9]{6}|2147[0-3][0-9]{5}|21474[0-7][0-9]{4}| +214748[0-2][0-9]{3}|2147483[0-5][0-9]{2}|21474836[0-3][0-9]|214748364[0-7]))*$", + ) + .unwrap() // The expression is valid, so this will never fail. 
+ .is_match(path) + { + return false; + } + + let segments: Vec<&str> = path.split('/').collect(); + segments + .iter() + .skip(1) + .map(|s| s.replace("'", "")) + .all(|s| s.parse::().is_ok()) + } + + /// Derive a key from a path and a seed. + pub fn derive_from_path(path: &str, seed: &[u8]) -> Result { + if !Slip0010::is_valid_path(path) { + return Err(Slip0010Error::KeyPathError); + } + + let mut key = Slip0010::generate_master(seed)?; + + let segments: Vec<&str> = path.split('/').collect(); + let segments = segments + .iter() + .skip(1) + .map(|s| s.replace("'", "")) + // We first check if the path is valid, so this will never fail. + .map(|s| s.parse::().unwrap()) + .collect::>(); + + for segment in segments { + key = Slip0010::derive_child_key(key, segment)?; + } + + Ok(key) + } +} + +/// An error type for SLIP-0010 key derivation issues. +/// +/// This enum reflects there are various causes of SLIP-0010 failures, including: +/// a) invalid key-path. +/// b) secret_key generation errors. +/// c) hmac related errors. +#[derive(Clone, Debug, PartialEq, Eq, failure::prelude::Fail)] +pub enum Slip0010Error { + /// Invalid key path. + #[fail(display = "SLIP-0010 invalid key path")] + KeyPathError, + /// Any error related to key derivation. + #[fail(display = "SLIP-0010 - cannot generate key")] + SecretKeyError, + /// HMAC key related error; unlikely to happen because every key size is accepted in HMAC. + #[fail(display = "SLIP-0010 - HMAC key error")] + MACKeyError, +} diff --git a/crypto/nextgen_crypto/src/test_utils.rs b/crypto/nextgen_crypto/src/test_utils.rs new file mode 100644 index 0000000000000..2f3629b6fe0ea --- /dev/null +++ b/crypto/nextgen_crypto/src/test_utils.rs @@ -0,0 +1,56 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! Internal module containing convenience utility functions mainly for testing + +use crate::traits::{SeedableCryptoRng, Uniform}; +use bincode::serialize; +use serde::Serialize; + +/// A keypair consisting of a private and public key +#[derive(Clone)] +pub struct KeyPair +where + for<'a> P: From<&'a S>, +{ + pub private_key: S, + pub public_key: P, +} + +impl From for KeyPair +where + for<'a> P: From<&'a S>, +{ + fn from(private_key: S) -> Self { + KeyPair { + public_key: (&private_key).into(), + private_key, + } + } +} + +impl Uniform for KeyPair +where + S: Uniform, + for<'a> P: From<&'a S>, +{ + fn generate_for_testing(rng: &mut R) -> Self + where + R: SeedableCryptoRng, + { + let private_key = S::generate_for_testing(rng); + private_key.into() + } +} + +impl std::fmt::Debug for KeyPair +where + Priv: Serialize, + Pub: Serialize + for<'a> From<&'a Priv>, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut v = serialize(&self.private_key).unwrap(); + v.extend(&serialize(&self.public_key).unwrap()); + write!(f, "{}", hex::encode(&v[..])) + } +} diff --git a/crypto/nextgen_crypto/src/traits.rs b/crypto/nextgen_crypto/src/traits.rs new file mode 100644 index 0000000000000..159c654a5c62e --- /dev/null +++ b/crypto/nextgen_crypto/src/traits.rs @@ -0,0 +1,212 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module provides a generic set of traits for dealing with cryptographic primitives. +//! +//! For examples on how to use these traits, see the implementations of the [`ed25519`] or +//! [`bls12381`] modules. 
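+//!
+//! As a sketch of the intended genericity (a hypothetical helper, not part of
+//! this module's API; it compiles against the associated-type ties declared
+//! below):
+//!
+//! ```
+//! use crypto::hash::HashValue;
+//! use nextgen_crypto::traits::*;
+//!
+//! fn sign_and_check<S: SigningKey>(
+//!     signer: &S,
+//!     verifier: &S::VerifyingKeyMaterial,
+//!     message: &HashValue,
+//! ) -> bool {
+//!     let signature = signer.sign_message(message);
+//!     signature.verify(message, verifier).is_ok()
+//! }
+//! ```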
+
+use core::convert::{From, TryFrom};
+use crypto::hash::HashValue;
+use failure::prelude::*;
+use std::{fmt::Debug, hash::Hash};
+
+/// An error type for key and signature validation issues, see [`ValidKey`][ValidKey].
+///
+/// This enum reflects that there are two interesting causes of validation
+/// failure for the ingestion of key or signature material: deserialization errors
+/// (often, due to mangled material or curve equation failure for ECC) and
+/// validation errors (material recognizable but unacceptable for use,
+/// e.g. unsafe).
+#[derive(Clone, Debug, PartialEq, Eq, failure::prelude::Fail)]
+pub enum CryptoMaterialError {
+    /// Key or signature material does not deserialize correctly.
+    #[fail(display = "DeserializationError")]
+    DeserializationError,
+    /// Key or signature material deserializes, but is otherwise not valid.
+    #[fail(display = "ValidationError")]
+    ValidationError,
+    /// Key or signature material does not have the expected size.
+    #[fail(display = "WrongLengthError")]
+    WrongLengthError,
+    /// Part of the signature or key is not canonical, resulting in malleability issues.
+    #[fail(display = "CanonicalRepresentationError")]
+    CanonicalRepresentationError,
+    /// A curve point (i.e., a public key) lies in a small subgroup.
+    #[fail(display = "SmallSubgroupError")]
+    SmallSubgroupError,
+    /// A curve point (i.e., a public key) does not satisfy the curve equation.
+    #[fail(display = "PointNotOnCurveError")]
+    PointNotOnCurveError,
+}
+
+/// Key material with a notion of byte validation.
+///
+/// A type family for material that knows how to serialize and
+/// deserialize, as well as validate byte-encoded material. The
+/// validation must be implemented as a [`TryFrom`][TryFrom] which
+/// classifies its failures against the above
+/// [`CryptoMaterialError`][CryptoMaterialError].
+///
+/// This provides an implementation for a validation that relies on a
+/// round-trip to bytes and corresponding [`TryFrom`][TryFrom].
+pub trait ValidKey:
+    // The for<'a> exactly matches the assumption "deserializable from any lifetime".
+    for<'a> TryFrom<&'a [u8], Error = CryptoMaterialError> + Debug
+{
+    /// TryFrom is the source of truth on whether we can build a valid key.
+    /// => we can use it once we've built, to validate!
+    fn validate(&self) -> std::result::Result<(), CryptoMaterialError> {
+        Self::try_from(self.to_bytes().as_slice())?;
+        Ok(())
+    }
+
+    /// Convert the valid key to bytes.
+    fn to_bytes(&self) -> Vec<u8>;
+}
+
+/// An extension for to/from-string conversions on [`ValidKey`][ValidKey] material.
+///
+/// Relies on [`hex`][::hex] for string encoding / decoding.
+/// No required fields, provides a default implementation.
+pub trait ValidKeyStringExt: ValidKey {
+    /// When trying to convert from a string, we simply decode the string into
+    /// bytes before checking if we can convert.
+    fn from_encoded_string(encoded_str: &str) -> std::result::Result<Self, CryptoMaterialError> {
+        let bytes_out = ::hex::decode(encoded_str);
+        // We defer to `try_from` to make sure we only produce valid keys.
+        bytes_out
+            // We reinterpret a failure to decode: the key is mangled in some way.
+            .or(Err(CryptoMaterialError::DeserializationError))
+            .and_then(|ref bytes| Self::try_from(bytes))
+    }
+    /// A function to encode into hex-string after serializing.
+    fn to_encoded_string(&self) -> Result<String> {
+        Ok(::hex::encode(&self.to_bytes()))
+    }
+}
+
+// There's nothing required in this extension, so let's just derive it
+// for anybody that has a ValidKey.
+impl<T: ValidKey> ValidKeyStringExt for T {}
+
+/// A type family for key material that should remain secret and has an
+/// associated type of the [`PublicKey`][PublicKey] family.
+pub trait PrivateKey: ValidKey {
+    /// We require public / private types to be coupled, i.e. their
+    /// associated type is each other.
+    type PublicKeyMaterial: PublicKey<PrivateKeyMaterial = Self>;
+}
+
+/// A type family of valid keys that know how to sign.
+///
+/// A trait for a [`ValidKey`][ValidKey] which knows how to sign a
+/// message, and return an associated `Signature` type.
+pub trait SigningKey:
+    PrivateKey<PublicKeyMaterial = <Self as SigningKey>::VerifyingKeyMaterial>
+{
+    /// The associated verifying key for this signing key.
+    type VerifyingKeyMaterial: VerifyingKey<SignatureMaterial = <Self as SigningKey>::SignatureMaterial>;
+    /// The associated signature for this signing key.
+    type SignatureMaterial: Signature<VerifyingKeyMaterial = <Self as SigningKey>::VerifyingKeyMaterial>;
+
+    /// Signs an input message.
+    fn sign_message(&self, message: &HashValue) -> Self::SignatureMaterial;
+}
+
+/// A type for key material that can be publicly shared, and in asymmetric
+/// fashion, can be obtained from a [`PrivateKey`][PrivateKey]
+/// reference.
+/// This convertibility requirement ensures the existence of a
+/// deterministic, canonical public key construction from a private key.
+pub trait PublicKey: ValidKey + Clone + Eq + Hash +
+    // This unsightly turbofish type parameter is the precise constraint
+    // needed to require that there exists an
+    //
+    // ```
+    // impl From<&MyPrivateKeyMaterial> for MyPublicKeyMaterial
+    // ```
+    //
+    // declaration, for any `MyPrivateKeyMaterial`, `MyPublicKeyMaterial`
+    // on which we register (respectively) `PublicKey` and `PrivateKey`
+    // implementations.
+    for<'a> From<&'a <Self as PublicKey>::PrivateKeyMaterial> {
+    /// We require public / private types to be coupled, i.e. their
+    /// associated type is each other.
+    type PrivateKeyMaterial: PrivateKey<PublicKeyMaterial = Self>;
+    /// The length of the [`PublicKey`] in bytes.
+    fn length() -> usize;
+}
+
+/// A type family of public keys that are used for signing.
+///
+/// It is linked to a type of the Signature family, which carries the
+/// verification implementation.
+pub trait VerifyingKey:
+    PublicKey<PrivateKeyMaterial = <Self as VerifyingKey>::SigningKeyMaterial>
+{
+    /// The associated signing key for this verifying key.
+    type SigningKeyMaterial: SigningKey<VerifyingKeyMaterial = Self>;
+    /// The associated signature for this verifying key.
+    type SignatureMaterial: Signature<VerifyingKeyMaterial = Self>;
+
+    /// We provide the logical implementation which dispatches to the signature.
+    fn verify_signature(
+        &self,
+        message: &HashValue,
+        signature: &Self::SignatureMaterial,
+    ) -> Result<()> {
+        signature.verify(message, self)
+    }
+}
+
+/// A type family for signature material that knows which public key type
+/// is needed to verify it, and given such a public key, knows how to
+/// verify.
+///
+/// This trait simply requires an association to some type of the
+/// [`PublicKey`][PublicKey] family of which we are the `SignatureMaterial`.
+///
+/// It should be possible to write a generic signature function that
+/// checks signature material passed as `&[u8]` and only returns Ok when
+/// that material de-serializes to a signature of the expected concrete
+/// scheme. This would be done as an extension trait of
+/// [`Signature`][Signature].
+pub trait Signature:
+    for<'a> TryFrom<&'a [u8], Error = CryptoMaterialError> + Sized + Debug + Clone + Eq + Hash
+{
+    /// The associated verifying key for this signature.
+    type VerifyingKeyMaterial: VerifyingKey<SignatureMaterial = Self>;
+
+    /// The verification function.
+    fn verify(&self, message: &HashValue, public_key: &Self::VerifyingKeyMaterial) -> Result<()>;
+
+    /// Native verification function.
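+    /// Takes the raw message bytes rather than a `HashValue`; concrete schemes
+    /// use this e.g. for native signature verification in Move.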
+ fn verify_arbitrary_msg( + &self, + message: &[u8], + public_key: &Self::VerifyingKeyMaterial, + ) -> Result<()>; + + /// Convert the signature into a byte representation. + fn to_bytes(&self) -> Vec; +} + +/// An alias for the RNG used in the [`Uniform`] trait. +pub trait SeedableCryptoRng = ::rand::SeedableRng + ::rand::RngCore + ::rand::CryptoRng; + +/// A type family for schemes which know how to generate key material from +/// a cryptographically-secure [`CryptoRng`][::rand::CryptoRng]. +pub trait Uniform { + /// Generate key material from an RNG for testing purposes. + fn generate_for_testing(rng: &mut R) -> Self + where + R: SeedableCryptoRng; +} + +/// A type family with a by-convention notion of genesis private key. +pub trait Genesis: PrivateKey { + /// Produces the genesis private key. + fn genesis() -> Self; +} diff --git a/crypto/nextgen_crypto/src/unit_tests/bls12381_test.rs b/crypto/nextgen_crypto/src/unit_tests/bls12381_test.rs new file mode 100644 index 0000000000000..a05130b81ac6e --- /dev/null +++ b/crypto/nextgen_crypto/src/unit_tests/bls12381_test.rs @@ -0,0 +1,70 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + bls12381::{BLS12381PrivateKey, BLS12381PublicKey, BLS12381Signature}, + traits::*, + unit_tests::uniform_keypair_strategy, +}; +use bincode::{deserialize, serialize}; +use crypto::hash::HashValue; +use proptest::prelude::*; +use std::convert::TryFrom; + +proptest! { + #[test] + fn test_keys_encode(keypair in uniform_keypair_strategy::()) { + { + let serialized = serialize(&keypair.private_key).unwrap(); + let encoded = ::hex::encode(&serialized); + let decoded = BLS12381PrivateKey::from_encoded_string(&encoded); + prop_assert_eq!(Some(keypair.private_key), decoded.ok()); + } + { + let serialized = serialize(&keypair.public_key).unwrap(); + let encoded = ::hex::encode(&serialized); + let decoded = BLS12381PublicKey::from_encoded_string(&encoded); + prop_assert_eq!(Some(keypair.public_key), decoded.ok()); + } + } + + #[test] + fn test_keys_serde(keypair in uniform_keypair_strategy::()) { + { + let serialized: &[u8] = &serialize(&keypair.private_key).unwrap(); + let deserialized = BLS12381PrivateKey::try_from(serialized); + prop_assert_eq!(Some(keypair.private_key), deserialized.ok()); + } + { + let serialized: &[u8] = &serialize(&keypair.public_key).unwrap(); + let deserialized = BLS12381PublicKey::try_from(serialized); + prop_assert_eq!(Some(keypair.public_key), deserialized.ok()); + } + } +} + +proptest! 
{ + #![proptest_config(ProptestConfig::with_cases(20))] + + #[test] + fn test_signature_serde( + hash in any::(), + keypair in uniform_keypair_strategy::() + ) { + let signature = keypair.private_key.sign_message(&hash); + let serialized = serialize(&signature).unwrap(); + let deserialized = deserialize::(&serialized).unwrap(); + assert!(keypair.public_key.verify_signature(&hash, &deserialized).is_ok()); + } + + #[test] + fn test_sign_and_verify( + hash in any::(), + keypair in uniform_keypair_strategy::() + ) { + let signature = keypair.private_key.sign_message(&hash); + let serialized = serialize(&signature).unwrap(); + let deserialized = deserialize::(&serialized).unwrap(); + prop_assert!(keypair.public_key.verify_signature(&hash, &deserialized).is_ok()); + } +} diff --git a/crypto/nextgen_crypto/src/unit_tests/cross_test.rs b/crypto/nextgen_crypto/src/unit_tests/cross_test.rs new file mode 100644 index 0000000000000..98f9e0939ecbe --- /dev/null +++ b/crypto/nextgen_crypto/src/unit_tests/cross_test.rs @@ -0,0 +1,229 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + bls12381::{BLS12381PrivateKey, BLS12381PublicKey, BLS12381Signature}, + ed25519::{Ed25519PrivateKey, Ed25519PublicKey, Ed25519Signature}, + traits::*, + unit_tests::uniform_keypair_strategy, +}; + +use crypto::hash::HashValue; + +use core::convert::TryFrom; +use crypto_derive::SilentDebug; +use failure::prelude::*; +use proptest::prelude::*; +use serde::{Deserialize, Serialize}; + +// Here we aim to make a point about how we can build an enum generically +// on top of a few choice signing scheme types. This enum implements the +// VerifyingKey, SigningKey for precisely the types selected for that enum +// (here, Ed25519(PublicKey|PrivateKey|Signature) and BLS(...)). +// +// Note that we do not break type safety (see towards the end), and that +// this enum can safely be put into the usual collection (Vec, HashMap). +// +// Note also that *nothing* in this definition requires knowing the details +// of the enum, so all of the below declarations canΒΉ be generated by a +// macro for any enum type making a coherent choice of concrete alternatives. +// +// ΒΉ: and in the near future will << TODO(fga) + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +enum PublicK { + Ed(Ed25519PublicKey), + BLS(BLS12381PublicKey), +} + +#[derive(Serialize, Deserialize, SilentDebug)] +enum PrivateK { + Ed(Ed25519PrivateKey), + BLS(BLS12381PrivateKey), +} + +#[allow(clippy::large_enum_variant)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +enum Sig { + Ed(Ed25519Signature), + BLS(BLS12381Signature), +} + +impl From<&PrivateK> for PublicK { + fn from(secret_key: &PrivateK) -> Self { + match secret_key { + PrivateK::Ed(pk) => PublicK::Ed(pk.into()), + PrivateK::BLS(pk) => PublicK::BLS(pk.into()), + } + } +} + +impl TryFrom<&[u8]> for PrivateK { + type Error = CryptoMaterialError; + fn try_from(bytes: &[u8]) -> std::result::Result { + Ed25519PrivateKey::try_from(bytes) + .and_then(|ed_priv_key| Ok(PrivateK::Ed(ed_priv_key))) + .or_else(|_err| { + BLS12381PrivateKey::try_from(bytes) + .and_then(|bls_priv_key| Ok(PrivateK::BLS(bls_priv_key))) + }) + } +} + +impl ValidKey for PrivateK { + fn to_bytes(&self) -> Vec { + match self { + PrivateK::BLS(privkey) => privkey.to_bytes().to_vec(), + PrivateK::Ed(privkey) => privkey.to_bytes().to_vec(), + } + } +} + +impl PublicKey for PublicK { + type PrivateKeyMaterial = PrivateK; + // TODO(fga): fix this! 
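+    // In the meantime we report the larger of the two concrete key lengths, so
+    // this is an upper bound: an `Ed` variant actually encodes to fewer bytes
+    // than `length()` suggests.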
+ fn length() -> usize { + std::cmp::max(BLS12381PublicKey::length(), Ed25519PublicKey::length()) + } +} + +impl TryFrom<&[u8]> for PublicK { + type Error = CryptoMaterialError; + fn try_from(bytes: &[u8]) -> std::result::Result { + Ed25519PublicKey::try_from(bytes) + .and_then(|ed_priv_key| Ok(PublicK::Ed(ed_priv_key))) + .or_else(|_err| { + BLS12381PublicKey::try_from(bytes) + .and_then(|bls_priv_key| Ok(PublicK::BLS(bls_priv_key))) + }) + } +} + +impl ValidKey for PublicK { + fn to_bytes(&self) -> Vec { + match self { + PublicK::BLS(pubkey) => pubkey.to_bytes().to_vec(), + PublicK::Ed(pubkey) => pubkey.to_bytes().to_vec(), + } + } +} + +impl PrivateKey for PrivateK { + type PublicKeyMaterial = PublicK; +} + +impl SigningKey for PrivateK { + type VerifyingKeyMaterial = PublicK; + type SignatureMaterial = Sig; + + fn sign_message(&self, message: &HashValue) -> Sig { + match self { + PrivateK::Ed(ed_priv) => Sig::Ed(ed_priv.sign_message(message)), + PrivateK::BLS(bls_priv) => Sig::BLS(bls_priv.sign_message(message)), + } + } +} + +impl Signature for Sig { + type VerifyingKeyMaterial = PublicK; + + fn verify(&self, message: &HashValue, public_key: &PublicK) -> Result<()> { + self.verify_arbitrary_msg(message.as_ref(), public_key) + } + + fn verify_arbitrary_msg(&self, message: &[u8], public_key: &PublicK) -> Result<()> { + match (self, public_key) { + (Sig::Ed(ed_sig), PublicK::Ed(ed_pub)) => ed_sig.verify_arbitrary_msg(message, ed_pub), + (Sig::BLS(bls_sig), PublicK::BLS(bls_pub)) => { + bls_sig.verify_arbitrary_msg(message, bls_pub) + } + _ => bail!( + "provided the wrong alternative in {:?}!", + (self, public_key) + ), + } + } + + fn to_bytes(&self) -> Vec { + match self { + Sig::Ed(sig) => sig.to_bytes().to_vec(), + Sig::BLS(sig) => sig.to_bytes().to_vec(), + } + } +} + +impl TryFrom<&[u8]> for Sig { + type Error = CryptoMaterialError; + fn try_from(bytes: &[u8]) -> std::result::Result { + Ed25519Signature::try_from(bytes) + .and_then(|ed_sig| Ok(Sig::Ed(ed_sig))) + .or_else(|_err| { + BLS12381Signature::try_from(bytes).and_then(|bls_sig| Ok(Sig::BLS(bls_sig))) + }) + } +} + +impl VerifyingKey for PublicK { + type SigningKeyMaterial = PrivateK; + type SignatureMaterial = Sig; +} + +/////////////////////////////////////////////////////// +// End of declarations β€” let's now prove type safety // +/////////////////////////////////////////////////////// +proptest! 
{ + #![proptest_config(ProptestConfig::with_cases(20))] + + #[test] + fn test_keys_mix( + hash in any::(), + ed_keypair1 in uniform_keypair_strategy::(), + ed_keypair2 in uniform_keypair_strategy::(), + bls_keypair in uniform_keypair_strategy::() + ) { + // this is impossible to write statically, due to the trait not being + // object-safe (voluntarily) + // let mut l: Vec> = vec![]; + let mut l: Vec = vec![]; + l.push(ed_keypair1.private_key); + let ed_key = l.pop().unwrap(); + let signature = ed_key.sign_message(&hash); + + // This is business as usual + prop_assert!(signature.verify(&hash, &ed_keypair1.public_key).is_ok()); + + // This is impossible to write, and generates: + // expected struct `ed25519::Ed25519PublicKey`, found struct `bls12381::BLS12381PublicKey` + // prop_assert!(signature.verify(&hash, &bls_keypair.public_key).is_ok()); + + let mut l2: Vec = vec![]; + l2.push(PrivateK::BLS(bls_keypair.private_key)); + l2.push(PrivateK::Ed(ed_keypair2.private_key)); + + let ed_key = l2.pop().unwrap(); + let ed_signature = ed_key.sign_message(&hash); + + // This is still business as usual + let ed_pubkey2 = PublicK::Ed(ed_keypair2.public_key); + let good_sigver = ed_signature.verify(&hash, &ed_pubkey2); + prop_assert!(good_sigver.is_ok(), "{:?}", good_sigver); + + // but this still fails, as expected + let bls_pubkey = PublicK::BLS(bls_keypair.public_key); + let bad_sigver = ed_signature.verify(&hash, &bls_pubkey); + prop_assert!(bad_sigver.is_err(), "{:?}", bad_sigver); + + // And now just in case we're confused again, we pop in the + // reverse direction + let bls_key = l2.pop().unwrap(); + let bls_signature = bls_key.sign_message(&hash); + + // This is still business as usual + let good_sigver = bls_signature.verify(&hash, &bls_pubkey); + prop_assert!(good_sigver.is_ok(), "{:?}", good_sigver); + + // but this still fails, as expected + let bad_sigver = bls_signature.verify(&hash, &ed_pubkey2); + prop_assert!(bad_sigver.is_err(), "{:?}", bad_sigver); + } +} diff --git a/crypto/nextgen_crypto/src/unit_tests/ed25519_test.rs b/crypto/nextgen_crypto/src/unit_tests/ed25519_test.rs new file mode 100644 index 0000000000000..9a2ef98e7ffb1 --- /dev/null +++ b/crypto/nextgen_crypto/src/unit_tests/ed25519_test.rs @@ -0,0 +1,295 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + ed25519::{Ed25519PrivateKey, Ed25519PublicKey, Ed25519Signature}, + traits::*, + unit_tests::uniform_keypair_strategy, +}; + +use core::{ + convert::TryFrom, + ops::{Index, IndexMut}, +}; + +use crypto::hash::HashValue; +use ed25519_dalek; +use proptest::prelude::*; + +proptest! { + #[test] + fn test_keys_encode(keypair in uniform_keypair_strategy::()) { + { + let encoded = keypair.private_key.to_encoded_string().unwrap(); + // Hex encoding of a 32-bytes key is 64 (2 x 32) characters. + prop_assert_eq!(2 * ed25519_dalek::SECRET_KEY_LENGTH, encoded.len()); + let decoded = Ed25519PrivateKey::from_encoded_string(&encoded); + prop_assert_eq!(Some(keypair.private_key), decoded.ok()); + } + { + let encoded = keypair.public_key.to_encoded_string().unwrap(); + // Hex encoding of a 32-bytes key is 64 (2 x 32) characters. 
+ prop_assert_eq!(2 * ed25519_dalek::PUBLIC_KEY_LENGTH, encoded.len()); + let decoded = Ed25519PublicKey::from_encoded_string(&encoded); + prop_assert_eq!(Some(keypair.public_key), decoded.ok()); + } + } + + #[test] + fn test_keys_custom_serialisation( + keypair in uniform_keypair_strategy::() + ) { + { + let serialized: &[u8] = &(keypair.private_key.to_bytes()); + prop_assert_eq!(ed25519_dalek::SECRET_KEY_LENGTH, serialized.len()); + let deserialized = Ed25519PrivateKey::try_from(serialized); + prop_assert_eq!(Some(keypair.private_key), deserialized.ok()); + } + { + let serialized: &[u8] = &(keypair.public_key.to_bytes()); + prop_assert_eq!(ed25519_dalek::PUBLIC_KEY_LENGTH, serialized.len()); + let deserialized = Ed25519PublicKey::try_from(serialized); + prop_assert_eq!(Some(keypair.public_key), deserialized.ok()); + } + } + + #[test] + fn test_signature_verification_custom_serialisation( + hash in any::(), + keypair in uniform_keypair_strategy::() + ) { + let signature = keypair.private_key.sign_message(&hash); + let serialized: &[u8] = &(signature.to_bytes()); + prop_assert_eq!(ed25519_dalek::SIGNATURE_LENGTH, serialized.len()); + let deserialized = Ed25519Signature::try_from(serialized).unwrap(); + assert!(keypair.public_key.verify_signature(&hash, &deserialized).is_ok()); + } + + // Check for canonical s. + #[test] + fn test_signature_malleability( + hash in any::(), + keypair in uniform_keypair_strategy::() + ) { + let signature = keypair.private_key.sign_message(&hash); + let mut serialized = signature.to_bytes(); + + let mut r_bits: [u8; 32] = [0u8; 32]; + r_bits.copy_from_slice(&serialized[..32]); + + let mut s_bits: [u8; 32] = [0u8; 32]; + s_bits.copy_from_slice(&serialized[32..]); + + // ed25519-dalek signing ensures a canonical s value. + let s = Scalar52::from_bytes(&s_bits); + + // adding L (order of the base point) so that s + L > L + let malleable_s = Scalar52::add(&s, &L); + let malleable_s_bits = malleable_s.to_bytes(); + // Update the signature (the s part). + serialized[32..].copy_from_slice(&malleable_s_bits); + + // Check that malleable signatures will pass verification and deserialization in dalek. + // Construct the corresponding dalek public key. + let dalek_public_key = ed25519_dalek::PublicKey::from_bytes( + &keypair.public_key.to_bytes() + ).unwrap(); + + // Construct the corresponding dalek Signature. This signature is malleable. + let dalek_sig = ed25519_dalek::Signature::from_bytes(&serialized); + + // ed25519_dalek will deserialize the malleable signature. It does NOT detect it. + prop_assert!(dalek_sig.is_ok()); + // ed25519_dalek will verify malleable signatures as valid. + prop_assert!(dalek_public_key.verify(hash.as_ref(), &dalek_sig.unwrap()).is_ok()); + + let serialized_malleable: &[u8] = &serialized; + // from_bytes will fail on malleable signatures. We detect malleable signatures + // during deserialization. + prop_assert_eq!( + Ed25519Signature::try_from(serialized_malleable), + Err(CryptoMaterialError::CanonicalRepresentationError) + ); + + // We expect from_bytes_unchecked deserialization not to fail, as we don't check + // for signature malleability. This method is pub(crate) and only used for test purposes. + let sig_unchecked = Ed25519Signature::from_bytes_unchecked(&serialized).unwrap(); + + // Malleable signatures will fail to verify in our implementation, even if for some reason + // we receive one. Note that this is a second step of validation as we typically check + // malleable signatures during deserialization. 
+        prop_assert!(keypair.public_key.verify_signature(&hash, &sig_unchecked).is_err());
+    }
+}
+
+// Test against known small subgroup public keys.
+#[test]
+fn test_publickey_smallorder() {
+    for torsion_point in &EIGHT_TORSION {
+        let serialized: &[u8] = torsion_point;
+        // We expect from_bytes_unchecked to pass, as it does not validate the key.
+        assert!(Ed25519PublicKey::from_bytes_unchecked(serialized).is_ok());
+        // try_from will fail on an invalid key.
+        assert_eq!(
+            Ed25519PublicKey::try_from(serialized),
+            Err(CryptoMaterialError::SmallSubgroupError)
+        );
+    }
+}
+
+// The 8-torsion subgroup E[8].
+//
+// In the case of Curve25519, it is cyclic; the i-th element of
+// the array is [i]P, where P is a point of order 8
+// generating E[8].
+//
+// Thus E[8] is the points indexed by `0,2,4,6`, and
+// E[2] is the points indexed by `0,4`.
+//
+// The following byte arrays have been ported from
+// curve25519-dalek/backend/serial/u64/constants.rs
+// and they represent the serialised version of the CompressedEdwardsY points.
+
+const EIGHT_TORSION: [[u8; 32]; 8] = [
+    [
+        1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        0, 0,
+    ],
+    [
+        199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44,
+        57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122,
+    ],
+    [
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        0, 128,
+    ],
+    [
+        38, 232, 149, 143, 194, 178, 39, 176, 69, 195, 244, 137, 242, 239, 152, 240, 213, 223, 172,
+        5, 211, 198, 51, 57, 177, 56, 2, 136, 109, 83, 252, 5,
+    ],
+    [
+        236, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127,
+    ],
+    [
+        38, 232, 149, 143, 194, 178, 39, 176, 69, 195, 244, 137, 242, 239, 152, 240, 213, 223, 172,
+        5, 211, 198, 51, 57, 177, 56, 2, 136, 109, 83, 252, 133,
+    ],
+    [
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        0, 0,
+    ],
+    [
+        199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44,
+        57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 250,
+    ],
+];
+
+/// The `Scalar52` struct represents an element in
+/// ℤ/ℓℤ as 5 52-bit limbs.
+struct Scalar52(pub [u64; 5]);
+
+/// `L` is the order of the base point, i.e. 2^252 + 27742317777372353535851937790883648493
+const L: Scalar52 = Scalar52([
+    0x0002_631a_5cf5_d3ed,
+    0x000d_ea2f_79cd_6581,
+    0x0000_0000_0014_def9,
+    0x0000_0000_0000_0000,
+    0x0000_1000_0000_0000,
+]);
+
+impl Scalar52 {
+    /// Return the zero scalar
+    fn zero() -> Scalar52 {
+        Scalar52([0, 0, 0, 0, 0])
+    }
+
+    /// Unpack a 32-byte / 256-bit scalar into 5 52-bit limbs.
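+    /// Viewing the 32 bytes as a little-endian 256-bit integer `b`, the result has
+    /// `limb[i] = (b >> (52 * i)) & ((1 << 52) - 1)` for `i` in `0..4`, with the
+    /// remaining 48 high bits (`b >> 208`) landing in `limb[4]`.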
+    fn from_bytes(bytes: &[u8; 32]) -> Scalar52 {
+        let mut words = [0u64; 4];
+        for i in 0..4 {
+            for j in 0..8 {
+                words[i] |= u64::from(bytes[(i * 8) + j]) << (j * 8);
+            }
+        }
+
+        let mask = (1u64 << 52) - 1;
+        let top_mask = (1u64 << 48) - 1;
+        let mut s = Scalar52::zero();
+
+        s[0] = words[0] & mask;
+        s[1] = ((words[0] >> 52) | (words[1] << 12)) & mask;
+        s[2] = ((words[1] >> 40) | (words[2] << 24)) & mask;
+        s[3] = ((words[2] >> 28) | (words[3] << 36)) & mask;
+        s[4] = (words[3] >> 16) & top_mask;
+
+        s
+    }
+
+    /// Pack the limbs of this `Scalar52` into 32 bytes
+    fn to_bytes(&self) -> [u8; 32] {
+        let mut s = [0u8; 32];
+
+        s[0] = self.0[0] as u8;
+        s[1] = (self.0[0] >> 8) as u8;
+        s[2] = (self.0[0] >> 16) as u8;
+        s[3] = (self.0[0] >> 24) as u8;
+        s[4] = (self.0[0] >> 32) as u8;
+        s[5] = (self.0[0] >> 40) as u8;
+        s[6] = ((self.0[0] >> 48) | (self.0[1] << 4)) as u8;
+        s[7] = (self.0[1] >> 4) as u8;
+        s[8] = (self.0[1] >> 12) as u8;
+        s[9] = (self.0[1] >> 20) as u8;
+        s[10] = (self.0[1] >> 28) as u8;
+        s[11] = (self.0[1] >> 36) as u8;
+        s[12] = (self.0[1] >> 44) as u8;
+        s[13] = self.0[2] as u8;
+        s[14] = (self.0[2] >> 8) as u8;
+        s[15] = (self.0[2] >> 16) as u8;
+        s[16] = (self.0[2] >> 24) as u8;
+        s[17] = (self.0[2] >> 32) as u8;
+        s[18] = (self.0[2] >> 40) as u8;
+        s[19] = ((self.0[2] >> 48) | (self.0[3] << 4)) as u8;
+        s[20] = (self.0[3] >> 4) as u8;
+        s[21] = (self.0[3] >> 12) as u8;
+        s[22] = (self.0[3] >> 20) as u8;
+        s[23] = (self.0[3] >> 28) as u8;
+        s[24] = (self.0[3] >> 36) as u8;
+        s[25] = (self.0[3] >> 44) as u8;
+        s[26] = self.0[4] as u8;
+        s[27] = (self.0[4] >> 8) as u8;
+        s[28] = (self.0[4] >> 16) as u8;
+        s[29] = (self.0[4] >> 24) as u8;
+        s[30] = (self.0[4] >> 32) as u8;
+        s[31] = (self.0[4] >> 40) as u8;
+
+        s
+    }
+
+    /// Compute `a + b` (without reduction mod ℓ)
+    fn add(a: &Scalar52, b: &Scalar52) -> Scalar52 {
+        let mut sum = Scalar52::zero();
+        let mask = (1u64 << 52) - 1;
+
+        // a + b
+        let mut carry: u64 = 0;
+        for i in 0..5 {
+            carry = a[i] + b[i] + (carry >> 52);
+            sum[i] = carry & mask;
+        }
+
+        sum
+    }
+}
+
+impl Index<usize> for Scalar52 {
+    type Output = u64;
+    fn index(&self, _index: usize) -> &u64 {
+        &(self.0[_index])
+    }
+}
+
+impl IndexMut<usize> for Scalar52 {
+    fn index_mut(&mut self, _index: usize) -> &mut u64 {
+        &mut (self.0[_index])
+    }
+}
diff --git a/crypto/nextgen_crypto/src/unit_tests/mod.rs b/crypto/nextgen_crypto/src/unit_tests/mod.rs
new file mode 100644
index 0000000000000..40b4b9352c36c
--- /dev/null
+++ b/crypto/nextgen_crypto/src/unit_tests/mod.rs
@@ -0,0 +1,31 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+mod bls12381_test;
+mod cross_test;
+mod ed25519_test;
+mod slip0010_test;
+
+use crate::{
+    test_utils::KeyPair,
+    traits::{SeedableCryptoRng, Uniform},
+};
+use proptest::prelude::*;
+use rand::rngs::StdRng;
+use serde::Serialize;
+
+/// Produces a uniformly random keypair from a seed
+pub(super) fn uniform_keypair_strategy<Priv, Pub>() -> impl Strategy<Value = KeyPair<Priv, Pub>>
+where
+    Pub: Serialize + for<'a> From<&'a Priv>,
+    Priv: Serialize + Uniform,
+{
+    // The no_shrink is because keypairs should be fixed -- shrinking would cause a different
+    // keypair to be generated, which appears to not be very useful.
+ any::<[u8; 32]>() + .prop_map(|seed| { + let mut rng = StdRng::from_seed(seed); + KeyPair::::generate_for_testing(&mut rng) + }) + .no_shrink() +} diff --git a/crypto/nextgen_crypto/src/unit_tests/slip0010_test.rs b/crypto/nextgen_crypto/src/unit_tests/slip0010_test.rs new file mode 100644 index 0000000000000..b5fab26c73300 --- /dev/null +++ b/crypto/nextgen_crypto/src/unit_tests/slip0010_test.rs @@ -0,0 +1,148 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::slip0010::Slip0010; + +// Testing against SLIP-0010 test vectors. +#[test] +fn test_slip0010_vectors() { + let tests = test_vectors_slip0010_ed25519(); + for t in tests.iter() { + let seed = hex::decode(t.seed).unwrap(); + let path = t.path; + let key = Slip0010::derive_from_path(path, &seed).unwrap(); + assert_eq!(t.c_code, hex::encode(key.chain_code)); + assert_eq!(t.pr_key, hex::encode(key.private_key.to_bytes())); + assert_eq!(t.pb_key, hex::encode(key.get_public().to_bytes())); + } +} + +// Test Vectors for SLIP 0010 (ed25519) from https://github.com/satoshilabs/slips/blob/master/slip-0010.md +#[allow(dead_code)] +struct Test<'a> { + seed: &'a str, + path: &'a str, + fingerprint: &'a str, + c_code: &'a str, // chain code + pr_key: &'a str, // private key + pb_key: &'a str, // public key +} + +fn test_vectors_slip0010_ed25519<'a>() -> Vec> { + vec![ + Test { + // Test Vector 1 - chain 1 + seed: "000102030405060708090a0b0c0d0e0f", + path: "m", + fingerprint: "00000000", + c_code: "90046a93de5380a72b5e45010748567d5ea02bbf6522f979e05c0d8d8ca9fffb", + pr_key: "2b4be7f19ee27bbf30c667b642d5f4aa69fd169872f8fc3059c08ebae2eb19e7", + pb_key: "a4b2856bfec510abab89753fac1ac0e1112364e7d250545963f135f2a33188ed", + }, + Test { + // Test Vector 1 - chain 2 + seed: "000102030405060708090a0b0c0d0e0f", + path: "m/0", + fingerprint: "ddebc675", + c_code: "8b59aa11380b624e81507a27fedda59fea6d0b779a778918a2fd3590e16e9c69", + pr_key: "68e0fe46dfb67e368c75379acec591dad19df3cde26e63b93a8e704f1dade7a3", + pb_key: "8c8a13df77a28f3445213a0f432fde644acaa215fc72dcdf300d5efaa85d350c", + }, + Test { + // Test Vector 1 - chain 3 + seed: "000102030405060708090a0b0c0d0e0f", + path: "m/0/1", + fingerprint: "13dab143", + c_code: "a320425f77d1b5c2505a6b1b27382b37368ee640e3557c315416801243552f14", + pr_key: "b1d0bad404bf35da785a64ca1ac54b2617211d2777696fbffaf208f746ae84f2", + pb_key: "1932a5270f335bed617d5b935c80aedb1a35bd9fc1e31acafd5372c30f5c1187", + }, + Test { + // Test Vector 1 - chain 4 + seed: "000102030405060708090a0b0c0d0e0f", + path: "m/0/1/2", + fingerprint: "ebe4cb29", + c_code: "2e69929e00b5ab250f49c3fb1c12f252de4fed2c1db88387094a0f8c4c9ccd6c", + pr_key: "92a5b23c0b8a99e37d07df3fb9966917f5d06e02ddbd909c7e184371463e9fc9", + pb_key: "ae98736566d30ed0e9d2f4486a64bc95740d89c7db33f52121f8ea8f76ff0fc1", + }, + Test { + // Test Vector 1 - chain 5 + seed: "000102030405060708090a0b0c0d0e0f", + path: "m/0/1/2/2", + fingerprint: "316ec1c6", + c_code: "8f6d87f93d750e0efccda017d662a1b31a266e4a6f5993b15f5c1f07f74dd5cc", + pr_key: "30d1dc7e5fc04c31219ab25a27ae00b50f6fd66622f6e9c913253d6511d1e662", + pb_key: "8abae2d66361c879b900d204ad2cc4984fa2aa344dd7ddc46007329ac76c429c", + }, + Test { + // Test Vector 1 - chain 6 + seed: "000102030405060708090a0b0c0d0e0f", + path: "m/0/1/2/2/1000000000", + fingerprint: "d6322ccd", + c_code: "68789923a0cac2cd5a29172a475fe9e0fb14cd6adb5ad98a3fa70333e7afa230", + pr_key: "8f94d394a8e8fd6b1bc2f3f49f5c47e385281d5c17e65324b0f62483e37e8793", + pb_key: 
"3c24da049451555d51a7014a37337aa4e12d41e485abccfa46b47dfb2af54b7a", + }, + Test { + // Test Vector 2 - chain 1 + seed: "fffcf9f6f3f0edeae7e4e1dedbd8d5d2cfccc9c6c3c0bdbab7b4b1aeaba8a5a2\ + 9f9c999693908d8a8784817e7b7875726f6c696663605d5a5754514e4b484542", + path: "m", + fingerprint: "00000000", + c_code: "ef70a74db9c3a5af931b5fe73ed8e1a53464133654fd55e7a66f8570b8e33c3b", + pr_key: "171cb88b1b3c1db25add599712e36245d75bc65a1a5c9e18d76f9f2b1eab4012", + pb_key: "8fe9693f8fa62a4305a140b9764c5ee01e455963744fe18204b4fb948249308a", + }, + Test { + // Test Vector 2 - chain 2 + seed: "fffcf9f6f3f0edeae7e4e1dedbd8d5d2cfccc9c6c3c0bdbab7b4b1aeaba8a5a2\ + 9f9c999693908d8a8784817e7b7875726f6c696663605d5a5754514e4b484542", + path: "m/0", + fingerprint: "31981b50", + c_code: "0b78a3226f915c082bf118f83618a618ab6dec793752624cbeb622acb562862d", + pr_key: "1559eb2bbec5790b0c65d8693e4d0875b1747f4970ae8b650486ed7470845635", + pb_key: "86fab68dcb57aa196c77c5f264f215a112c22a912c10d123b0d03c3c28ef1037", + }, + Test { + // Test Vector 2 - chain 3 + seed: "fffcf9f6f3f0edeae7e4e1dedbd8d5d2cfccc9c6c3c0bdbab7b4b1aeaba8a5a2\ + 9f9c999693908d8a8784817e7b7875726f6c696663605d5a5754514e4b484542", + path: "m/0/2147483647", + fingerprint: "1e9411b1", + c_code: "138f0b2551bcafeca6ff2aa88ba8ed0ed8de070841f0c4ef0165df8181eaad7f", + pr_key: "ea4f5bfe8694d8bb74b7b59404632fd5968b774ed545e810de9c32a4fb4192f4", + pb_key: "5ba3b9ac6e90e83effcd25ac4e58a1365a9e35a3d3ae5eb07b9e4d90bcf7506d", + }, + Test { + // Test Vector 2 - chain 4 + seed: "fffcf9f6f3f0edeae7e4e1dedbd8d5d2cfccc9c6c3c0bdbab7b4b1aeaba8a5a2\ + 9f9c999693908d8a8784817e7b7875726f6c696663605d5a5754514e4b484542", + path: "m/0/2147483647/1", + fingerprint: "fcadf38c", + c_code: "73bd9fff1cfbde33a1b846c27085f711c0fe2d66fd32e139d3ebc28e5a4a6b90", + pr_key: "3757c7577170179c7868353ada796c839135b3d30554bbb74a4b1e4a5a58505c", + pb_key: "2e66aa57069c86cc18249aecf5cb5a9cebbfd6fadeab056254763874a9352b45", + }, + Test { + // Test Vector 2 - chain 5 + seed: "fffcf9f6f3f0edeae7e4e1dedbd8d5d2cfccc9c6c3c0bdbab7b4b1aeaba8a5a2\ + 9f9c999693908d8a8784817e7b7875726f6c696663605d5a5754514e4b484542", + path: "m/0/2147483647/1/2147483646", + fingerprint: "aca70953", + c_code: "0902fe8a29f9140480a00ef244bd183e8a13288e4412d8389d140aac1794825a", + pr_key: "5837736c89570de861ebc173b1086da4f505d4adb387c6a1b1342d5e4ac9ec72", + pb_key: "e33c0f7d81d843c572275f287498e8d408654fdf0d1e065b84e2e6f157aab09b", + }, + Test { + // Test Vector 2 - chain 6 + seed: "fffcf9f6f3f0edeae7e4e1dedbd8d5d2cfccc9c6c3c0bdbab7b4b1aeaba8a5a2\ + 9f9c999693908d8a8784817e7b7875726f6c696663605d5a5754514e4b484542", + path: "m/0/2147483647/1/2147483646/2", + fingerprint: "422c654b", + c_code: "5d70af781f3a37b829f0d060924d5e960bdc02e85423494afc0b1a41bbe196d4", + pr_key: "551d333177df541ad876a60ea71f00447931c0a9da16f227c11ea080d7391b8d", + pb_key: "47150c75db263559a70d5778bf36abbab30fb061ad69f69ece61a72b0cfa4fc0", + }, + ] +} diff --git a/crypto/nextgen_crypto/src/vrf/ecvrf.rs b/crypto/nextgen_crypto/src/vrf/ecvrf.rs new file mode 100644 index 0000000000000..662bbee42c34c --- /dev/null +++ b/crypto/nextgen_crypto/src/vrf/ecvrf.rs @@ -0,0 +1,323 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module implements an instantiation of a verifiable random function known as +//! [ECVRF-ED25519-SHA512-TAI](https://tools.ietf.org/html/draft-irtf-cfrg-vrf-04). +//! +//! # Examples +//! +//! ``` +//! use nextgen_crypto::{traits::Uniform, vrf::ecvrf::*}; +//! 
use rand::{rngs::StdRng, SeedableRng}; +//! +//! let message = b"Test message"; +//! let mut rng: StdRng = SeedableRng::from_seed([0; 32]); +//! let private_key = VRFPrivateKey::generate_for_testing(&mut rng); +//! let public_key: VRFPublicKey = (&private_key).into(); +//! ``` +//! **Note**: The above example generates a private key using a private function intended only for +//! testing purposes. Production code should find an alternate means for secure key generation. +//! +//! Produce a proof for a message from a `VRFPrivateKey`, and verify the proof and message +//! using a `VRFPublicKey`: +//! +//! ``` +//! # use nextgen_crypto::{traits::Uniform, vrf::ecvrf::*}; +//! # use rand::{rngs::StdRng, SeedableRng}; +//! # let message = b"Test message"; +//! # let mut rng: StdRng = SeedableRng::from_seed([0; 32]); +//! # let private_key = VRFPrivateKey::generate_for_testing(&mut rng); +//! # let public_key: VRFPublicKey = (&private_key).into(); +//! let proof = private_key.prove(message); +//! assert!(public_key.verify(&proof, message).is_ok()); +//! ``` +//! +//! Produce a pseudorandom output from a `Proof`: +//! +//! ``` +//! # use nextgen_crypto::{traits::Uniform, vrf::ecvrf::*}; +//! # use rand::{rngs::StdRng, SeedableRng}; +//! # let message = b"Test message"; +//! # let mut rng: StdRng = SeedableRng::from_seed([0; 32]); +//! # let private_key = VRFPrivateKey::generate_for_testing(&mut rng); +//! # let public_key: VRFPublicKey = (&private_key).into(); +//! # let proof = private_key.prove(message); +//! let output: Output = (&proof).into(); +//! ``` + +use crate::traits::*; +use core::convert::TryFrom; +use curve25519_dalek::{ + constants::ED25519_BASEPOINT_POINT, + edwards::{CompressedEdwardsY, EdwardsPoint}, + scalar::Scalar as ed25519_Scalar, +}; +use derive_deref::Deref; +use ed25519_dalek::{ + self, Digest, PublicKey as ed25519_PublicKey, SecretKey as ed25519_PrivateKey, Sha512, +}; +use failure::prelude::*; +use serde::{Deserialize, Serialize}; + +const SUITE: u8 = 0x03; +const ONE: u8 = 0x01; +const TWO: u8 = 0x02; +const THREE: u8 = 0x03; + +/// The number of bytes of [`Output`] +pub const OUTPUT_LENGTH: usize = 64; +/// The number of bytes of [`Proof`] +pub const PROOF_LENGTH: usize = 80; + +/// An ECVRF private key +#[derive(Serialize, Deserialize, Deref, Debug)] +pub struct VRFPrivateKey(ed25519_PrivateKey); + +/// An ECVRF public key +#[derive(Serialize, Deserialize, Deref, Debug, PartialEq, Eq)] +pub struct VRFPublicKey(ed25519_PublicKey); + +/// A longer private key which is slightly optimized for proof generation. +/// +/// This is similar in structure to ed25519_dalek::ExpandedSecretKey. It can be produced from +/// a VRFPrivateKey. 
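+/// (As the `From<&VRFPrivateKey>` impl below shows, `key` is the clamped lower
+/// 32 bytes of the SHA-512 digest of the private key bytes, and `nonce` is the
+/// upper 32 bytes.)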
+pub struct VRFExpandedPrivateKey { + pub(super) key: ed25519_Scalar, + pub(super) nonce: [u8; 32], +} + +impl VRFPrivateKey { + /// Produces a proof for an input (using the private key) + pub fn prove(&self, alpha: &[u8]) -> Proof { + VRFExpandedPrivateKey::from(self).prove(&VRFPublicKey((&self.0).into()), alpha) + } +} + +impl VRFExpandedPrivateKey { + /// Produces a proof for an input (using the expanded private key) + pub fn prove(&self, pk: &VRFPublicKey, alpha: &[u8]) -> Proof { + let h_point = pk.hash_to_curve(alpha); + let k_scalar = + ed25519_Scalar::from_bytes_mod_order_wide(&nonce_generation_bytes(self.nonce, h_point)); + let gamma = h_point * self.key; + let c_scalar = hash_points(&[ + h_point, + gamma, + ED25519_BASEPOINT_POINT * k_scalar, + h_point * k_scalar, + ]); + + Proof { + gamma, + c: c_scalar, + s: k_scalar + c_scalar * self.key, + } + } +} + +impl Uniform for VRFPrivateKey { + fn generate_for_testing(rng: &mut R) -> Self + where + R: SeedableCryptoRng, + { + VRFPrivateKey(ed25519_PrivateKey::generate(rng)) + } +} + +impl TryFrom<&[u8]> for VRFPrivateKey { + type Error = CryptoMaterialError; + + fn try_from(bytes: &[u8]) -> std::result::Result { + Ok(VRFPrivateKey( + ed25519_PrivateKey::from_bytes(bytes).unwrap(), + )) + } +} + +impl TryFrom<&[u8]> for VRFPublicKey { + type Error = CryptoMaterialError; + + fn try_from(bytes: &[u8]) -> std::result::Result { + if bytes.len() != ed25519_dalek::PUBLIC_KEY_LENGTH { + return Err(CryptoMaterialError::WrongLengthError); + } + + let mut bits: [u8; 32] = [0u8; 32]; + bits.copy_from_slice(&bytes[..32]); + + let compressed = curve25519_dalek::edwards::CompressedEdwardsY(bits); + let point = compressed + .decompress() + .ok_or(CryptoMaterialError::DeserializationError)?; + + // Check if the point lies on a small subgroup. This is required + // when using curves with a small cofactor (in ed25519, cofactor = 8). 
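+        // (A small-order point cannot be an honestly generated public key, since
+        // honest keys lie in the prime-order subgroup, and accepting one would
+        // void the VRF's uniqueness guarantees, so such keys are rejected outright.)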
+ if point.is_small_order() { + return Err(CryptoMaterialError::SmallSubgroupError); + } + + Ok(VRFPublicKey(ed25519_PublicKey::from_bytes(bytes).unwrap())) + } +} + +impl VRFPublicKey { + /// Given a [`Proof`] and an input, returns whether or not the proof is valid for the input + /// and public key + pub fn verify(&self, proof: &Proof, alpha: &[u8]) -> Result<()> { + let h_point = self.hash_to_curve(alpha); + let pk_point = CompressedEdwardsY::from_slice(self.as_bytes()) + .decompress() + .unwrap(); + let cprime = hash_points(&[ + h_point, + proof.gamma, + ED25519_BASEPOINT_POINT * proof.s - pk_point * proof.c, + h_point * proof.s - proof.gamma * proof.c, + ]); + + if proof.c == cprime { + Ok(()) + } else { + bail!("The proof failed to verify for this public key") + } + } + + pub(super) fn hash_to_curve(&self, alpha: &[u8]) -> EdwardsPoint { + let mut result = [0u8; 32]; + let mut counter = 0; + let mut wrapped_point: Option = None; + + while wrapped_point.is_none() { + result.copy_from_slice( + &Sha512::new() + .chain(&[SUITE, ONE]) + .chain(self.as_bytes()) + .chain(&alpha) + .chain(&[counter]) + .result()[..32], + ); + wrapped_point = CompressedEdwardsY::from_slice(&result).decompress(); + counter += 1; + } + + wrapped_point.unwrap().mul_by_cofactor() + } +} + +impl<'a> From<&'a VRFPrivateKey> for VRFPublicKey { + fn from(private_key: &'a VRFPrivateKey) -> Self { + let secret: &ed25519_PrivateKey = private_key; + let public: ed25519_PublicKey = secret.into(); + VRFPublicKey(public) + } +} + +impl<'a> From<&'a VRFPrivateKey> for VRFExpandedPrivateKey { + fn from(private_key: &'a VRFPrivateKey) -> Self { + let mut h: Sha512 = Sha512::default(); + let mut hash: [u8; 64] = [0u8; 64]; + let mut lower: [u8; 32] = [0u8; 32]; + let mut upper: [u8; 32] = [0u8; 32]; + + h.input(private_key.to_bytes()); + hash.copy_from_slice(h.result().as_slice()); + + lower.copy_from_slice(&hash[00..32]); + upper.copy_from_slice(&hash[32..64]); + + lower[0] &= 248; + lower[31] &= 63; + lower[31] |= 64; + + VRFExpandedPrivateKey { + key: ed25519_Scalar::from_bits(lower), + nonce: upper, + } + } +} + +/// A VRF proof that can be used to validate an input with a public key +pub struct Proof { + gamma: EdwardsPoint, + c: ed25519_Scalar, + s: ed25519_Scalar, +} + +impl Proof { + /// Produces a new Proof struct from its fields + pub fn new(gamma: EdwardsPoint, c: ed25519_Scalar, s: ed25519_Scalar) -> Proof { + Proof { gamma, c, s } + } + + /// Converts a Proof into bytes + pub fn to_bytes(&self) -> [u8; PROOF_LENGTH] { + let mut ret = [0u8; PROOF_LENGTH]; + ret[..32].copy_from_slice(&self.gamma.compress().to_bytes()[..]); + ret[32..48].copy_from_slice(&self.c.to_bytes()[..16]); + ret[48..].copy_from_slice(&self.s.to_bytes()[..]); + ret + } +} + +impl TryFrom<&[u8]> for Proof { + type Error = CryptoMaterialError; + + fn try_from(bytes: &[u8]) -> std::result::Result { + let mut c_buf = [0u8; 32]; + c_buf[..16].copy_from_slice(&bytes[32..48]); + let mut s_buf = [0u8; 32]; + s_buf.copy_from_slice(&bytes[48..]); + Ok(Proof { + gamma: CompressedEdwardsY::from_slice(&bytes[..32]) + .decompress() + .unwrap(), + c: ed25519_Scalar::from_bits(c_buf), + s: ed25519_Scalar::from_bits(s_buf), + }) + } +} + +/// The ECVRF output produced from the proof +pub struct Output([u8; OUTPUT_LENGTH]); + +impl Output { + /// Converts an Output into bytes + #[inline] + pub fn to_bytes(&self) -> [u8; OUTPUT_LENGTH] { + self.0 + } +} + +impl<'a> From<&'a Proof> for Output { + fn from(proof: &'a Proof) -> Output { + let mut output = [0u8; 
OUTPUT_LENGTH]; + output.copy_from_slice( + &Sha512::new() + .chain(&[SUITE, THREE]) + .chain(&proof.gamma.mul_by_cofactor().compress().to_bytes()[..]) + .result()[..], + ); + Output(output) + } +} + +pub(super) fn nonce_generation_bytes(nonce: [u8; 32], h_point: EdwardsPoint) -> [u8; 64] { + let mut k_buf = [0u8; 64]; + k_buf.copy_from_slice( + &Sha512::new() + .chain(nonce) + .chain(h_point.compress().as_bytes()) + .result()[..], + ); + k_buf +} + +pub(super) fn hash_points(points: &[EdwardsPoint]) -> ed25519_Scalar { + let mut result = [0u8; 32]; + let mut hash = Sha512::new().chain(&[SUITE, TWO]); + for point in points.iter() { + hash = hash.chain(point.compress().to_bytes()); + } + result[..16].copy_from_slice(&hash.result()[..16]); + ed25519_Scalar::from_bits(result) +} diff --git a/crypto/nextgen_crypto/src/vrf/mod.rs b/crypto/nextgen_crypto/src/vrf/mod.rs new file mode 100644 index 0000000000000..cc97f95599f0b --- /dev/null +++ b/crypto/nextgen_crypto/src/vrf/mod.rs @@ -0,0 +1,11 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module contains implementations of a +//! [verifiable random function](https://en.wikipedia.org/wiki/Verifiable_random_function) +//! (currently only ECVRF). VRFs can be used in the consensus protocol for leader election. + +pub mod ecvrf; + +#[cfg(test)] +mod unit_tests; diff --git a/crypto/nextgen_crypto/src/vrf/unit_tests/mod.rs b/crypto/nextgen_crypto/src/vrf/unit_tests/mod.rs new file mode 100644 index 0000000000000..24f78c049a0ab --- /dev/null +++ b/crypto/nextgen_crypto/src/vrf/unit_tests/mod.rs @@ -0,0 +1,4 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod vrf_test; diff --git a/crypto/nextgen_crypto/src/vrf/unit_tests/vrf_test.rs b/crypto/nextgen_crypto/src/vrf/unit_tests/vrf_test.rs new file mode 100644 index 0000000000000..731701d0aa43f --- /dev/null +++ b/crypto/nextgen_crypto/src/vrf/unit_tests/vrf_test.rs @@ -0,0 +1,196 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{unit_tests::uniform_keypair_strategy, vrf::ecvrf::*}; +use core::convert::TryFrom; +use crypto::hash::HashValue; +use curve25519_dalek::{ + constants::ED25519_BASEPOINT_POINT, edwards::CompressedEdwardsY, + scalar::Scalar as ed25519_Scalar, +}; +use proptest::prelude::*; + +macro_rules! to_string { + ($e:expr) => { + format!("{}", ::hex::encode($e.to_bytes().as_ref())) + }; +} + +macro_rules! 
from_string { + (CompressedEdwardsY, $e:expr) => { + CompressedEdwardsY::from_slice(&::hex::decode($e).unwrap()) + .decompress() + .unwrap() + }; + (VRFPublicKey, $e:expr) => {{ + let v: &[u8] = &::hex::decode($e).unwrap(); + VRFPublicKey::try_from(v).unwrap() + }}; + ($t:ty, $e:expr) => { + <$t>::try_from(::hex::decode($e).unwrap().as_ref()).unwrap() + }; +} + +#[allow(dead_code, non_snake_case)] +struct VRFTestVector { + SK: &'static str, + PK: &'static str, + alpha: &'static [u8], + x: &'static str, + H: &'static str, + k: &'static str, + U: &'static str, + V: &'static str, + pi: &'static str, + beta: &'static str, +} + +const TESTVECTORS : [VRFTestVector; 3] = [ + VRFTestVector { + SK : "9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60", + PK : "d75a980182b10ab7d54bfed3c964073a0ee172f3daa62325af021a68f707511a", + alpha : b"", + x : "307c83864f2833cb427a2ef1c00a013cfdff2768d980c0a3a520f006904de94f", + // try_and_increment succeded on ctr = 0 + H : "5b2c80db3ce2d79cc85b1bfb269f02f915c5f0e222036dc82123f640205d0d24", + k : "647ac2b3ca3f6a77e4c4f4f79c6c4c8ce1f421a9baaa294b0adf0244915130f7067640acb6fd9e7e84f8bc30d4e03a95e410b82f96a5ada97080e0f187758d38", + U : "a21c342b8704853ad10928e3db3e58ede289c798e3cdfd485fbbb8c1b620604f", + V : "426fe41752f0b27439eb3d0c342cb645174a720cae2d4e9bb37de034eefe27ad", + pi : "9275df67a68c8745c0ff97b48201ee6db447f7c93b23ae24cdc2400f52fdb08a1a6ac7ec71bf9c9c76e96ee4675ebff60625af28718501047bfd87b810c2d2139b73c23bd69de66360953a642c2a330a", + beta : "a64c292ec45f6b252828aff9a02a0fe88d2fcc7f5fc61bb328f03f4c6c0657a9d26efb23b87647ff54f71cd51a6fa4c4e31661d8f72b41ff00ac4d2eec2ea7b3", + }, + VRFTestVector { + SK : "4ccd089b28ff96da9db6c346ec114e0f5b8a319f35aba624da8cf6ed4fb8a6fb", + PK : "3d4017c3e843895a92b70aa74d1b7ebc9c982ccf2ec4968cc0cd55f12af4660c", + alpha : b"\x72", + x : "68bd9ed75882d52815a97585caf4790a7f6c6b3b7f821c5e259a24b02e502e51", + // try_and_increment succeded on ctr = 4 + H : "08e18a34f3923db32e80834fb8ced4e878037cd0459c63ddd66e5004258cf76c", + k : "627237308294a8b344a09ad893997c630153ee514cd292eddd577a9068e2a6f24cbee0038beb0b1ee5df8be08215e9fc74608e6f9358b0e8d6383b1742a70628", + U : "18b5e500cb34690ced061a0d6995e2722623c105221eb91b08d90bf0491cf979", + V : "87e1f47346c86dbbd2c03eafc7271caa1f5307000a36d1f71e26400955f1f627", + pi : "84a63e74eca8fdd64e9972dcda1c6f33d03ce3cd4d333fd6cc789db12b5a7b9d03f1cb6b2bf7cd81a2a20bacf6e1c04e59f2fa16d9119c73a45a97194b504fb9a5c8cf37f6da85e03368d6882e511008", + beta : "cddaa399bb9c56d3be15792e43a6742fb72b1d248a7f24fd5cc585b232c26c934711393b4d97284b2bcca588775b72dc0b0f4b5a195bc41f8d2b80b6981c784e", + }, + VRFTestVector { + SK : "c5aa8df43f9f837bedb7442f31dcb7b166d38535076f094b85ce3a2e0b4458f7", + PK : "fc51cd8e6218a1a38da47ed00230f0580816ed13ba3303ac5deb911548908025", + alpha : b"\xaf\x82", + x : "909a8b755ed902849023a55b15c23d11ba4d7f4ec5c2f51b1325a181991ea95c", + // try_and_increment succeded on ctr = 0 + H : "e4581824b70badf0e57af789dd8cf85513d4b9814566de0e3f738439becfba33", + k : "a950f736af2e3ae2dbcb76795f9cbd57c671eee64ab17069f945509cd6c4a74852fe1bbc331e1bd573038ec703ca28601d861ad1e9684ec89d57bc22986acb0e", + U : "5114dc4e741b7c4a28844bc585350240a51348a05f337b5fd75046d2c2423f7a", + V : "a6d5780c472dea1ace78795208aaa05473e501ed4f53da57e1fb13b7e80d7f59", + pi : "aca8ade9b7f03e2b149637629f95654c94fc9053c225ec21e5838f193af2b727b84ad849b0039ad38b41513fe5a66cdd2367737a84b488d62486bd2fb110b4801a46bfca770af98e059158ac563b690f", + beta : 
"d938b2012f2551b0e13a49568612effcbdca2aed5d1d3a13f47e180e01218916e049837bd246f66d5058e56d3413dbbbad964f5e9f160a81c9a1355dcd99b453", + }, +]; + +#[test] +fn test_expand_secret_key() { + for tv in TESTVECTORS.iter() { + let sk = from_string!(VRFPrivateKey, tv.SK); + println!("{:?}", sk); + let esk = VRFExpandedPrivateKey::from(&sk); + let pk = VRFPublicKey::try_from(&sk).unwrap(); + assert_eq!(tv.PK, to_string!(pk)); + assert_eq!(tv.x, to_string!(esk.key)); + } +} + +#[test] +fn test_hash_to_curve() { + for tv in TESTVECTORS.iter() { + let pk = from_string!(VRFPublicKey, tv.PK); + let h_point = pk.hash_to_curve(&tv.alpha); + assert_eq!(tv.H, to_string!(h_point.compress())); + } +} + +#[test] +fn test_nonce_generation() { + for tv in TESTVECTORS.iter() { + let sk = VRFExpandedPrivateKey::from(&from_string!(VRFPrivateKey, tv.SK)); + let h_point = from_string!(CompressedEdwardsY, tv.H); + let k = nonce_generation_bytes(sk.nonce, h_point); + assert_eq!(tv.k, ::hex::encode(&k[..])); + } +} + +#[test] +fn test_hash_points() { + for tv in TESTVECTORS.iter() { + let sk = VRFExpandedPrivateKey::from(&from_string!(VRFPrivateKey, tv.SK)); + let h_point = from_string!(CompressedEdwardsY, tv.H); + let k_bytes = nonce_generation_bytes(sk.nonce, h_point); + let k_scalar = ed25519_Scalar::from_bytes_mod_order_wide(&k_bytes); + + let gamma = h_point * sk.key; + let u = ED25519_BASEPOINT_POINT * k_scalar; + let v = h_point * k_scalar; + + assert_eq!(tv.U, to_string!(u.compress())); + assert_eq!(tv.V, to_string!(v.compress())); + + let c_scalar = hash_points(&[h_point, gamma, u, v]); + + let s_scalar = k_scalar + c_scalar * sk.key; + s_scalar.reduce(); + + let mut c_bytes = [0u8; 16]; + c_bytes.copy_from_slice(&c_scalar.to_bytes()[..16]); + + let pi = Proof::new(gamma, c_scalar, s_scalar); + + assert_eq!(tv.pi, to_string!(pi)); + } +} + +#[test] +fn test_prove() { + for tv in TESTVECTORS.iter() { + let sk = from_string!(VRFPrivateKey, tv.SK); + let pi = sk.prove(tv.alpha); + + assert_eq!(tv.pi, to_string!(pi)); + } +} + +#[test] +fn test_verify() { + for tv in TESTVECTORS.iter() { + assert!(from_string!(VRFPublicKey, tv.PK) + .verify(&from_string!(Proof, tv.pi), tv.alpha) + .is_ok()); + } +} + +#[test] +fn test_output_from_proof() { + for tv in TESTVECTORS.iter() { + assert_eq!( + tv.beta, + to_string!(Output::from( + &from_string!(VRFPrivateKey, tv.SK).prove(tv.alpha) + )) + ); + } +} + +proptest! 
{
+    #[test]
+    fn test_prove_and_verify(
+        hash1 in any::<HashValue>(),
+        hash2 in any::<HashValue>(),
+        keypair in uniform_keypair_strategy::<VRFPrivateKey, VRFPublicKey>()
+    ) {
+        let (pk, sk) = (&keypair.public_key, &keypair.private_key);
+        let pk_test = VRFPublicKey::try_from(sk).unwrap();
+        prop_assert_eq!(pk, &pk_test);
+        let (input1, input2) = (hash1.as_ref(), hash2.as_ref());
+        let proof1 = sk.prove(input1);
+        prop_assert!(pk.verify(&proof1, input1).is_ok());
+        prop_assert!(pk.verify(&proof1, input2).is_err());
+    }
+}
diff --git a/crypto/secret_service/Cargo.toml b/crypto/secret_service/Cargo.toml
new file mode 100644
index 0000000000000..ae5e3c70c7246
--- /dev/null
+++ b/crypto/secret_service/Cargo.toml
@@ -0,0 +1,33 @@
+[package]
+name = "secret_service"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+futures = "0.1.25"
+grpcio = "0.4.3"
+protobuf = "2.6"
+
+config = { path = "../../config" }
+grpc_helpers = { path = "../../common/grpc_helpers" }
+debug_interface = { path = "../../common/debug_interface" }
+failure = { package = "failure_ext", path = "../../common/failure_ext" }
+executable_helpers = { path = "../../common/executable_helpers" }
+logger = { path = "../../common/logger" }
+
+nextgen_crypto = { path = "../nextgen_crypto" }
+crypto = { path = "../legacy_crypto" }
+# ed25519-dalek = { version = "1.0.0-pre.1", features = ["serde"] }
+serde = { version = "1.0.89", features = ["derive"] }
+rand = "0.6.5"
+rand_chacha = "0.1.1"
+
+derive_deref = "1.0.2"
+
+crypto-derive = { path = "../legacy_crypto/src/macros" }
+
+[build-dependencies]
+build_helpers = { path = "../../common/build_helpers" }
diff --git a/crypto/secret_service/README.md b/crypto/secret_service/README.md
new file mode 100644
index 0000000000000..f3bf67be15579
--- /dev/null
+++ b/crypto/secret_service/README.md
@@ -0,0 +1,46 @@
+---
+id: secret-service
+title: Secret Service
+custom_edit_url: https://github.com/libra/libra/edit/master/crypto/secret_service/README.md
+---
+# Secret Service
+
+The secret service is a separate process that will manage cryptographic secret keys.
+
+## Overview
+
+**Note**: The secret service is under development. The rest of the code does not use the secret service yet, but will in an upcoming version.
+
+The secret service will hold the following secret keys for a validator node:
+* an account key giving the validator control over the three keys below,
+* a consensus signing key,
+* a network discovery signing key,
+* a network handshake Diffie-Hellman static key.
+
+All signing operations will happen on the side of the secret service, and no signing key will ever leave the secret service process boundary.
+We also plan to eventually equip the secret service with the ability to perform the network handshake itself, so that the Diffie-Hellman static key also stays within the boundaries of the process.
+
+Right now the secret service exposes the following APIs:
+* generate key: takes in a specification for key generation and returns a key id, which is a handle to the newly generated key,
+* get public key: returns the public key for a given key id,
+* sign: given a prehashed message and a key id, returns a signature.
+These APIs will evolve, possibly allowing for key rotation, key backup, key provisioning, key deletion, etc.
+
+Right now the keys are generated randomly: a seed is derived from OS randomness (`EntropyRng`), a seedable RNG (`ChaChaRng`) is instantiated with the seed, and the keys are generated using this seedable RNG, as sketched below.
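+A rough sketch of that flow, mirroring `generate_key_inner` in `secret_service_server.rs` (the exact trait imports shown here are assumptions):
+
+```rust
+use nextgen_crypto::{ed25519::Ed25519PrivateKey, traits::Uniform};
+use rand::{rngs::EntropyRng, Rng, SeedableRng};
+use rand_chacha::ChaChaRng;
+
+fn generate_ed25519_key() -> Ed25519PrivateKey {
+    // Draw a fresh 32-byte seed from the OS entropy source...
+    let seed: [u8; 32] = EntropyRng::new().gen();
+    // ...and expand it with a seedable ChaCha RNG that drives key generation.
+    let mut rng = ChaChaRng::from_seed(seed);
+    Ed25519PrivateKey::generate_for_testing(&mut rng)
+}
+```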
+The procedure for key derivation will be changed to facilitate:
+* forward security,
+* post-compromise security,
+* easy backup,
+* strong entropy.
+
+## How is this module organized?
+    secret_service/src
+    ├── secret_service_server.rs    # Struct SecretServiceServer that holds the map of the generated secret keys and implements the API answering the requests
+    ├── secret_service_client.rs    # ConsensusKeyManager that represents a client for the secret service; it submits the requests and wraps the responses
+    ├── secret_service_node.rs      # Runnable SecretServiceNode that opens connections on the ports specified in the node_config
+    ├── crypto_wrappers.rs          # Helper methods for the new crypto API located in the nextgen directory
+    ├── main.rs                     # Runs the secret service in its own process
+    ├── unit_tests                  # Tests
+    ├── lib.rs
+    └── proto/
+        └── secret_service.proto    # RPC definitions of the callable functions, the formats of request and response messages, as well as the error codes
+
diff --git a/crypto/secret_service/build.rs b/crypto/secret_service/build.rs
new file mode 100644
index 0000000000000..307e0d8b454a7
--- /dev/null
+++ b/crypto/secret_service/build.rs
@@ -0,0 +1,17 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This compiles all the `.proto` files under the `src/` directory.
+//!
+//! For example, if there is a file `src/a/b/c.proto`, it will generate `src/a/b/c.rs` and
+//! `src/a/b/c_grpc.rs`.
+
+fn main() {
+    let proto_root = "src/proto";
+
+    build_helpers::build_helpers::compile_proto(
+        proto_root,
+        vec![],
+        true, /* generate_client_stub */
+    );
+}
diff --git a/crypto/secret_service/src/crypto_wrappers.rs b/crypto/secret_service/src/crypto_wrappers.rs
new file mode 100644
index 0000000000000..8402cb3fc4802
--- /dev/null
+++ b/crypto/secret_service/src/crypto_wrappers.rs
@@ -0,0 +1,186 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module contains wrappers around the nextgen crypto API to support crypto agility
+//! and to make the secret service agnostic to the details of the particular signing algorithms.
+
+use core::convert::TryFrom;
+use crypto::hash::HashValue;
+use crypto_derive::SilentDebug;
+use derive_deref::Deref;
+use failure::prelude::*;
+use nextgen_crypto::{
+    bls12381::{BLS12381PrivateKey, BLS12381PublicKey, BLS12381Signature},
+    ed25519::{Ed25519PrivateKey, Ed25519PublicKey, Ed25519Signature},
+    traits::*,
+};
+use serde::{Deserialize, Serialize};
+
+/// A KeyID is a handle to the secret key and a simple wrapper around the hash value.
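+/// (Key ids are generated server-side as random 32-byte hash values, independent
+/// of the key material; see `generate_key_inner` in `secret_service_server.rs`.)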
+#[derive(Clone, PartialEq, Eq, Hash, Deref)] +pub struct KeyID(pub HashValue); + +/////////////////////////////////////////////////////////////////// +// Declarations pulled from crypto/src/unit_tests/cross_test.rs // +/////////////////////////////////////////////////////////////////// +/// Generic public key enum +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash)] +pub enum GenericPublicKey { + /// Ed25519 public key + Ed(Ed25519PublicKey), + /// BLS12-381 public key + BLS(BLS12381PublicKey), +} + +/// Generic private key enum +#[derive(Serialize, Deserialize, SilentDebug)] +pub enum GenericPrivateKey { + /// Ed25519 private key + Ed(Ed25519PrivateKey), + /// BLS12-381 private key + BLS(BLS12381PrivateKey), +} + +/// Generic signature enum +#[allow(clippy::large_enum_variant)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum GenericSignature { + /// Ed25519 signature + Ed(Ed25519Signature), + /// BLS12-381 signature + BLS(BLS12381Signature), +} + +impl GenericSignature { + /// Convert the signature to bytes for serialization + pub fn to_bytes(&self) -> Vec { + match self { + GenericSignature::Ed(sig) => sig.to_bytes().to_vec(), + GenericSignature::BLS(sig) => sig.to_bytes().to_vec(), + } + } +} + +impl From<&GenericPrivateKey> for GenericPublicKey { + fn from(secret_key: &GenericPrivateKey) -> Self { + match secret_key { + GenericPrivateKey::Ed(pk) => GenericPublicKey::Ed(pk.into()), + GenericPrivateKey::BLS(pk) => GenericPublicKey::BLS(pk.into()), + } + } +} + +impl TryFrom<&[u8]> for GenericPrivateKey { + type Error = CryptoMaterialError; + fn try_from(bytes: &[u8]) -> std::result::Result { + Ed25519PrivateKey::try_from(bytes) + .and_then(|ed_priv_key| Ok(GenericPrivateKey::Ed(ed_priv_key))) + .or_else(|_err| { + BLS12381PrivateKey::try_from(bytes) + .and_then(|bls_priv_key| Ok(GenericPrivateKey::BLS(bls_priv_key))) + }) + } +} + +impl ValidKey for GenericPrivateKey { + fn to_bytes(&self) -> Vec { + match self { + GenericPrivateKey::BLS(privkey) => privkey.to_bytes().to_vec(), + GenericPrivateKey::Ed(privkey) => privkey.to_bytes().to_vec(), + } + } +} + +impl PublicKey for GenericPublicKey { + type PrivateKeyMaterial = GenericPrivateKey; + + fn length() -> usize { + std::cmp::max(BLS12381PublicKey::length(), Ed25519PublicKey::length()) + } +} + +impl TryFrom<&[u8]> for GenericPublicKey { + type Error = CryptoMaterialError; + fn try_from(bytes: &[u8]) -> std::result::Result { + Ed25519PublicKey::try_from(bytes) + .and_then(|ed_priv_key| Ok(GenericPublicKey::Ed(ed_priv_key))) + .or_else(|_err| { + BLS12381PublicKey::try_from(bytes) + .and_then(|bls_priv_key| Ok(GenericPublicKey::BLS(bls_priv_key))) + }) + } +} + +impl ValidKey for GenericPublicKey { + fn to_bytes(&self) -> Vec { + match self { + GenericPublicKey::BLS(pubkey) => pubkey.to_bytes().to_vec(), + GenericPublicKey::Ed(pubkey) => pubkey.to_bytes().to_vec(), + } + } +} + +impl PrivateKey for GenericPrivateKey { + type PublicKeyMaterial = GenericPublicKey; +} + +impl SigningKey for GenericPrivateKey { + type VerifyingKeyMaterial = GenericPublicKey; + type SignatureMaterial = GenericSignature; + + fn sign_message(&self, message: &HashValue) -> GenericSignature { + match self { + GenericPrivateKey::Ed(ed_priv) => GenericSignature::Ed(ed_priv.sign_message(message)), + GenericPrivateKey::BLS(bls_priv) => { + GenericSignature::BLS(bls_priv.sign_message(message)) + } + } + } +} + +impl Signature for GenericSignature { + type VerifyingKeyMaterial = GenericPublicKey; + + fn verify(&self, message: &HashValue, 
public_key: &GenericPublicKey) -> Result<()> { + self.verify_arbitrary_msg(message.as_ref(), public_key) + } + + fn verify_arbitrary_msg(&self, message: &[u8], public_key: &GenericPublicKey) -> Result<()> { + match (self, public_key) { + (GenericSignature::Ed(ed_sig), GenericPublicKey::Ed(ed_pub)) => { + ed_sig.verify_arbitrary_msg(message, ed_pub) + } + (GenericSignature::BLS(bls_sig), GenericPublicKey::BLS(bls_pub)) => { + bls_sig.verify_arbitrary_msg(message, bls_pub) + } + _ => bail!( + "provided the wrong alternative in {:?}!", + (self, public_key) + ), + } + } + + fn to_bytes(&self) -> Vec { + match self { + GenericSignature::Ed(sig) => sig.to_bytes().to_vec(), + GenericSignature::BLS(sig) => sig.to_bytes().to_vec(), + } + } +} + +impl TryFrom<&[u8]> for GenericSignature { + type Error = CryptoMaterialError; + fn try_from(bytes: &[u8]) -> std::result::Result { + Ed25519Signature::try_from(bytes) + .and_then(|ed_sig| Ok(GenericSignature::Ed(ed_sig))) + .or_else(|_err| { + BLS12381Signature::try_from(bytes) + .and_then(|bls_sig| Ok(GenericSignature::BLS(bls_sig))) + }) + } +} + +impl VerifyingKey for GenericPublicKey { + type SigningKeyMaterial = GenericPrivateKey; + type SignatureMaterial = GenericSignature; +} diff --git a/crypto/secret_service/src/lib.rs b/crypto/secret_service/src/lib.rs new file mode 100644 index 0000000000000..ec2ce47b1687f --- /dev/null +++ b/crypto/secret_service/src/lib.rs @@ -0,0 +1,12 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![deny(missing_docs)] +//! A secret service providing cryptographic operations on secret keys, will be used in future +//! releases. +pub mod crypto_wrappers; +pub mod proto; +pub mod secret_service_client; + +pub mod secret_service_node; +pub mod secret_service_server; diff --git a/crypto/secret_service/src/main.rs b/crypto/secret_service/src/main.rs new file mode 100644 index 0000000000000..43505becdab77 --- /dev/null +++ b/crypto/secret_service/src/main.rs @@ -0,0 +1,21 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use executable_helpers::helpers::{ + setup_executable, ARG_CONFIG_PATH, ARG_DISABLE_LOGGING, ARG_PEER_ID, +}; +use secret_service::secret_service_node; + +/// Run a SecretService in its own process. 
+fn main() { + let (config, _logger, _args) = setup_executable( + "Libra Secret Service".to_string(), + vec![ARG_PEER_ID, ARG_CONFIG_PATH, ARG_DISABLE_LOGGING], + ); + + let secret_service_node = secret_service_node::SecretServiceNode::new(config); + + secret_service_node + .run() + .expect("Unable to run SecretService"); +} diff --git a/crypto/secret_service/src/proto/mod.rs b/crypto/secret_service/src/proto/mod.rs new file mode 100644 index 0000000000000..f136c82b4ab6c --- /dev/null +++ b/crypto/secret_service/src/proto/mod.rs @@ -0,0 +1,7 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![allow(missing_docs)] +// use types::proto::*; +pub mod secret_service; +pub mod secret_service_grpc; diff --git a/crypto/secret_service/src/proto/secret_service.proto b/crypto/secret_service/src/proto/secret_service.proto new file mode 100644 index 0000000000000..2aa326ae592d2 --- /dev/null +++ b/crypto/secret_service/src/proto/secret_service.proto @@ -0,0 +1,64 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package secret_service; + +// ----------------------------------------------------------------------------- +// ---------------- Service definition +// ----------------------------------------------------------------------------- +service SecretService { + // API to request key generation + rpc GenerateKey (GenerateKeyRequest) returns (GenerateKeyResponse) {} + // API to request a public key + rpc GetPublicKey (PublicKeyRequest) returns (PublicKeyResponse) {} + // API to request a signature + rpc Sign (SignRequest) returns (SignResponse) {} +} + +message GenerateKeyRequest { + // Spec gives a way to generate the key (potentially BIP32 private derivation path here) + KeyType spec = 1; +} + +message GenerateKeyResponse { + bytes key_id = 1; + ErrorCode code = 2; +} + +message PublicKeyRequest { + bytes key_id = 1; +} + +message PublicKeyResponse { + bytes public_key = 1; + ErrorCode code = 2; +} + +message SignRequest { + bytes key_id = 1; + // message_hash should be a prehashed message of length crypto::HashValue::LENGTH = 32 bytes + bytes message_hash = 2; +} + +message SignResponse { + bytes signature = 1; + ErrorCode code = 2; +} + +enum ErrorCode { + Success = 0; + KeyIdNotFound = 1; + WrongLength = 2; + InvalidParameters = 3; + AuthenticationFailed = 4; + Unspecified = 5; + + // Good examples of more error codes: https://developers.yubico.com/YubiHSM2/Component_Reference/KSP/Status_codes.html +} + +enum KeyType { + Ed25519 = 0; + BLS12381 = 1; +} diff --git a/crypto/secret_service/src/secret_service_client.rs b/crypto/secret_service/src/secret_service_client.rs new file mode 100644 index 0000000000000..263a98b565911 --- /dev/null +++ b/crypto/secret_service/src/secret_service_client.rs @@ -0,0 +1,67 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! ConsensusKeyManager gives a simple interface for consensus to interact with the secret service. +//! This simple key manager will become more complicated in future versions, +//! now it asks the secret service to generate an ed25519 key on creation, +//! it can then transfer to the secret service the requests to get consensus public key and to sign +//! a consensus message. 
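+//!
+//! A minimal usage sketch (assuming a connected `SecretServiceClient`; error
+//! handling elided, so this is not compiled as a doc test):
+//!
+//! ```ignore
+//! let key_manager = ConsensusKeyManager::new(Arc::new(client))?;
+//! let public_key = key_manager.get_consensus_public_key()?;
+//! let signature = key_manager.sign_consensus_message(&message_hash)?;
+//! assert!(signature.verify(&message_hash, &public_key).is_ok());
+//! ```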
+ +use crate::{ + crypto_wrappers::{GenericPublicKey, GenericSignature, KeyID}, + proto::{ + secret_service::{GenerateKeyRequest, KeyType, PublicKeyRequest, SignRequest}, + secret_service_grpc::SecretServiceClient, + }, +}; +use crypto::hash::HashValue; +use failure::prelude::*; +use nextgen_crypto::ed25519::{Ed25519PublicKey, Ed25519Signature}; +use std::{convert::TryFrom, sync::Arc}; + +/// A consensus key manager - interface between consensus and the secret service. +pub struct ConsensusKeyManager { + secret_service: Arc, + signing_keyid: KeyID, +} + +impl ConsensusKeyManager { + /// Saves a reference to the secret service and asks it to generate a new signing key. + pub fn new(secret_service: Arc) -> Result { + Ok(Self { + secret_service: Arc::clone(&secret_service), + signing_keyid: { + // generating consensus key: for simplicity it's assumed we only have one key and it + // is generated here we will have to modify this later + let mut gen_req: GenerateKeyRequest = GenerateKeyRequest::new(); + gen_req.set_spec(KeyType::Ed25519); + + let response = secret_service.generate_key(&gen_req)?; + KeyID(HashValue::from_slice(response.get_key_id())?) + }, + }) + } + + /// Asks the secret service for the public key and returns it. + pub fn get_consensus_public_key(&self) -> Result { + let mut pk_req: PublicKeyRequest = PublicKeyRequest::new(); + pk_req.set_key_id(self.signing_keyid.to_vec()); + let response = self.secret_service.get_public_key(&pk_req)?; + let public_key: &[u8] = response.get_public_key(); + + Ok(GenericPublicKey::Ed(Ed25519PublicKey::try_from( + public_key, + )?)) + } + + /// Asks the secret service to sign a hash of the consensus message. + pub fn sign_consensus_message(&self, message: &HashValue) -> Result<(GenericSignature)> { + let mut sig_req: SignRequest = SignRequest::new(); + sig_req.set_key_id(self.signing_keyid.to_vec()); + sig_req.set_message_hash(message.to_vec()); + let response = self.secret_service.sign(&sig_req)?; + let signature = response.get_signature(); + + Ok(GenericSignature::Ed(Ed25519Signature::try_from(signature)?)) + } +} diff --git a/crypto/secret_service/src/secret_service_node.rs b/crypto/secret_service/src/secret_service_node.rs new file mode 100644 index 0000000000000..dee33ac20762d --- /dev/null +++ b/crypto/secret_service/src/secret_service_node.rs @@ -0,0 +1,67 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! A secret service node upon run creates its own process. +//! It is a remote process running on address read from node_config.secret_service.address and +//! accepts connections on port node_config.secret_service.secret_service_port. +//! The proto/secret_service.proto file shows the requests that the service accepts and the +//! responses that it gives back. For an example on how to run the secret service see main.rs. + +use crate::{proto::secret_service_grpc, secret_service_server::SecretServiceServer}; +use config::config::NodeConfig; +use debug_interface::{node_debug_service::NodeDebugService, proto::node_debug_interface_grpc}; +use failure::prelude::*; +use grpc_helpers::spawn_service_thread; +use logger::prelude::*; +use std::thread; + +#[cfg(test)] +#[path = "unit_tests/secret_service_node_test.rs"] +mod secret_service_test; + +/// Secret service node is run a separate process and handles the secret keys. +pub struct SecretServiceNode { + node_config: NodeConfig, +} + +impl SecretServiceNode { + /// Instantiates the service with a config file. 
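+    /// (Only the `secret_service` and `debug_interface` sections of the config
+    /// are consulted: the service address and port, plus the node debug port;
+    /// see `run` below.)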
+    pub fn new(node_config: NodeConfig) -> Self {
+        SecretServiceNode { node_config }
+    }
+
+    /// Starts the secret service
+    pub fn run(&self) -> Result<()> {
+        info!("Starting secret service node");
+
+        let handle = SecretServiceServer::new();
+        let service = secret_service_grpc::create_secret_service(handle);
+        let _ss_service_handle = spawn_service_thread(
+            service,
+            self.node_config.secret_service.address.clone(),
+            self.node_config.secret_service.secret_service_port,
+            "secret_service",
+        );
+
+        // Start the debug interface
+        let debug_service =
+            node_debug_interface_grpc::create_node_debug_interface(NodeDebugService::new());
+        let _debug_handle = spawn_service_thread(
+            debug_service,
+            self.node_config.secret_service.address.clone(),
+            self.node_config
+                .debug_interface
+                .secret_service_node_debug_port,
+            "debug_secret_service",
+        );
+
+        info!(
+            "Started secret service node on port {}",
+            self.node_config.secret_service.secret_service_port
+        );
+
+        loop {
+            thread::park();
+        }
+    }
+}
diff --git a/crypto/secret_service/src/secret_service_server.rs b/crypto/secret_service/src/secret_service_server.rs
new file mode 100644
index 0000000000000..1cc000e07a65f
--- /dev/null
+++ b/crypto/secret_service/src/secret_service_server.rs
@@ -0,0 +1,194 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! The secret service server stores secret keys and performs operations on them.
+//! Right now the service supports requests to generate a secret key (of Ed25519 or BLS12-381
+//! type), return the corresponding public key, and sign.
+
+use crate::{
+    crypto_wrappers::{GenericPrivateKey, GenericPublicKey, GenericSignature, KeyID},
+    proto::{
+        secret_service::{
+            ErrorCode, GenerateKeyRequest, GenerateKeyResponse, KeyType, PublicKeyRequest,
+            PublicKeyResponse, SignRequest, SignResponse,
+        },
+        secret_service_grpc,
+    },
+};
+use crypto::hash::HashValue;
+use failure::prelude::*;
+use grpc_helpers::provide_grpc_response;
+use nextgen_crypto::{bls12381::BLS12381PrivateKey, ed25519::Ed25519PrivateKey, traits::*};
+use rand::{rngs::EntropyRng, Rng};
+use rand_chacha::ChaChaRng;
+use std::{
+    collections::HashMap,
+    sync::{Arc, RwLock},
+};
+
+#[cfg(test)]
+#[path = "unit_tests/secret_service_test.rs"]
+mod secret_service_test;
+
+/// Secret service server that holds the secret keys and implements the necessary operations on
+/// them.
+#[derive(Clone, Default)]
+pub struct SecretServiceServer {
+    // RwLock is chosen over Mutex because the RwLock won't get poisoned if a panic occurs
+    // during read. Arc allows shared ownership by clones; RwLock makes the map
+    // write-accessible by one thread at a time.
+    keys: Arc<RwLock<HashMap<KeyID, GenericPrivateKey>>>,
+}
+
+/// SecretServiceServer matches the API of proto/secret_service.proto but operates on our own
+/// crypto API structures.
+impl SecretServiceServer {
+    /// A fresh secret service creates an empty HashMap for the keys.
+    pub fn new() -> Self {
+        SecretServiceServer {
+            keys: Arc::new(RwLock::new(HashMap::new())),
+        }
+    }
+
+    /// Generates a new secret key (for now this is the code for testing).
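+    /// The flow below: draw a fresh seed from OS entropy, expand it with a
+    /// seeded ChaCha RNG, generate a key of the requested type, and store it
+    /// under a freshly random 32-byte KeyID that is returned to the caller.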
+    pub fn generate_key_inner(&mut self, spec: KeyType) -> Result<KeyID> {
+        let seed: [u8; 32] = EntropyRng::new().gen();
+
+        let private_key: GenericPrivateKey = {
+            let mut rng = ChaChaRng::from_seed(seed);
+
+            match spec {
+                KeyType::Ed25519 => {
+                    GenericPrivateKey::Ed(Ed25519PrivateKey::generate_for_testing(&mut rng))
+                }
+                KeyType::BLS12381 => {
+                    GenericPrivateKey::BLS(BLS12381PrivateKey::generate_for_testing(&mut rng))
+                }
+            }
+        };
+
+        // For better security the keyid is a random string independent of the keys.
+        let keyid = KeyID(HashValue::random());
+
+        // Alternatively (but less securely) the keyid could be the hash of the public key.
+        // The problem is that someone who knows the public key (e.g. from the chain) would then
+        // know how to make a signature request to the secret service.
+        /*
+        let ed25519_public_key: ed25519_dalek::PublicKey = (&new_ed25519_key).into();
+        let keyid = KeyID(HashValue::from_slice(ed25519_public_key.as_bytes()).unwrap());
+        */
+
+        let result = keyid.clone();
+        let mut keys = self
+            .keys
+            .write()
+            .expect("[generating new key] acquire keys lock");
+        keys.insert(keyid, private_key);
+        Ok(result)
+    }
+
+    /// Computes and returns the public key of the corresponding secret key.
+    pub fn get_public_key_inner(&self, keyid: &KeyID) -> Option<GenericPublicKey> {
+        let keys = self
+            .keys
+            .read()
+            .expect("[getting public key] acquire keys lock");
+        keys.get(keyid).map(GenericPublicKey::from)
+    }
+
+    /// Signs a hash value and returns the signature.
+    pub fn sign_inner(&self, keyid: &KeyID, message: &HashValue) -> Option<GenericSignature> {
+        let keys = self
+            .keys
+            .read()
+            .expect("[obtaining signature] acquire keys lock");
+        keys.get(keyid)
+            .map(|secret_key| secret_key.sign_message(message))
+    }
+}
+
+/// SecretServiceServer implements the proto trait secret_service_grpc::SecretService.
+/// The methods below wrap around the inner methods of SecretServiceServer and operate on grpc's
+/// requests/responses.
+impl secret_service_grpc::SecretService for SecretServiceServer {
+    /// Generates a new key, answering a GenerateKeyRequest with a GenerateKeyResponse.
+    fn generate_key(
+        &mut self,
+        ctx: ::grpcio::RpcContext,
+        req: GenerateKeyRequest,
+        sink: ::grpcio::UnarySink<GenerateKeyResponse>,
+    ) {
+        let mut response = GenerateKeyResponse::new();
+        let spec = req.get_spec();
+        let keyid = self.generate_key_inner(spec);
+        if let Ok(key_identity) = keyid {
+            response.set_code(ErrorCode::Success);
+            response.set_key_id(key_identity.to_vec());
+        } else {
+            response.set_code(ErrorCode::Unspecified);
+        }
+        provide_grpc_response(Ok(response), ctx, sink);
+    }
+
+    /// Returns the corresponding public key, answering a PublicKeyRequest with a PublicKeyResponse.
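+    /// The request is validated first: the key id must be exactly
+    /// `HashValue::LENGTH` bytes (otherwise `WrongLength`), and it must name a
+    /// known key (otherwise `KeyIdNotFound`).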
+ fn get_public_key( + &mut self, + ctx: ::grpcio::RpcContext, + req: PublicKeyRequest, + sink: ::grpcio::UnarySink, + ) { + let mut response = PublicKeyResponse::new(); + let keyid_raw_bytes = req.get_key_id(); + if keyid_raw_bytes.len() != HashValue::LENGTH { + response.set_code(ErrorCode::WrongLength); + } else if let Ok(keyid) = HashValue::from_slice(keyid_raw_bytes) { + let keyid = KeyID(keyid); + let public_key = self.get_public_key_inner(&keyid); + if let Some(pkey) = public_key { + response.set_code(ErrorCode::Success); + response.set_public_key(pkey.to_bytes().to_vec()); + } else { + response.set_code(ErrorCode::KeyIdNotFound); + } + } else { + response.set_code(ErrorCode::Unspecified); + } + provide_grpc_response(Ok(response), ctx, sink); + } + + /// Returns a signature on a given hash-value with a correponding signing key answering a + /// SignRequest with a SignResponse. + fn sign( + &mut self, + ctx: ::grpcio::RpcContext, + req: SignRequest, + sink: ::grpcio::UnarySink, + ) { + let mut response = SignResponse::new(); + let keyid_raw_bytes = req.get_key_id(); + let message_raw_bytes = req.get_message_hash(); + if keyid_raw_bytes.len() != HashValue::LENGTH + || message_raw_bytes.len() != HashValue::LENGTH + { + response.set_code(ErrorCode::WrongLength); + } else if let Ok(keyid) = HashValue::from_slice(keyid_raw_bytes) { + let keyid = KeyID(keyid); + if let Ok(message) = HashValue::from_slice(message_raw_bytes) { + let signature = self.sign_inner(&keyid, &message); + if let Some(sig) = signature { + response.set_code(ErrorCode::Success); + response.set_signature(sig.to_bytes().to_vec()); + } else { + response.set_code(ErrorCode::KeyIdNotFound); + } + } else { + response.set_code(ErrorCode::Unspecified); + } + } else { + response.set_code(ErrorCode::Unspecified); + } + provide_grpc_response(Ok(response), ctx, sink); + } +} diff --git a/crypto/secret_service/src/unit_tests/secret_service_node_test.rs b/crypto/secret_service/src/unit_tests/secret_service_node_test.rs new file mode 100644 index 0000000000000..927751e9295ac --- /dev/null +++ b/crypto/secret_service/src/unit_tests/secret_service_node_test.rs @@ -0,0 +1,151 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::proto::{ + secret_service::{ErrorCode, GenerateKeyRequest, PublicKeyRequest}, + secret_service_grpc::SecretServiceClient, +}; +use config::config::{NodeConfig, NodeConfigHelpers}; +use debug_interface::node_debug_helpers::{check_node_up, create_debug_client}; +use grpcio::{ChannelBuilder, EnvBuilder}; +use std::{sync::Arc, thread}; + +use crate::{ + proto::secret_service::KeyType, secret_service_client::ConsensusKeyManager, + secret_service_node::SecretServiceNode, +}; +use crypto::hash::HashValue; +use logger::prelude::*; +// use crate::crypto_wrappers::GenericSignature; +use nextgen_crypto::traits::Signature; + +///////////////////////////////////////////////////////////////////////////////////// +// These tests check interoperability of key_generation, // +// key_retrieval and signing for crate::secret_service_server::SecretServiceServer // +///////////////////////////////////////////////////////////////////////////////////// + +#[test] +fn create_secret_service_node() { + let node_config = NodeConfigHelpers::get_single_node_test_config(true); + let _secret_service_node = SecretServiceNode::new(node_config); +} + +fn create_client(public_port: u16) -> SecretServiceClient { + let node_connection_str = format!("localhost:{}", public_port); + let env = 
Arc::new(EnvBuilder::new().build()); + let ch = ChannelBuilder::new(env).connect(&node_connection_str); + SecretServiceClient::new(ch) +} + +fn create_secret_service_node_and_client(node_config: NodeConfig) -> SecretServiceClient { + let public_port = node_config.secret_service.secret_service_port; + let debug_port = node_config.debug_interface.secret_service_node_debug_port; + + thread::spawn(move || { + let secret_service_node = SecretServiceNode::new(node_config); + secret_service_node.run().unwrap(); + info!("SecretService node stopped"); + }); + + let debug_client = create_debug_client(debug_port); + check_node_up(&debug_client); + + create_client(public_port) +} + +// Testing higher level interface // +#[test] +fn test_generate_consensus_key() { + let node_config = NodeConfigHelpers::get_single_node_test_config(true); + let client = create_secret_service_node_and_client(node_config.clone()); + + let key_manager = ConsensusKeyManager::new(Arc::new(client)).unwrap(); + let public_key = key_manager.get_consensus_public_key().unwrap(); + // let message_hash = b"consensus test message".digest(STANDARD_DIGESTER.get()); + let message_hash = HashValue::random(); + let signature = key_manager.sign_consensus_message(&message_hash).unwrap(); + // check that the signature verifies + assert!( + signature.verify(&message_hash, &public_key).is_ok(), + "Correct signature does not verify" + ); +} + +#[test] +fn test_generate_key() { + let node_config = NodeConfigHelpers::get_single_node_test_config(true); + let client = create_secret_service_node_and_client(node_config.clone()); + + // create request to generate new key + let mut gen_req: GenerateKeyRequest = GenerateKeyRequest::new(); + gen_req.set_spec(KeyType::Ed25519); + + let result = client.generate_key(&gen_req); + if result.is_ok() { + let response = result.ok().unwrap(); + let is_successful = response.get_code() == ErrorCode::Success; + assert!(is_successful); + } else { + panic!("key generation failed: {}", result.err().unwrap()); + } +} + +#[test] +fn test_generate_and_retrieve_key() { + let node_config = NodeConfigHelpers::get_single_node_test_config(true); + let client = create_secret_service_node_and_client(node_config.clone()); + + // create request to generate new key + let mut gen_req: GenerateKeyRequest = GenerateKeyRequest::new(); + gen_req.set_spec(KeyType::Ed25519); + + let result = client.generate_key(&gen_req); + if result.is_err() { + panic!("key generation failed: {}", result.err().unwrap()); + } + + let response = result.ok().unwrap(); + let is_successful = response.get_code() == ErrorCode::Success; + assert!(is_successful); + let keyid = response.get_key_id(); + + // assert_eq!(keyid, [230, 70, 121, 164, 224, 224, 49, 236, 222, 129, 71, 209, 108, 208, 39, + // 161, 6, 166, 100, 236, 85, 0, 83, 224, 28, 229, 132, 230, 86, 31, 198, 235]); + + // existing key can be obtained + let mut pk_req: PublicKeyRequest = PublicKeyRequest::new(); + pk_req.set_key_id(keyid.to_vec()); + + let result = client.get_public_key(&pk_req); + if result.is_err() { + panic!("public key retrieval failed: {}", result.err().unwrap()); + } + let response = result.ok().unwrap(); + let is_successful = response.get_code() == ErrorCode::Success; + assert!(is_successful); + let _result = response.get_public_key(); + + // invalid length keyid argument returns an error + let mut pk_req: PublicKeyRequest = PublicKeyRequest::new(); + pk_req.set_key_id([0, 1, 2, 4].to_vec()); + + let result = client.get_public_key(&pk_req); + if result.is_err() { + panic!("public 
key retrieval failed: {}", result.err().unwrap()); + } + let response = result.ok().unwrap(); + let is_successful = response.get_code() == ErrorCode::WrongLength; + assert!(is_successful); + + // obtaining non-existing key returns an error + let mut pk_req: PublicKeyRequest = PublicKeyRequest::new(); + pk_req.set_key_id([255; HashValue::LENGTH].to_vec()); + + let result = client.get_public_key(&pk_req); + if result.is_err() { + panic!("public key retrieval failed: {}", result.err().unwrap()); + } + let response = result.ok().unwrap(); + let is_successful = response.get_code() == ErrorCode::KeyIdNotFound; + assert!(is_successful); +} diff --git a/crypto/secret_service/src/unit_tests/secret_service_test.rs b/crypto/secret_service/src/unit_tests/secret_service_test.rs new file mode 100644 index 0000000000000..4d0b5224b822c --- /dev/null +++ b/crypto/secret_service/src/unit_tests/secret_service_test.rs @@ -0,0 +1,162 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + proto::secret_service::KeyType, + secret_service_server::{KeyID, SecretServiceServer}, +}; +use crypto::hash::HashValue; +use nextgen_crypto::traits::{Signature, ValidKey}; + +///////////////////////////////////////////////////////////////////////////////////// +// These tests check interoperability of key_generation, // +// key_retrieval and signing for crate::secret_service_server::SecretServiceServer // +///////////////////////////////////////////////////////////////////////////////////// + +#[test] +fn test_generate_and_retrieve_keys() { + let mut ss_service = SecretServiceServer::new(); + let keys = [KeyType::Ed25519, KeyType::BLS12381]; + + for key_type in keys.iter() { + /* no placeholder key id exists */ + assert!( + ss_service + .get_public_key_inner(&KeyID(HashValue::zero())) + .is_none(), + "Empty SecretService returns a key" + ); + + let keyid1 = ss_service.generate_key_inner(*key_type).unwrap(); + + /* no placeholder key id exists */ + assert!( + ss_service + .get_public_key_inner(&KeyID(HashValue::zero())) + .is_none(), + "SecretService returns a key on incorrect keyid" + ); + + /* key id that was put exists */ + let public_key1 = ss_service.get_public_key_inner(&keyid1); + assert!( + ss_service.get_public_key_inner(&keyid1).is_some(), + "SecretService does not return a key" + ); + let public_key1 = public_key1.unwrap(); + + /* serialized and deserialized key id that was put exists */ + let key_id_from_wire = KeyID(HashValue::from_slice(&keyid1.0.to_vec()).unwrap()); + assert!( + ss_service.get_public_key_inner(&key_id_from_wire).is_some(), + "KeyId serialization into byte array is broken" + ); + + let keyid2 = ss_service.generate_key_inner(*key_type).unwrap(); + let public_key2 = ss_service.get_public_key_inner(&keyid2); + assert!( + ss_service.get_public_key_inner(&keyid2).is_some(), + "SecretService does not return a key" + ); + let public_key2 = public_key2.unwrap(); + + assert_ne!( + public_key1.to_bytes(), + public_key2.to_bytes(), + "SecretService returns same keys on different key ids" + ); + + /* check same keys received when invoked again */ + let public_key10 = ss_service.get_public_key_inner(&keyid1); + assert!( + ss_service.get_public_key_inner(&keyid1).is_some(), + "SecretService does not return a key" + ); + let public_key10 = public_key10.unwrap(); + let public_key20 = ss_service.get_public_key_inner(&keyid2); + assert!( + ss_service.get_public_key_inner(&keyid2).is_some(), + "SecretService does not return a key" + ); + let public_key20 = 
public_key20.unwrap();
+        assert_eq!(
+            public_key1.to_bytes(),
+            public_key10.to_bytes(),
+            "SecretService keys don't match"
+        );
+        assert_eq!(
+            public_key2.to_bytes(),
+            public_key20.to_bytes(),
+            "SecretService keys don't match"
+        );
+    }
+}
+
+#[test]
+fn test_ed25519_sign() {
+    let mut ss_service = SecretServiceServer::new();
+
+    let keys = [KeyType::Ed25519, KeyType::BLS12381];
+
+    for key_type in keys.iter() {
+        let keyid1 = ss_service.generate_key_inner(*key_type).unwrap();
+        let public_key1 = ss_service.get_public_key_inner(&keyid1);
+        assert!(
+            ss_service.get_public_key_inner(&keyid1).is_some(),
+            "SecretService does not return a key"
+        );
+        let public_key1 = public_key1.unwrap();
+
+        let keyid2 = ss_service.generate_key_inner(*key_type).unwrap();
+        let public_key2 = ss_service.get_public_key_inner(&keyid2);
+        assert!(
+            ss_service.get_public_key_inner(&keyid2).is_some(),
+            "SecretService does not return a key"
+        );
+        let public_key2 = public_key2.unwrap();
+
+        /* signature obtained verifies */
+        // let message_hash1 = b"hello".digest(STANDARD_DIGESTER.get());
+        let message_hash1 = HashValue::random();
+        let signature11 = ss_service.sign_inner(&keyid1, &message_hash1);
+        assert!(
+            signature11.is_some(),
+            "SecretService does not return a signature"
+        );
+        let signature11 = signature11.unwrap();
+        assert!(
+            signature11.verify(&message_hash1, &public_key1).is_ok(),
+            "Correct signature does not verify"
+        );
+
+        // let message_hash2 = b"world".digest(STANDARD_DIGESTER.get());
+        let message_hash2 = HashValue::random();
+        assert!(
+            signature11.verify(&message_hash2, &public_key1).is_err(),
+            "Incorrect signature verifies"
+        );
+
+        let signature22 = ss_service.sign_inner(&keyid2, &message_hash2);
+        assert!(
+            signature22.is_some(),
+            "SecretService does not return a signature"
+        );
+        let signature22 = signature22.unwrap();
+        assert!(
+            signature22.verify(&message_hash2, &public_key2).is_ok(),
+            "Correct signature does not verify"
+        );
+        assert!(
+            signature22.verify(&message_hash1, &public_key2).is_err(),
+            "Incorrect signature verifies"
+        );
+        assert!(
+            signature22.verify(&message_hash1, &public_key1).is_err(),
+            "Incorrect signature verifies"
+        );
+        assert!(
+            signature11.verify(&message_hash2, &public_key2).is_err(),
+            "Incorrect signature verifies"
+        );
+    }
+}
diff --git a/docker/mint/build.sh b/docker/mint/build.sh
new file mode 100755
index 0000000000000..a86567c10932d
--- /dev/null
+++ b/docker/mint/build.sh
@@ -0,0 +1,9 @@
+#!/bin/sh
+set -e
+
+PROXY=""
+if [ "$https_proxy" ]; then
+    PROXY=" --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy"
+fi
+
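+# GIT_REV, GIT_UPSTREAM and BUILD_DATE are passed through to image labels
+# (see mint.Dockerfile) so the image records exactly what was built.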
+docker build -f docker/mint/mint.Dockerfile . --tag libra_mint --build-arg GIT_REV="$(git rev-parse HEAD)" --build-arg GIT_UPSTREAM="$(git merge-base --fork-point origin/master)" --build-arg BUILD_DATE="$(date -u +'%Y-%m-%dT%H:%M:%SZ')" $PROXY
diff --git a/docker/mint/mint.Dockerfile b/docker/mint/mint.Dockerfile
new file mode 100644
index 0000000000000..2ac3bbdefc0b1
--- /dev/null
+++ b/docker/mint/mint.Dockerfile
@@ -0,0 +1,56 @@
+FROM debian:stretch as builder
+
+# To use http/https proxy while building, use:
+# docker build --build-arg https_proxy=http://fwdproxy:8080 --build-arg http_proxy=http://fwdproxy:8080
+
+RUN echo "deb http://deb.debian.org/debian stretch-backports main" > /etc/apt/sources.list.d/backports.list \
+    && apt-get update && apt-get install -y protobuf-compiler/stretch-backports cmake golang curl \
+    && apt-get clean && rm -r /var/lib/apt/lists/*
+
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain none
+ENV PATH "$PATH:/root/.cargo/bin"
+
+WORKDIR /libra
+COPY rust-toolchain /libra/rust-toolchain
+RUN rustup install $(cat rust-toolchain)
+
+COPY . /libra
+RUN cargo build -p client --release
+
+### Production Image ###
+FROM debian:stretch
+
+# TODO: Unsure which of these are needed exactly for client
+RUN apt-get update && apt-get install -y python3-pip nano net-tools tcpdump iproute2 netcat \
+    && apt-get clean && rm -r /var/lib/apt/lists/*
+
+# TODO: Move to requirements.txt
+RUN pip3 install flask flask_limiter gunicorn pexpect
+
+RUN mkdir -p /opt/libra/bin /opt/libra/etc
+
+# TODO: Remove this once wallet location is set properly
+RUN mkdir -p /libra/client/data/wallet/
+
+COPY --from=builder /libra/target/release/client /opt/libra/bin
+COPY docker/mint/server.py /opt/libra/bin
+
+# Mint proxy listening address
+EXPOSE 8000
+
+# Define TRUSTED_PEERS, MINT_KEY, LOG_LEVEL, AC_HOST and AC_PORT environment variables when running
+CMD cd /opt/libra/etc && echo "$TRUSTED_PEERS" > trusted_peers.config.toml && echo "$MINT_KEY" | \
+    base64 -d > mint.key && \
+    cd /opt/libra/bin && \
+    gunicorn --bind 0.0.0.0:8000 --access-logfile - --error-logfile - --log-level $LOG_LEVEL server
+
+
+ARG BUILD_DATE
+ARG GIT_REV
+ARG GIT_UPSTREAM
+
+LABEL org.label-schema.schema-version="1.0"
+LABEL org.label-schema.build-date=$BUILD_DATE
+LABEL org.label-schema.vcs-ref=$GIT_REV
+LABEL vcs-upstream=$GIT_UPSTREAM
diff --git a/docker/mint/run.sh b/docker/mint/run.sh
new file mode 100755
index 0000000000000..ab8d014a57e24
--- /dev/null
+++ b/docker/mint/run.sh
@@ -0,0 +1,17 @@
+#!/bin/sh
+
+# Example Usage:
+# ./run.sh <IMAGE> <AC_HOST> <AC_PORT> <LOG_LEVEL>
+# ./run.sh libra_mint:latest ac.dev.aws.hlw3truzy4ls.com 80 info
+
+set -ex
+
+IMAGE="$1"
+CONFIGDIR="$(dirname "$0")/../../terraform/validator-sets/dev"
+
+docker network create --subnet 172.18.0.0/24 testnet || true
+
+docker run -p 8000:8000 -e AC_HOST=$2 -e AC_PORT=$3 -e LOG_LEVEL=$4 \
+    -e TRUSTED_PEERS="$(cat $CONFIGDIR/trusted_peers.config.toml)" \
+    -e MINT_KEY="$(base64 $CONFIGDIR/mint.key)" \
+    --network testnet $IMAGE
diff --git a/docker/mint/server.py b/docker/mint/server.py
new file mode 100644
index 0000000000000..c5258f11643b3
--- /dev/null
+++ b/docker/mint/server.py
@@ -0,0 +1,70 @@
+"""
+Dummy web server that proxies incoming requests to a local client that owns the association keys
+
+Installation:
+    virtualenv -p python3 ~/env
+    source ~/env/bin/activate
+    pip install flask gunicorn pexpect
+
+To run:
+    gunicorn --bind 0.0.0.0:8000 server
+
+"""
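+# A single interactive `client` process is spawned per worker at import time
+# (see setup_app below) and driven with pexpect; each HTTP request reuses it.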
+import flask
+
+import decimal
+import os
+import pexpect
+import platform
+import random
+import re
+import sys
+
+print(sys.version)
+print(platform.python_version())
+
+
+def setup_app():
+    application = flask.Flask(__name__)
+    ac_host = os.environ['AC_HOST']
+    ac_port = os.environ['AC_PORT']
+
+    # If we have a comma-separated list, take a random one
+    ac_hosts = ac_host.split(',')
+    ac_host = random.choice(ac_hosts)
+
+    print("Connecting to ac on: {}:{}".format(ac_host, ac_port))
+
+    cmd = "/opt/libra/bin/client --host {} --port {} -m /opt/libra/etc/mint.key --validator_set_file /opt/libra/etc/trusted_peers.config.toml".format(
+        ac_host, ac_port)
+    application.client = pexpect.spawn(cmd)
+    application.client.expect("Please, input commands")
+    return application
+
+
+application = setup_app()
+
+MAX_MINT = 10 ** 12
+
+
+@application.route("/")
+def send_transaction():
+    address = flask.request.args['address']
+
+    # Return immediately if address is invalid
+    if re.match('^[a-f0-9]{64}$', address) is None:
+        return 'Malformed address', 400
+
+    try:
+        amount = decimal.Decimal(flask.request.args['amount'])
+    except (KeyError, decimal.InvalidOperation):
+        return 'Bad amount', 400
+
+    if amount > MAX_MINT:
+        return 'Exceeded max mint amount of {}'.format(MAX_MINT), 400
+
+    application.client.sendline("a m {} {}".format(address, amount / (10 ** 6)))
+    application.client.expect("Mint request submitted", timeout=2)
+
+    application.client.sendline("a la")
+    application.client.expect(r"sequence_number: ([0-9]+)", timeout=1)
+    return application.client.match.groups()[0]
diff --git a/docker/validator/build.sh b/docker/validator/build.sh
new file mode 100755
index 0000000000000..48a7ea1bc9016
--- /dev/null
+++ b/docker/validator/build.sh
@@ -0,0 +1,9 @@
+#!/bin/sh
+set -e
+
+PROXY=""
+if [ "$https_proxy" ]; then
+    PROXY=" --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy"
+fi
+
+docker build -f docker/validator/validator.Dockerfile . --tag libra_e2e --build-arg GIT_REV="$(git rev-parse HEAD)" --build-arg GIT_UPSTREAM="$(git merge-base HEAD origin/master)" --build-arg BUILD_DATE="$(date -u +'%Y-%m-%dT%H:%M:%SZ')" $PROXY
diff --git a/docker/validator/install-tools.sh b/docker/validator/install-tools.sh
new file mode 100755
index 0000000000000..f61e5ce0ff54e
--- /dev/null
+++ b/docker/validator/install-tools.sh
@@ -0,0 +1,4 @@
+#!/bin/sh
+
+apt-get update
+apt-get install --no-install-recommends -y nano net-tools tcpdump iproute2 netcat ngrep atop gdb strace curl
diff --git a/docker/validator/run.sh b/docker/validator/run.sh
new file mode 100755
index 0000000000000..d7a159096aaf2
--- /dev/null
+++ b/docker/validator/run.sh
@@ -0,0 +1,16 @@
+#!/bin/sh
+set -ex
+
+IMAGE="libra_e2e:latest"
+CONFIGDIR="$(dirname "$0")/../../terraform/validator-sets/dev/"
+
+SEED_PEERS="$(sed 's,SEED_IP,172.18.0.10,' $CONFIGDIR/seed_peers.config.toml)"
+GENESIS_BLOB="$(base64 $CONFIGDIR/genesis.blob)"
+TRUSTED_PEERS="$(cat $CONFIGDIR/trusted_peers.config.toml)"
+
+docker network create --subnet 172.18.0.0/24 testnet || true
+
+docker run -e PEER_ID=8deeeaed65f0cd7484a9e4e5ac51fbac548f2f71299a05e000156031ca78fb9f -e SELF_IP=172.18.0.10 -e SEED_PEERS="$SEED_PEERS" -e TRUSTED_PEERS="$TRUSTED_PEERS" -e GENESIS_BLOB="$GENESIS_BLOB" -e PEER_KEYPAIRS="$(cat $CONFIGDIR/8deeeaed65f0cd7484a9e4e5ac51fbac548f2f71299a05e000156031ca78fb9f.node.keys.toml)" --ip 172.18.0.10 --network testnet --detach "$IMAGE"
+docker run -e PEER_ID=1e5d5a74b0fd09f601ac0fca2fe7d213704e02e51943d18cf25a546b8416e9e1 -e SELF_IP=172.18.0.11 -e SEED_PEERS="$SEED_PEERS" -e TRUSTED_PEERS="$TRUSTED_PEERS" -e GENESIS_BLOB="$GENESIS_BLOB" -e PEER_KEYPAIRS="$(cat $CONFIGDIR/1e5d5a74b0fd09f601ac0fca2fe7d213704e02e51943d18cf25a546b8416e9e1.node.keys.toml)" --ip 172.18.0.11 --network testnet --detach "$IMAGE"
+docker run -e PEER_ID=ab0d6a54ce9d7fc79c061f95883a308f9bdfc987262b6a34a360fdd788fcd9cd -e SELF_IP=172.18.0.12 -e SEED_PEERS="$SEED_PEERS" -e TRUSTED_PEERS="$TRUSTED_PEERS" -e GENESIS_BLOB="$GENESIS_BLOB" -e PEER_KEYPAIRS="$(cat $CONFIGDIR/ab0d6a54ce9d7fc79c061f95883a308f9bdfc987262b6a34a360fdd788fcd9cd.node.keys.toml)" --ip 172.18.0.12 --network testnet --detach "$IMAGE"
+docker run -e PEER_ID=57ff83747054695f2228042c26eb6a243ac73de1b9038aea103999480b076d45 -e SELF_IP=172.18.0.13 -e SEED_PEERS="$SEED_PEERS" -e TRUSTED_PEERS="$TRUSTED_PEERS" -e GENESIS_BLOB="$GENESIS_BLOB" -e PEER_KEYPAIRS="$(cat $CONFIGDIR/57ff83747054695f2228042c26eb6a243ac73de1b9038aea103999480b076d45.node.keys.toml)" --ip 172.18.0.13 --network testnet --publish 30307:30307 "$IMAGE"
diff --git a/docker/validator/validator.Dockerfile b/docker/validator/validator.Dockerfile
new file mode 100644
index 0000000000000..33911b9019b29
--- /dev/null
+++ b/docker/validator/validator.Dockerfile
@@ -0,0 +1,43 @@
+FROM debian:stretch as builder
+
+# To use http/https proxy while building, use:
+# docker build --build-arg https_proxy=http://fwdproxy:8080 --build-arg http_proxy=http://fwdproxy:8080
+
+RUN echo "deb http://deb.debian.org/debian stretch-backports main" > /etc/apt/sources.list.d/backports.list && apt-get update && apt-get install -y protobuf-compiler/stretch-backports cmake golang curl && apt-get clean && rm -r /var/lib/apt/lists/*
+
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain none
+ENV PATH "$PATH:/root/.cargo/bin"
+
+WORKDIR /libra
+COPY rust-toolchain /libra/rust-toolchain
+RUN rustup install $(cat rust-toolchain)
+
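+# `rust-toolchain` is copied and installed before the full `COPY . /libra` so
+# that the toolchain layer stays cached across source-only rebuilds.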
+COPY . /libra
+RUN cargo build --release -p libra_node
+
+### Production Image ###
+FROM debian:stretch
+
+RUN mkdir -p /opt/libra/bin /opt/libra/etc
+COPY terraform/validator-sets/dev/node.config.toml /opt/libra/etc
+COPY docker/validator/install-tools.sh /root
+COPY --from=builder /libra/target/release/libra_node /opt/libra/bin
+
+# Admission control
+EXPOSE 30307
+# Validator network
+EXPOSE 30303
+# Metrics
+EXPOSE 14297
+
+# Define SEED_PEERS, SELF_IP, PEER_KEYPAIRS, GENESIS_BLOB and PEER_ID environment variables when running
+CMD cd /opt/libra/etc && sed -i "s,SELF_IP,$SELF_IP," node.config.toml && echo "$SEED_PEERS" > seed_peers.config.toml && echo "$TRUSTED_PEERS" > trusted_peers.config.toml && echo "$PEER_KEYPAIRS" > peer_keypairs.config.toml && echo "$GENESIS_BLOB" | base64 -d > genesis.blob && exec /opt/libra/bin/libra_node -f node.config.toml --peer_id "$PEER_ID"
+
+ARG BUILD_DATE
+ARG GIT_REV
+ARG GIT_UPSTREAM
+
+LABEL org.label-schema.schema-version="1.0"
+LABEL org.label-schema.build-date=$BUILD_DATE
+LABEL org.label-schema.vcs-ref=$GIT_REV
+LABEL vcs-upstream=$GIT_UPSTREAM
diff --git a/documentation/coding_guidelines.md b/documentation/coding_guidelines.md
new file mode 100644
index 0000000000000..d9fd9fc6b403d
--- /dev/null
+++ b/documentation/coding_guidelines.md
@@ -0,0 +1,264 @@
+---
+id: coding-guidelines
+title: Coding Guidelines
+---
+
+# Libra Core coding guidelines
+
+This document describes the coding guidelines for the Libra Core Rust codebase.
+
+## Code formatting
+
+All code formatting is enforced with [rustfmt](https://github.com/rust-lang/rustfmt) with a project-specific configuration. Below is an example command to adhere to the project conventions.
+
+```
+libra/libra$ cargo fmt
+```
+
+## Code analysis
+
+[Clippy](https://github.com/rust-lang/rust-clippy) is used to catch common mistakes and is run as a part of continuous integration. Before submitting your code for review, you can run clippy with our configuration:
+
+```
+libra$ ../setup_scripts/clippy.sh
+```
+
+In general, we follow the recommendations from [rust-lang-nursery](https://rust-lang-nursery.github.io/api-guidelines/about.html). The remainder of this guide provides more detailed guidance on specific areas, aiming for as much uniformity as possible.
+
+## Code documentation
+
+Any public fields, functions, and methods should be documented with [Rustdoc](https://doc.rust-lang.org/book/ch14-02-publishing-to-crates-io.html#making-useful-documentation-comments).
+
+Please follow the conventions as detailed below for modules, structs, enums, and functions. The *single line* is used as a preview when navigating Rustdoc. As an example, see the 'Structs' and 'Enums' sections in the [collections](https://doc.rust-lang.org/std/collections/index.html) Rustdoc.
+
+```
+/// [Single line] One line summary description
+///
+/// [Longer description] Multiple lines, inline code
+/// examples, invariants, purpose, usage, etc.
+[Attributes] If attributes exist, add after Rustdoc
+```
+
+Example below:
+
+```
+/// Represents (x, y) of a 2-dimensional grid
+///
+/// A line is defined by 2 instances.
+/// A plane is defined by 3 instances.
+#[repr(C)]
+struct Point {
+    x: i32,
+    y: i32,
+}
+```
+
+### Constants and fields
+
+Describe the purpose and definition of this data.
+
+### Functions and methods
+
+Document the following for each function:
+
+* The action the method performs - "This method *adds* a new transaction to the mempool." Use *active voice* and *present tense* (i.e. adds/creates/checks/updates/deletes).
+* Describe how and why to use this method.
+* Any condition that must be met _before_ calling the method.
+* State conditions under which the function will `panic!()` or return an `Error`.
+* Brief description of return values.
+* Any special behavior that is not obvious.
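+
+For instance, a method documented to this standard might look like the following sketch (the mempool type, its fields, and the method itself are hypothetical, used purely for illustration):
+
+```
+/// Adds a signed transaction to the mempool.
+///
+/// Call this only after the transaction's signature has been verified.
+/// Returns an `Error` if the mempool is at capacity; panics only if the
+/// internal lock has been poisoned.
+fn add_transaction(&mut self, txn: SignedTransaction) -> Result<()> {
+    if self.transactions.len() >= self.capacity {
+        bail!("mempool is full");
+    }
+    self.transactions.push(txn);
+    Ok(())
+}
+```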
+
+### README.md for top-level directories and other major components
+
+Each major component of the system needs to have a `README.md` file. Major components are:
+* top-level directories (e.g. `libra/network`, `libra/language`)
+* the most important crates in the system (e.g. `vm_runtime`)
+
+This file should contain:
+
+ * The *conceptual* *documentation* of the component.
+ * A link to external API documentation for the component.
+ * A link to the master license of the project.
+ * A link to the master contributing guide for the project.
+
+A template for READMEs:
+
+```
+# Component Name
+
+[Summary line: Start with one sentence about this component.]
+
+## Overview
+
+* Describe its purpose and how the code in this directory works.
+* Describe the interaction of code in this directory with other components.
+* Describe the security model and assumptions about the crates in this directory. Examples of how to describe the security assumptions will be added in the future.
+
+## Implementation Details
+
+* Describe how the component is modeled. For example, why is the code organized the way it is?
+* Other relevant implementation details.
+
+## API Documentation
+
+For the external API of this crate refer to [Link to rustdoc API].
+
+[For a top-level directory, link to the most important APIs within.]
+
+## Contributing
+
+Refer to the Libra Project contributing guide [LINK].
+
+## License
+
+Refer to the Libra Project License [LINK].
+```
+
+A good example of README.md is `libra/network/README.md`, which describes the networking crate.
+
+## Code suggestions
+
+Below are suggested best practices for a uniform codebase that will evolve and improve over time. We will be investigating what can be enforced by clippy and will update this section in the future.
+
+### Attributes
+
+Make sure to use the appropriate attributes for handling dead code:
+
+```
+// For code that is intended for production usage in the future
+#[allow(dead_code)]
+// For code that is only intended for testing and
+// has no intended production use
+#[cfg(test)]
+```
+
+### Avoid Deref polymorphism
+
+Don't abuse the `Deref` trait to emulate inheritance between structs in order to reuse methods. For more information, read [here](https://github.com/rust-unofficial/patterns/blob/master/anti_patterns/deref.md).
+
+### Comments
+
+Prefer `//` and `///` comments rather than block comments `/* ... */` for uniformity and simpler grepping.
+
+### Cloning
+
+If `x` is reference counted, prefer [`Arc::clone(x)`](https://doc.rust-lang.org/std/sync/struct.Arc.html) rather than `x.clone()`, as the former makes explicit that we are cloning the reference-counted pointer. This avoids confusion about whether we are performing an expensive clone of a `struct`, `enum`, or other type, or a cheap reference copy.
+
+Also, if you are passing around [`Arc<T>`](https://doc.rust-lang.org/std/sync/struct.Arc.html) types, consider using a newtype wrapper:
+
+```
+#[derive(Clone, Debug)]
+pub struct Foo(Arc<FooInner>);
+```
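+
+As a minimal sketch of the distinction (the names here are illustrative only):
+
+```
+use std::sync::Arc;
+
+fn main() {
+    let shared = Arc::new(vec![1, 2, 3]);
+    // Explicit: clearly a cheap reference-count bump, never a deep copy.
+    let handle = Arc::clone(&shared);
+    // Implicit: a reader must know the type of `shared` to tell whether this
+    // is an expensive clone or a pointer copy.
+    let other = shared.clone();
+    assert_eq!(handle.len(), other.len());
+}
+```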
+
+### Concurrent types
+
+Concurrent types such as [`CHashMap`](https://docs.rs/crate/chashmap), [`AtomicUsize`](https://doc.rust-lang.org/std/sync/atomic/struct.AtomicUsize.html), etc. take an immutable borrow on self (i.e. `fn foo_mut(&self, ...)`) in order to support concurrent access on interior mutating methods. Good practices (such as those in the examples mentioned) avoid exposing synchronization primitives externally (e.g. `Mutex`, `RwLock`) and document the method semantics and invariants clearly.
+
+*When to use channels vs concurrent types?*
+
+Below are high-level suggestions for the distinction based on experience:
+
+* Channels are for ownership transfer, decoupling of types, and coarse-grained messages. They fit well for transferring ownership of data, distributing units of work, and communicating async results. Furthermore, they help break circular dependencies (e.g. `struct Foo` contains an `Arc<Bar>` and `struct Bar` contains an `Arc<Foo>`, which leads to complex initialization).
+
+* Concurrent types (e.g. such as [`CHashMap`](https://docs.rs/crate/chashmap) or structs that have interior mutability building on [`Mutex`](https://doc.rust-lang.org/std/sync/struct.Mutex.html), [`RwLock`](https://doc.rust-lang.org/std/sync/struct.RwLock.html), etc.) are better suited for caches and states.
+
+### Error handling
+
+Error handling suggestions follow the [Rust book guidance](https://doc.rust-lang.org/book/ch09-00-error-handling.html). Rust groups errors into two major categories: recoverable and unrecoverable errors, where recoverable errors should be handled with [Result](https://doc.rust-lang.org/std/result/). For our suggestions on unrecoverable errors, see below:
+
+*Panic*
+
+* `panic!()` - Runtime panic! should only be used when the resulting state cannot be processed going forward. It should not be used for any recoverable errors.
+* `unwrap()` - Unwrap should only be used for mutexes (e.g. `lock().unwrap()`) and test code. For all other use cases, prefer `expect()`. The only exception is if the error message is custom-generated, in which case use `.unwrap_or_else(|| panic!("error: {}", foo))`.
+* `expect()` - Expect should be invoked when a system invariant is expected to be preserved. It is preferred over unwrap() and should have a detailed error message on failure in most cases.
+* `assert!()` - This macro is kept in both debug and release builds, and should be used to protect invariants of the system as necessary.
+* `unreachable!()` - This macro will panic on code that should not be reached (violating an invariant) and can be used as appropriate.
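+
+To make the `expect()`/`unwrap_or_else()` guidance concrete, a small sketch (the lock and the map are illustrative only):
+
+```
+// Preferred: document the invariant that is expected to hold.
+let guard = self.keys.lock().expect("keys mutex poisoned: a writer panicked");
+
+// For a custom-generated message, panic lazily instead of building the
+// message up front:
+let value = map
+    .get(&key)
+    .unwrap_or_else(|| panic!("no entry for key {}", key));
+```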
+
+### Generics
+
+Generics allow dynamic behavior (similar to [`trait`](https://doc.rust-lang.org/book/ch10-02-traits.html) methods) with static dispatch. Consider that as the number of generic type parameters increases, the difficulty of using the type/method also increases (e.g. what combination of trait bounds is required for this type, duplicate trait bounds on related types, etc.). To limit this complexity, we generally try to avoid using a large number of generic type parameters. We have found that converting code with a large number of generic objects to trait objects with dynamic dispatch often simplifies our code.
+
+### Getters/setters
+
+Excluding test code, set field visibility to private as much as possible. Private fields allow constructors to enforce internal invariants. Implement getters for data that consumers may need, but avoid setters unless mutable state is necessary.
+
+Public fields are most appropriate for [`struct`](https://doc.rust-lang.org/book/ch05-00-structs.html) types in the C spirit: compound, passive data structures without internal invariants. Naming suggestions follow the guidance [here](https://rust-lang-nursery.github.io/api-guidelines/naming.html#getter-names-follow-rust-convention-c-getter) as shown below.
+
+```
+struct Foo {
+    size: usize,
+    key_to_value: HashMap<u32, u32>,
+}
+
+impl Foo {
+    /// Return a copy when inexpensive
+    fn size(&self) -> usize {
+        self.size
+    }
+
+    /// Borrow for expensive copies
+    fn key_to_value(&self) -> &HashMap<u32, u32> {
+        &self.key_to_value
+    }
+
+    /// Setter follows the set_xxx pattern
+    fn set_size(&mut self, size: usize) {
+        self.size = size;
+    }
+
+    /// For a more complex getter, using get_XXX is acceptable
+    /// (similar to HashMap) with well-defined and
+    /// commented semantics
+    fn get_value(&self, key: u32) -> Option<&u32> {
+        self.key_to_value.get(&key)
+    }
+}
+```
+
+### Logging
+
+We currently use [slog](https://docs.rs/slog/) for logging.
+
+* [error!](https://docs.rs/slog/2.4.1/slog/macro.error.html) - Error-level messages have the highest urgency in [slog](https://docs.rs/slog/). An unexpected error has occurred (e.g. exceeded the maximum number of retries to complete an RPC or inability to store data to local storage).
+* [warn!](https://docs.rs/slog/2.4.1/slog/macro.warn.html) - Warn-level messages help notify admins about automatically handled issues (e.g. retrying a failed network connection or receiving the same message multiple times, etc.).
+* [info!](https://docs.rs/slog/2.4.1/slog/macro.info.html) - Info-level messages are well suited for "one time" events (such as logging state on one-time startup and shutdown) or periodic events that are not frequently occurring - e.g. changing the validator set every day.
+* [debug!](https://docs.rs/slog/2.4.1/slog/macro.debug.html) - Debug-level messages are frequently occurring (i.e. potentially > 1 message per second) and are not typically expected to be enabled in production.
+* [trace!](https://docs.rs/slog/2.4.1/slog/macro.trace.html) - Trace-level logging is typically only used for function entry/exit.
+
+### Testing
+
+*Unit tests*
+
+Ideally, all code will be unit tested. Unit test files should exist in the same directory as `mod.rs` and their file names should end in `_test.rs`. A module to be tested should have the test modules annotated with `#[cfg(test)]`. For example, if in a crate there is a db module, the expected directory structure is as follows:
+
+```
+src/db                       -> directory of db module
+src/db/mod.rs                -> code of db module
+src/db/read_test.rs          -> db test 1
+src/db/write_test.rs         -> db test 2
+src/db/access/mod.rs         -> directory of access submodule
+src/db/access/access_test.rs -> test of access submodule
+```
+
+*Property-based tests*
+
+Libra contains [property-based tests](https://blog.jessitron.com/2013/04/25/property-based-testing-what-is-it/) written in Rust using the [`proptest` framework](https://github.com/AltSysrq/proptest). Property-based tests generate random test cases and assert that invariants, also called *properties*, hold about the code under test.
+
+Some examples of properties tested in Libra:
+
+* Every serializer and deserializer pair is tested for correctness with random inputs to the serializer. Any pair of functions that are inverses of each other can be tested this way.
+* The results of executing common transactions through the VM are tested using randomly generated scenarios, and a simplified model as an *oracle*.
+
+A tutorial for `proptest` can be found in the [`proptest` book](https://altsysrq.github.io/proptest-book/proptest/getting-started.html).
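+
+As a tiny illustration of the serializer/deserializer property (the identity codec here is a stand-in for any real encode/decode pair):
+
+```
+use proptest::prelude::*;
+
+// Stand-in codec: any serializer/deserializer pair fits this shape.
+fn encode(bytes: &[u8]) -> Vec<u8> { bytes.to_vec() }
+fn decode(encoded: Vec<u8>) -> Vec<u8> { encoded }
+
+proptest! {
+    #[test]
+    fn encode_decode_roundtrip(bytes in proptest::collection::vec(any::<u8>(), 0..64)) {
+        // Property: decoding an encoded value returns the original input.
+        prop_assert_eq!(decode(encode(&bytes)), bytes);
+    }
+}
+```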
+
+References:
+
+* [What is Property Based Testing?](https://hypothesis.works/articles/what-is-property-based-testing/) (includes a comparison with fuzzing)
+* [An introduction to property-based testing](https://fsharpforfunandprofit.com/posts/property-based-testing/)
+* [Choosing properties for property-based testing](https://fsharpforfunandprofit.com/posts/property-based-testing-2/)
+
+*Fuzzing*
+
+Libra contains harnesses for fuzzing crash-prone code like deserializers, using [`libFuzzer`](https://llvm.org/docs/LibFuzzer.html) through [`cargo fuzz`](https://rust-fuzz.github.io/book/cargo-fuzz.html). For more, see the `testsuite/libra_fuzzer` directory.
diff --git a/execution/README.md b/execution/README.md
new file mode 100644
index 0000000000000..bebbbc4f88c95
--- /dev/null
+++ b/execution/README.md
@@ -0,0 +1,113 @@
+---
+id: execution
+title: Execution
+custom_edit_url: https://github.com/libra/libra/edit/master/execution/README.md
+---
+
+# Execution
+
+## Overview
+
+The Libra Blockchain is a replicated state machine. Each validator is a replica
+of the system. Starting from genesis state S<sub>0</sub>, each transaction
+T<sub>i</sub> updates previous state S<sub>i-1</sub> to S<sub>i</sub>. Each
+S<sub>i</sub> is a mapping from accounts (represented by 32-byte addresses) to
+some data associated with each account.
+
+The execution component takes the totally ordered transactions, computes the
+output for each transaction via the Move virtual machine, applies the output on
+the previous state, and generates the new state. The execution system
+cooperates with the consensus algorithm (HotStuff, a leader-based algorithm) to
+help it agree on a proposed set of transactions and their execution. Such a
+group of transactions is a block. Unlike in other blockchain systems, blocks
+have no significance other than being a batch of transactions; every
+transaction is identified by its position within the ledger, which is also
+referred to as its "version". Each consensus participant builds a tree of
+blocks like the following:
+
+```
+                              ┌-- C
+                    ┌-- B <--┤
+                    |         └-- D
+          <--- A <--┤          (A is the last committed block)
+                    |         ┌-- F <--- G
+                    └-- E <--┤
+                              └-- H
+
+                   ↓ After committing block E
+
+                    ┌-- F <--- G
+          <--- A <--- E <--┤   (E is the last committed block)
+                    └-- H
+```
+
+A block is a list of transactions that should be applied in the given order once
+the block is committed. Each path from the last committed block to an
+uncommitted block forms a valid chain. Regardless of the commit rule of the
+consensus algorithm, there are two possible operations on this tree:
+
+1. Adding a block to the tree using a given parent and extending a specific
+   chain (for example, extending block `F` with block `G`). When we extend a
+   chain with a new block, the block should include the correct execution
+   results of the transactions in the block as if all its ancestors had been
+   committed in the same order. However, all the uncommitted blocks and their
+   execution results are held in some temporary location and are not visible to
+   external clients.
+2. Committing a block. As consensus collects more and more votes on blocks, it
+   decides to commit a block and all its ancestors according to some specific
+   rules. Then we save all these blocks to permanent storage and also discard
+   all the conflicting blocks at the same time.
+
+Therefore, the execution component provides two primary APIs - `execute_block`
+and `commit_block` - to support the above operations.
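+
+Concretely, a consensus driver might use these two APIs as sketched below (based
+on the `ExecutionClient` wrapper in this repository; the signature-collection
+step is modeled here as a caller-provided closure, since assembling a
+`LedgerInfoWithSignatures` is the job of consensus, not execution):
+
+```
+use crypto::HashValue;
+use execution_client::ExecutionClient;
+use execution_proto::ExecuteBlockRequest;
+use failure::Result;
+use types::{ledger_info::LedgerInfoWithSignatures, transaction::SignedTransaction};
+
+/// Execute a proposed block, then commit it once consensus has gathered
+/// 2f+1 signatures over the resulting ledger info.
+fn execute_then_commit(
+    client: &ExecutionClient,
+    txns: Vec<SignedTransaction>,
+    parent_block_id: HashValue,
+    block_id: HashValue,
+    collect_signatures: impl Fn(HashValue) -> LedgerInfoWithSignatures,
+) -> Result<()> {
+    let request = ExecuteBlockRequest::new(txns, parent_block_id, block_id);
+    // The block stays speculative here: its results are not yet visible to clients.
+    let response = client.execute_block(request)?;
+    // Validators vote on the root hash; once 2f+1 signatures are collected,
+    // committing persists this block and all its uncommitted ancestors.
+    let ledger_info_with_sigs = collect_signatures(response.root_hash());
+    client.commit_block(ledger_info_with_sigs)
+}
+```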
+
+## Implementation Details
+
+The state at each version is represented as a sparse Merkle tree in storage.
+When a transaction modifies an account, the account and the siblings from the
+tree root to the account are loaded into memory. For example, if we execute a
+transaction T<sub>i</sub> on top of committed state and it modifies account
+`A`, we will end up having the following tree:
+
+```
+                S_i
+               /   \
+              o     y
+             / \
+            x   A
+```
+
+where `A` has the new state of the account, and `y` and `x` are the siblings on
+the path from the root to `A` in the tree. If the next transaction
+T<sub>i+1</sub> modifies another account `B` that lives in the subtree at `y`,
+a new tree will be constructed, and the structure will look like the following:
+
+```
+                S_i      S_{i+1}
+               /    \   /      \
+              /      y /        \
+             / ______/           \
+            //                    \
+           o                       y'
+          / \                     / \
+         x   A                   z   B
+```
+
+Using this structure, we are able to query the global state, taking into account
+the output of uncommitted transactions. For example, if we want to execute
+another transaction T<sub>i+1</sub>', we can use the tree S<sub>i</sub>. If we
+look for account `A`, we can find its new value in the tree. Otherwise, we know
+the account does not exist in the tree, and we can fall back to storage. As
+another example, if we want to execute transaction T<sub>i+2</sub>, we can use
+the tree S<sub>i+1</sub> that has updated values for both accounts `A` and `B`.
+
+## How is this component organized?
+```
+    execution
+    └── execution_client   # A Rust wrapper on top of GRPC clients.
+    └── execution_proto    # All interfaces provided by the execution component.
+    └── execution_service  # Execution component as a GRPC service.
+    └── executor           # The main implementation of execution component.
+```
diff --git a/execution/execution_client/Cargo.toml b/execution/execution_client/Cargo.toml
new file mode 100644
index 0000000000000..889969c9b1660
--- /dev/null
+++ b/execution/execution_client/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "execution_client"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+grpcio = "0.4.4"
+
+execution_proto = { path = "../execution_proto" }
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+types = { path = "../../types" }
+proto_conv = { path = "../../common/proto_conv" }
diff --git a/execution/execution_client/src/lib.rs b/execution/execution_client/src/lib.rs
new file mode 100644
index 0000000000000..08f2ab833b6f5
--- /dev/null
+++ b/execution/execution_client/src/lib.rs
@@ -0,0 +1,42 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use execution_proto::{
+    proto::{execution::CommitBlockRequest, execution_grpc},
+    ExecuteBlockRequest, ExecuteBlockResponse,
+};
+use failure::{bail, Result};
+use grpcio::{ChannelBuilder, Environment};
+use proto_conv::{FromProto, IntoProto};
+use std::sync::Arc;
+use types::ledger_info::LedgerInfoWithSignatures;
+
+pub struct ExecutionClient {
+    client: execution_grpc::ExecutionClient,
+}
+
+impl ExecutionClient {
+    pub fn new(env: Arc<Environment>, host: &str, port: u16) -> Self {
+        let channel = ChannelBuilder::new(env).connect(&format!("{}:{}", host, port));
+        let client = execution_grpc::ExecutionClient::new(channel);
+        ExecutionClient { client }
+    }
+
+    pub fn execute_block(&self, request: ExecuteBlockRequest) -> Result<ExecuteBlockResponse> {
+        let proto_request = request.into_proto();
+        match self.client.execute_block(&proto_request) {
+            Ok(proto_response) => Ok(ExecuteBlockResponse::from_proto(proto_response)?),
+            Err(err) => bail!("GRPC error: {}", err),
+        }
+    }
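+
+    // Like `execute_block` above, this is a blocking unary call that folds
+    // `grpcio::Error` into the crate-wide `failure::Result`, so callers never
+    // handle gRPC types directly.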
+    pub fn commit_block(&self, ledger_info_with_sigs: LedgerInfoWithSignatures) -> Result<()> {
+        let proto_ledger_info = ledger_info_with_sigs.into_proto();
+        let mut request = CommitBlockRequest::new();
+        request.set_ledger_info_with_sigs(proto_ledger_info);
+        match self.client.commit_block(&request) {
+            Ok(_proto_response) => Ok(()),
+            Err(err) => bail!("GRPC error: {}", err),
+        }
+    }
+}
diff --git a/execution/execution_proto/Cargo.toml b/execution/execution_proto/Cargo.toml
new file mode 100644
index 0000000000000..f4ea5dec470ce
--- /dev/null
+++ b/execution/execution_proto/Cargo.toml
@@ -0,0 +1,22 @@
+[package]
+name = "execution_proto"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+futures = "0.1.25"
+grpcio = "0.4.4"
+proptest = "0.9.2"
+proptest-derive = "0.1.0"
+protobuf = "2.6"
+
+crypto = { path = "../../crypto/legacy_crypto" }
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+proto_conv = { path = "../../common/proto_conv", features = ["derive"] }
+types = { path = "../../types" }
+
+[build-dependencies]
+build_helpers = { path = "../../common/build_helpers" }
diff --git a/execution/execution_proto/build.rs b/execution/execution_proto/build.rs
new file mode 100644
index 0000000000000..764103e0f065b
--- /dev/null
+++ b/execution/execution_proto/build.rs
@@ -0,0 +1,18 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This compiles all the `.proto` files under the `src/` directory.
+//!
+//! For example, if there is a file `src/a/b/c.proto`, it will generate `src/a/b/c.rs` and
+//! `src/a/b/c_grpc.rs`.
+
+fn main() {
+    let proto_root = "src/proto";
+    let dependent_root = "../../types/src/proto";
+
+    build_helpers::build_helpers::compile_proto(
+        proto_root,
+        vec![dependent_root],
+        false, /* generate_client_stub */
+    );
+}
diff --git a/execution/execution_proto/src/lib.rs b/execution/execution_proto/src/lib.rs
new file mode 100644
index 0000000000000..999c7d11f0b01
--- /dev/null
+++ b/execution/execution_proto/src/lib.rs
@@ -0,0 +1,183 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+#![allow(clippy::unit_arg)]
+
+pub mod proto;
+
+#[cfg(test)]
+mod protobuf_conversion_test;
+
+use crypto::HashValue;
+use failure::prelude::*;
+use proptest_derive::Arbitrary;
+use proto_conv::{FromProto, IntoProto};
+use types::{
+    ledger_info::LedgerInfoWithSignatures,
+    transaction::{SignedTransaction, TransactionListWithProof, TransactionStatus},
+    validator_set::ValidatorSet,
+    vm_error::VMStatus,
+};
+
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, FromProto, IntoProto)]
+#[ProtoType(crate::proto::execution::ExecuteBlockRequest)]
+pub struct ExecuteBlockRequest {
+    /// The list of transactions from consensus.
+    pub transactions: Vec<SignedTransaction>,
+
+    /// Id of parent block.
+    pub parent_block_id: HashValue,
+
+    /// Id of current block.
+    pub block_id: HashValue,
+}
+
+impl ExecuteBlockRequest {
+    pub fn new(
+        transactions: Vec<SignedTransaction>,
+        parent_block_id: HashValue,
+        block_id: HashValue,
+    ) -> Self {
+        ExecuteBlockRequest {
+            transactions,
+            parent_block_id,
+            block_id,
+        }
+    }
+}
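+
+// Unlike the request type above, `ExecuteBlockResponse` implements its proto
+// conversions by hand (below): `status` is flattened to `types.VMStatus` on
+// the wire, which `#[derive(FromProto, IntoProto)]` cannot express.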
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq)]
+pub struct ExecuteBlockResponse {
+    /// Root hash of the transaction accumulator as if all transactions in this block are applied.
+    root_hash: HashValue,
+
+    /// Status code for each individual transaction in this block.
+    status: Vec<TransactionStatus>,
+
+    /// If set, this is the set of validators that will be used to start the next epoch
+    /// immediately after this state is committed.
+    validators: Option<ValidatorSet>,
+}
+
+impl ExecuteBlockResponse {
+    pub fn new(
+        root_hash: HashValue,
+        status: Vec<TransactionStatus>,
+        validators: Option<ValidatorSet>,
+    ) -> Self {
+        ExecuteBlockResponse {
+            root_hash,
+            status,
+            validators,
+        }
+    }
+
+    pub fn root_hash(&self) -> HashValue {
+        self.root_hash
+    }
+
+    pub fn status(&self) -> &[TransactionStatus] {
+        &self.status
+    }
+
+    pub fn validators(&self) -> &Option<ValidatorSet> {
+        &self.validators
+    }
+}
+
+impl FromProto for ExecuteBlockResponse {
+    type ProtoType = crate::proto::execution::ExecuteBlockResponse;
+
+    fn from_proto(mut object: Self::ProtoType) -> Result<Self> {
+        Ok(ExecuteBlockResponse {
+            root_hash: HashValue::from_slice(object.get_root_hash())?,
+            status: object
+                .take_status()
+                .into_iter()
+                .map(|proto_vm_status| {
+                    let vm_status = VMStatus::from_proto(proto_vm_status)?;
+                    Ok(vm_status.into())
+                })
+                .collect::<Result<Vec<_>>>()?,
+            validators: object
+                .validators
+                .take()
+                .map(ValidatorSet::from_proto)
+                .transpose()?,
+        })
+    }
+}
+
+impl IntoProto for ExecuteBlockResponse {
+    type ProtoType = crate::proto::execution::ExecuteBlockResponse;
+
+    fn into_proto(self) -> Self::ProtoType {
+        let mut out = Self::ProtoType::new();
+        out.set_root_hash(self.root_hash.to_vec());
+        out.set_status(
+            self.status
+                .into_iter()
+                .map(|transaction_status| {
+                    let vm_status = match transaction_status {
+                        TransactionStatus::Keep(status) => status,
+                        TransactionStatus::Discard(status) => status,
+                    };
+                    vm_status.into_proto()
+                })
+                .collect(),
+        );
+        if let Some(validators) = self.validators {
+            out.set_validators(validators.into_proto());
+        }
+        out
+    }
+}
+
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, FromProto, IntoProto)]
+#[ProtoType(crate::proto::execution::CommitBlockRequest)]
+pub struct CommitBlockRequest {
+    pub ledger_info_with_sigs: LedgerInfoWithSignatures,
+}
+
+#[derive(Arbitrary, Clone, Copy, Debug, Eq, PartialEq)]
+pub enum CommitBlockResponse {
+    Succeeded,
+    Failed,
+}
+
+impl FromProto for CommitBlockResponse {
+    type ProtoType = crate::proto::execution::CommitBlockResponse;
+
+    fn from_proto(object: Self::ProtoType) -> Result<Self> {
+        use crate::proto::execution::CommitBlockStatus;
+        Ok(match object.get_status() {
+            CommitBlockStatus::SUCCEEDED => CommitBlockResponse::Succeeded,
+            CommitBlockStatus::FAILED => CommitBlockResponse::Failed,
+        })
+    }
+}
+
+impl IntoProto for CommitBlockResponse {
+    type ProtoType = crate::proto::execution::CommitBlockResponse;
+
+    fn into_proto(self) -> Self::ProtoType {
+        use crate::proto::execution::CommitBlockStatus;
+        let mut out = Self::ProtoType::new();
+        out.set_status(match self {
+            CommitBlockResponse::Succeeded => CommitBlockStatus::SUCCEEDED,
+            CommitBlockResponse::Failed => CommitBlockStatus::FAILED,
+        });
+        out
+    }
+}
+
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, FromProto, IntoProto)]
+#[ProtoType(crate::proto::execution::ExecuteChunkRequest)]
+pub struct ExecuteChunkRequest {
+    pub txn_list_with_proof: TransactionListWithProof,
+    pub ledger_info_with_sigs: LedgerInfoWithSignatures,
+}
+
+#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, FromProto, IntoProto)]
+#[ProtoType(crate::proto::execution::ExecuteChunkResponse)]
+pub struct ExecuteChunkResponse {}
diff --git a/execution/execution_proto/src/proto/execution.proto b/execution/execution_proto/src/proto/execution.proto
new file mode 100644
index 0000000000000..d080986fd6078
--- /dev/null
+++ b/execution/execution_proto/src/proto/execution.proto
@@ -0,0 +1,89 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+syntax = "proto3";
+
+package execution;
+
+import "get_with_proof.proto";
+import "ledger_info.proto";
+import "transaction.proto";
+import "validator_set.proto";
+import "vm_errors.proto";
+
+// -----------------------------------------------------------------------------
+// ---------------- Execution Service Definition
+// -----------------------------------------------------------------------------
+service Execution {
+  // Execute a list of signed transactions given by consensus. Return the
+  // root hash of the ledger and the status of each transaction after
+  // applying the transactions in this block.
+  rpc ExecuteBlock(ExecuteBlockRequest) returns (ExecuteBlockResponse) {}
+
+  // Commit a previously executed block that has been agreed by consensus.
+  rpc CommitBlock(CommitBlockRequest) returns (CommitBlockResponse) {}
+
+  // Execute and commit a list of signed transactions received from a peer
+  // during synchronization.
+  rpc ExecuteChunk(ExecuteChunkRequest) returns (ExecuteChunkResponse) {}
+}
+
+message ExecuteBlockRequest {
+  // The list of transactions from consensus.
+  repeated types.SignedTransaction transactions = 1;
+
+  // Id of the parent block.
+  // We're going to use a special GENESIS_BLOCK_ID constant defined in
+  // crypto::hash module to refer to the block id of the Genesis block, which is
+  // executed in a special way.
+  bytes parent_block_id = 2;
+
+  // Id of the current block.
+  bytes block_id = 3;
+}
+
+// Result of transaction execution.
+message ExecuteBlockResponse {
+  // Root hash of the ledger after applying all the transactions in this
+  // block.
+  bytes root_hash = 1;
+
+  // The execution result of the transactions. Each transaction has a status
+  // field that indicates whether it should be included in the ledger once the
+  // block is committed.
+  repeated types.VMStatus status = 2;
+
+  // If set, this field designates that if this block is committed, then the
+  // next epoch will start immediately with the included set of validators.
+  types.ValidatorSet validators = 3;
+}
+
+message CommitBlockRequest {
+  // The ledger info with signatures from 2f+1 validators. It contains the id
+  // of the block consensus wants to commit. This will cause the given block
+  // and all the uncommitted ancestors to be committed to storage.
+  types.LedgerInfoWithSignatures ledger_info_with_sigs = 1;
+}
+
+message CommitBlockResponse { CommitBlockStatus status = 1; }
+
+enum CommitBlockStatus {
+  // The block is persisted.
+  SUCCEEDED = 0;
+
+  // Something went wrong.
+  FAILED = 1;
+}
+
+// Ask Execution service to execute and commit a chunk of contiguous
+// transactions. All the transactions in this chunk should belong to the same
+// epoch E. If the caller has a list of transactions that span two epochs, it
+// should split the transactions.
+message ExecuteChunkRequest {
+  types.TransactionListWithProof txn_list_with_proof = 1;
+  types.LedgerInfoWithSignatures ledger_info_with_sigs = 2;
+}
+
+// Either all transactions are successfully executed and persisted, or nothing
+// happens.
+message ExecuteChunkResponse {}
diff --git a/execution/execution_proto/src/proto/mod.rs b/execution/execution_proto/src/proto/mod.rs
new file mode 100644
index 0000000000000..94df4fcbc9693
--- /dev/null
+++ b/execution/execution_proto/src/proto/mod.rs
@@ -0,0 +1,7 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use types::proto::*;
+
+pub mod execution;
+pub mod execution_grpc;
diff --git a/execution/execution_proto/src/protobuf_conversion_test.rs b/execution/execution_proto/src/protobuf_conversion_test.rs
new file mode 100644
index 0000000000000..6680ecfbc32f6
--- /dev/null
+++ b/execution/execution_proto/src/protobuf_conversion_test.rs
@@ -0,0 +1,45 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{
+    CommitBlockRequest, CommitBlockResponse, ExecuteBlockRequest, ExecuteBlockResponse,
+    ExecuteChunkRequest, ExecuteChunkResponse,
+};
+use proptest::prelude::*;
+use proto_conv::test_helper::assert_protobuf_encode_decode;
+
+proptest! {
+    #![proptest_config(ProptestConfig::with_cases(10))]
+
+    #[test]
+    fn test_execute_block_request_roundtrip(execute_block_request in any::<ExecuteBlockRequest>()) {
+        assert_protobuf_encode_decode(&execute_block_request);
+    }
+
+    #[test]
+    fn test_execute_block_response_roundtrip(execute_block_response in any::<ExecuteBlockResponse>()) {
+        assert_protobuf_encode_decode(&execute_block_response);
+    }
+
+    #[test]
+    fn test_commit_block_request_roundtrip(commit_block_request in any::<CommitBlockRequest>()) {
+        assert_protobuf_encode_decode(&commit_block_request);
+    }
+
+    #[test]
+    fn test_execute_chunk_request_roundtrip(execute_chunk_request in any::<ExecuteChunkRequest>()) {
+        assert_protobuf_encode_decode(&execute_chunk_request);
+    }
+}
+
+proptest! {
+    #[test]
+    fn test_commit_block_response_roundtrip(commit_block_response in any::<CommitBlockResponse>()) {
+        assert_protobuf_encode_decode(&commit_block_response);
+    }
+
+    #[test]
+    fn test_execute_chunk_response_roundtrip(execute_chunk_response in any::<ExecuteChunkResponse>()) {
+        assert_protobuf_encode_decode(&execute_chunk_response);
+    }
+}
diff --git a/execution/execution_service/Cargo.toml b/execution/execution_service/Cargo.toml
new file mode 100644
index 0000000000000..1c874819eaa93
--- /dev/null
+++ b/execution/execution_service/Cargo.toml
@@ -0,0 +1,33 @@
+[package]
+name = "execution_service"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+futures01 = { package = "futures", version = "0.1.26" }
+futures03 = { package = "futures-preview", version = "=0.3.0-alpha.16", features = ["compat"] }
+grpcio = "0.4.4"
+
+config = { path = "../../config" }
+execution_proto = { path = "../execution_proto" }
+executor = { path = "../executor" }
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+grpc_helpers = { path = "../../common/grpc_helpers" }
+proto_conv = { path = "../../common/proto_conv" }
+storage_client = { path = "../../storage/storage_client" }
+types = { path = "../../types" }
+vm_runtime = { path = "../../language/vm/vm_runtime" }
+
+[dev-dependencies]
+tempfile = "3.0.7"
+
+config = { path = "../../config" }
+config_builder = { path = "../../config/config_builder" }
+crypto = { path = "../../crypto/legacy_crypto" }
+execution_client = { path = "../execution_client" }
+storage_proto = { path = "../../storage/storage_proto" }
+storage_service = { path = "../../storage/storage_service" }
+vm_genesis = { path = "../../language/vm/vm_genesis" }
diff --git a/execution/execution_service/src/lib.rs b/execution/execution_service/src/lib.rs
new file mode 100644
index 0000000000000..93a9fb063742a
--- /dev/null
+++ b/execution/execution_service/src/lib.rs
@@ -0,0 +1,168 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+#![feature(async_await)]
+
+use config::config::NodeConfig;
+use execution_proto::{CommitBlockRequest, ExecuteBlockRequest, ExecuteChunkRequest};
+use executor::Executor;
+use failure::Result;
+use futures01::future::Future;
+use futures03::{
+    channel::oneshot,
+    future::{FutureExt, TryFutureExt},
+};
+use grpc_helpers::default_reply_error_logger;
+use grpcio::{RpcStatus, RpcStatusCode};
+use proto_conv::{FromProto, IntoProto};
+use std::sync::Arc;
+use storage_client::{StorageRead, StorageWrite};
+use vm_runtime::MoveVM;
+
+#[derive(Clone)]
+pub struct ExecutionService {
+    /// `ExecutionService` simply contains an `Executor` and uses it to process requests. We wrap
+    /// it in `Arc` because `ExecutionService` has to implement `Clone`.
+    executor: Arc<Executor<MoveVM>>,
+}
+
+impl ExecutionService {
+    /// Constructs an `ExecutionService`.
+    pub fn new(
+        storage_read_client: Arc<dyn StorageRead>,
+        storage_write_client: Arc<dyn StorageWrite>,
+        config: &NodeConfig,
+    ) -> Self {
+        let executor = Arc::new(Executor::new(
+            storage_read_client,
+            storage_write_client,
+            config,
+        ));
+        ExecutionService { executor }
+    }
+}
+
+impl execution_proto::proto::execution_grpc::Execution for ExecutionService {
+    fn execute_block(
+        &mut self,
+        ctx: grpcio::RpcContext,
+        request: execution_proto::proto::execution::ExecuteBlockRequest,
+        sink: grpcio::UnarySink<execution_proto::proto::execution::ExecuteBlockResponse>,
+    ) {
+        match ExecuteBlockRequest::from_proto(request) {
+            Ok(req) => {
+                let fut = process_response(
+                    self.executor.execute_block(
+                        req.transactions,
+                        req.parent_block_id,
+                        req.block_id,
+                    ),
+                    sink,
+                )
+                .boxed()
+                .unit_error()
+                .compat();
+                ctx.spawn(fut);
+            }
+            Err(err) => {
+                let fut = process_conversion_error(err, sink);
+                ctx.spawn(fut);
+            }
+        }
+    }
+
+    fn commit_block(
+        &mut self,
+        ctx: grpcio::RpcContext,
+        request: execution_proto::proto::execution::CommitBlockRequest,
+        sink: grpcio::UnarySink<execution_proto::proto::execution::CommitBlockResponse>,
+    ) {
+        match CommitBlockRequest::from_proto(request) {
+            Ok(req) => {
+                let fut =
+                    process_response(self.executor.commit_block(req.ledger_info_with_sigs), sink)
+                        .boxed()
+                        .unit_error()
+                        .compat();
+                ctx.spawn(fut);
+            }
+            Err(err) => {
+                let fut = process_conversion_error(err, sink);
+                ctx.spawn(fut);
+            }
+        }
+    }
+
+    fn execute_chunk(
+        &mut self,
+        ctx: grpcio::RpcContext,
+        request: execution_proto::proto::execution::ExecuteChunkRequest,
+        sink: grpcio::UnarySink<execution_proto::proto::execution::ExecuteChunkResponse>,
+    ) {
+        match ExecuteChunkRequest::from_proto(request) {
+            Ok(req) => {
+                let fut = process_response(
+                    self.executor
+                        .execute_chunk(req.txn_list_with_proof, req.ledger_info_with_sigs),
+                    sink,
+                )
+                .boxed()
+                .unit_error()
+                .compat();
+                ctx.spawn(fut);
+            }
+            Err(err) => {
+                let fut = process_conversion_error(err, sink);
+                ctx.spawn(fut);
+            }
+        }
+    }
+}
+
+async fn process_response<T>(
+    resp: oneshot::Receiver<Result<T>>,
+    sink: grpcio::UnarySink<<T as IntoProto>::ProtoType>,
+) where
+    T: IntoProto,
+{
+    match resp.await {
+        Ok(Ok(response)) => {
+            sink.success(response.into_proto());
+        }
+        Ok(Err(err)) => {
+            set_failure_message(
+                RpcStatusCode::Unknown,
+                format!("Failed to process request: {}", err),
+                sink,
+            );
+        }
+        Err(oneshot::Canceled) => {
+            set_failure_message(
+                RpcStatusCode::Internal,
+                "Executor Internal error: sender is dropped.".to_string(),
+                sink,
+            );
+        }
+    }
+}
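+
+// The helpers below turn request-conversion failures into structured gRPC
+// errors (`InvalidArgument`) instead of letting the service panic.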
diff --git a/execution/execution_service/tests/execution_service_test.rs b/execution/execution_service/tests/execution_service_test.rs new file mode 100644 index 0000000000000..044dad65bbe87 --- /dev/null +++ b/execution/execution_service/tests/execution_service_test.rs @@ -0,0 +1,75 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod test_helper; + +use crate::test_helper::{create_and_start_server, gen_block_id, gen_ledger_info_with_sigs}; +use config_builder::util::get_test_config; +use crypto::{ hash::GENESIS_BLOCK_ID, signing::{generate_keypair, KeyPair}, }; +use execution_client::ExecutionClient; +use execution_proto::ExecuteBlockRequest; +use futures01::future::Future; +use grpcio::EnvBuilder; +use std::sync::Arc; +use types::{ account_address::AccountAddress, account_config, transaction::{RawTransaction, SignedTransaction}, }; +use vm_genesis::encode_mint_program; + +fn encode_mint_transaction(seqnum: u64, sender_keypair: &KeyPair) -> SignedTransaction { + let (_privkey, pubkey) = generate_keypair(); + let sender = account_config::association_address(); + let receiver = AccountAddress::from(pubkey); + let program = encode_mint_program(&receiver, 100); + let raw_txn = RawTransaction::new( sender, seqnum, program, /* max_gas_amount = */ 10_000, /* gas_unit_price = */ 1, std::time::Duration::from_secs(u64::max_value()), ); + raw_txn + .sign(&sender_keypair.private_key(), sender_keypair.public_key()) + .expect("Signing should work.") +} + +#[test] +fn test_execution_service_basic() { + let (config, faucet_keypair) = get_test_config(); + let (_storage_server_handle, mut execution_server) = create_and_start_server(&config); + + let execution_client = ExecutionClient::new( Arc::new(EnvBuilder::new().build()), &config.execution.address, config.execution.port, ); + + let parent_block_id = *GENESIS_BLOCK_ID; + let block_id = gen_block_id(1); + let version = 100; + + let txns = (0..version) + .map(|i| encode_mint_transaction(i, &faucet_keypair)) + .collect(); + let execute_block_request = ExecuteBlockRequest::new(txns, parent_block_id, block_id); + let execute_block_response = execution_client + .execute_block(execute_block_request) + .unwrap(); + + let ledger_info_with_sigs = gen_ledger_info_with_sigs( u64::from(version), execute_block_response.root_hash(), block_id, ); + execution_client + .commit_block(ledger_info_with_sigs) + .unwrap(); + + execution_server.shutdown().wait().unwrap(); +}
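// Illustrative arithmetic (hypothetical helper, not part of the patch) behind the test above:
// the genesis transaction occupies version 0, so committing n transactions on top of it yields
// ledger version n. Hence `version = 100` serves both as the mint-transaction count and as the
// version passed to gen_ledger_info_with_sigs.
fn ledger_version_after(genesis_version: u64, num_txns: u64) -> u64 {
    genesis_version + num_txns // 0 + 100 = 100 in test_execution_service_basic
}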
diff --git a/execution/execution_service/tests/storage_integration_test.rs b/execution/execution_service/tests/storage_integration_test.rs new file mode 100644 index 0000000000000..44e715aa5521e --- /dev/null +++ b/execution/execution_service/tests/storage_integration_test.rs @@ -0,0 +1,596 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod test_helper; + +use crate::test_helper::{create_and_start_server, gen_block_id, gen_ledger_info_with_sigs}; +use config_builder::util::get_test_config; +use crypto::{ hash::GENESIS_BLOCK_ID, signing::{generate_keypair, PrivateKey, PublicKey}, }; +use execution_client::ExecutionClient; +use execution_proto::ExecuteBlockRequest; +use failure::prelude::*; +use grpcio::EnvBuilder; +use proto_conv::FromProto; +use std::sync::Arc; +use storage_client::{StorageRead, StorageReadServiceClient}; +use types::{ access_path::AccessPath, account_address::AccountAddress, account_config::{association_address, get_account_resource_or_default}, account_state_blob::AccountStateWithProof, get_with_proof::{verify_update_to_latest_ledger_response, RequestItem}, test_helpers::transaction_test_helpers::get_test_signed_txn as get_test_signed_txn_proto, transaction::{ Program, SignedTransaction, SignedTransactionWithProof, TransactionListWithProof, }, validator_verifier::ValidatorVerifier, }; +use vm_genesis::{encode_create_account_program, encode_transfer_program}; + +fn get_test_signed_transaction( + sender: AccountAddress, + sequence_number: u64, + private_key: PrivateKey, + public_key: PublicKey, + program: Option<Program>, +) -> SignedTransaction { + SignedTransaction::from_proto(get_test_signed_txn_proto( sender, sequence_number, private_key, public_key, program, )) + .unwrap() +} + +#[test] +fn test_execution_with_storage() { + let (config, genesis_keypair) = get_test_config(); + let (_storage_server_handle, _execution_server) = create_and_start_server(&config); + + let storage_read_client = Arc::new(StorageReadServiceClient::new( Arc::new(EnvBuilder::new().build()), &config.storage.address, config.storage.port, )); + + let execution_client = ExecutionClient::new( Arc::new(EnvBuilder::new().build()), &config.execution.address, config.execution.port, ); + + let (privkey1, pubkey1) = generate_keypair(); + let account1 = AccountAddress::from(pubkey1); + let (privkey2, pubkey2) = generate_keypair(); + let account2 = AccountAddress::from(pubkey2); + let (_privkey3, pubkey3) = generate_keypair(); + let account3 = AccountAddress::from(pubkey3); + let genesis_account = association_address(); + + // Create account1 with 200k coins. + let txn1 = get_test_signed_transaction( genesis_account, /* sequence_number = */ 0, genesis_keypair.private_key().clone(), genesis_keypair.public_key(), Some(encode_create_account_program(&account1, 200_000)), ); + + // Create account2 with 20k coins. + let txn2 = get_test_signed_transaction( genesis_account, /* sequence_number = */ 1, genesis_keypair.private_key().clone(), genesis_keypair.public_key(), Some(encode_create_account_program(&account2, 20_000)), ); + + // Create account3 with 10k coins. + let txn3 = get_test_signed_transaction( genesis_account, /* sequence_number = */ 2, genesis_keypair.private_key().clone(), genesis_keypair.public_key(), Some(encode_create_account_program(&account3, 10_000)), ); + + // Transfer 2k coins from account1 to account2. + // balance: <198k, 22k, 10k + let txn4 = get_test_signed_transaction( account1, /* sequence_number = */ 0, privkey1.clone(), pubkey1, Some(encode_transfer_program(&account2, 2_000)), ); + + // Transfer 1k coins from account2 to account3. + // balance: <198k, <21k, 11k + let txn5 = get_test_signed_transaction( account2, /* sequence_number = */ 0, privkey2.clone(), pubkey2, Some(encode_transfer_program(&account3, 1_000)), ); + + // Transfer 7k coins from account1 to account3. + // balance: <191k, <21k, 18k + let txn6 = get_test_signed_transaction( account1, /* sequence_number = */ 1, privkey1.clone(), pubkey1, Some(encode_transfer_program(&account3, 7_000)), ); + + let block1 = vec![txn1, txn2, txn3, txn4, txn5, txn6]; + let block1_id = gen_block_id(1); + + let mut block2 = vec![]; + let block2_id = gen_block_id(2); + + // Create 14 txns, each transferring 1k from account1 to account3. + for i in 2..=15 { + block2.push(get_test_signed_transaction( account1, /* sequence_number = */ i, privkey1.clone(), pubkey1, Some(encode_transfer_program(&account3, 1_000)), )); + }
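// Illustrative bookkeeping (hypothetical, not part of the patch) for the balance comments above
// and the balance assertions below. Senders also pay gas, so their balances are strict upper
// bounds ("<"), while accounts that only receive are exact.
fn expected_nominal_balances() {
    let (mut a1, mut a2, mut a3) = (200_000u64, 20_000u64, 10_000u64); // txn1..txn3
    a1 -= 2_000; a2 += 2_000;           // txn4: <198k, 22k, 10k
    a2 -= 1_000; a3 += 1_000;           // txn5: <198k, <21k, 11k
    a1 -= 7_000; a3 += 7_000;           // txn6: <191k, <21k, 18k
    a1 -= 14 * 1_000; a3 += 14 * 1_000; // block2: <177k, <21k, 32k
    assert_eq!((a1, a2, a3), (177_000, 21_000, 32_000));
}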
+ let execute_block_request = ExecuteBlockRequest::new(block1.clone(), *GENESIS_BLOCK_ID, block1_id); + let execute_block_response = execution_client .execute_block(execute_block_request) .unwrap(); + let ledger_info_with_sigs = gen_ledger_info_with_sigs(6, execute_block_response.root_hash(), block1_id); + execution_client .commit_block(ledger_info_with_sigs) .unwrap(); + + let request_items = vec![ + RequestItem::GetAccountTransactionBySequenceNumber { account: genesis_account, sequence_number: 0, fetch_events: false, }, + RequestItem::GetAccountTransactionBySequenceNumber { account: genesis_account, sequence_number: 1, fetch_events: false, }, + RequestItem::GetAccountTransactionBySequenceNumber { account: genesis_account, sequence_number: 2, fetch_events: false, }, + RequestItem::GetAccountTransactionBySequenceNumber { account: genesis_account, sequence_number: 3, fetch_events: false, }, + RequestItem::GetAccountTransactionBySequenceNumber { account: account1, sequence_number: 0, fetch_events: true, }, + RequestItem::GetAccountTransactionBySequenceNumber { account: account2, sequence_number: 0, fetch_events: false, }, + RequestItem::GetAccountTransactionBySequenceNumber { account: account1, sequence_number: 1, fetch_events: false, }, + RequestItem::GetAccountState { address: account1 }, + RequestItem::GetAccountState { address: account2 }, + RequestItem::GetAccountState { address: account3 }, + RequestItem::GetTransactions { start_version: 3, limit: 10, fetch_events: false, }, + RequestItem::GetEventsByEventAccessPath { access_path: AccessPath::new_for_sent_event(account1), start_event_seq_num: 0, ascending: true, limit: 10, }, + RequestItem::GetEventsByEventAccessPath { access_path: AccessPath::new_for_sent_event(account2), start_event_seq_num: 0, ascending: true, limit: 10, }, + RequestItem::GetEventsByEventAccessPath { access_path: AccessPath::new_for_sent_event(account3), start_event_seq_num: 0, ascending: true, limit: 10, }, + RequestItem::GetEventsByEventAccessPath { access_path: AccessPath::new_for_received_event(account1), start_event_seq_num: u64::max_value(), ascending: false, limit: 10, }, + RequestItem::GetEventsByEventAccessPath { access_path: AccessPath::new_for_received_event(account2), start_event_seq_num: u64::max_value(), ascending: false, limit: 10, }, + RequestItem::GetEventsByEventAccessPath { access_path: AccessPath::new_for_received_event(account3), start_event_seq_num: u64::max_value(), ascending: false, limit: 10, }, + ]; + + let (mut response_items, ledger_info_with_sigs, _validator_change_events) = storage_read_client .update_to_latest_ledger(/* client_known_version = */ 0, request_items.clone()) .unwrap(); + verify_update_to_latest_ledger_response( Arc::new(ValidatorVerifier::new_empty()), 0, &request_items, &response_items, &ledger_info_with_sigs, ) + .unwrap(); + response_items.reverse(); + + let (t1, _) = response_items .pop() .unwrap() .into_get_account_txn_by_seq_num_response() .unwrap(); + verify_committed_txn_status(t1.as_ref(), &block1[0]).unwrap(); + + let (t2, _) = response_items .pop() .unwrap() .into_get_account_txn_by_seq_num_response() .unwrap(); + verify_committed_txn_status(t2.as_ref(), &block1[1]).unwrap(); + + let (t3, _) = response_items .pop() .unwrap() .into_get_account_txn_by_seq_num_response() .unwrap(); + verify_committed_txn_status(t3.as_ref(), &block1[2]).unwrap(); + + let (tn, pn) = response_items .pop() .unwrap() .into_get_account_txn_by_seq_num_response() .unwrap(); + verify_uncommitted_txn_status( tn.as_ref(), pn.as_ref(), /* next_seq_num_of_this_account = */ 3, ) + .unwrap(); + + let (t4, _) = response_items .pop() .unwrap() .into_get_account_txn_by_seq_num_response() .unwrap(); + verify_committed_txn_status(t4.as_ref(), &block1[3]).unwrap(); + // We requested the events to come back from this one, so verify that they did + assert_eq!(t4.unwrap().events.unwrap().len(), 2); + + let (t5, _) = response_items .pop() .unwrap() .into_get_account_txn_by_seq_num_response() .unwrap(); + verify_committed_txn_status(t5.as_ref(), &block1[4]).unwrap(); + + let (t6, _) = response_items .pop() .unwrap() .into_get_account_txn_by_seq_num_response() .unwrap(); + verify_committed_txn_status(t6.as_ref(), &block1[5]).unwrap(); + + let account1_state_with_proof = response_items .pop() .unwrap() .into_get_account_state_response() .unwrap(); + verify_account_balance(&account1_state_with_proof, |x| x < 191_000).unwrap(); + + let account2_state_with_proof = response_items .pop() .unwrap() .into_get_account_state_response() .unwrap(); + verify_account_balance(&account2_state_with_proof, |x| x < 21_000).unwrap(); + + let account3_state_with_proof = response_items .pop() .unwrap() .into_get_account_state_response() .unwrap(); + verify_account_balance(&account3_state_with_proof, |x| x == 18_000).unwrap(); + + let transaction_list_with_proof = response_items .pop() .unwrap() .into_get_transactions_response() .unwrap(); + verify_transactions(&transaction_list_with_proof, &block1[2..]).unwrap(); + + let (account1_sent_events, _) = response_items .pop() .unwrap() .into_get_events_by_access_path_response() .unwrap(); + assert_eq!(account1_sent_events.len(), 2); + + let (account2_sent_events, _) = response_items .pop() .unwrap() .into_get_events_by_access_path_response() .unwrap(); + assert_eq!(account2_sent_events.len(), 1); + + let (account3_sent_events, _) = response_items .pop() .unwrap() .into_get_events_by_access_path_response() .unwrap(); + assert_eq!(account3_sent_events.len(), 0); + + let (account1_received_events, _) = response_items .pop() .unwrap() .into_get_events_by_access_path_response() .unwrap(); + assert_eq!(account1_received_events.len(), 1); + + let (account2_received_events, _) = response_items .pop() .unwrap() .into_get_events_by_access_path_response() .unwrap(); + assert_eq!(account2_received_events.len(), 2); + + let (account3_received_events, _) = response_items .pop() .unwrap() .into_get_events_by_access_path_response() .unwrap(); + assert_eq!(account3_received_events.len(), 3); + + // Execute the 2nd block. + let execute_block_request = ExecuteBlockRequest::new(block2.clone(), block1_id, block2_id); + let execute_block_response = execution_client .execute_block(execute_block_request) .unwrap(); + let ledger_info_with_sigs = gen_ledger_info_with_sigs(20, execute_block_response.root_hash(), block2_id); + execution_client .commit_block(ledger_info_with_sigs) .unwrap(); + + let request_items = vec![ + RequestItem::GetAccountTransactionBySequenceNumber { account: account1, sequence_number: 2, fetch_events: false, }, + RequestItem::GetAccountTransactionBySequenceNumber { account: account1, sequence_number: 15, fetch_events: false, }, + RequestItem::GetAccountState { address: account1 }, + RequestItem::GetAccountState { address: account3 }, + RequestItem::GetTransactions { start_version: 7, limit: 14, fetch_events: false, }, + RequestItem::GetEventsByEventAccessPath { access_path: AccessPath::new_for_sent_event(account1), start_event_seq_num: 0, ascending: true, limit: 10, }, + RequestItem::GetEventsByEventAccessPath { access_path: AccessPath::new_for_sent_event(account1), start_event_seq_num: 10, ascending: true, limit: 10, }, + RequestItem::GetEventsByEventAccessPath { access_path: AccessPath::new_for_received_event(account3), start_event_seq_num: u64::max_value(), ascending: false, limit: 10, }, + RequestItem::GetEventsByEventAccessPath { access_path: AccessPath::new_for_received_event(account3), start_event_seq_num: 6, ascending: false, limit: 10, }, + ]; + let (mut response_items, ledger_info_with_sigs, _validator_change_events) = storage_read_client .update_to_latest_ledger(/* client_known_version = */ 0, request_items.clone()) .unwrap(); + verify_update_to_latest_ledger_response( Arc::new(ValidatorVerifier::new_empty()), 0, &request_items, &response_items, &ledger_info_with_sigs, ) + .unwrap(); + response_items.reverse(); + + let (t7, _) = response_items .pop() .unwrap() .into_get_account_txn_by_seq_num_response() .unwrap(); + verify_committed_txn_status(t7.as_ref(), &block2[0]).unwrap(); + + let (t20, _) = response_items .pop() .unwrap() .into_get_account_txn_by_seq_num_response() .unwrap(); + verify_committed_txn_status(t20.as_ref(), &block2[13]).unwrap(); + + let account1_state_with_proof = response_items .pop() .unwrap() .into_get_account_state_response() .unwrap(); + verify_account_balance(&account1_state_with_proof, |x| x < 177_000).unwrap(); + + let account3_state_with_proof = response_items .pop() .unwrap() .into_get_account_state_response() .unwrap(); + verify_account_balance(&account3_state_with_proof, |x| x == 32_000).unwrap(); + + let transaction_list_with_proof = response_items .pop() .unwrap() .into_get_transactions_response() .unwrap(); + verify_transactions(&transaction_list_with_proof, &block2[..]).unwrap(); + + let (account1_sent_events_batch1, _) = response_items .pop() .unwrap() .into_get_events_by_access_path_response() .unwrap(); + assert_eq!(account1_sent_events_batch1.len(), 10); + + let (account1_sent_events_batch2, _) = response_items .pop() .unwrap() .into_get_events_by_access_path_response() .unwrap(); + assert_eq!(account1_sent_events_batch2.len(), 6); + + let (account3_received_events_batch1, _) = response_items .pop() .unwrap() .into_get_events_by_access_path_response() .unwrap(); + assert_eq!(account3_received_events_batch1.len(), 10);
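// Why the sequence numbers below check out (illustrative): account3 received 3 payments in
// block 1 (its creation deposit, txn5, txn6) and 14 more in block 2, so its received-event
// sequence numbers run 0..=16. The descending query from u64::max_value() therefore sees
// seq 16 first and returns 10 events (16..=7); the second descending query, starting at
// seq 6, returns the remaining 7 events (6..=0).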
+ 16 + ); + + let (account3_received_events_batch2, _) = response_items + .pop() + .unwrap() + .into_get_events_by_access_path_response() + .unwrap(); + assert_eq!(account3_received_events_batch2.len(), 7); + assert_eq!( + account3_received_events_batch2[0].event.sequence_number(), + 6 + ); +} + +fn verify_account_balance(account_state_with_proof: &AccountStateWithProof, f: F) -> Result<()> +where + F: Fn(u64) -> bool, +{ + let balance = get_account_resource_or_default(&account_state_with_proof.blob)?.balance(); + ensure!( + f(balance), + "balance {} doesn't satisfy the condition passed in", + balance + ); + Ok(()) +} + +fn verify_transactions( + txn_list_with_proof: &TransactionListWithProof, + expected_txns: &[SignedTransaction], +) -> Result<()> { + let txns = txn_list_with_proof + .transaction_and_infos + .iter() + .map(|(txn, _)| txn) + .cloned() + .collect::>(); + ensure!( + expected_txns == &txns[..], + "expected txns {:?} doesn't equal to returned txns {:?}", + expected_txns, + txns + ); + Ok(()) +} + +fn verify_committed_txn_status( + signed_txn_with_proof: Option<&SignedTransactionWithProof>, + expected_txn: &SignedTransaction, +) -> Result<()> { + let signed_txn = &signed_txn_with_proof + .ok_or_else(|| format_err!("Transaction is not commited."))? + .signed_transaction; + + ensure!( + expected_txn == signed_txn, + "The two transactions do not match. Expected txn: {:?}, returned txn: {:?}", + expected_txn, + signed_txn, + ); + + Ok(()) +} + +fn verify_uncommitted_txn_status( + signed_txn_with_proof: Option<&SignedTransactionWithProof>, + proof_of_current_sequence_number: Option<&AccountStateWithProof>, + expected_seq_num: u64, +) -> Result<()> { + ensure!( + signed_txn_with_proof.is_none(), + "Transaction is unexpectedly committed." + ); + + let proof_of_current_sequence_number = proof_of_current_sequence_number.ok_or_else(|| { + format_err!( + "proof_of_current_sequence_number should be provided when transaction is not committed." 
+ ) + })?; + let seq_num_in_account = + get_account_resource_or_default(&proof_of_current_sequence_number.blob)?.sequence_number(); + + ensure!( + expected_seq_num == seq_num_in_account, + "expected_seq_num {} doesn't match that in account state \ + in TransactionStatus::Uncommmitted {}", + expected_seq_num, + seq_num_in_account, + ); + Ok(()) +} diff --git a/execution/execution_service/tests/test_helper.rs b/execution/execution_service/tests/test_helper.rs new file mode 100644 index 0000000000000..48a450e0cedbc --- /dev/null +++ b/execution/execution_service/tests/test_helper.rs @@ -0,0 +1,63 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use config::config::NodeConfig; +use crypto::HashValue; +use execution_proto::proto::execution_grpc::create_execution; +use execution_service::ExecutionService; +use grpc_helpers::ServerHandle; +use grpcio::{EnvBuilder, ServerBuilder}; +use std::{collections::HashMap, sync::Arc}; +use storage_client::{StorageReadServiceClient, StorageWriteServiceClient}; +use storage_service::start_storage_service; +use types::ledger_info::{LedgerInfo, LedgerInfoWithSignatures}; + +pub fn gen_block_id(index: u8) -> HashValue { + HashValue::new([index; HashValue::LENGTH]) +} + +pub fn gen_ledger_info_with_sigs( + version: u64, + root_hash: HashValue, + commit_block_id: HashValue, +) -> LedgerInfoWithSignatures { + let ledger_info = LedgerInfo::new( + version, + root_hash, + /* consensus_data_hash = */ HashValue::zero(), + commit_block_id, + 0, + /* timestamp = */ 0, + ); + LedgerInfoWithSignatures::new(ledger_info, /* signatures = */ HashMap::new()) +} + +pub fn create_and_start_server(config: &NodeConfig) -> (ServerHandle, grpcio::Server) { + let storage_server_handle = start_storage_service(config); + + let client_env = Arc::new(EnvBuilder::new().build()); + let storage_read_client = Arc::new(StorageReadServiceClient::new( + Arc::clone(&client_env), + &config.storage.address, + config.storage.port, + )); + let storage_write_client = Arc::new(StorageWriteServiceClient::new( + Arc::clone(&client_env), + &config.storage.address, + config.storage.port, + )); + + let execution_service = create_execution(ExecutionService::new( + storage_read_client, + storage_write_client, + config, + )); + let mut execution_server = ServerBuilder::new(Arc::new(EnvBuilder::new().build())) + .register_service(execution_service) + .bind(config.execution.address.clone(), config.execution.port) + .build() + .expect("Failed to create execution server."); + execution_server.start(); + + (storage_server_handle, execution_server) +} diff --git a/execution/executor/Cargo.toml b/execution/executor/Cargo.toml new file mode 100644 index 0000000000000..c68bdeaa9b509 --- /dev/null +++ b/execution/executor/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "executor" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +backoff = "0.1.5" +futures = { version = "=0.3.0-alpha.16", package = "futures-preview" } +itertools = "0.8.0" +lazy_static = "1.3.0" + +config = { path = "../../config" } +crypto = { path = "../../crypto/legacy_crypto" } +execution_proto = { path = "../execution_proto" } +failure = { path = "../../common/failure_ext", package = "failure_ext" } +logger = { path = "../../common/logger" } +metrics = { path = "../../common/metrics" } +proto_conv = { path = "../../common/proto_conv" } +scratchpad = { path = "../../storage/scratchpad" } +state_view = { path = 
"../../storage/state_view" } +storage_client = { path = "../../storage/storage_client" } +types = { path = "../../types" } +vm_runtime = { path = "../../language/vm/vm_runtime" } +vm_genesis = { path = "../../language/vm/vm_genesis" } + +[dev-dependencies] +grpcio = "0.4.4" +proptest = "0.9.2" +rusty-fork = "0.2.1" + +storage_proto = { path = "../../storage/storage_proto" } +storage_service = { path = "../../storage/storage_service" } diff --git a/execution/executor/src/block_processor.rs b/execution/executor/src/block_processor.rs new file mode 100644 index 0000000000000..7e45ca2515699 --- /dev/null +++ b/execution/executor/src/block_processor.rs @@ -0,0 +1,860 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + block_tree::{Block, BlockTree}, + transaction_block::{ProcessedVMOutput, TransactionBlock, TransactionData}, + Command, OP_COUNTERS, +}; +use backoff::{ExponentialBackoff, Operation}; +use config::config::VMConfig; +use crypto::{ + hash::{CryptoHash, EventAccumulatorHasher, TransactionAccumulatorHasher}, + HashValue, +}; +use execution_proto::{CommitBlockResponse, ExecuteBlockResponse, ExecuteChunkResponse}; +use failure::prelude::*; +use futures::channel::oneshot; +use logger::prelude::*; +use scratchpad::{Accumulator, ProofRead, SparseMerkleTree}; +use std::{ + collections::{hash_map, BTreeMap, HashMap, HashSet, VecDeque}, + convert::TryFrom, + marker::PhantomData, + rc::Rc, + sync::{mpsc, Arc}, +}; +use storage_client::{StorageRead, StorageWrite, VerifiedStateView}; +use types::{ + account_address::AccountAddress, + account_state_blob::AccountStateBlob, + ledger_info::LedgerInfoWithSignatures, + proof::SparseMerkleProof, + transaction::{ + SignedTransaction, TransactionInfo, TransactionListWithProof, TransactionOutput, + TransactionPayload, TransactionStatus, TransactionToCommit, Version, + }, + write_set::{WriteOp, WriteSet}, +}; +use vm_runtime::VMExecutor; + +#[derive(Debug)] +enum Mode { + Normal, + Syncing, +} + +pub(crate) struct BlockProcessor { + /// Where the processor receives commands. + command_receiver: mpsc::Receiver, + + /// The timestamp of the last committed ledger info. + committed_timestamp_usecs: u64, + + /// The in-memory Sparse Merkle Tree representing last committed state. This tree always has a + /// single Subtree node (or Empty node) whose hash equals the root hash of the newest Sparse + /// Merkle Tree in storage. + committed_state_tree: Rc, + + /// The in-memory Merkle Accumulator representing all the committed transactions. + committed_transaction_accumulator: Rc>, + + /// The main block tree data structure that holds all the uncommitted blocks in memory. + block_tree: BlockTree, + + /// The blocks that are ready to be sent to storage. After pruning `block_tree` we always put + /// the blocks here before sending them to storage, so in the case when storage is temporarily + /// unavailable, we will still prune `block_tree` as normal but blocks will stay here for a bit + /// longer. + blocks_to_store: VecDeque, + + /// Client to storage service. + storage_read_client: Arc, + storage_write_client: Arc, + + /// The current mode. If we are doing state synchronization, we will refuse to serve normal + /// execute_block and commit_block requests. + mode: Mode, + + /// Configuration for the VM. The block processor currently creates a new VM for each block. + vm_config: VMConfig, + + phantom: PhantomData, +} + +impl BlockProcessor +where + V: VMExecutor, +{ + /// Constructs a new `BlockProcessor`. 
+ pub fn new( + command_receiver: mpsc::Receiver, + committed_timestamp_usecs: u64, + previous_state_root_hash: HashValue, + previous_frozen_subtrees_in_accumulator: Vec, + previous_num_elements_in_accumulator: u64, + last_committed_block_id: HashValue, + storage_read_client: Arc, + storage_write_client: Arc, + vm_config: VMConfig, + ) -> Self { + BlockProcessor { + command_receiver, + committed_timestamp_usecs, + committed_state_tree: Rc::new(SparseMerkleTree::new(previous_state_root_hash)), + committed_transaction_accumulator: Rc::new(Accumulator::new( + previous_frozen_subtrees_in_accumulator, + previous_num_elements_in_accumulator, + )), + block_tree: BlockTree::new(last_committed_block_id), + blocks_to_store: VecDeque::new(), + storage_read_client, + storage_write_client, + mode: Mode::Normal, + vm_config, + phantom: PhantomData, + } + } + + /// Keeps processing blocks until the command sender is disconnected. + pub fn run(&mut self) { + loop { + // Fetch and process all commands sent by consensus until there is no more left in the + // channel. + while let Ok(cmd) = self.command_receiver.try_recv() { + self.process_command(cmd); + } + + // Prune the block tree and check if there are eligible blocks ready to be sent to + // storage (the blocks that have finished execution and been marked as committed). This + // will move these blocks from the block tree to `self.blocks_to_store`. + // + // Note: If save_blocks_to_storage below fails, these blocks will stay in + // `self.blocks_to_store`. This is okay because consensus will not retry committing + // these blocks after it receives the errors. Instead it will try to commit a + // descendant block later, which will be found in the block tree and cause the entire + // chain to be saved if storage has recovered. (If consensus retries committing these + // moved blocks, we won't find these blocks in the block tree because we only look up + // the blocks in the block tree, so we will return an error.) + self.blocks_to_store + .extend(self.block_tree.prune().into_iter()); + if !self.blocks_to_store.is_empty() { + let time = std::time::Instant::now(); + let mut save_op = || { + self.save_blocks_to_storage().map_err(|err| { + error!("Failed to save blocks to storage: {}", err); + backoff::Error::Transient(err) + }) + }; + let mut backoff = Self::storage_retry_backoff(); + match save_op.retry(&mut backoff) { + Ok(()) => OP_COUNTERS + .observe("blocks_commit_time_us", time.elapsed().as_micros() as f64), + Err(_err) => crit!( + "Failed to save blocks to storage after trying for {} seconds.", + backoff.get_elapsed_time().as_secs(), + ), + } + } + + // If we do not have anything else to do, check if there is a block pending execution. + // Continue if this function made progress (executed one block). + if self.maybe_execute_block() { + continue; + } + + // In case the previous attempt to send blocks to storage failed, we want to retry + // instead of waiting for new command. + if !self.blocks_to_store.is_empty() { + continue; + } + + // We really have nothing to do. Just block the thread until consensus sends us new + // command. + match self.command_receiver.recv() { + Ok(cmd) => self.process_command(cmd), + Err(mpsc::RecvError) => break, + } + } + } + + /// Processes a single command from consensus. Note that this only modifies the block tree, the + /// actual block execution and commit may happen later. 
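    /// For instance (illustrative): an `ExecuteBlock` command only inserts the block into the
    /// tree here; the VM actually runs later, when the main loop in `run` picks the block up
    /// via `maybe_execute_block`.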
+ fn process_command(&mut self, cmd: Command) { + match cmd { + Command::ExecuteBlock { transactions, parent_id, id, resp_sender, } => { + if let Mode::Syncing = self.mode { + Self::send_error_when_syncing(resp_sender, id); + return; + } + + // If the block already exists, we simply store the sender via which the response + // will be sent when available. Otherwise construct a block and add to the block + // tree. + match self.block_tree.get_block_mut(id) { + Some(block) => { + warn!("Block {:x} already exists.", id); + block.queue_execute_block_response_sender(resp_sender); + } + None => { + let block = TransactionBlock::new(transactions, parent_id, id, resp_sender); + // If `add_block` errors, we return the error immediately. Otherwise the + // response will be returned once the block is executed. + if let Err(err) = self.block_tree.add_block(block) { + let resp = Err(format_err!("{}", err)); + let mut block = err.into_block(); + block.send_execute_block_response(resp); + } + } + } + } + Command::CommitBlock { ledger_info_with_sigs, resp_sender, } => { + let id = ledger_info_with_sigs.ledger_info().consensus_block_id(); + if let Mode::Syncing = self.mode { + Self::send_error_when_syncing(resp_sender, id); + return; + } + + match self.block_tree.mark_as_committed(id, ledger_info_with_sigs) { + Ok(()) => { + let block = self .block_tree .get_block_mut(id) .expect("Block must exist if mark_as_committed succeeded."); + // We have successfully marked the block as committed, but the real + // response will not be sent to consensus until the block is successfully + // persisted in storage. So we just save the sender in the block. + block.set_commit_response_sender(resp_sender); + } + Err(err) => resp_sender + .send(Err(format_err!("{}", err))) + .expect("Failed to send error message."), + } + } + Command::ExecuteChunk { txn_list_with_proof, ledger_info_with_sigs, resp_sender, } => { + let res = self .execute_and_commit_chunk( txn_list_with_proof.clone(), ledger_info_with_sigs.clone(), ) + .map_err(|e| { + security_log(SecurityEvent::InvalidChunkExecutor) + .error(&e) + .data(txn_list_with_proof) + .data(ledger_info_with_sigs) + .log(); + e + }); + resp_sender + .send(res.map(|_| ExecuteChunkResponse {})) + .expect("Failed to send execute chunk response."); + } + } + } + + fn send_error_when_syncing<T>(resp_sender: oneshot::Sender<Result<T>>, id: HashValue) + where + T: std::fmt::Debug, + { + let message = format!("Syncing. Unable to serve request for block {:x}.", id); + warn!("{}", message); + resp_sender + .send(Err(format_err!("{}", message))) + .expect("Failed to send error message."); + } + + /// Verifies the transactions based on the provided proofs and ledger info. If the transactions + /// are valid, executes them and commits immediately if execution results match the proofs. + fn execute_and_commit_chunk( + &mut self, + txn_list_with_proof: TransactionListWithProof, + ledger_info_with_sigs: LedgerInfoWithSignatures, + ) -> Result<()> { + if ledger_info_with_sigs.ledger_info().timestamp_usecs() <= self.committed_timestamp_usecs { + warn!( "Ledger info is too old: local timestamp: {}, timestamp in request: {}.", self.committed_timestamp_usecs, ledger_info_with_sigs.ledger_info().timestamp_usecs(), ); + return Ok(()); + } + + if let Mode::Normal = self.mode { + self.mode = Mode::Syncing; + info!("Start syncing..."); + } + info!( "Local version: {}. First transaction version in request: {:?}.
\ + Number of transactions in request: {}.", + self.committed_transaction_accumulator.num_elements() - 1, + txn_list_with_proof.first_transaction_version, + txn_list_with_proof.transaction_and_infos.len(), + ); + + let (num_txns_to_skip, first_version) = self.verify_chunk(&txn_list_with_proof, &ledger_info_with_sigs)?; + info!("Skipping the first {} transactions.", num_txns_to_skip); + let (transactions, infos): (Vec<_>, Vec<_>) = txn_list_with_proof .transaction_and_infos .into_iter() .skip(num_txns_to_skip as usize) .unzip(); + + // Construct a StateView and pass the transactions to VM. + let db_root_hash = self.committed_state_tree.root_hash(); + let state_view = VerifiedStateView::new( Arc::clone(&self.storage_read_client), db_root_hash, &self.committed_state_tree, ); + let vm_outputs = { + let time = std::time::Instant::now(); + let out = V::execute_block(transactions.clone(), &self.vm_config, &state_view); + OP_COUNTERS.observe( "vm_execute_block_time_us", time.elapsed().as_micros() as f64, ); + out + }; + + // Since other validators have committed these transactions, their status should all be + // TransactionStatus::Keep. + for output in &vm_outputs { + if let TransactionStatus::Discard(_) = output.status() { + bail!("Syncing transactions that should be discarded."); + } + } + + let (account_to_btree, account_to_proof) = state_view.into(); + let output = Self::process_vm_outputs( account_to_btree, account_to_proof, &transactions, vm_outputs, Rc::clone(&self.committed_state_tree), Rc::clone(&self.committed_transaction_accumulator), )?; + + // Since we have verified the proofs, we just need to verify that each TransactionInfo + // object matches what we have computed locally. + let mut txns_to_commit = vec![]; + for ((txn, txn_data), (i, txn_info)) in itertools::zip_eq( itertools::zip_eq(transactions, output.transaction_data()), infos.into_iter().enumerate(), ) { + ensure!( txn_info.state_root_hash() == txn_data.state_root_hash(), "State root hashes do not match for {}-th transaction in chunk.", i, ); + ensure!( txn_info.event_root_hash() == txn_data.event_root_hash(), "Event root hashes do not match for {}-th transaction in chunk.", i, ); + ensure!( txn_info.gas_used() == txn_data.gas_used(), "Gas used does not match for {}-th transaction in chunk.", i, ); + txns_to_commit.push(TransactionToCommit::new( txn, txn_data.account_blobs().clone(), txn_data.events().to_vec(), txn_data.gas_used(), )); + } + + // If this is the last chunk corresponding to this ledger info, send the ledger info to + // storage. + let ledger_info_to_commit = if self.committed_transaction_accumulator.num_elements() + txns_to_commit.len() as u64 == ledger_info_with_sigs.ledger_info().version() + 1 { + // We have constructed the transaction accumulator root and checked that it matches the + // given ledger info in the verification process above, so this check can possibly fail + // only when input transaction list is empty. + ensure!( ledger_info_with_sigs .ledger_info() .transaction_accumulator_hash() == output.clone_transaction_accumulator().root_hash(), "Root hash in ledger info does not match local computation." ); + Some(ledger_info_with_sigs) + } else { + None + }; + self.storage_write_client.save_transactions( txns_to_commit, first_version, ledger_info_to_commit.clone(), )?; + + self.committed_state_tree = output.clone_state_tree(); + self.committed_transaction_accumulator = output.clone_transaction_accumulator(); + if let Some(ledger_info_with_sigs) = ledger_info_to_commit { + self.committed_timestamp_usecs = ledger_info_with_sigs.ledger_info().timestamp_usecs(); + self.block_tree .reset(ledger_info_with_sigs.ledger_info().consensus_block_id()); + self.mode = Mode::Normal; + info!( "Synced to version {}.", ledger_info_with_sigs.ledger_info().version() ); + } + + Ok(()) + } + + /// Verifies the proofs using provided ledger info. Also verifies that the version of the first + /// transaction matches the latest committed transaction. If the first few transactions happen + /// to be older, returns how many need to be skipped and the first version to be committed. + fn verify_chunk( + &self, + txn_list_with_proof: &TransactionListWithProof, + ledger_info_with_sigs: &LedgerInfoWithSignatures, + ) -> Result<(u64, Version)> { + txn_list_with_proof.verify( ledger_info_with_sigs.ledger_info(), txn_list_with_proof.first_transaction_version, )?; + + let num_committed_txns = self.committed_transaction_accumulator.num_elements(); + if txn_list_with_proof.transaction_and_infos.is_empty() { + return Ok((0, num_committed_txns /* first_version */)); + } + + let first_txn_version = txn_list_with_proof .first_transaction_version .expect("first_transaction_version should exist."); + + ensure!( first_txn_version <= num_committed_txns, "Transaction list too new. Expected version: {}. First transaction version: {}.", num_committed_txns, first_txn_version ); + Ok((num_committed_txns - first_txn_version, num_committed_txns)) + } + + /// If `save_blocks_to_storage` below fails, we retry based on this setting. + fn storage_retry_backoff() -> ExponentialBackoff { + let mut backoff = ExponentialBackoff::default(); + backoff.max_interval = std::time::Duration::from_secs(10); + backoff.max_elapsed_time = Some(std::time::Duration::from_secs(120)); + backoff + } + + /// Saves eligible blocks to persistent storage. If the blocks are successfully persisted, they + /// will be removed from `self.blocks_to_store` and the in-memory Sparse Merkle Trees in these + /// blocks will be pruned. Otherwise nothing happens. + /// + /// If we have multiple blocks and not all of them have signatures, we may send them to storage + /// in a few batches. For example, if we have + /// ```text + /// A <- B <- C <- D <- E + /// ``` + /// and only `C` and `E` have signatures, we will send `A`, `B` and `C` in the first batch, + /// then `D` and `E` later in another batch. + fn save_blocks_to_storage(&mut self) -> Result<()> { + // The blocks we send to storage in this batch. In the above example, this means block A, B + // and C. + let mut block_batch = vec![]; + for block in &mut self.blocks_to_store { + let should_stop = block.ledger_info_with_sigs().is_some(); + block_batch.push(block); + if should_stop { + break; + } + } + assert!(!block_batch.is_empty()); + + // All transactions that need to go to storage. In the above example, this means all the + // transactions in A, B and C whose status == TransactionStatus::Keep.
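        // (Illustrative note: only transactions with TransactionStatus::Keep are persisted;
        // discarded transactions never enter storage or the transaction accumulator, so the
        // version arithmetic below counts Keep transactions only.)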
+ let mut txns_to_commit = vec![]; + let mut num_accounts_created = 0; + for block in &block_batch { + for (txn, txn_data) in itertools::zip_eq( + block.transactions(), + block + .output() + .as_ref() + .expect("All blocks in self.blocks_to_store should have finished execution.") + .transaction_data(), + ) { + if let TransactionStatus::Keep(_) = txn_data.status() { + txns_to_commit.push(TransactionToCommit::new( + txn.clone(), + txn_data.account_blobs().clone(), + txn_data.events().to_vec(), + txn_data.gas_used(), + )); + num_accounts_created += txn_data.num_account_created(); + } + } + } + + let last_block = block_batch + .last_mut() + .expect("There must be at least one block with signatures."); + + // Check that the version in ledger info (computed by consensus) matches the version + // computed by us. TODO: we should also verify signatures and check that timestamp is + // strictly increasing. + let ledger_info_with_sigs = last_block + .ledger_info_with_sigs() + .as_ref() + .expect("This block must have signatures."); + let version = ledger_info_with_sigs.ledger_info().version(); + let num_txns_in_accumulator = last_block.clone_transaction_accumulator().num_elements(); + assert_eq!( + version + 1, + num_txns_in_accumulator, + "Number of transactions in ledger info ({}) does not match number of transactions \ + in accumulator ({}).", + version + 1, + num_txns_in_accumulator, + ); + + let num_txns_to_commit = txns_to_commit.len() as u64; + { + let time = std::time::Instant::now(); + self.storage_write_client.save_transactions( + txns_to_commit, + version + 1 - num_txns_to_commit, /* first_version */ + Some(ledger_info_with_sigs.clone()), + )?; + OP_COUNTERS.observe( + "storage_save_transactions_time_us", + time.elapsed().as_micros() as f64, + ); + } + // Only bump the counter when the commit succeeds. + OP_COUNTERS.inc_by("num_accounts", num_accounts_created); + + // Now that the blocks are persisted successfully, we can reply to consensus and update + // in-memory state. + self.committed_timestamp_usecs = ledger_info_with_sigs.ledger_info().timestamp_usecs(); + self.committed_state_tree = last_block.clone_state_tree(); + self.committed_transaction_accumulator = last_block.clone_transaction_accumulator(); + last_block.send_commit_block_response(Ok(CommitBlockResponse::Succeeded)); + + let num_saved = block_batch.len(); + for _i in 0..num_saved { + let block = self + .blocks_to_store + .pop_front() + .expect("self.blocks_to_store must have more blocks."); + let block_data = block + .output() + .as_ref() + .expect("All blocks in self.blocks_to_store should have output."); + for txn_data in block_data.transaction_data() { + txn_data.prune_state_tree(); + } + } + + Ok(()) + } + + /// Checks if there is a block in the tree ready for execution, if so run it by calling the VM. + /// Returns `true` if a block was successfully executed, `false` if there was no block to + /// execute. 
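    /// (Illustrative) `run` calls this between draining the command channel and blocking on
    /// `recv`, so at most one block is executed per loop iteration and newly arrived commands
    /// are picked up promptly in between.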
+ fn maybe_execute_block(&mut self) -> bool { + let id = match self.block_tree.get_block_to_execute() { + Some(block_id) => block_id, + None => return false, + }; + + { + let time = std::time::Instant::now(); + self.execute_block(id); + OP_COUNTERS.observe("block_execute_time_us", time.elapsed().as_micros() as f64); + } + + true + } + + fn execute_block(&mut self, id: HashValue) { + let (previous_state_tree, previous_transaction_accumulator) = self.get_trees_from_parent(id); + + let block_to_execute = self .block_tree .get_block_mut(id) .expect("Block to execute should exist."); + + // Construct a StateView and pass the transactions to VM. + let db_root_hash = self.committed_state_tree.root_hash(); + let state_view = VerifiedStateView::new( Arc::clone(&self.storage_read_client), db_root_hash, &previous_state_tree, ); + let vm_outputs = V::execute_block( block_to_execute.transactions().to_vec(), &self.vm_config, &state_view, ); + + let status: Vec<_> = vm_outputs .iter() .map(TransactionOutput::status) .cloned() .collect(); + if !status.is_empty() { + debug!("Execution status: {:?}", status); + } + + let (account_to_btree, account_to_proof) = state_view.into(); + match Self::process_vm_outputs( account_to_btree, account_to_proof, block_to_execute.transactions(), vm_outputs, previous_state_tree, previous_transaction_accumulator, ) { + Ok(output) => { + let root_hash = output.clone_transaction_accumulator().root_hash(); + block_to_execute.set_output(output); + + // Now that we have the root hash and execution status we can send the response to + // consensus. + // TODO: The VM will support a special transaction to set the validators for the + // next epoch that is part of a block execution. + let execute_block_response = ExecuteBlockResponse::new(root_hash, status, None); + block_to_execute.set_execute_block_response(execute_block_response); + } + Err(err) => { + block_to_execute.send_execute_block_response(Err(format_err!( "Failed to execute block: {}", err ))); + // If we failed to execute this block, remove the block and its descendants from + // the block tree. + self.block_tree.remove_subtree(id); + } + } + } + + /// Given id of the block that is about to be executed, returns the state tree and the + /// transaction accumulator at the end of the parent block. + fn get_trees_from_parent( + &self, + id: HashValue, + ) -> ( + Rc<SparseMerkleTree>, + Rc<Accumulator<TransactionAccumulatorHasher>>, + ) { + let parent_id = self .block_tree .get_block(id) .expect("Block should exist.") .parent_id(); + match self.block_tree.get_block(parent_id) { + Some(parent_block) => ( parent_block.clone_state_tree(), parent_block.clone_transaction_accumulator(), ), + None => ( Rc::clone(&self.committed_state_tree), Rc::clone(&self.committed_transaction_accumulator), ), + } + } + + /// Post-processing of what the VM outputs. Returns the entire block's output. + fn process_vm_outputs( + mut account_to_btree: HashMap<AccountAddress, BTreeMap<Vec<u8>, Vec<u8>>>, + account_to_proof: HashMap<HashValue, SparseMerkleProof>, + transactions: &[SignedTransaction], + vm_outputs: Vec<TransactionOutput>, + previous_state_tree: Rc<SparseMerkleTree>, + previous_transaction_accumulator: Rc<Accumulator<TransactionAccumulatorHasher>>, + ) -> Result<ProcessedVMOutput> { + // The data of each individual transaction. For convenience, even for the transactions + // that will be discarded, we compute their in-memory Sparse Merkle Tree (it will be + // identical to the previous one). + let mut txn_data = vec![]; + let mut current_state_tree = previous_state_tree; + // The hash of each individual TransactionInfo object. This will not include the + // transactions that will be discarded, since they do not go into the transaction + // accumulator. + let mut txn_info_hashes = vec![]; + + let proof_reader = ProofReader::new(account_to_proof); + for (vm_output, signed_txn) in itertools::zip_eq(vm_outputs.into_iter(), transactions.iter()) { + let (blobs, state_tree, num_accounts_created) = Self::process_write_set( signed_txn, &mut account_to_btree, &proof_reader, vm_output.write_set().clone(), &current_state_tree, )?; + + let event_tree = Accumulator::<EventAccumulatorHasher>::default() + .append(vm_output.events().iter().map(CryptoHash::hash).collect()); + + match vm_output.status() { + TransactionStatus::Keep(_) => { + ensure!( !vm_output.write_set().is_empty(), "Transaction with empty write set should be discarded.", ); + // Compute hash for the TransactionInfo object. We need the hash of the + // transaction itself, the state root hash as well as the event root hash. + let txn_info = TransactionInfo::new( signed_txn.hash(), state_tree.root_hash(), event_tree.root_hash(), vm_output.gas_used(), ); + txn_info_hashes.push(txn_info.hash()); + } + TransactionStatus::Discard(_) => { + ensure!( vm_output.write_set().is_empty(), "Discarded transaction has non-empty write set.", ); + ensure!( vm_output.events().is_empty(), "Discarded transaction has non-empty events.", ); + } + } + + txn_data.push(TransactionData::new( blobs, vm_output.events().to_vec(), vm_output.status().clone(), Rc::clone(&state_tree), Rc::new(event_tree), vm_output.gas_used(), num_accounts_created, )); + current_state_tree = state_tree; + } + + let current_transaction_accumulator = previous_transaction_accumulator.append(txn_info_hashes); + Ok(ProcessedVMOutput::new( txn_data, Rc::new(current_transaction_accumulator), current_state_tree, )) + } + + /// For all accounts modified by this transaction, find the previous blob and update it based + /// on the write set. Returns the blob value of all these accounts as well as the newly + /// constructed state tree. + fn process_write_set( + transaction: &SignedTransaction, + account_to_btree: &mut HashMap<AccountAddress, BTreeMap<Vec<u8>, Vec<u8>>>, + proof_reader: &ProofReader, + write_set: WriteSet, + previous_state_tree: &SparseMerkleTree, + ) -> Result<( + HashMap<AccountAddress, AccountStateBlob>, + Rc<SparseMerkleTree>, + usize, /* num_accounts_created */ + )> { + let mut updated_blobs = HashMap::new(); + let mut num_accounts_created = 0; + + // Find all addresses this transaction touches while processing each write op. + let mut addrs = HashSet::new(); + for (access_path, write_op) in write_set.into_iter() { + let address = access_path.address; + let path = access_path.path; + match account_to_btree.entry(address) { + hash_map::Entry::Occupied(mut entry) => { + let account_btree = entry.get_mut(); + // TODO(gzh): we check account creation here for now. Will remove it once we + // have a better way. + if account_btree.is_empty() { + num_accounts_created += 1; + } + Self::update_account_btree(account_btree, path, write_op); + } + hash_map::Entry::Vacant(entry) => { + // Before writing to an account, VM should always read that account. So we + // should not reach this code path. The exception is the genesis transaction (and + // maybe other FTVM transactions).
+ match transaction.payload() { + TransactionPayload::Program(_) => { bail!("Write set should be a subset of read set.") } + TransactionPayload::WriteSet(_) => (), + } + + let mut account_btree = BTreeMap::new(); + Self::update_account_btree(&mut account_btree, path, write_op); + entry.insert(account_btree); + } + } + addrs.insert(address); + } + + for addr in addrs { + let account_btree = account_to_btree.get(&addr).expect("Address should exist."); + let account_blob = AccountStateBlob::try_from(account_btree)?; + updated_blobs.insert(addr, account_blob); + } + let state_tree = Rc::new( previous_state_tree .update( updated_blobs .iter() .map(|(addr, value)| (addr.hash(), value.clone())) .collect(), proof_reader, ) .expect("Failed to update state tree."), ); + + Ok((updated_blobs, state_tree, num_accounts_created)) + } + + fn update_account_btree( + account_btree: &mut BTreeMap<Vec<u8>, Vec<u8>>, + path: Vec<u8>, + write_op: WriteOp, + ) { + match write_op { + WriteOp::Value(new_value) => account_btree.insert(path, new_value), + WriteOp::Deletion => account_btree.remove(&path), + }; + } +} + +struct ProofReader { + account_to_proof: HashMap<HashValue, SparseMerkleProof>, +} + +impl ProofReader { + fn new(account_to_proof: HashMap<HashValue, SparseMerkleProof>) -> Self { + ProofReader { account_to_proof } + } +} + +impl ProofRead for ProofReader { + fn get_proof(&self, key: HashValue) -> Option<&SparseMerkleProof> { + self.account_to_proof.get(&key) + } +} diff --git a/execution/executor/src/block_tree/block_tree_test.rs b/execution/executor/src/block_tree/block_tree_test.rs new file mode 100644 index 0000000000000..380064cb3d92f --- /dev/null +++ b/execution/executor/src/block_tree/block_tree_test.rs @@ -0,0 +1,581 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![allow(clippy::unit_arg)] + +use super::{AddBlockError, Block, CommitBlockError}; +use crypto::HashValue; +use std::collections::HashSet; + +#[derive(Clone, Debug, Eq, PartialEq)] +struct TestBlock { + committed: bool, + id: HashValue, + parent_id: HashValue, + children: HashSet<HashValue>, + output: Option<()>, + signature: Option<()>, +} + +impl Block for TestBlock { + type Output = (); + type Signature = (); + + fn is_committed(&self) -> bool { + self.committed + } + + fn set_committed(&mut self) { + assert!(!self.committed); + self.committed = true; + } + + fn is_executed(&self) -> bool { + self.output.is_some() + } + + fn set_output(&mut self, output: Self::Output) { + assert!(self.output.is_none()); + self.output = Some(output); + } + + fn set_signature(&mut self, signature: Self::Signature) { + assert!(self.signature.is_none()); + self.signature = Some(signature); + } + + fn id(&self) -> HashValue { + self.id + } + + fn parent_id(&self) -> HashValue { + self.parent_id + } + + fn add_child(&mut self, child_id: HashValue) { + assert!(self.children.insert(child_id)); + } + + fn children(&self) -> &HashSet<HashValue> { + &self.children + } +} + +impl TestBlock { + fn new(parent_id: HashValue, id: HashValue) -> Self { + TestBlock { committed: false, id, parent_id, children: HashSet::new(), output: None, signature: None, } + } +} + +type BlockTree = super::BlockTree<TestBlock>; + +fn id(i: u8) -> HashValue { + HashValue::new([i; HashValue::LENGTH]) +} + +fn create_block_tree() -> BlockTree { + // 0 ---> 1 ---> 2 + // | | + // | β””----> 3 ---> 4 + // | | + // | β””----> 5 + // | + // β””----> 6 ---> 7 ---> 8 + // | + // β””----> 9 ---> 10 + // | + // β””----> 11 + let mut block_tree = BlockTree::new(id(0)); + + block_tree.add_block(TestBlock::new(id(0), id(1))).unwrap(); + block_tree.add_block(TestBlock::new(id(1), id(2))).unwrap(); + block_tree.add_block(TestBlock::new(id(1), id(3))).unwrap(); + block_tree.add_block(TestBlock::new(id(3), id(4))).unwrap(); + block_tree.add_block(TestBlock::new(id(3), id(5))).unwrap(); + block_tree.add_block(TestBlock::new(id(0), id(6))).unwrap(); + block_tree.add_block(TestBlock::new(id(6), id(7))).unwrap(); + block_tree.add_block(TestBlock::new(id(7), id(8))).unwrap(); + block_tree.add_block(TestBlock::new(id(6), id(9))).unwrap(); + block_tree.add_block(TestBlock::new(id(9), id(10))).unwrap(); + block_tree.add_block(TestBlock::new(id(9), id(11))).unwrap(); + + block_tree +} + +#[test] +fn test_add_duplicate_block() { + let mut block_tree = create_block_tree(); + let block = TestBlock::new(id(1), id(7)); + let res = block_tree.add_block(block.clone()); + assert_eq!( res.err().unwrap(), AddBlockError::BlockAlreadyExists { block } ); +} + +#[test] +fn test_add_block_missing_parent() { + let mut block_tree = create_block_tree(); + let block = TestBlock::new(id(99), id(200)); + let res = block_tree.add_block(block.clone()); + assert_eq!(res.err().unwrap(), AddBlockError::ParentNotFound { block }); +} + +fn assert_parent_and_children(block: &TestBlock, expected_parent: u8, expected_children: Vec<u8>) { + assert_eq!(block.parent_id, id(expected_parent)); + assert_eq!( block.children, expected_children .into_iter() .map(id) .collect::<HashSet<_>>(), ); +} + +fn assert_heads(block_tree: &BlockTree, expected_heads: Vec<u8>) { + assert_eq!(block_tree.heads.len(), expected_heads.len()); + assert_eq!( block_tree.heads, expected_heads.into_iter().map(id).collect::<HashSet<_>>(), ); +} + +#[test] +fn test_add_block() { + let block_tree = create_block_tree(); + + assert_heads(&block_tree, vec![1, 6]); + assert_eq!(block_tree.last_committed_id, id(0)); + + for i in 1..=11 { + let block = block_tree.get_block(id(i)).unwrap(); + assert_eq!(block.id(), id(i)); + assert!(!block.is_committed()); + assert!(!block.is_executed()); + + match i { + 1 => assert_parent_and_children(block, 0, vec![2, 3]), + 2 => assert_parent_and_children(block, 1, vec![]), + 3 => assert_parent_and_children(block, 1, vec![4, 5]), + 4 => assert_parent_and_children(block, 3, vec![]), + 5 => assert_parent_and_children(block, 3, vec![]), + 6 => assert_parent_and_children(block, 0, vec![7, 9]), + 7 => assert_parent_and_children(block, 6, vec![8]), + 8 => assert_parent_and_children(block, 7, vec![]), + 9 => assert_parent_and_children(block, 6, vec![10, 11]), + 10 => assert_parent_and_children(block, 9, vec![]), + 11 => assert_parent_and_children(block, 9, vec![]), + _ => unreachable!(), + } + } +} + +#[test] +fn test_mark_as_committed_missing_block() { + let mut block_tree = create_block_tree(); + let res = block_tree.mark_as_committed(id(99), ()); + assert_eq!( res.err().unwrap(), CommitBlockError::BlockNotFound { id: id(99) } ); +} + +#[test] +fn test_mark_as_committed_twice() { + let mut block_tree = create_block_tree(); + block_tree.mark_as_committed(id(9), ()).unwrap(); + let res = block_tree.mark_as_committed(id(9), ()); + assert_eq!( res.err().unwrap(), CommitBlockError::BlockAlreadyMarkedAsCommitted { id: id(9) } ); +} + +#[test] +fn test_mark_as_committed_1_2() { + let mut block_tree = create_block_tree(); + + block_tree.mark_as_committed(id(1), ()).unwrap(); + for i in 1..=11 { + let block = block_tree.get_block(id(i)).unwrap(); + match i { + 1 => assert!(block.is_committed()), + _ => assert!(!block.is_committed()), + } + }
block_tree.mark_as_committed(id(2), ()).unwrap(); + for i in 1..=11 { + let block = block_tree.get_block(id(i)).unwrap(); + match i { + 1 | 2 => assert!(block.is_committed()), + _ => assert!(!block.is_committed()), + } + } +} + +#[test] +fn test_mark_as_committed_2() { + let mut block_tree = create_block_tree(); + block_tree.mark_as_committed(id(2), ()).unwrap(); + for i in 1..=11 { + let block = block_tree.get_block(id(i)).unwrap(); + match i { + 1 | 2 => assert!(block.is_committed()), + _ => assert!(!block.is_committed()), + } + } +} + +#[test] +fn test_mark_as_committed_4() { + let mut block_tree = create_block_tree(); + block_tree.mark_as_committed(id(4), ()).unwrap(); + for i in 1..=11 { + let block = block_tree.get_block(id(i)).unwrap(); + match i { + 1 | 3 | 4 => assert!(block.is_committed()), + _ => assert!(!block.is_committed()), + } + } +} + +#[test] +fn test_mark_as_committed_3_5() { + let mut block_tree = create_block_tree(); + + block_tree.mark_as_committed(id(3), ()).unwrap(); + for i in 1..=11 { + let block = block_tree.get_block(id(i)).unwrap(); + match i { + 1 | 3 => assert!(block.is_committed()), + _ => assert!(!block.is_committed()), + } + } + + block_tree.mark_as_committed(id(5), ()).unwrap(); + for i in 1..=11 { + let block = block_tree.get_block(id(i)).unwrap(); + match i { + 1 | 3 | 5 => assert!(block.is_committed()), + _ => assert!(!block.is_committed()), + } + } +} + +#[test] +fn test_mark_as_committed_8() { + let mut block_tree = create_block_tree(); + block_tree.mark_as_committed(id(8), ()).unwrap(); + for i in 1..=11 { + let block = block_tree.get_block(id(i)).unwrap(); + match i { + 6 | 7 | 8 => assert!(block.is_committed()), + _ => assert!(!block.is_committed()), + } + } +} + +#[test] +fn test_mark_as_committed_11() { + let mut block_tree = create_block_tree(); + block_tree.mark_as_committed(id(11), ()).unwrap(); + for i in 1..=11 { + let block = block_tree.get_block(id(i)).unwrap(); + match i { + 6 | 9 | 11 => assert!(block.is_committed()), + _ => assert!(!block.is_committed()), + } + } +} + +#[test] +fn test_get_committed_head_one_committed() { + let mut block_tree = create_block_tree(); + block_tree.mark_as_committed(id(7), ()).unwrap(); + assert_eq!( + block_tree.get_committed_head(&block_tree.heads), + Some(id(6)), + ); +} + +#[test] +fn test_get_committed_head_all_pending() { + let block_tree = create_block_tree(); + assert_eq!(block_tree.get_committed_head(&block_tree.heads), None); +} + +#[test] +#[should_panic(expected = "Conflicting blocks are both committed.")] +fn test_get_committed_head_two_committed() { + let mut block_tree = create_block_tree(); + block_tree.mark_as_committed(id(2), ()).unwrap(); + block_tree.mark_as_committed(id(7), ()).unwrap(); + let _committed_head = block_tree.get_committed_head(&block_tree.heads); +} + +#[test] +#[should_panic(expected = "Trying to remove a committed block")] +fn test_remove_branch_committed_block() { + let mut block_tree = create_block_tree(); + block_tree.mark_as_committed(id(2), ()).unwrap(); + block_tree.remove_branch(id(1)); +} + +#[test] +fn test_remove_branch_1_7_11() { + let mut block_tree = create_block_tree(); + + block_tree.remove_branch(id(1)); + for i in 1..=11 { + let block = block_tree.get_block(id(i)); + match i { + 1 | 2 | 3 | 4 | 5 => assert!(block.is_none()), + _ => assert!(block.is_some()), + } + } + + block_tree.remove_branch(id(7)); + for i in 1..=11 { + let block = block_tree.get_block(id(i)); + match i { + 1 | 2 | 3 | 4 | 5 | 7 | 8 => assert!(block.is_none()), + _ => 
assert!(block.is_some()), + } + } + + block_tree.remove_branch(id(11)); + for i in 1..=11 { + let block = block_tree.get_block(id(i)); + match i { + 1 | 2 | 3 | 4 | 5 | 7 | 8 | 11 => assert!(block.is_none()), + _ => assert!(block.is_some()), + } + } +} + +fn set_executed(block_tree: &mut BlockTree, ids: &[HashValue]) { + for id in ids { + let block = block_tree.get_block_mut(*id).unwrap(); + block.set_output(()); + } +} + +fn assert_block_id_in_set(block_id: HashValue, candidates: Vec) { + let ids: Vec<_> = candidates.into_iter().map(id).collect(); + assert!(ids.contains(&block_id)); +} + +#[test] +fn test_get_block_to_execute() { + let mut block_tree = create_block_tree(); + let to_execute = block_tree.get_block_to_execute().unwrap(); + assert_block_id_in_set(to_execute, vec![1, 6]); + + set_executed(&mut block_tree, &[id(6)]); + let to_execute = block_tree.get_block_to_execute().unwrap(); + assert_block_id_in_set(to_execute, vec![1, 7, 9]); + + set_executed( + &mut block_tree, + &[ + id(1), + id(2), + id(3), + id(4), + id(5), + id(7), + id(8), + id(9), + id(10), + ], + ); + let to_execute = block_tree.get_block_to_execute().unwrap(); + assert_block_id_in_set(to_execute, vec![11]); + + set_executed(&mut block_tree, &[id(11)]); + assert!(block_tree.get_block_to_execute().is_none()); +} + +fn assert_to_store(to_store: &[TestBlock], ids: &[HashValue]) { + let committed_ids: Vec<_> = to_store.iter().map(Block::id).collect(); + assert_eq!(committed_ids, ids); +} + +#[test] +fn test_prune_1_executed() { + let mut block_tree = create_block_tree(); + set_executed(&mut block_tree, &[id(1)]); + block_tree.mark_as_committed(id(1), ()).unwrap(); + let to_store = block_tree.prune(); + + assert_to_store(&to_store, &[id(1)]); + assert_heads(&block_tree, vec![2, 3]); + assert_eq!(block_tree.last_committed_id, id(1)); + for i in 1..=11 { + let block = block_tree.get_block(id(i)); + match i { + 2 | 3 | 4 | 5 => assert!(block.is_some()), + _ => assert!(block.is_none()), + } + } +} + +#[test] +fn test_prune_1_not_executed() { + let mut block_tree = create_block_tree(); + block_tree.mark_as_committed(id(1), ()).unwrap(); + let to_store = block_tree.prune(); + + assert!(to_store.is_empty()); + assert_heads(&block_tree, vec![1]); + assert_eq!(block_tree.last_committed_id, id(0)); + for i in 1..=11 { + let block = block_tree.get_block(id(i)); + match i { + 1 | 2 | 3 | 4 | 5 => assert!(block.is_some()), + _ => assert!(block.is_none()), + } + } +} + +#[test] +fn test_prune_2_executed() { + let mut block_tree = create_block_tree(); + set_executed(&mut block_tree, &[id(1), id(2)]); + block_tree.mark_as_committed(id(2), ()).unwrap(); + let to_store = block_tree.prune(); + + assert_to_store(&to_store, &[id(1), id(2)]); + assert_eq!(block_tree.last_committed_id, id(2)); + assert!(block_tree.heads.is_empty()); + assert!(block_tree.id_to_block.is_empty()); +} + +#[test] +fn test_prune_2_not_all_executed() { + let mut block_tree = create_block_tree(); + set_executed(&mut block_tree, &[id(1)]); + block_tree.mark_as_committed(id(2), ()).unwrap(); + let to_store = block_tree.prune(); + + assert_to_store(&to_store, &[id(1)]); + assert_heads(&block_tree, vec![2]); + assert_eq!(block_tree.last_committed_id, id(1)); + for i in 1..=11 { + let block = block_tree.get_block(id(i)); + match i { + 2 => assert!(block.is_some()), + _ => assert!(block.is_none()), + } + } + + set_executed(&mut block_tree, &[id(2)]); + let to_store = block_tree.prune(); + assert_to_store(&to_store, &[id(2)]); + assert!(block_tree.heads.is_empty()); + 
assert!(block_tree.id_to_block.is_empty());
+}
+
+#[test]
+fn test_prune_7_executed() {
+    let mut block_tree = create_block_tree();
+    set_executed(&mut block_tree, &[id(6), id(7)]);
+    block_tree.mark_as_committed(id(7), ()).unwrap();
+    let to_store = block_tree.prune();
+
+    assert_to_store(&to_store, &[id(6), id(7)]);
+    assert_heads(&block_tree, vec![8]);
+    assert_eq!(block_tree.last_committed_id, id(7));
+    for i in 1..=11 {
+        let block = block_tree.get_block(id(i));
+        match i {
+            8 => assert!(block.is_some()),
+            _ => assert!(block.is_none()),
+        }
+    }
+}
+
+#[test]
+fn test_prune_9_not_executed() {
+    let mut block_tree = create_block_tree();
+    block_tree.mark_as_committed(id(9), ()).unwrap();
+    let to_store = block_tree.prune();
+
+    assert_to_store(&to_store, &[]);
+    assert_heads(&block_tree, vec![6]);
+    assert_eq!(block_tree.last_committed_id, id(0));
+    for i in 1..=11 {
+        let block = block_tree.get_block(id(i));
+        match i {
+            6 | 9 | 10 | 11 => assert!(block.is_some()),
+            _ => assert!(block.is_none()),
+        }
+    }
+}
+
+#[test]
+fn test_prune_10_not_executed() {
+    let mut block_tree = create_block_tree();
+    block_tree.mark_as_committed(id(10), ()).unwrap();
+    let to_store = block_tree.prune();
+
+    assert_to_store(&to_store, &[]);
+    assert_heads(&block_tree, vec![6]);
+    assert_eq!(block_tree.last_committed_id, id(0));
+    for i in 1..=11 {
+        let block = block_tree.get_block(id(i));
+        match i {
+            6 | 9 | 10 => assert!(block.is_some()),
+            _ => assert!(block.is_none()),
+        }
+    }
+}
+
+#[test]
+fn test_remove_subtree_1() {
+    let mut block_tree = create_block_tree();
+    block_tree.remove_subtree(id(1));
+
+    assert_heads(&block_tree, vec![6]);
+    assert_eq!(block_tree.last_committed_id, id(0));
+    for i in 1..=11 {
+        let block = block_tree.get_block(id(i));
+        match i {
+            1 | 2 | 3 | 4 | 5 => assert!(block.is_none()),
+            _ => assert!(block.is_some()),
+        }
+    }
+}
+
+#[test]
+fn test_remove_subtree_3() {
+    let mut block_tree = create_block_tree();
+    block_tree.remove_subtree(id(3));
+
+    assert_heads(&block_tree, vec![1, 6]);
+    assert_eq!(block_tree.last_committed_id, id(0));
+    for i in 1..=11 {
+        let block = block_tree.get_block(id(i));
+        match i {
+            3 | 4 | 5 => assert!(block.is_none()),
+            _ => assert!(block.is_some()),
+        }
+    }
+}
+
+#[test]
+fn test_reset() {
+    let mut block_tree = create_block_tree();
+    block_tree.reset(id(100));
+    assert!(block_tree.id_to_block.is_empty());
+    assert!(block_tree.heads.is_empty());
+    assert_eq!(block_tree.last_committed_id, id(100));
+}
diff --git a/execution/executor/src/block_tree/mod.rs b/execution/executor/src/block_tree/mod.rs
new file mode 100644
index 0000000000000..e5dac8a266339
--- /dev/null
+++ b/execution/executor/src/block_tree/mod.rs
@@ -0,0 +1,346 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! In a leader based consensus algorithm, each participant maintains a block tree that looks like
+//! the following:
+//! ```text
+//! Height      5      6      7      ...
+//!
+//! Committed -> B5 -> B6 -> B7
+//!             |
+//!             └--> B5' -> B6' -> B7'
+//!                    |
+//!                    └----> B7"
+//! ```
+//! This module implements `BlockTree` that is an in-memory representation of this tree.
+
+#[cfg(test)]
+mod block_tree_test;
+
+use crypto::HashValue;
+use failure::bail_err;
+use std::collections::{hash_map, HashMap, HashSet};
+
+/// Each block has a unique identifier that is a `HashValue` computed by consensus. It has exactly
+/// one parent and zero or more children.
+pub trait Block: std::fmt::Debug {
+    /// The output of executing this block.
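+    /// For `TransactionBlock` this is `ProcessedVMOutput` (see `transaction_block.rs`); the
+    /// tests in this module simply use `()`.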
+    type Output;
+
+    /// The signatures on this block.
+    type Signature;
+
+    /// Whether consensus has decided to commit this block. Blocks of this kind are expected to
+    /// be sent to storage very soon, unless execution is lagging behind.
+    fn is_committed(&self) -> bool;
+
+    /// Marks this block as committed.
+    fn set_committed(&mut self);
+
+    /// Whether this block has finished execution.
+    fn is_executed(&self) -> bool;
+
+    /// Sets the output of this block.
+    fn set_output(&mut self, output: Self::Output);
+
+    /// Sets the signatures for this block.
+    fn set_signature(&mut self, signature: Self::Signature);
+
+    /// The id of this block.
+    fn id(&self) -> HashValue;
+
+    /// The id of the parent block.
+    fn parent_id(&self) -> HashValue;
+
+    /// Adds a block as its child.
+    fn add_child(&mut self, child_id: HashValue);
+
+    /// The list of children of this block.
+    fn children(&self) -> &HashSet<HashValue>;
+}
+
+/// The `BlockTree` implementation.
+#[derive(Debug)]
+pub struct BlockTree<B> {
+    /// A map that keeps track of all existing blocks by their ids.
+    id_to_block: HashMap<HashValue, B>,
+
+    /// The blocks at the lowest height in the map. B5 and B5' in the following example.
+    /// ```text
+    /// Committed(B0..4) -> B5 -> B6 -> B7
+    ///                    |
+    ///                    └--> B5' -> B6' -> B7'
+    ///                           |
+    ///                           └----> B7"
+    /// ```
+    heads: HashSet<HashValue>,
+
+    /// Id of the last committed block. B4 in the above example.
+    last_committed_id: HashValue,
+}
+
+impl<B> BlockTree<B>
+where
+    B: Block,
+{
+    /// Constructs a new `BlockTree`.
+    pub fn new(last_committed_id: HashValue) -> Self {
+        BlockTree {
+            id_to_block: HashMap::new(),
+            heads: HashSet::new(),
+            last_committed_id,
+        }
+    }
+
+    /// Adds a new block to the tree.
+    pub fn add_block(&mut self, block: B) -> Result<(), AddBlockError<B>> {
+        assert!(!self.id_to_block.contains_key(&self.last_committed_id));
+
+        let id = block.id();
+        if self.id_to_block.contains_key(&id) {
+            bail_err!(AddBlockError::BlockAlreadyExists { block });
+        }
+
+        let parent_id = block.parent_id();
+        if parent_id == self.last_committed_id {
+            assert!(self.heads.insert(id), "Block already existed in heads.");
+            self.id_to_block.insert(id, block);
+            return Ok(());
+        }
+
+        match self.id_to_block.entry(parent_id) {
+            hash_map::Entry::Occupied(mut entry) => {
+                entry.get_mut().add_child(id);
+                assert!(
+                    self.id_to_block.insert(id, block).is_none(),
+                    "Block {:x} already existed.",
+                    id,
+                );
+            }
+            hash_map::Entry::Vacant(_) => bail_err!(AddBlockError::ParentNotFound { block }),
+        }
+
+        Ok(())
+    }
+
+    /// Returns a reference to a specific block, if it exists in the tree.
+    pub fn get_block(&self, id: HashValue) -> Option<&B> {
+        self.id_to_block.get(&id)
+    }
+
+    /// Returns a mutable reference to a specific block, if it exists in the tree.
+    pub fn get_block_mut(&mut self, id: HashValue) -> Option<&mut B> {
+        self.id_to_block.get_mut(&id)
+    }
+
+    /// Returns the id of a block that is ready to be sent to the VM for execution (its parent has
+    /// finished execution), if such a block exists in the tree.
+    pub fn get_block_to_execute(&mut self) -> Option<HashValue> {
+        let mut to_visit: Vec<HashValue> = self.heads.iter().cloned().collect();
+
+        while let Some(id) = to_visit.pop() {
+            let block = self
+                .id_to_block
+                .get(&id)
+                .expect("Missing block in id_to_block.");
+            if !block.is_executed() {
+                return Some(id);
+            }
+            to_visit.extend(block.children().iter().cloned());
+        }
+
+        None
+    }
+
+    /// Marks the given block and all its uncommitted ancestors as committed. This does not cause
+    /// these blocks to be sent to storage immediately.
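+    ///
+    /// A minimal usage sketch (the `tree`, `block_id` and `sigs` bindings here are hypothetical):
+    ///
+    /// ```ignore
+    /// // Marks `block_id` and every pending ancestor; blocks reach storage later via `prune`.
+    /// tree.mark_as_committed(block_id, sigs)?;
+    /// let to_store = tree.prune();
+    /// ```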
+    pub fn mark_as_committed(
+        &mut self,
+        id: HashValue,
+        signature: B::Signature,
+    ) -> Result<(), CommitBlockError> {
+        // First put the signatures in the block. Note that if this causes multiple blocks to be
+        // marked as committed, only the last one will have the signatures.
+        match self.id_to_block.get_mut(&id) {
+            Some(block) => {
+                if block.is_committed() {
+                    bail_err!(CommitBlockError::BlockAlreadyMarkedAsCommitted { id });
+                } else {
+                    block.set_signature(signature);
+                }
+            }
+            None => bail_err!(CommitBlockError::BlockNotFound { id }),
+        }
+
+        // Mark the current block as committed. Go to the parent block and repeat until a
+        // committed block is found or there are no more blocks.
+        let mut current_id = id;
+        while let Some(block) = self.id_to_block.get_mut(&current_id) {
+            if block.is_committed() {
+                break;
+            }
+
+            block.set_committed();
+            current_id = block.parent_id();
+        }
+
+        Ok(())
+    }
+
+    /// Removes all blocks in the tree that conflict with committed blocks. Returns a list of
+    /// blocks that are ready to be sent to storage (all the committed blocks that have been
+    /// executed).
+    pub fn prune(&mut self) -> Vec<B> {
+        let mut blocks_to_store = vec![];
+
+        // First find if there is a committed block in the current heads. Since these blocks are
+        // at the same height, at most one of them can be committed. If all of them are pending we
+        // have nothing to do here. Otherwise, one of the branches is committed. Throw away the
+        // rest of them and advance to the next height.
+        let mut current_heads = self.heads.clone();
+        while let Some(committed_head) = self.get_committed_head(&current_heads) {
+            assert!(
+                current_heads.remove(&committed_head),
+                "committed_head should exist.",
+            );
+            for id in current_heads {
+                self.remove_branch(id);
+            }
+
+            match self.id_to_block.entry(committed_head) {
+                hash_map::Entry::Occupied(entry) => {
+                    current_heads = entry.get().children().clone();
+                    let current_id = *entry.key();
+                    let parent_id = entry.get().parent_id();
+                    if entry.get().is_executed() {
+                        // If this block has been executed, all its proper ancestors must have
+                        // finished execution and be present in `blocks_to_store`.
+                        self.heads = current_heads.clone();
+                        self.last_committed_id = current_id;
+                        blocks_to_store.push(entry.remove());
+                    } else {
+                        // The current block has not finished execution. If the parent block does
+                        // not exist in the map, that means the parent block (also committed) has
+                        // been executed and removed. Otherwise self.heads does not need to be
+                        // changed.
+                        if !self.id_to_block.contains_key(&parent_id) {
+                            self.heads = HashSet::new();
+                            self.heads.insert(current_id);
+                        }
+                    }
+                }
+                hash_map::Entry::Vacant(_) => unreachable!("committed_head_id should exist."),
+            }
+        }
+
+        blocks_to_store
+    }
+
+    /// Given a list of heads, returns the committed one if it exists.
+    fn get_committed_head(&self, heads: &HashSet<HashValue>) -> Option<HashValue> {
+        let mut committed_head = None;
+        for head in heads {
+            let block = self
+                .id_to_block
+                .get(head)
+                .expect("Head should exist in id_to_block.");
+            if block.is_committed() {
+                assert!(
+                    committed_head.is_none(),
+                    "Conflicting blocks are both committed.",
+                );
+                committed_head = Some(*head);
+            }
+        }
+        committed_head
+    }
+
+    /// Removes a branch at block `head`.
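+    /// Every block in the branch must still be pending; the assertion below panics if a
+    /// committed block is encountered.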
+    fn remove_branch(&mut self, head: HashValue) {
+        let mut remaining = vec![head];
+        while let Some(current_block_id) = remaining.pop() {
+            let block = self
+                .id_to_block
+                .remove(&current_block_id)
+                .unwrap_or_else(|| {
+                    panic!(
+                        "Trying to remove a non-existing block {:x}.",
+                        current_block_id,
+                    )
+                });
+            assert!(
+                !block.is_committed(),
+                "Trying to remove a committed block {:x}.",
+                current_block_id,
+            );
+            remaining.extend(block.children().iter());
+        }
+    }
+
+    /// Removes the entire subtree at block `id`.
+    pub fn remove_subtree(&mut self, id: HashValue) {
+        self.heads.remove(&id);
+        self.remove_branch(id);
+    }
+
+    /// Resets the block tree with a new `last_committed_id`. This removes all the in-memory
+    /// blocks.
+    pub fn reset(&mut self, last_committed_id: HashValue) {
+        let mut new_block_tree = BlockTree::new(last_committed_id);
+        std::mem::swap(self, &mut new_block_tree);
+    }
+}
+
+/// An error returned by `add_block`. The error contains the block being added so the caller does
+/// not lose it.
+#[derive(Debug, Eq, PartialEq)]
+pub enum AddBlockError<B> {
+    ParentNotFound { block: B },
+    BlockAlreadyExists { block: B },
+}
+
+impl<B> AddBlockError<B>
+where
+    B: Block,
+{
+    pub fn into_block(self) -> B {
+        match self {
+            AddBlockError::ParentNotFound { block } => block,
+            AddBlockError::BlockAlreadyExists { block } => block,
+        }
+    }
+}
+
+impl<B> std::fmt::Display for AddBlockError<B>
+where
+    B: Block,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match self {
+            AddBlockError::ParentNotFound { block } => {
+                write!(f, "Parent block {:x} was not found.", block.parent_id())
+            }
+            AddBlockError::BlockAlreadyExists { block } => {
+                write!(f, "Block {:x} already exists.", block.id())
+            }
+        }
+    }
+}
+
+/// An error returned by `mark_as_committed`. The error contains the id of the block the caller
+/// wants to commit.
+#[derive(Debug, Eq, PartialEq)] +pub enum CommitBlockError { + BlockNotFound { id: HashValue }, + BlockAlreadyMarkedAsCommitted { id: HashValue }, +} + +impl std::fmt::Display for CommitBlockError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + CommitBlockError::BlockNotFound { id } => write!(f, "Block {:x} was not found.", id), + CommitBlockError::BlockAlreadyMarkedAsCommitted { id } => { + write!(f, "Block {:x} was already marked as committed.", id) + } + } + } +} diff --git a/execution/executor/src/executor_test.rs b/execution/executor/src/executor_test.rs new file mode 100644 index 0000000000000..eaeb011998116 --- /dev/null +++ b/execution/executor/src/executor_test.rs @@ -0,0 +1,534 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + mock_vm::{ + encode_mint_transaction, encode_transfer_transaction, MockVM, DISCARD_STATUS, KEEP_STATUS, + }, + Executor, OP_COUNTERS, +}; +use config::config::{NodeConfig, NodeConfigHelpers}; +use crypto::{hash::GENESIS_BLOCK_ID, HashValue}; +use futures::executor::block_on; +use grpcio::{EnvBuilder, ServerBuilder}; +use proptest::prelude::*; +use proto_conv::IntoProtoBytes; +use rusty_fork::{rusty_fork_id, rusty_fork_test, rusty_fork_test_name}; +use std::{ + collections::HashMap, + fs::File, + io::Write, + sync::{mpsc, Arc}, +}; +use storage_client::{StorageRead, StorageReadServiceClient, StorageWriteServiceClient}; +use storage_proto::proto::storage_grpc::create_storage; +use storage_service::StorageService; +use types::{ + account_address::{AccountAddress, ADDRESS_LENGTH}, + ledger_info::{LedgerInfo, LedgerInfoWithSignatures}, + transaction::SignedTransaction, +}; +use vm_genesis::{encode_genesis_transaction, GENESIS_KEYPAIR}; + +fn get_config() -> NodeConfig { + let config = NodeConfigHelpers::get_single_node_test_config(true); + // Write out the genesis blob to the correct location. + // XXX Should this logic live in NodeConfigHelpers? + let genesis_txn = encode_genesis_transaction(&GENESIS_KEYPAIR.0, GENESIS_KEYPAIR.1); + let mut file = File::create(&config.execution.genesis_file_location).unwrap(); + file.write_all(&genesis_txn.into_proto_bytes().unwrap()) + .unwrap(); + + config +} + +fn create_storage_server(config: &mut NodeConfig) -> (grpcio::Server, mpsc::Receiver<()>) { + let (service, shutdown_receiver) = StorageService::new(&config.storage.get_dir()); + let mut server = ServerBuilder::new(Arc::new(EnvBuilder::new().build())) + .register_service(create_storage(service)) + .bind("localhost", 0) + .build() + .expect("Failed to create storage server."); + server.start(); + + assert_eq!(server.bind_addrs().len(), 1); + let (_, port) = server.bind_addrs()[0]; + + // This is a little messy -- technically the config should also be used to set up the storage + // server, but the code currently creates the storage server, binds it to a port, then sets up + // the config. + // XXX Clean this up a little. 
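+    // The server was bound to port 0 above, so the OS chose a free port; record the actual
+    // port in the config so that the storage clients created later connect to it.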
+ config.storage.port = port; + + (server, shutdown_receiver) +} + +fn create_executor(config: &NodeConfig) -> Executor { + let client_env = Arc::new(EnvBuilder::new().build()); + let read_client = Arc::new(StorageReadServiceClient::new( + Arc::clone(&client_env), + "localhost", + config.storage.port, + )); + let write_client = Arc::new(StorageWriteServiceClient::new( + Arc::clone(&client_env), + "localhost", + config.storage.port, + )); + Executor::new(read_client, write_client, config) +} + +fn execute_and_commit_block(executor: &TestExecutor, txn_index: u64) { + let txn = encode_mint_transaction(gen_address(txn_index), 100); + let parent_block_id = match txn_index { + 0 => *GENESIS_BLOCK_ID, + x => gen_block_id(x), + }; + let id = gen_block_id(txn_index + 1); + + let response = block_on(executor.execute_block(vec![txn], parent_block_id, id)) + .unwrap() + .unwrap(); + + let ledger_info = gen_ledger_info(txn_index + 1, response.root_hash(), id, txn_index + 1); + block_on(executor.commit_block(ledger_info)) + .unwrap() + .unwrap(); +} + +struct TestExecutor { + // The config is kept around because it owns the temp dir used in the test. + _config: NodeConfig, + storage_server: Option, + shutdown_receiver: mpsc::Receiver<()>, + executor: Executor, +} + +impl TestExecutor { + fn new() -> TestExecutor { + let mut config = get_config(); + let (storage_server, shutdown_receiver) = create_storage_server(&mut config); + let executor = create_executor(&config); + + TestExecutor { + _config: config, + storage_server: Some(storage_server), + shutdown_receiver, + executor, + } + } +} + +impl std::ops::Deref for TestExecutor { + type Target = Executor; + + fn deref(&self) -> &Self::Target { + &self.executor + } +} + +impl Drop for TestExecutor { + fn drop(&mut self) { + self.storage_server + .take() + .expect("Storage server should exist."); + self.shutdown_receiver.recv().unwrap(); + } +} + +fn gen_address(index: u64) -> AccountAddress { + let bytes = index.to_be_bytes(); + let mut buf = [0; ADDRESS_LENGTH]; + buf[ADDRESS_LENGTH - 8..].copy_from_slice(&bytes); + AccountAddress::new(buf) +} + +fn gen_block_id(index: u64) -> HashValue { + let bytes = index.to_be_bytes(); + let mut buf = [0; HashValue::LENGTH]; + buf[HashValue::LENGTH - 8..].copy_from_slice(&bytes); + HashValue::new(buf) +} + +fn gen_ledger_info( + version: u64, + root_hash: HashValue, + commit_block_id: HashValue, + timestamp_usecs: u64, +) -> LedgerInfoWithSignatures { + let ledger_info = LedgerInfo::new( + version, + root_hash, + /* consensus_data_hash = */ HashValue::zero(), + commit_block_id, + /* epoch_num = */ 0, + timestamp_usecs, + ); + LedgerInfoWithSignatures::new(ledger_info, /* signatures = */ HashMap::new()) +} + +#[test] +fn test_executor_status() { + let executor = TestExecutor::new(); + + let txn0 = encode_mint_transaction(gen_address(0), 100); + let txn1 = encode_mint_transaction(gen_address(1), 100); + let txn2 = encode_transfer_transaction(gen_address(0), gen_address(1), 500); + + let parent_block_id = *GENESIS_BLOCK_ID; + let block_id = gen_block_id(1); + + let response = + block_on(executor.execute_block(vec![txn0, txn1, txn2], parent_block_id, block_id)) + .unwrap() + .unwrap(); + + assert_eq!( + vec![KEEP_STATUS, KEEP_STATUS, DISCARD_STATUS], + response.status() + ); +} + +#[test] +fn test_executor_one_block() { + let executor = TestExecutor::new(); + + let parent_block_id = *GENESIS_BLOCK_ID; + let block_id = gen_block_id(1); + let version = 100; + + let txns = (0..version) + .map(|i| 
encode_mint_transaction(gen_address(i), 100)) + .collect(); + let execute_block_future = executor.execute_block(txns, parent_block_id, block_id); + let execute_block_response = block_on(execute_block_future).unwrap().unwrap(); + + let ledger_info = gen_ledger_info(version, execute_block_response.root_hash(), block_id, 1); + let commit_block_future = executor.commit_block(ledger_info); + let _commit_block_response = block_on(commit_block_future).unwrap().unwrap(); +} + +#[test] +fn test_executor_multiple_blocks() { + let executor = TestExecutor::new(); + + for i in 0..100 { + execute_and_commit_block(&executor, i) + } +} + +#[test] +fn test_executor_execute_same_block_multiple_times() { + let parent_block_id = *GENESIS_BLOCK_ID; + let block_id = gen_block_id(1); + let version = 100; + + let txns: Vec<_> = (0..version) + .map(|i| encode_mint_transaction(gen_address(i), 100)) + .collect(); + + { + let executor = TestExecutor::new(); + let mut responses = vec![]; + for _i in 0..100 { + let execute_block_future = + executor.execute_block(txns.clone(), parent_block_id, block_id); + let execute_block_response = block_on(execute_block_future).unwrap().unwrap(); + responses.push(execute_block_response); + } + responses.dedup(); + assert_eq!(responses.len(), 1); + } + { + let executor = TestExecutor::new(); + let mut futures = vec![]; + for _i in 0..100 { + let execute_block_future = + executor.execute_block(txns.clone(), parent_block_id, block_id); + futures.push(execute_block_future); + } + let mut responses: Vec<_> = futures + .into_iter() + .map(|fut| block_on(fut).unwrap().unwrap()) + .collect(); + responses.dedup(); + assert_eq!(responses.len(), 1); + } +} + +rusty_fork_test! { + #[test] + fn test_num_accounts_created_counter() { + let executor = TestExecutor::new(); + for i in 0..20 { + execute_and_commit_block(&executor, i); + assert_eq!(OP_COUNTERS.counter("num_accounts").get() as u64, i + 1); + } + } +} + +#[test] +fn test_executor_execute_chunk() { + let first_batch_size = 30; + let second_batch_size = 40; + let third_batch_size = 20; + let overlapping_size = 5; + + // To obtain the two batches of transactions, we first execute and save these transactions in a + // separate DB. Then we call get_transactions to retrieve them. 
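+    // With the sizes above the ledger holds 30 + 40 + 20 - 5 = 85 transactions: the first batch
+    // covers versions 1..=30, the second 31..=70 and the third 66..=85, so the third chunk
+    // re-sends the last five transactions of the second one.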
+    let (first_batch, second_batch, third_batch, ledger_info) = {
+        let mut config = get_config();
+        let (storage_server, shutdown_receiver) = create_storage_server(&mut config);
+        let executor = create_executor(&config);
+
+        let mut txns = vec![];
+        for i in 0..first_batch_size + second_batch_size + third_batch_size - overlapping_size {
+            let txn = encode_mint_transaction(gen_address(i), 100);
+            txns.push(txn);
+        }
+        let id = gen_block_id(1);
+
+        let response = block_on(executor.execute_block(txns.clone(), *GENESIS_BLOCK_ID, id))
+            .unwrap()
+            .unwrap();
+        let ledger_version = txns.len() as u64;
+        let ledger_info = gen_ledger_info(ledger_version, response.root_hash(), id, 1);
+        block_on(executor.commit_block(ledger_info.clone()))
+            .unwrap()
+            .unwrap();
+
+        let storage_client = StorageReadServiceClient::new(
+            Arc::new(EnvBuilder::new().build()),
+            "localhost",
+            config.storage.port,
+        );
+        let first_batch = storage_client
+            .get_transactions(
+                /* start_version = */ 1,
+                first_batch_size,
+                ledger_version,
+                false, /* fetch_events */
+            )
+            .unwrap();
+        let second_batch = storage_client
+            .get_transactions(
+                /* start_version = */ first_batch_size + 1,
+                second_batch_size,
+                ledger_version,
+                false, /* fetch_events */
+            )
+            .unwrap();
+        let third_batch = storage_client
+            .get_transactions(
+                /* start_version = */
+                first_batch_size + second_batch_size + 1 - overlapping_size,
+                third_batch_size,
+                ledger_version,
+                false, /* fetch_events */
+            )
+            .unwrap();
+
+        drop(storage_server);
+        shutdown_receiver.recv().unwrap();
+
+        (first_batch, second_batch, third_batch, ledger_info)
+    };
+
+    // Now we execute these three chunks of transactions.
+    let mut config = get_config();
+    let (storage_server, shutdown_receiver) = create_storage_server(&mut config);
+    let executor = create_executor(&config);
+    let storage_client = StorageReadServiceClient::new(
+        Arc::new(EnvBuilder::new().build()),
+        "localhost",
+        config.storage.port,
+    );
+
+    // Execute the first chunk. After that we should still get the genesis ledger info from DB.
+    block_on(executor.execute_chunk(first_batch, ledger_info.clone()))
+        .unwrap()
+        .unwrap();
+    let (_, li, _) = storage_client.update_to_latest_ledger(0, vec![]).unwrap();
+    assert_eq!(li.ledger_info().version(), 0);
+    assert_eq!(li.ledger_info().consensus_block_id(), *GENESIS_BLOCK_ID);
+
+    // Execute the second chunk. After that we should still get the genesis ledger info from DB.
+    block_on(executor.execute_chunk(second_batch, ledger_info.clone()))
+        .unwrap()
+        .unwrap();
+    let (_, li, _) = storage_client.update_to_latest_ledger(0, vec![]).unwrap();
+    assert_eq!(li.ledger_info().version(), 0);
+    assert_eq!(li.ledger_info().consensus_block_id(), *GENESIS_BLOCK_ID);
+
+    // Execute the third chunk. After that we should get the new ledger info.
+    block_on(executor.execute_chunk(third_batch, ledger_info.clone()))
+        .unwrap()
+        .unwrap();
+    let (_, li, _) = storage_client.update_to_latest_ledger(0, vec![]).unwrap();
+    assert_eq!(li, ledger_info);
+
+    drop(storage_server);
+    shutdown_receiver.recv().unwrap();
+}
+
+struct TestBlock {
+    txns: Vec<SignedTransaction>,
+    parent_id: HashValue,
+    id: HashValue,
+}
+
+impl TestBlock {
+    fn new(
+        addr_index: std::ops::Range<u64>,
+        amount: u32,
+        parent_id: HashValue,
+        id: HashValue,
+    ) -> Self {
+        TestBlock {
+            txns: addr_index
+                .map(|index| encode_mint_transaction(gen_address(index), u64::from(amount)))
+                .collect(),
+            parent_id,
+            id,
+        }
+    }
+}
+
+// Executes a list of transactions by executing and immediately committing one at a time.
+// Returns the root hash after all transactions are committed.
+fn run_transactions_naive(transactions: Vec<SignedTransaction>) -> HashValue {
+    let executor = TestExecutor::new();
+    let mut iter = transactions.into_iter();
+    let first_txn = iter.next();
+    let response = block_on(executor.execute_block(
+        match first_txn {
+            None => vec![],
+            Some(txn) => vec![txn],
+        },
+        *GENESIS_BLOCK_ID,
+        gen_block_id(1),
+    ))
+    .unwrap()
+    .unwrap();
+    let mut root_hash = response.root_hash();
+
+    for (i, txn) in iter.enumerate() {
+        let parent_block_id = gen_block_id(i as u64 + 1);
+        // when i = 0, id should be 2.
+        let id = gen_block_id(i as u64 + 2);
+        let response = block_on(executor.execute_block(vec![txn], parent_block_id, id))
+            .unwrap()
+            .unwrap();
+
+        root_hash = response.root_hash();
+        let ledger_info = gen_ledger_info(i as u64 + 2, root_hash, id, i as u64 + 1);
+        block_on(executor.commit_block(ledger_info))
+            .unwrap()
+            .unwrap();
+    }
+    root_hash
+}
+
+proptest! {
+    #![proptest_config(ProptestConfig::with_cases(10))]
+
+    #[test]
+    fn test_executor_two_branches(
+        a_size in 0..30u64,
+        b_size in 0..30u64,
+        c_size in 0..30u64,
+        amount in any::<u32>(),
+    ) {
+        // Genesis -> A -> B
+        //          |
+        //          └--> C
+        let block_a = TestBlock::new(0..a_size, amount, *GENESIS_BLOCK_ID, gen_block_id(1));
+        let block_b = TestBlock::new(0..b_size, amount, gen_block_id(1), gen_block_id(2));
+        let block_c = TestBlock::new(0..c_size, amount, gen_block_id(1), gen_block_id(3));
+
+        // Execute blocks A, B and C. Hold all results in memory.
+        let executor = TestExecutor::new();
+
+        let response_a = block_on(executor.execute_block(
+            block_a.txns.clone(), block_a.parent_id, block_a.id,
+        )).unwrap().unwrap();
+        let response_b = block_on(executor.execute_block(
+            block_b.txns.clone(), block_b.parent_id, block_b.id,
+        )).unwrap().unwrap();
+        let response_c = block_on(executor.execute_block(
+            block_c.txns.clone(), block_c.parent_id, block_c.id,
+        )).unwrap().unwrap();
+
+        let root_hash_a = response_a.root_hash();
+        let root_hash_b = response_b.root_hash();
+        let root_hash_c = response_c.root_hash();
+
+        // Execute blocks A, B and C again, this time executing and committing one transaction
+        // at a time.
+        let expected_root_hash_a = run_transactions_naive(block_a.txns.clone());
+        prop_assert_eq!(root_hash_a, expected_root_hash_a);
+
+        let expected_root_hash_b = run_transactions_naive({
+            let mut txns = vec![];
+            txns.extend(block_a.txns.iter().cloned());
+            txns.extend(block_b.txns.iter().cloned());
+            txns
+        });
+        prop_assert_eq!(root_hash_b, expected_root_hash_b);
+
+        let expected_root_hash_c = run_transactions_naive({
+            let mut txns = vec![];
+            txns.extend(block_a.txns.iter().cloned());
+            txns.extend(block_c.txns.iter().cloned());
+            txns
+        });
+        prop_assert_eq!(root_hash_c, expected_root_hash_c);
+    }
+
+    #[test]
+    fn test_executor_restart(a_size in 0..30u64, b_size in 0..30u64, amount in any::<u32>()) {
+        let block_a = TestBlock::new(0..a_size, amount, *GENESIS_BLOCK_ID, gen_block_id(1));
+        let block_b = TestBlock::new(0..b_size, amount, gen_block_id(1), gen_block_id(2));
+
+        let mut config = get_config();
+        let (storage_server, shutdown_receiver) = create_storage_server(&mut config);
+
+        // First execute and commit one block, then destroy the executor.
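+        // The executor is dropped at the end of this scope; the storage server stays alive, so
+        // the committed data survives the simulated restart.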
+        {
+            let executor = create_executor(&config);
+            let response_a = block_on(executor.execute_block(
+                block_a.txns.clone(), block_a.parent_id, block_a.id,
+            )).unwrap().unwrap();
+            let root_hash = response_a.root_hash();
+            let ledger_info = gen_ledger_info(block_a.txns.len() as u64, root_hash, block_a.id, 1);
+            block_on(executor.commit_block(ledger_info)).unwrap().unwrap();
+        }
+
+        // Now we construct a new executor and run one more block.
+        let root_hash = {
+            let executor = create_executor(&config);
+            let response_b = block_on(executor.execute_block(
+                block_b.txns.clone(), block_b.parent_id, block_b.id,
+            )).unwrap().unwrap();
+            let root_hash = response_b.root_hash();
+            let ledger_info = gen_ledger_info(
+                (block_a.txns.len() + block_b.txns.len()) as u64,
+                root_hash,
+                block_b.id,
+                2,
+            );
+            block_on(executor.commit_block(ledger_info)).unwrap().unwrap();
+            root_hash
+        };
+
+        let expected_root_hash = run_transactions_naive({
+            let mut txns = vec![];
+            txns.extend(block_a.txns.iter().cloned());
+            txns.extend(block_b.txns.iter().cloned());
+            txns
+        });
+        prop_assert_eq!(root_hash, expected_root_hash);
+
+        drop(storage_server);
+        shutdown_receiver.recv().unwrap();
+    }
+}
diff --git a/execution/executor/src/lib.rs b/execution/executor/src/lib.rs
new file mode 100644
index 0000000000000..ff98aeee5f2e4
--- /dev/null
+++ b/execution/executor/src/lib.rs
@@ -0,0 +1,298 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+mod block_processor;
+mod block_tree;
+mod transaction_block;
+
+#[cfg(test)]
+mod executor_test;
+#[cfg(test)]
+mod mock_vm;
+
+use crate::block_processor::BlockProcessor;
+use config::config::NodeConfig;
+use crypto::{
+    hash::{GENESIS_BLOCK_ID, PRE_GENESIS_BLOCK_ID, SPARSE_MERKLE_PLACEHOLDER_HASH},
+    HashValue,
+};
+use execution_proto::{CommitBlockResponse, ExecuteBlockResponse, ExecuteChunkResponse};
+use failure::{format_err, Result};
+use futures::{channel::oneshot, executor::block_on};
+use lazy_static::lazy_static;
+use logger::prelude::*;
+use std::{
+    collections::HashMap,
+    marker::PhantomData,
+    sync::{mpsc, Arc, Mutex},
+};
+use storage_client::{StorageRead, StorageWrite};
+use types::{
+    ledger_info::{LedgerInfo, LedgerInfoWithSignatures},
+    transaction::{SignedTransaction, TransactionListWithProof},
+};
+use vm_runtime::VMExecutor;
+
+lazy_static! {
+    static ref OP_COUNTERS: metrics::OpMetrics = metrics::OpMetrics::new_and_registered("executor");
+}
+
+/// `Executor` implements all functionalities the execution module needs to provide.
+pub struct Executor<V> {
+    /// A thread that keeps processing blocks.
+    block_processor_thread: Option<std::thread::JoinHandle<()>>,
+
+    /// Where we can send commands to the block processor. The block processor sits at the other
+    /// end of the channel and processes the commands.
+    command_sender: Mutex<Option<mpsc::Sender<Command>>>,
+
+    phantom: PhantomData<V>,
+}
+
+impl<V> Executor<V>
+where
+    V: VMExecutor,
+{
+    /// Constructs an `Executor`.
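+    /// This reads the startup info from storage, spawns the block processor thread and, if the
+    /// DB turns out to be empty, executes and commits the genesis transaction.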
+ pub fn new( + storage_read_client: Arc, + storage_write_client: Arc, + config: &NodeConfig, + ) -> Self { + let startup_info = storage_read_client + .get_executor_startup_info() + .expect("Failed to read startup info from storage."); + + let ( + state_root_hash, + frozen_subtrees_in_accumulator, + num_elements_in_accumulator, + committed_timestamp_usecs, + committed_block_id, + ) = match startup_info { + Some(info) => { + info!("Startup info read from DB: {:?}.", info); + let ledger_info = info.ledger_info; + ( + info.account_state_root_hash, + info.ledger_frozen_subtree_hashes, + ledger_info.version() + 1, + ledger_info.timestamp_usecs(), + ledger_info.consensus_block_id(), + ) + } + None => { + info!("Startup info is empty. Will start from GENESIS."); + ( + *SPARSE_MERKLE_PLACEHOLDER_HASH, + vec![], + 0, + 0, + *PRE_GENESIS_BLOCK_ID, + ) + } + }; + + let (command_sender, command_receiver) = mpsc::channel(); + + let vm_config = config.vm_config.clone(); + let executor = Executor { + block_processor_thread: Some( + std::thread::Builder::new() + .name("block_processor".into()) + .spawn(move || { + let mut block_processor = BlockProcessor::::new( + command_receiver, + committed_timestamp_usecs, + state_root_hash, + frozen_subtrees_in_accumulator, + num_elements_in_accumulator, + committed_block_id, + storage_read_client, + storage_write_client, + vm_config, + ); + block_processor.run(); + }) + .expect("Failed to create block processor thread."), + ), + command_sender: Mutex::new(Some(command_sender)), + phantom: PhantomData, + }; + + if num_elements_in_accumulator == 0 { + let genesis_transaction = config + .execution + .get_genesis_transaction() + .expect("failed to load genesis transaction!"); + executor.init_genesis(genesis_transaction); + } + + executor + } + + /// This is used when we start for the first time and the DB is completely empty. It will write + /// necessary information to DB by committing the genesis transaction. + fn init_genesis(&self, genesis_txn: SignedTransaction) { + // Create a block with genesis_txn being the only transaction. Execute it then commit it + // immediately. + // We create `PRE_GENESIS_BLOCK_ID` as the parent of the genesis block. + let response = block_on(self.execute_block( + vec![genesis_txn], + *PRE_GENESIS_BLOCK_ID, + *GENESIS_BLOCK_ID, + )) + .expect("Response sender was unexpectedly dropped.") + .expect("Failed to execute genesis block."); + + let root_hash = response.root_hash(); + let ledger_info = LedgerInfo::new( + /* version = */ 0, + root_hash, + /* consensus_data_hash = */ HashValue::zero(), + *GENESIS_BLOCK_ID, + /* epoch_num = */ 0, + /* timestamp_usecs = */ 0, + ); + let ledger_info_with_sigs = + LedgerInfoWithSignatures::new(ledger_info, /* signatures = */ HashMap::new()); + block_on(self.commit_block(ledger_info_with_sigs)) + .expect("Response sender was unexpectedly dropped.") + .expect("Failed to commit genesis block."); + info!("GENESIS transaction is committed.") + } + + /// Executes a block. + pub fn execute_block( + &self, + transactions: Vec, + parent_id: HashValue, + id: HashValue, + ) -> oneshot::Receiver> { + debug!( + "Received request to execute block. Parent id: {:x}. 
Id: {:x}.", + parent_id, id + ); + + let (resp_sender, resp_receiver) = oneshot::channel(); + match self + .command_sender + .lock() + .expect("Failed to lock mutex.") + .as_ref() + { + Some(sender) => sender + .send(Command::ExecuteBlock { + transactions, + parent_id, + id, + resp_sender, + }) + .expect("Did block processor thread panic?"), + None => resp_sender + .send(Err(format_err!("Executor is shutting down."))) + .expect("Failed to send error message."), + } + resp_receiver + } + + /// Commits a block and all its ancestors. + pub fn commit_block( + &self, + ledger_info_with_sigs: LedgerInfoWithSignatures, + ) -> oneshot::Receiver> { + debug!( + "Received request to commit block {:x}.", + ledger_info_with_sigs.ledger_info().consensus_block_id() + ); + + let (resp_sender, resp_receiver) = oneshot::channel(); + match self + .command_sender + .lock() + .expect("Failed to lock mutex.") + .as_ref() + { + Some(sender) => sender + .send(Command::CommitBlock { + ledger_info_with_sigs, + resp_sender, + }) + .expect("Did block processor thread panic?"), + None => resp_sender + .send(Err(format_err!("Executor is shutting down."))) + .expect("Failed to send error message."), + } + resp_receiver + } + + /// Executes and commits a chunk of transactions that are already committed by majority of the + /// validators. + pub fn execute_chunk( + &self, + txn_list_with_proof: TransactionListWithProof, + ledger_info_with_sigs: LedgerInfoWithSignatures, + ) -> oneshot::Receiver> { + debug!( + "Received request to execute chunk. Chunk size: {}. Target version: {}.", + txn_list_with_proof.transaction_and_infos.len(), + ledger_info_with_sigs.ledger_info().version(), + ); + + let (resp_sender, resp_receiver) = oneshot::channel(); + match self + .command_sender + .lock() + .expect("Failed to lock mutex.") + .as_ref() + { + Some(sender) => sender + .send(Command::ExecuteChunk { + txn_list_with_proof, + ledger_info_with_sigs, + resp_sender, + }) + .expect("Did block processor thread panic?"), + None => resp_sender + .send(Err(format_err!("Executor is shutting down."))) + .expect("Failed to send error message."), + } + resp_receiver + } +} + +impl Drop for Executor { + fn drop(&mut self) { + // Drop the sender so the block processor thread will exit. 
+ self.command_sender + .lock() + .expect("Failed to lock mutex.") + .take() + .expect("Command sender should exist."); + self.block_processor_thread + .take() + .expect("Block processor thread should exist.") + .join() + .expect("Did block processor thread panic?"); + } +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +enum Command { + ExecuteBlock { + transactions: Vec, + parent_id: HashValue, + id: HashValue, + resp_sender: oneshot::Sender>, + }, + CommitBlock { + ledger_info_with_sigs: LedgerInfoWithSignatures, + resp_sender: oneshot::Sender>, + }, + ExecuteChunk { + txn_list_with_proof: TransactionListWithProof, + ledger_info_with_sigs: LedgerInfoWithSignatures, + resp_sender: oneshot::Sender>, + }, +} diff --git a/execution/executor/src/mock_vm/mock_vm_test.rs b/execution/executor/src/mock_vm/mock_vm_test.rs new file mode 100644 index 0000000000000..569afe653394e --- /dev/null +++ b/execution/executor/src/mock_vm/mock_vm_test.rs @@ -0,0 +1,142 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::{balance_ap, encode_mint_transaction, encode_transfer_transaction, seqnum_ap, MockVM}; +use config::config::VMConfig; +use failure::Result; +use state_view::StateView; +use types::{ + access_path::AccessPath, + account_address::{AccountAddress, ADDRESS_LENGTH}, + write_set::WriteOp, +}; +use vm_runtime::VMExecutor; + +fn gen_address(index: u8) -> AccountAddress { + AccountAddress::new([index; ADDRESS_LENGTH]) +} + +struct MockStateView; + +impl StateView for MockStateView { + fn get(&self, _access_path: &AccessPath) -> Result>> { + Ok(None) + } + + fn multi_get(&self, _access_paths: &[AccessPath]) -> Result>>> { + unimplemented!(); + } + + fn is_genesis(&self) -> bool { + false + } +} + +#[test] +fn test_mock_vm_different_senders() { + let amount = 100; + let mut txns = vec![]; + for i in 0..10 { + txns.push(encode_mint_transaction(gen_address(i), amount)); + } + + let outputs = MockVM::execute_block( + txns.clone(), + &VMConfig::empty_whitelist_FOR_TESTING(), + &MockStateView, + ); + + for (output, txn) in itertools::zip_eq(outputs.iter(), txns.iter()) { + let sender = txn.sender(); + assert_eq!( + output.write_set().iter().cloned().collect::>(), + vec![ + ( + balance_ap(sender), + WriteOp::Value(amount.to_le_bytes().to_vec()) + ), + ( + seqnum_ap(sender), + WriteOp::Value(1u64.to_le_bytes().to_vec()) + ), + ] + ); + } +} + +#[test] +fn test_mock_vm_same_sender() { + let amount = 100; + let sender = gen_address(1); + let mut txns = vec![]; + for _i in 0..10 { + txns.push(encode_mint_transaction(sender, amount)); + } + + let outputs = MockVM::execute_block( + txns, + &VMConfig::empty_whitelist_FOR_TESTING(), + &MockStateView, + ); + + for (i, output) in outputs.iter().enumerate() { + assert_eq!( + output.write_set().iter().cloned().collect::>(), + vec![ + ( + balance_ap(sender), + WriteOp::Value((amount * (i as u64 + 1)).to_le_bytes().to_vec()) + ), + ( + seqnum_ap(sender), + WriteOp::Value((i as u64 + 1).to_le_bytes().to_vec()) + ), + ] + ); + } +} + +#[test] +fn test_mock_vm_payment() { + let mut txns = vec![]; + txns.push(encode_mint_transaction(gen_address(0), 100)); + txns.push(encode_mint_transaction(gen_address(1), 100)); + txns.push(encode_transfer_transaction( + gen_address(0), + gen_address(1), + 50, + )); + + let output = MockVM::execute_block( + txns, + &VMConfig::empty_whitelist_FOR_TESTING(), + &MockStateView, + ); + + let mut output_iter = output.iter(); + output_iter.next(); + output_iter.next(); + assert_eq!( + 
output_iter + .next() + .unwrap() + .write_set() + .iter() + .cloned() + .collect::>(), + vec![ + ( + balance_ap(gen_address(0)), + WriteOp::Value(50u64.to_le_bytes().to_vec()) + ), + ( + seqnum_ap(gen_address(0)), + WriteOp::Value(2u64.to_le_bytes().to_vec()) + ), + ( + balance_ap(gen_address(1)), + WriteOp::Value(150u64.to_le_bytes().to_vec()) + ), + ] + ); +} diff --git a/execution/executor/src/mock_vm/mod.rs b/execution/executor/src/mock_vm/mod.rs new file mode 100644 index 0000000000000..cf26125991c37 --- /dev/null +++ b/execution/executor/src/mock_vm/mod.rs @@ -0,0 +1,309 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#[cfg(test)] +mod mock_vm_test; + +use config::config::VMConfig; +use crypto::signing::generate_keypair; +use state_view::StateView; +use std::collections::HashMap; +use types::{ + access_path::AccessPath, + account_address::{AccountAddress, ADDRESS_LENGTH}, + contract_event::ContractEvent, + transaction::{ + Program, RawTransaction, SignedTransaction, TransactionArgument, TransactionOutput, + TransactionPayload, TransactionStatus, + }, + vm_error::{ExecutionStatus, VMStatus}, + write_set::{WriteOp, WriteSet, WriteSetMut}, +}; +use vm_runtime::VMExecutor; + +#[derive(Debug)] +enum Transaction { + Mint { + sender: AccountAddress, + amount: u64, + }, + Payment { + sender: AccountAddress, + recipient: AccountAddress, + amount: u64, + }, +} + +pub const KEEP_STATUS: TransactionStatus = + TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)); + +// We use 10 as the assertion error code for insufficient balance within the Libra coin contract. +pub const DISCARD_STATUS: TransactionStatus = + TransactionStatus::Discard(VMStatus::Execution(ExecutionStatus::AssertionFailure(10))); + +pub struct MockVM; + +impl VMExecutor for MockVM { + fn execute_block( + transactions: Vec, + _config: &VMConfig, + state_view: &dyn StateView, + ) -> Vec { + if state_view.is_genesis() { + assert_eq!( + transactions.len(), + 1, + "Genesis block should have only one transaction." + ); + let output = TransactionOutput::new(gen_genesis_writeset(), vec![], 0, KEEP_STATUS); + return vec![output]; + } + + // output_cache is used to store the output of transactions so they are visible to later + // transactions. 
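+        // The cache maps access paths to the latest u64 value (balance or sequence number)
+        // written by an earlier transaction in this block; reads consult it first and fall
+        // back to the state view on a miss.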
+ let mut output_cache = HashMap::new(); + let mut outputs = vec![]; + + for txn in transactions { + match decode_transaction(&txn) { + Transaction::Mint { sender, amount } => { + let old_balance = read_balance(&output_cache, state_view, sender); + let new_balance = old_balance + amount; + let old_seqnum = read_seqnum(&output_cache, state_view, sender); + let new_seqnum = old_seqnum + 1; + + output_cache.insert(balance_ap(sender), new_balance); + output_cache.insert(seqnum_ap(sender), new_seqnum); + + let write_set = gen_mint_writeset(sender, new_balance, new_seqnum); + let events = gen_events(sender); + outputs.push(TransactionOutput::new(write_set, events, 0, KEEP_STATUS)); + } + Transaction::Payment { + sender, + recipient, + amount, + } => { + let sender_old_balance = read_balance(&output_cache, state_view, sender); + let recipient_old_balance = read_balance(&output_cache, state_view, recipient); + if sender_old_balance < amount { + outputs.push(TransactionOutput::new( + WriteSet::default(), + vec![], + 0, + DISCARD_STATUS, + )); + continue; + } + + let sender_old_seqnum = read_seqnum(&output_cache, state_view, sender); + let sender_new_seqnum = sender_old_seqnum + 1; + let sender_new_balance = sender_old_balance - amount; + let recipient_new_balance = recipient_old_balance + amount; + + output_cache.insert(balance_ap(sender), sender_new_balance); + output_cache.insert(seqnum_ap(sender), sender_new_seqnum); + output_cache.insert(balance_ap(recipient), recipient_new_balance); + + let write_set = gen_payment_writeset( + sender, + sender_new_balance, + sender_new_seqnum, + recipient, + recipient_new_balance, + ); + let events = gen_events(sender); + outputs.push(TransactionOutput::new( + write_set, + events, + 0, + TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)), + )); + } + } + } + + outputs + } +} + +fn read_balance( + output_cache: &HashMap, + state_view: &dyn StateView, + account: AccountAddress, +) -> u64 { + let balance_access_path = balance_ap(account); + match output_cache.get(&balance_access_path) { + Some(balance) => *balance, + None => read_balance_from_storage(state_view, &balance_access_path), + } +} + +fn read_seqnum( + output_cache: &HashMap, + state_view: &dyn StateView, + account: AccountAddress, +) -> u64 { + let seqnum_access_path = seqnum_ap(account); + match output_cache.get(&seqnum_access_path) { + Some(seqnum) => *seqnum, + None => read_seqnum_from_storage(state_view, &seqnum_access_path), + } +} + +fn read_balance_from_storage(state_view: &dyn StateView, balance_access_path: &AccessPath) -> u64 { + read_u64_from_storage(state_view, &balance_access_path) +} + +fn read_seqnum_from_storage(state_view: &dyn StateView, seqnum_access_path: &AccessPath) -> u64 { + read_u64_from_storage(state_view, &seqnum_access_path) +} + +fn read_u64_from_storage(state_view: &dyn StateView, access_path: &AccessPath) -> u64 { + match state_view + .get(&access_path) + .expect("Failed to query storage.") + { + Some(bytes) => decode_bytes(&bytes), + None => 0, + } +} + +fn decode_bytes(bytes: &[u8]) -> u64 { + let mut buf = [0; 8]; + buf.copy_from_slice(bytes); + u64::from_le_bytes(buf) +} + +fn balance_ap(account: AccountAddress) -> AccessPath { + AccessPath::new(account, b"balance".to_vec()) +} + +fn seqnum_ap(account: AccountAddress) -> AccessPath { + AccessPath::new(account, b"seqnum".to_vec()) +} + +fn gen_genesis_writeset() -> WriteSet { + let address = AccountAddress::new([0xff; ADDRESS_LENGTH]); + let path = b"hello".to_vec(); + let mut write_set = 
WriteSetMut::default(); + write_set.push(( + AccessPath::new(address, path), + WriteOp::Value(b"world".to_vec()), + )); + write_set + .freeze() + .expect("genesis writeset should be valid") +} + +fn gen_mint_writeset(sender: AccountAddress, balance: u64, seqnum: u64) -> WriteSet { + let mut write_set = WriteSetMut::default(); + write_set.push(( + balance_ap(sender), + WriteOp::Value(balance.to_le_bytes().to_vec()), + )); + write_set.push(( + seqnum_ap(sender), + WriteOp::Value(seqnum.to_le_bytes().to_vec()), + )); + write_set.freeze().expect("mint writeset should be valid") +} + +fn gen_payment_writeset( + sender: AccountAddress, + sender_balance: u64, + sender_seqnum: u64, + recipient: AccountAddress, + recipient_balance: u64, +) -> WriteSet { + let mut write_set = WriteSetMut::default(); + write_set.push(( + balance_ap(sender), + WriteOp::Value(sender_balance.to_le_bytes().to_vec()), + )); + write_set.push(( + seqnum_ap(sender), + WriteOp::Value(sender_seqnum.to_le_bytes().to_vec()), + )); + write_set.push(( + balance_ap(recipient), + WriteOp::Value(recipient_balance.to_le_bytes().to_vec()), + )); + write_set + .freeze() + .expect("payment write set should be valid") +} + +fn gen_events(sender: AccountAddress) -> Vec { + let access_path = AccessPath::new(sender, b"event".to_vec()); + let event = ContractEvent::new(access_path, 0, b"event_data".to_vec()); + vec![event] +} + +pub fn encode_mint_program(amount: u64) -> Program { + let argument = TransactionArgument::U64(amount); + Program::new(vec![], vec![], vec![argument]) +} + +pub fn encode_transfer_program(recipient: AccountAddress, amount: u64) -> Program { + let argument1 = TransactionArgument::Address(recipient); + let argument2 = TransactionArgument::U64(amount); + Program::new(vec![], vec![], vec![argument1, argument2]) +} + +pub fn encode_mint_transaction(sender: AccountAddress, amount: u64) -> SignedTransaction { + encode_transaction(sender, encode_mint_program(amount)) +} + +pub fn encode_transfer_transaction( + sender: AccountAddress, + recipient: AccountAddress, + amount: u64, +) -> SignedTransaction { + encode_transaction(sender, encode_transfer_program(recipient, amount)) +} + +fn encode_transaction(sender: AccountAddress, program: Program) -> SignedTransaction { + let raw_transaction = + RawTransaction::new(sender, 0, program, 0, 0, std::time::Duration::from_secs(0)); + + let (privkey, pubkey) = generate_keypair(); + raw_transaction + .sign(&privkey, pubkey) + .expect("Failed to sign raw transaction.") +} + +fn decode_transaction(txn: &SignedTransaction) -> Transaction { + let sender = txn.sender(); + match txn.payload() { + TransactionPayload::Program(program) => { + assert!(program.code().is_empty(), "Code should be empty."); + assert!(program.modules().is_empty(), "Modules should be empty."); + match program.args().len() { + 1 => match program.args()[0] { + TransactionArgument::U64(amount) => Transaction::Mint { sender, amount }, + _ => unimplemented!( + "Only one integer argument is allowed for mint transactions." + ), + }, + 2 => match (&program.args()[0], &program.args()[1]) { + (TransactionArgument::Address(recipient), TransactionArgument::U64(amount)) => { + Transaction::Payment { + sender, + recipient: *recipient, + amount: *amount, + } + } + _ => unimplemented!( + "The first argument for payment transaction must be recipient address \ + and the second argument must be amount." 
+ ), + }, + _ => unimplemented!("Transaction must have one or two arguments."), + } + } + TransactionPayload::WriteSet(_) => { + unimplemented!("MockVM does not support WriteSet transaction payload.") + } + } +} diff --git a/execution/executor/src/transaction_block.rs b/execution/executor/src/transaction_block.rs new file mode 100644 index 0000000000000..4da925030ee1e --- /dev/null +++ b/execution/executor/src/transaction_block.rs @@ -0,0 +1,370 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::block_tree::Block; +use crypto::{ + hash::{EventAccumulatorHasher, TransactionAccumulatorHasher}, + HashValue, +}; +use execution_proto::{CommitBlockResponse, ExecuteBlockResponse}; +use failure::{format_err, Result}; +use futures::channel::oneshot; +use logger::prelude::*; +use scratchpad::{Accumulator, SparseMerkleTree}; +use std::{ + collections::{HashMap, HashSet}, + rc::Rc, +}; +use types::{ + account_address::AccountAddress, + account_state_blob::AccountStateBlob, + contract_event::ContractEvent, + ledger_info::LedgerInfoWithSignatures, + transaction::{SignedTransaction, TransactionStatus}, +}; + +/// `TransactionBlock` holds everything about the block of transactions. +#[derive(Debug)] +pub struct TransactionBlock { + /// Whether consensus has decided to commit this block. + committed: bool, + + /// Id of this block. + id: HashValue, + + /// Id of the parent block. + parent_id: HashValue, + + /// The set of children. + children: HashSet, + + /// The transactions themselves. + transactions: Vec, + + /// The result of processing VM's output. + output: Option, + + /// The signatures on this block. Not all committed blocks will have signatures, if multiple + /// blocks are committed at once. + ledger_info_with_sigs: Option, + + /// The response for `execute_block` request. + execute_response: Option, + + /// The senders associated with this block. These senders are like the promises associated with + /// the futures returned by `execute_block` and `commit_block` APIs, which are fulfilled when + /// the responses are ready. + execute_response_senders: Vec>>, + commit_response_sender: Option>>, +} + +impl TransactionBlock { + /// Constructs a new block. A `TransactionBlock` is constructed as soon as consensus gives us a + /// new block. It has not been executed yet so output is `None`. + pub fn new( + transactions: Vec, + parent_id: HashValue, + id: HashValue, + execute_response_sender: oneshot::Sender>, + ) -> Self { + TransactionBlock { + committed: false, + id, + parent_id, + children: HashSet::new(), + transactions, + output: None, + ledger_info_with_sigs: None, + execute_response: None, + execute_response_senders: vec![execute_response_sender], + commit_response_sender: None, + } + } + + /// Returns the list of transactions. + pub fn transactions(&self) -> &[SignedTransaction] { + &self.transactions + } + + /// Returns the output of the block. + pub fn output(&self) -> &Option { + &self.output + } + + /// Returns the signatures on this block. + pub fn ledger_info_with_sigs(&self) -> &Option { + &self.ledger_info_with_sigs + } + + /// Saves the response in the block. If there are any queued senders, send the response. + pub fn set_execute_block_response(&mut self, response: ExecuteBlockResponse) { + assert!(self.execute_response.is_none(), "Response is already set."); + self.execute_response = Some(response.clone()); + // Send the response since it's now available. 
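+        // This drains every sender that was queued while execution was still in flight; senders
+        // arriving afterwards are answered directly in `queue_execute_block_response_sender`.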
+        self.send_execute_block_response(Ok(response));
+    }
+
+    /// Puts a sender in the queue. The response will be sent via the sender once available
+    /// (possibly as soon as the function is called if the response is already available).
+    pub fn queue_execute_block_response_sender(
+        &mut self,
+        sender: oneshot::Sender<Result<ExecuteBlockResponse>>,
+    ) {
+        // If the response is already available, just send it. Otherwise store the sender for later
+        // use.
+        match self.execute_response {
+            Some(ref response) => {
+                if let Err(_err) = sender.send(Ok(response.clone())) {
+                    warn!("Failed to send execute block response.");
+                }
+            }
+            None => self.execute_response_senders.push(sender),
+        }
+    }
+
+    /// Sends finished `ExecuteBlockResponse` to consensus. This removes all the existing senders.
+    pub fn send_execute_block_response(&mut self, response: Result<ExecuteBlockResponse>) {
+        while let Some(sender) = self.execute_response_senders.pop() {
+            // We need to send the result multiple times, but the error is not cloneable, thus the
+            // result is not cloneable. This is a bit of a workaround.
+            let resp = match &response {
+                Ok(resp) => Ok(resp.clone()),
+                Err(err) => Err(format_err!("{}", err)),
+            };
+            if let Err(_err) = sender.send(resp) {
+                warn!("Failed to send execute block response.");
+            }
+        }
+    }
+
+    /// When the block is created, no one has called `commit_block` on this block yet, so we do not
+    /// have the sender and `self.commit_response_sender` is initialized to `None`. When consensus
+    /// calls `commit_block` on a block, we will put the sender in the block. So when this block is
+    /// persisted in storage later, we will call `send_commit_block_response` and consensus will
+    /// receive the response.
+    pub fn set_commit_response_sender(
+        &mut self,
+        commit_response_sender: oneshot::Sender<Result<CommitBlockResponse>>,
+    ) {
+        assert!(
+            self.commit_response_sender.is_none(),
+            "CommitBlockResponse sender should not exist."
+        );
+        self.commit_response_sender = Some(commit_response_sender);
+    }
+
+    /// Sends finished `CommitBlockResponse` to consensus.
+    pub fn send_commit_block_response(&mut self, response: Result<CommitBlockResponse>) {
+        let sender = self
+            .commit_response_sender
+            .take()
+            .expect("CommitBlockResponse sender should exist.");
+        if let Err(_err) = sender.send(response) {
+            warn!("Failed to send commit block response.");
+        }
+    }
+
+    /// Returns a pointer to the Sparse Merkle Tree representing the state at the end of the block.
+    /// Should only be called when the block has finished execution and `set_output` has been
+    /// called.
+    pub fn clone_state_tree(&self) -> Rc<SparseMerkleTree> {
+        self.output
+            .as_ref()
+            .expect("The block has no output yet.")
+            .clone_state_tree()
+    }
+
+    /// Returns a pointer to the Merkle Accumulator representing the end of the block. Should only
+    /// be called when the block has finished execution and `set_output` has been called.
+ pub fn clone_transaction_accumulator(&self) -> Rc> { + self.output + .as_ref() + .expect("The block has no output yet.") + .clone_transaction_accumulator() + } +} + +impl Block for TransactionBlock { + type Output = ProcessedVMOutput; + type Signature = LedgerInfoWithSignatures; + + fn is_committed(&self) -> bool { + self.committed + } + + fn set_committed(&mut self) { + assert!(!self.committed); + self.committed = true; + } + + fn is_executed(&self) -> bool { + self.output.is_some() + } + + fn set_output(&mut self, output: Self::Output) { + assert!(self.output.is_none(), "Output is already set."); + self.output = Some(output); + } + + fn set_signature(&mut self, signature: Self::Signature) { + assert!( + self.ledger_info_with_sigs.is_none(), + "Signature is already set." + ); + self.ledger_info_with_sigs = Some(signature); + } + + fn id(&self) -> HashValue { + self.id + } + + fn parent_id(&self) -> HashValue { + self.parent_id + } + + fn add_child(&mut self, child_id: HashValue) { + assert!(self.children.insert(child_id)); + } + + fn children(&self) -> &HashSet { + &self.children + } +} + +impl Drop for TransactionBlock { + fn drop(&mut self) { + // It is possible for a block to be discarded before it gets executed, for example, due to + // a parallel block getting committed. In this case we still want to send a response back. + if !self.execute_response_senders.is_empty() { + assert!(self.execute_response.is_none()); + self.send_execute_block_response(Err(format_err!("Block {} is discarded.", self.id))); + } + } +} + +/// The entire set of data associated with a transaction. In addition to the output generated by VM +/// which includes the write set and events, this also has the in-memory trees. +#[derive(Debug)] +pub struct TransactionData { + /// Each entry in this map represents the new blob value of an account touched by this + /// transaction. The blob is obtained by deserializing the previous blob into a BTreeMap, + /// applying relevant portion of write set on the map and serializing the updated map into a + /// new blob. + account_blobs: HashMap, + + /// The list of events emitted during this transaction. + events: Vec, + + /// The execution status set by the VM. + status: TransactionStatus, + + /// The in-memory Sparse Merkle Tree after the write set is applied. This is `Rc` because the + /// tree has uncommitted state and sometimes `StateVersionView` needs to have a pointer to the + /// tree so VM can read it. + state_tree: Rc, + + /// The in-memory Merkle Accumulator that has all events emitted by this transaction. + event_tree: Rc>, + + /// The amount of gas used. + gas_used: u64, + + /// The number of newly created accounts. 
+ num_account_created: usize, +} + +impl TransactionData { + pub fn new( + account_blobs: HashMap, + events: Vec, + status: TransactionStatus, + state_tree: Rc, + event_tree: Rc>, + gas_used: u64, + num_account_created: usize, + ) -> Self { + TransactionData { + account_blobs, + events, + status, + state_tree, + event_tree, + gas_used, + num_account_created, + } + } + + pub fn account_blobs(&self) -> &HashMap { + &self.account_blobs + } + + pub fn events(&self) -> &[ContractEvent] { + &self.events + } + + pub fn status(&self) -> &TransactionStatus { + &self.status + } + + pub fn state_root_hash(&self) -> HashValue { + self.state_tree.root_hash() + } + + pub fn event_root_hash(&self) -> HashValue { + self.event_tree.root_hash() + } + + pub fn gas_used(&self) -> u64 { + self.gas_used + } + + pub fn num_account_created(&self) -> usize { + self.num_account_created + } + + pub fn prune_state_tree(&self) { + self.state_tree.prune() + } +} + +/// Generated by processing VM's output. +#[derive(Debug)] +pub struct ProcessedVMOutput { + /// The entire set of data assoicated with each transaction. + transaction_data: Vec, + + /// The in-memory Merkle Accumulator after appending new `TransactionInfo` objects. + transaction_accumulator: Rc>, + + /// This is the same tree as the state tree in the last transaction's output. When we execute a + /// child block we will need this tree as it stores the output of all previous transactions. It + /// is only for convenience purpose so we do not need to deal with the special case of empty + /// block. + state_tree: Rc, +} + +impl ProcessedVMOutput { + pub fn new( + transaction_data: Vec, + transaction_accumulator: Rc>, + state_tree: Rc, + ) -> Self { + ProcessedVMOutput { + transaction_data, + transaction_accumulator, + state_tree, + } + } + + pub fn transaction_data(&self) -> &[TransactionData] { + &self.transaction_data + } + + pub fn clone_transaction_accumulator(&self) -> Rc> { + Rc::clone(&self.transaction_accumulator) + } + + pub fn clone_state_tree(&self) -> Rc { + Rc::clone(&self.state_tree) + } +} diff --git a/language/README.md b/language/README.md new file mode 100644 index 0000000000000..5b87f570bac50 --- /dev/null +++ b/language/README.md @@ -0,0 +1,53 @@ +--- +id: move-language +title: Move Language +custom_edit_url: https://github.com/libra/libra/edit/master/language/README.md +--- + +# Move + +Move is a new programming language developed to provide a safe and programmable foundation for the Libra Blockchain. + +## Organization + +The Move language directory consists of five parts: + +- The [virtual machine](https://github.com/libra/libra/tree/master/language/vm) (VM), which contains the bytecode format, a bytecode interpreter, and infrastructure for executing a block of transactions. This directory also contains the infrastructure to generate the genesis block. + +- The [bytecode verifier](https://github.com/libra/libra/tree/master/language/bytecode_verifier), which contains a static analysis tool for rejecting invalid Move bytecode. The virtual machine runs the bytecode verifier on any new Move code it encounters before executing it. The compiler runs the bytecode verifier on its output and surfaces the errors to the programmer. + +- The Move intermediate representation (IR) [compiler](https://github.com/libra/libra/tree/master/language/stdlib), which compiles human-readable program text into Move bytecode. *Warning: the IR compiler is a testing tool. It can generate invalid bytecode that will be rejected by the Move bytecode verifier. 
The IR syntax is a work in progress that will undergo significant changes.*
+
+- The [standard library](https://github.com/libra/libra/tree/master/language/stdlib), which contains the Move IR code for the core system modules such as `LibraAccount` and `LibraCoin`.
+
+- The [tests](https://github.com/libra/libra/tree/master/language/functional_tests) for the virtual machine, bytecode verifier, and compiler. These tests are written in Move IR and run by a testing framework that parses the expected result of running a test from special directives encoded in comments.
+
+## How the Move Language Fits Into Libra Core
+
+Libra Core components interact with the language component through the VM. Specifically, the [admission control](https://github.com/libra/libra/tree/master/admission_control) component uses a limited, read-only [subset](https://github.com/libra/libra/tree/master/vm_validator) of the VM functionality to discard invalid transactions before they are admitted to the mempool and consensus. The [execution](https://github.com/libra/libra/tree/master/execution) component uses the VM to execute a block of transactions; a hypothetical sketch of this division of labor appears at the end of this section.
+
+### Exploring Move IR
+
+* You can find many small Move IR examples in the [tests](https://github.com/libra/libra/tree/master/language/functional_tests/tests/testsuite). The easiest way to experiment with Move IR is to create a new test in this directory and follow the instructions for running the tests.
+* Some more substantial examples can be found in the [standard library](https://github.com/libra/libra/tree/master/language/stdlib/modules). The two most notable ones are [LibraAccount.mvir](https://github.com/libra/libra/blob/master/language/stdlib/modules/libra_account.mvir), which implements accounts on the Libra blockchain, and [LibraCoin.mvir](https://github.com/libra/libra/blob/master/language/stdlib/modules/libra_coin.mvir), which implements Libra coin.
+* The four transaction scripts supported in the Libra testnet are also in the standard library directory. They are [peer-to-peer transfer](https://github.com/libra/libra/blob/master/language/stdlib/transaction_scripts/peer_to_peer_transfer.mvir), [account creation](https://github.com/libra/libra/blob/master/language/stdlib/transaction_scripts/create_account.mvir), [minting new Libra](https://github.com/libra/libra/blob/master/language/stdlib/transaction_scripts/mint.mvir) (will only work for an account with proper privileges), and [key rotation](https://github.com/libra/libra/blob/master/language/stdlib/transaction_scripts/rotate_authentication_key.mvir).
+* The most complete documentation of the Move IR syntax is the [grammar](https://github.com/libra/libra/blob/master/language/compiler/src/parser/mod.rs). You can also take a look at the [parser for the Move IR](https://github.com/libra/libra/blob/master/language/compiler/src/parser/syntax.lalrpop).
+* Check out the [IR compiler README](https://github.com/libra/libra/blob/master/language/compiler/README.md) for more details on writing Move IR code.
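+To make the division of labor above concrete, here is a hypothetical Rust sketch. The traits `VmValidator` and `VmExecutor` and all signatures below are illustrative stand-ins invented for this README, not the actual crate APIs:
+
+```rust
+// Illustrative stand-ins for the real transaction and output types.
+struct SignedTransaction;
+struct TransactionOutput;
+
+/// How admission control might use a read-only slice of the VM: validate a
+/// transaction without applying any state changes (hypothetical API).
+trait VmValidator {
+    fn validate_transaction(&self, txn: &SignedTransaction) -> Result<(), String>;
+}
+
+/// How execution might use the full VM: run a whole block and collect one
+/// output (write set and events) per transaction (hypothetical API).
+trait VmExecutor {
+    fn execute_block(&self, block: Vec<SignedTransaction>) -> Vec<TransactionOutput>;
+}
+
+fn admit_then_execute<V: VmValidator, E: VmExecutor>(
+    validator: &V,
+    executor: &E,
+    block: Vec<SignedTransaction>,
+) -> Vec<TransactionOutput> {
+    // Invalid transactions are discarded before mempool and consensus...
+    let admitted: Vec<_> = block
+        .into_iter()
+        .filter(|txn| validator.validate_transaction(txn).is_ok())
+        .collect();
+    // ...and the execution component later runs the admitted block.
+    executor.execute_block(admitted)
+}
+```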
+
+### Directory Organization
+
+```
+β”œβ”€β”€ README.md          # This README
+β”œβ”€β”€ bytecode_verifier  # The bytecode verifier
+β”œβ”€β”€ functional_tests   # Testing framework for the Move language
+β”œβ”€β”€ compiler           # The IR to Move bytecode compiler
+β”œβ”€β”€ stdlib             # Core Move modules and transaction scripts
+β”œβ”€β”€ test.sh            # Script for running all the language tests
+└── vm
+    β”œβ”€β”€ cost_synthesis # Cost synthesis for bytecode instructions
+    β”œβ”€β”€ src            # Bytecode language definitions, serializer, and deserializer
+    β”œβ”€β”€ tests          # VM tests
+    β”œβ”€β”€ vm_genesis     # The genesis state creation, and blockchain genesis writeset
+    └── vm_runtime     # The bytecode interpreter
+```
+
diff --git a/language/bytecode_verifier/Cargo.toml b/language/bytecode_verifier/Cargo.toml
new file mode 100644
index 0000000000000..39b42448ece5c
--- /dev/null
+++ b/language/bytecode_verifier/Cargo.toml
@@ -0,0 +1,18 @@
+[package]
+name = "bytecode_verifier"
+version = "0.1.0"
+authors = ["Libra Association "]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+mirai-annotations = "0.0.1"
+petgraph = "0.4"
+
+vm = { path = "../vm" }
+types = { path = "../../types" }
+
+[dev-dependencies]
+invalid_mutations = { path = "invalid_mutations" }
+proptest = "0.9"
diff --git a/language/bytecode_verifier/README.md b/language/bytecode_verifier/README.md
new file mode 100644
index 0000000000000..1422680c6ba8d
--- /dev/null
+++ b/language/bytecode_verifier/README.md
@@ -0,0 +1,157 @@
+---
+id: bytecode-verifier
+title: Bytecode Verifier
+custom_edit_url: https://github.com/libra/libra/edit/master/language/bytecode_verifier/README.md
+---
+
+### Bytecode Verifier: Checking Safety of Stack Usage, Types, Resources, and References
+
+The body of each function in a compiled module is verified separately while trusting the correctness of function signatures in the module. Checking that each function signature matches its definition is a separate responsibility. The body of a function is a sequence of bytecode instructions. This instruction sequence is checked in several phases described below.
+
+## CFG Construction
+
+A control-flow graph is constructed by decomposing the instruction sequence into a collection of basic blocks. Each basic block contains a contiguous sequence of instructions; the set of all instructions is partitioned among the blocks. Each block ends with a branch or return instruction. The decomposition into blocks guarantees that branch targets land only at the beginning of some block. The decomposition also attempts to ensure that the generated blocks are maximal. However, the soundness of the analysis does not depend on maximality.
+
+## Stack Safety
+
+The execution of a block happens in the context of an array of local variables and a stack. The parameters of the function are a prefix of the array of local variables. Passing arguments and return values across function calls is done via the stack. When a function starts executing, its arguments are already loaded into its parameters. Suppose the stack height is *n* when a function starts executing; then valid bytecode must enforce the invariant that when execution lands at the beginning of a basic block, the stack height is *n*. Furthermore, at a return instruction, the stack height must be *n*+*k*, where *k* (with *k* >= 0) is the number of return values. The first phase of the analysis checks that this invariant is maintained by analyzing each block separately, calculating the effect of each instruction in the block on the stack height, checking that the height does not go below *n*, and that it is left either at *n* or *n*+*k* (depending on the final instruction of the block and the return type of the function) at the end of the block.
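+Below is a small, self-contained sketch of this first-phase height check. It is illustrative only: the instruction set, its stack effects, and all names are invented here, not the verifier's actual code.
+
+```rust
+/// Toy instruction set with fixed stack effects (illustrative only).
+enum Instr {
+    Push,   // pushes one value
+    Pop,    // pops one value
+    BinOp,  // pops two operands, pushes one result
+    Branch, // ends a block; height must be back at the entry height
+    Ret,    // ends a block; height must be the entry height plus `k`
+}
+
+/// Checks one basic block: starting at height `n`, the height must never
+/// drop below `n`, and must end at `n` (branch) or `n + k` (return).
+fn check_block(block: &[Instr], n: usize, k: usize) -> Result<(), String> {
+    let mut height = n;
+    for instr in block {
+        let (pops, pushes) = match instr {
+            Instr::Push => (0, 1),
+            Instr::Pop => (1, 0),
+            Instr::BinOp => (2, 1),
+            Instr::Branch | Instr::Ret => (0, 0),
+        };
+        // The height may never dip below the entry height `n`.
+        if height < n + pops {
+            return Err(format!("stack underflow at height {}", height));
+        }
+        height = height - pops + pushes;
+    }
+    // The block's final instruction determines the required exit height.
+    let expected = match block.last() {
+        Some(Instr::Ret) => n + k,
+        _ => n,
+    };
+    if height == expected {
+        Ok(())
+    } else {
+        Err(format!("exit height {}, expected {}", height, expected))
+    }
+}
+```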
+## Type Safety
+
+The second phase of the analysis checks that each operation, primitive or defined function, is invoked with arguments of appropriate types. The operands of an operation are values located either in a local variable or on the stack. The types of local variables of a function are already provided in the bytecode. However, the types of stack values are inferred. This inference and the type checking of each operation can be done separately for each block. Since the stack height at the beginning of each block is *n* and does not go below *n* during the execution of the block, we only need to model the suffix of the stack starting at *n* for type checking the block instructions. We model this suffix using a stack of types on which types are pushed and popped as the instruction stream in a block is processed. The type stack and the statically-known types of local variables are enough to type check each instruction.
+
+## Resource Safety
+
+Resources represent assets of the blockchain. As such, there are certain restrictions over these types that do not apply to normal values. Intuitively, resource values cannot be copied and must be used by the end of the transaction (this means moved to global storage or destroyed). Concretely, the following restrictions apply:
+
+* `CopyLoc` and `StLoc` require that the type of the local is not of resource kind.
+* `WriteRef`, `Eq`, and `Neq` require that the type of the reference is not of resource kind.
+* At the end of a function (when `Ret` is reached), any local whose type is of resource kind must be empty, i.e., the value must have been moved out of the local.
+
+As mentioned above, this last rule around `Ret` implies that the resource *must* have been either:
+
+* Moved to global storage via `MoveToSender`.
+* Destroyed via `Unpack`.
+
+Both `MoveToSender` and `Unpack` are internal to the module in which the resource is declared.
+
+## Reference Safety
+
+References are first-class in the bytecode language. Fresh references become available to a function in several ways:
+
+* Input parameters.
+* Taking the address of the value in a local variable.
+* Taking the address of the globally published value in an address.
+* Taking the address of a field from a reference to the containing struct.
+* Return value from a function.
+
+The goal of reference safety checking is to ensure that there are no dangling references. Here are some examples of dangling references:
+
+* Local variable `y` contains a reference to the value in a local variable `x`; `x` is then moved.
+* Local variable `y` contains a reference to the value in a local variable `x`; `x` is then bound to a new value.
+* A reference is taken to a local variable that has not been initialized.
+* A reference to a value in a local variable is returned from a function.
+* Reference `r` is taken to a globally published value `v`; `v` is then unpublished.
+
+References can be either exclusive or shared; the latter allow only read access.
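+Move's reference rules are close in spirit to Rust borrows, so a loose Rust analogy (not the Move verifier itself) can illustrate the dangling cases listed above; the commented-out lines are the ones the rules reject:
+
+```rust
+fn main() {
+    let x = String::from("value");
+    let y = &x; // `y` borrows (shared, read-only) the value in `x`
+
+    // Each of these would be rejected, mirroring the examples above:
+    // let z = x;                  // `x` moved while `y` still borrows it
+    // x = String::from("other");  // `x` rebound while `y` still borrows it
+    // let w: &String;
+    // println!("{}", w);          // reference used before initialization
+
+    println!("{}", y); // reading through the shared reference is fine
+}
+```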
A secondary goal of reference safety checking is to ensure that in the execution context of the bytecode program β€” including the entire evaluation stack and all function frames β€” if there are two distinct storage locations containing references `r1` and `r2` such that `r2` extends `r1`, then both of the following conditions hold:
+
+* If `r1` is tagged exclusive, then it must be inactive, i.e., it is impossible to reach a control location where `r1` is dereferenced or mutated.
+* If `r1` is shared, then `r2` is shared.
+
+The two conditions above establish the property of referential transparency, important for scalable program verification, which looks roughly as follows: consider the piece of code `v1 = *r; S; v2 = *r`, where `S` is an arbitrary computation that does not perform any write through the syntactic reference `r` (and no writes to any `r'` that extends `r`). Then `v1 == v2`.
+
+**Analysis Setup.** The reference safety analysis is set up as a flow analysis (or abstract interpretation, for those who are familiar with the concept). An abstract state is defined for abstractly executing the code of a basic block. A map is maintained from basic blocks to abstract states. Given an abstract state *S* at the beginning of a basic block *B*, the abstract execution of *B* results in state *S'*. This state *S'* is propagated to all successors of *B* and recorded in the map. If a state already existed for a block, the freshly propagated state is β€œjoined” with the existing state. The join might fail, in which case an error is reported. If the join succeeds but the abstract state remains unchanged, no further propagation is done. Otherwise, the state is updated and propagated again through the block. An error may also be reported when an instruction is processed during propagation of the abstract state through a block. This propagation terminates because the abstract domain for a given function is finite and the join is monotonic, so the state recorded for each block can change only a bounded number of times.
+
+**Abstract State.** The abstract state has three components:
+
+* A partial map from locals to abstract values. Locals not in the domain of this map are unavailable. Availability is a generalization of the concept of being initialized. A local variable may become unavailable subsequent to initialization as a result of being moved. An abstract value is either *Reference*(*n*) (for variables of reference type) or *Value*(*ns*) (for variables of value type), where *n* is a nonce and *ns* is a set of nonces. A nonce is a constant used to represent a reference. Let *Nonce* represent the set of all nonces. If a local variable *l* is mapped to *Value*(*ns*), it means that there are outstanding borrowed references pointing into the value stored in *l*. For each member *n* of *ns*, there must be a local variable *l* mapped to *Reference*(*n*). If a local variable *x* is mapped to *Reference*(*n*) and there are local variables *y* and *z* mapped to *Value*(*ns1*) and *Value*(*ns2*) respectively, then it is possible that *n* is a member of both *ns1* and *ns2*. This simply means that the analysis is lossy. The special case when *l* is mapped to *Value*({}) means that there are no borrowed references to *l*, and, therefore, *l* may be destroyed or moved.
+* The partial map from locals to abstract values is not enough by itself to check bytecode programs, because values manipulated by the bytecode can be large nested structures with references pointing into the middle. A reference pointing into the middle of a value could be extended to get another reference. Some extensions should be allowed but others should not.
To keep track of relative extensions among references, we have a second component to the abstract state. This component is a map from nonces to one of two kinds of borrow information: either a set of nonces or a map from fields to sets of nonces. The current implementation stores this information as two separate maps with disjoint domains:
+    1. *borrowed_by* maps from *Nonce* to *Set*<*Nonce*>.
+    2. *fields_borrowed_by* maps from *Nonce* to *Map*<*Field*, *Set*<*Nonce*>>.
+    * If *n2* is in *borrowed_by*[*n1*], it means that the reference represented by *n2* is an extension of the reference represented by *n1*.
+    * If *n2* is in *fields_borrowed_by*[*n1*][*f*], it means that the reference represented by *n2* is an extension of the *f*-extension of the reference represented by *n1*. Based on this intuition, it is a sound overapproximation to move a nonce *n* from the domain of *fields_borrowed_by* to the domain of *borrowed_by* by taking the union of all nonce sets corresponding to all fields in the domain of *fields_borrowed_by*[*n*].
+* To propagate an abstract state across the instructions in a block, the values and references on the stack must also be modeled. We had earlier described how we model the usable stack suffix as a stack of types. We now augment the contents of this stack to be a structure containing a type and an abstract value. We maintain the invariant that non-reference values on the stack cannot have pending borrows on them. Therefore, if there is an abstract value *Value*(*ns*) on the stack, then *ns* is empty.
+
+**Values and References.** Let us take a closer look at how values and references, shared and exclusive, are modeled.
+
+* A non-reference value is modeled as *Value*(*ns*), where *ns* is a set of nonces representing borrowed references. Destruction/move/copy of this value is deemed safe only if *ns* is empty. Values on the stack trivially satisfy this property, but values in local variables may not.
+* A reference is modeled as *Reference*(*n*), where *n* is a nonce. If the reference is tagged shared, then read accesses are always allowed and write accesses are never allowed. If a reference *Reference*(*n*) is tagged exclusive, write access is allowed only if *n* does not have a borrow, and read access is allowed if all nonces that borrow from *n* reside in references that are tagged as shared. Furthermore, the rules for constructing references guarantee that an extension of a reference tagged shared must also be tagged shared. Together, these checks provide the property of referential transparency mentioned earlier.
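+The shapes just described can be summarized in a compact sketch. The names below are illustrative stand-ins; the verifier's own `abstract_state` module differs in detail:
+
+```rust
+use std::collections::{BTreeMap, BTreeSet};
+
+/// A nonce is a constant standing for a reference.
+type Nonce = usize;
+type Local = u8;   // local-variable index (illustrative)
+type Field = u16;  // field index (illustrative)
+
+/// Abstract value of a local: a reference (one nonce), or a value
+/// together with the nonces of the references borrowed from it.
+enum AbstractValue {
+    Reference(Nonce),
+    Value(BTreeSet<Nonce>),
+}
+
+/// The abstract state; locals absent from `locals` are unavailable.
+struct AbstractState {
+    locals: BTreeMap<Local, AbstractValue>,
+    /// n2 in borrowed_by[n1]: reference n2 extends reference n1.
+    borrowed_by: BTreeMap<Nonce, BTreeSet<Nonce>>,
+    /// n2 in fields_borrowed_by[n1][f]: n2 extends the f-extension of n1.
+    fields_borrowed_by: BTreeMap<Nonce, BTreeMap<Field, BTreeSet<Nonce>>>,
+}
+
+impl AbstractState {
+    /// A value may be destroyed, moved, or copied only when nothing
+    /// borrows from it, i.e., its nonce set is empty.
+    fn is_safe_to_destroy(&self, local: Local) -> bool {
+        matches!(
+            self.locals.get(&local),
+            Some(AbstractValue::Value(ns)) if ns.is_empty()
+        )
+    }
+}
+```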
At the moment, the bytecode language does not contain any direct constructors for shared references. `BorrowLoc` and `BorrowGlobal` create exclusive references. `BorrowField` creates a reference that inherits its tag from the source reference. Move (when applied to a local containing a reference) moves the reference from a local variable to the stack. `FreezeRef` is used to convert an existing exclusive reference to a shared reference. In the future, we may add a version of `BorrowGlobal` that generates a shared reference.
+
+**Errors.** As mentioned before, an error is reported by the checker in one of the following situations:
+
+* An instruction cannot be proved safe during propagation of the abstract state through a block.
+* The join of abstract states propagated via different incoming edges into a block fails.
+
+Let us take a closer look at the second reason for error reporting above. Note that the stack of type and abstract value pairs representing the usable stack suffix is empty at the beginning of a block. So, the join occurs only over the abstract state representing the available local variables and the borrow information. The join fails only when the set of available local variables differs on the two edges. If the set of available variables is identical, the join itself is straightforward---the borrow sets are unioned point-wise. There are two subtleties worth mentioning, though:
+
+* The set of nonces used in the abstract states along the two edges may not have any connection with each other. Since the actual nonce values are immaterial, the nonces are canonically mapped to fixed integers (the indices of the local variables containing them) before performing the join.
+* During the join, if a nonce *n* is in the domain of *borrowed_by* on one side and in the domain of *fields_borrowed_by* on the other side, *n* is moved from *fields_borrowed_by* to *borrowed_by* before doing the join.
+
+**Borrowing References.** Each of the reference constructors ---`BorrowLoc`, `BorrowField`, `BorrowGlobal`, `FreezeRef`, and `CopyLoc`--- is modeled via the generation of a fresh nonce. While `BorrowLoc` borrows from a value in a local variable, `BorrowGlobal` borrows from the global pool of values. `BorrowField`, `FreezeRef`, and `CopyLoc` (when applied to a local containing a reference) borrow from the source reference. Since each fresh nonce is distinct from all previously-generated nonces, the analysis maintains the invariant that all available local variables and stack locations of reference type have distinct nonces representing their abstract value. Another important invariant is that every nonce referred to in the borrow information must reside in some abstract value representing a local variable or a stack location.
+
+**Releasing References.** References, both global and local, are released by the `ReleaseRef` operation. All references must be explicitly released; therefore, it is an error to return from a function with an unreleased reference in one of its local variables, and it is likewise an error to overwrite an available reference using the `StLoc` operation.
+
+References are implicitly released when consumed by the operations `ReadRef`, `WriteRef`, `Eq`, `Neq`, and `EmitEvent`.
+
+**Global References.** The safety of global references depends on a combination of static and dynamic analysis. The static analysis does not distinguish between global and local references. But the dynamic analysis distinguishes between them and performs reference counting on the global references as follows: the bytecode interpreter maintains a map `M` from a pair of Address and fully-qualified resource type to a union (Rust enum) comprising the following values:
+
+* `Empty`
+* `RefCount(n)` for some `n` >= 0
+
+Extra state updates and checks are performed by the interpreter for the following operations. In the code below, assert failure indicates programmer error, and panic failure indicates an internal error in the interpreter.
+ +```text +MoveFrom(addr) { + assert M[addr, T] == RefCount(0); + M[addr, T] := Empty; +} + +MoveToSender(addr) { + assert M[addr, T] == Empty; + M[addr, T] := RefCount(0); +} + +BorrowGlobal(addr) { + if let RefCount(n) = M[addr, T] then { + assert n == 0; + M[addr, T] := RefCount(n+1); + } else { + assert false; + } +} + +CopyLoc(ref) { + if let Global(addr, T) = ref { + if let RefCount(n) = M[addr, T] then { + assert n > 0; + M[addr, T] := RefCount(n+1); + } else { + panic false; + } + } +} + +ReleaseRef(ref) { + if let Global(addr, T) = ref { + if let RefCount(n) = M[addr, T] then { + assert n > 0; + M[addr, T] := RefCount(n-1); + } else { + panic false; + } + } +} +``` + +A subtle point not explicated by the rules above is that `BorrowField` and `FreezeRef`, when applied to a global reference, leave the reference count unchanged. The reason is because these instructions consume the reference at the top of the stack while producing an extension of it at the top of the stack. Similarly, since `ReadRef`, `WriteRef`, `Eq`, `Neq`, and `EmitEvent` consume the reference at the top of the stack, they will reduce the reference count by 1. + +## Folder Structure + +```text +* +β”œβ”€β”€ invalid_mutations # Library used by proptests +β”œβ”€β”€ src # Core bytecode verifier files +β”œβ”€β”€ tests # Proptests +``` diff --git a/language/bytecode_verifier/invalid_mutations/Cargo.toml b/language/bytecode_verifier/invalid_mutations/Cargo.toml new file mode 100644 index 0000000000000..1366c67047d1d --- /dev/null +++ b/language/bytecode_verifier/invalid_mutations/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "invalid_mutations" +version = "0.1.0" +edition = "2018" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false + +[dependencies] +proptest = "0.9" +proptest_helpers = { path = "../../../common/proptest_helpers" } +vm = { path = "../../vm" } diff --git a/language/bytecode_verifier/invalid_mutations/src/bounds.rs b/language/bytecode_verifier/invalid_mutations/src/bounds.rs new file mode 100644 index 0000000000000..164d3f63319a2 --- /dev/null +++ b/language/bytecode_verifier/invalid_mutations/src/bounds.rs @@ -0,0 +1,455 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use proptest::{ + prelude::*, + sample::{self, Index as PropIndex}, +}; +use proptest_helpers::pick_slice_idxs; +use std::collections::BTreeMap; +use vm::{ + errors::{VMStaticViolation, VerificationError}, + file_format::{ + AddressPoolIndex, CompiledModule, FieldDefinitionIndex, FunctionHandleIndex, + FunctionSignatureIndex, LocalsSignatureIndex, ModuleHandleIndex, StringPoolIndex, + StructHandleIndex, TableIndex, TypeSignatureIndex, + }, + internals::ModuleIndex, + views::{ModuleView, SignatureTokenView}, + IndexKind, +}; + +mod code_unit; +pub use code_unit::{ApplyCodeUnitBoundsContext, CodeUnitBoundsMutation}; + +/// Represents the number of pointers that exist out from a node of a particular kind. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum PointerKind { + /// Exactly one pointer out with this index kind as its destination. + One(IndexKind), + /// Zero or one pointer out with this index kind as its destination. Like the `?` operator in + /// regular expressions. + Optional(IndexKind), + /// Zero or more pointers out with this index kind as its destination. Like the `*` operator + /// in regular expressions. 
+ Star(IndexKind), +} + +impl PointerKind { + /// A list of what pointers (indexes) exist out from a particular kind of node within the + /// module. + /// + /// The only special case is `FunctionDefinition`, which contains a `CodeUnit` that can contain + /// one of several kinds of pointers out. That is not represented in this table. + #[inline] + pub fn pointers_from(src_kind: IndexKind) -> &'static [PointerKind] { + use IndexKind::*; + use PointerKind::*; + + match src_kind { + ModuleHandle => &[One(AddressPool), One(StringPool)], + StructHandle => &[One(ModuleHandle), One(StringPool)], + FunctionHandle => &[One(ModuleHandle), One(StringPool), One(FunctionSignature)], + StructDefinition => &[One(StructHandle), One(FieldDefinition)], + FieldDefinition => &[One(StructHandle), One(StringPool), One(TypeSignature)], + FunctionDefinition => &[One(FunctionHandle), One(LocalsSignature)], + TypeSignature => &[Optional(StructHandle)], + FunctionSignature => &[Star(StructHandle)], + LocalsSignature => &[Star(StructHandle)], + StringPool => &[], + ByteArrayPool => &[], + AddressPool => &[], + // LocalPool and CodeDefinition are function-local, and this only works for + // module-scoped indexes. + // XXX maybe don't treat LocalPool and CodeDefinition the same way as the others? + LocalPool => &[], + CodeDefinition => &[], + } + } + + #[inline] + pub fn to_index_kind(self) -> IndexKind { + match self { + PointerKind::One(idx) | PointerKind::Optional(idx) | PointerKind::Star(idx) => idx, + } + } +} + +pub static VALID_POINTER_SRCS: &[IndexKind] = &[ + IndexKind::ModuleHandle, + IndexKind::StructHandle, + IndexKind::FunctionHandle, + IndexKind::StructDefinition, + IndexKind::FieldDefinition, + IndexKind::FunctionDefinition, + IndexKind::TypeSignature, + IndexKind::FunctionSignature, + IndexKind::LocalsSignature, +]; + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn pointer_kind_sanity() { + for variant in IndexKind::variants() { + if VALID_POINTER_SRCS.iter().any(|x| x == variant) { + assert!( + !PointerKind::pointers_from(*variant).is_empty(), + "expected variant {:?} to be a valid pointer source", + variant, + ); + } else { + assert!( + PointerKind::pointers_from(*variant).is_empty(), + "expected variant {:?} to not be a valid pointer source", + variant, + ); + } + } + } +} + +/// Represents a single mutation to a `CompiledModule` to produce an out-of-bounds situation. +/// +/// Use `OutOfBoundsMutation::strategy()` to generate them, preferably using `Vec` to generate +/// many at a time. Then use `ApplyOutOfBoundsContext` to apply those mutations. +#[derive(Debug)] +pub struct OutOfBoundsMutation { + src_kind: IndexKind, + src_idx: PropIndex, + dst_kind: IndexKind, + offset: usize, +} + +impl OutOfBoundsMutation { + pub fn strategy() -> impl Strategy { + ( + Self::src_kind_strategy(), + any::(), + any::(), + 0..16 as usize, + ) + .prop_map(|(src_kind, src_idx, dst_kind_idx, offset)| { + let dst_kind = Self::dst_kind(src_kind, dst_kind_idx); + Self { + src_kind, + src_idx, + dst_kind, + offset, + } + }) + } + + // Not all source kinds can be made to be out of bounds (e.g. inherent types can't.) + fn src_kind_strategy() -> impl Strategy { + sample::select(VALID_POINTER_SRCS) + } + + fn dst_kind(src_kind: IndexKind, dst_kind_idx: PropIndex) -> IndexKind { + dst_kind_idx + .get(PointerKind::pointers_from(src_kind)) + .to_index_kind() + } +} + +/// This is used for source indexing, to work with pick_slice_idxs. 
+impl AsRef for OutOfBoundsMutation { + #[inline] + fn as_ref(&self) -> &PropIndex { + &self.src_idx + } +} + +pub struct ApplyOutOfBoundsContext<'a> { + module: &'a mut CompiledModule, + // This is an Option because it gets moved out in apply before apply_one is called. Rust + // doesn't let you call another con-consuming method after a partial move out. + mutations: Option>, + + // Some precomputations done for signatures. + type_sig_structs: Vec, + function_sig_structs: Vec, + locals_sig_structs: Vec<(LocalsSignatureIndex, usize)>, +} + +impl<'a> ApplyOutOfBoundsContext<'a> { + pub fn new(module: &'a mut CompiledModule, mutations: Vec) -> Self { + let type_sig_structs: Vec<_> = Self::type_sig_structs(module).collect(); + let function_sig_structs: Vec<_> = Self::function_sig_structs(module).collect(); + let locals_sig_structs: Vec<_> = Self::locals_sig_structs(module).collect(); + + Self { + module, + mutations: Some(mutations), + type_sig_structs, + function_sig_structs, + locals_sig_structs, + } + } + + pub fn apply(mut self) -> Vec { + // This is a map from (source kind, dest kind) to the actual mutations -- this is done to + // figure out how many mutations to do for a particular pair, which is required for + // pick_slice_idxs below. + let mut mutation_map = BTreeMap::new(); + for mutation in self + .mutations + .take() + .expect("mutations should always be present") + { + mutation_map + .entry((mutation.src_kind, mutation.dst_kind)) + .or_insert_with(|| vec![]) + .push(mutation); + } + + let mut results = vec![]; + + for ((src_kind, dst_kind), mutations) in mutation_map { + // It would be cool to use an iterator here, if someone could figure out exactly how + // to get the lifetimes right :) + results.extend(self.apply_one(src_kind, dst_kind, mutations)); + } + results + } + + fn apply_one( + &mut self, + src_kind: IndexKind, + dst_kind: IndexKind, + mutations: Vec, + ) -> Vec { + let src_count = match src_kind { + // Only the signature indexes that have structs in them (i.e. are in *_sig_structs) + // are going to be modifiable, so pick among them. + IndexKind::TypeSignature => self.type_sig_structs.len(), + IndexKind::FunctionSignature => self.function_sig_structs.len(), + IndexKind::LocalsSignature => self.locals_sig_structs.len(), + // For the other sorts it's always possible to change an index. + src_kind => self.module.kind_count(src_kind), + }; + // Any signature can be a destination, not just the ones that have structs in them. + let dst_count = self.module.kind_count(dst_kind); + let to_mutate = pick_slice_idxs(src_count, &mutations); + + mutations + .iter() + .zip(to_mutate) + .map(move |(mutation, src_idx)| { + self.set_index( + src_kind, + src_idx, + dst_kind, + dst_count, + (dst_count + mutation.offset) as TableIndex, + ) + }) + .collect() + } + + /// Sets the particular index in the table + /// + /// For example, with `src_kind` set to `ModuleHandle` and `dst_kind` set to `AddressPool`, + /// this will set self.module_handles[src_idx].address to new_idx. + /// + /// This is mainly used for test generation. + fn set_index( + &mut self, + src_kind: IndexKind, + src_idx: usize, + dst_kind: IndexKind, + dst_count: usize, + new_idx: TableIndex, + ) -> VerificationError { + use IndexKind::*; + + // These are default values, but some of the match arms below mutate them. 
+ let mut src_idx = src_idx; + let mut err = VMStaticViolation::IndexOutOfBounds(dst_kind, dst_count, new_idx as usize); + + // A dynamic type system would be able to express this next block of code far more + // concisely. A static type system would require some sort of complicated dependent type + // structure that Rust doesn't have. As things stand today, every possible case needs to + // be listed out. + + match (src_kind, dst_kind) { + (ModuleHandle, AddressPool) => { + self.module.module_handles[src_idx].address = AddressPoolIndex::new(new_idx); + } + (ModuleHandle, StringPool) => { + self.module.module_handles[src_idx].name = StringPoolIndex::new(new_idx) + } + (StructHandle, ModuleHandle) => { + self.module.struct_handles[src_idx].module = ModuleHandleIndex::new(new_idx) + } + (StructHandle, StringPool) => { + self.module.struct_handles[src_idx].name = StringPoolIndex::new(new_idx) + } + (FunctionHandle, ModuleHandle) => { + self.module.function_handles[src_idx].module = ModuleHandleIndex::new(new_idx) + } + (FunctionHandle, StringPool) => { + self.module.function_handles[src_idx].name = StringPoolIndex::new(new_idx) + } + (FunctionHandle, FunctionSignature) => { + self.module.function_handles[src_idx].signature = + FunctionSignatureIndex::new(new_idx) + } + (StructDefinition, StructHandle) => { + self.module.struct_defs[src_idx].struct_handle = StructHandleIndex::new(new_idx) + } + (StructDefinition, FieldDefinition) => { + // Consider a situation with 3 fields, and with first field = 1 and count = 2. + // A graphical representation of that might be: + // + // |___|___|___| + // idx 0 1 2 + // ^ ^ + // | | + // first field = 1 (first field + count) = 3 + // + // Given that the lowest value for new_idx is 3 (offset 0), the goal is to make + // (first field + count) at least 4, or (new_idx + 1). This means that the first + // field would be new_idx + 1 - count. + let end_idx = new_idx + 1; + let first_new_idx = end_idx - self.module.struct_defs[src_idx].field_count; + self.module.struct_defs[src_idx].fields = FieldDefinitionIndex::new(first_new_idx); + err = VMStaticViolation::RangeOutOfBounds( + dst_kind, + dst_count, + first_new_idx as usize, + end_idx as usize, + ); + } + (FieldDefinition, StructHandle) => { + self.module.field_defs[src_idx].struct_ = StructHandleIndex::new(new_idx) + } + (FieldDefinition, StringPool) => { + self.module.field_defs[src_idx].name = StringPoolIndex::new(new_idx) + } + (FieldDefinition, TypeSignature) => { + self.module.field_defs[src_idx].signature = TypeSignatureIndex::new(new_idx) + } + (FunctionDefinition, FunctionHandle) => { + self.module.function_defs[src_idx].function = FunctionHandleIndex::new(new_idx) + } + (FunctionDefinition, LocalsSignature) => { + self.module.function_defs[src_idx].code.locals = LocalsSignatureIndex::new(new_idx) + } + (TypeSignature, StructHandle) => { + // For this and the other signatures, the source index will be picked from + // only the ones that have struct handles in them. 
+ src_idx = self.type_sig_structs[src_idx].into_index(); + self.module.type_signatures[src_idx] + .0 + .debug_set_sh_idx(StructHandleIndex::new(new_idx)); + } + (FunctionSignature, StructHandle) => match &self.function_sig_structs[src_idx] { + FunctionSignatureTokenIndex::ReturnType(actual_src_idx, ret_idx) => { + src_idx = actual_src_idx.into_index(); + self.module.function_signatures[src_idx].return_types[*ret_idx] + .debug_set_sh_idx(StructHandleIndex::new(new_idx)); + } + FunctionSignatureTokenIndex::ArgType(actual_src_idx, arg_idx) => { + src_idx = actual_src_idx.into_index(); + self.module.function_signatures[src_idx].arg_types[*arg_idx] + .debug_set_sh_idx(StructHandleIndex::new(new_idx)); + } + }, + (LocalsSignature, StructHandle) => { + let (actual_src_idx, arg_idx) = self.locals_sig_structs[src_idx]; + src_idx = actual_src_idx.into_index(); + self.module.locals_signatures[src_idx].0[arg_idx] + .debug_set_sh_idx(StructHandleIndex::new(new_idx)); + } + _ => panic!("Invalid pointer kind: {:?} -> {:?}", src_kind, dst_kind), + } + + VerificationError { + kind: src_kind, + idx: src_idx, + err, + } + } + + /// Returns the indexes of type signatures that contain struct handles inside them. + fn type_sig_structs<'b>( + module: &'b CompiledModule, + ) -> impl Iterator + 'b { + let module_view = ModuleView::new(module); + module_view + .type_signatures() + .enumerate() + .filter_map(|(idx, signature)| { + signature + .token() + .struct_handle() + .map(|_| TypeSignatureIndex::new(idx as u16)) + }) + } + + /// Returns the indexes of function signatures that contain struct handles inside them. + fn function_sig_structs<'b>( + module: &'b CompiledModule, + ) -> impl Iterator + 'b { + let module_view = ModuleView::new(module); + let return_tokens = module_view + .function_signatures() + .enumerate() + .map(|(idx, signature)| { + let idx = FunctionSignatureIndex::new(idx as u16); + Self::find_struct_tokens(signature.return_tokens(), move |arg_idx| { + FunctionSignatureTokenIndex::ReturnType(idx, arg_idx) + }) + }) + .flatten(); + let arg_tokens = module_view + .function_signatures() + .enumerate() + .map(|(idx, signature)| { + let idx = FunctionSignatureIndex::new(idx as u16); + Self::find_struct_tokens(signature.arg_tokens(), move |arg_idx| { + FunctionSignatureTokenIndex::ArgType(idx, arg_idx) + }) + }) + .flatten(); + return_tokens.chain(arg_tokens) + } + + /// Returns the indexes of locals signatures that contain struct handles inside them. 
+ fn locals_sig_structs<'b>( + module: &'b CompiledModule, + ) -> impl Iterator + 'b { + let module_view = ModuleView::new(module); + module_view + .locals_signatures() + .enumerate() + .map(|(idx, signature)| { + let idx = LocalsSignatureIndex::new(idx as u16); + Self::find_struct_tokens(signature.tokens(), move |arg_idx| (idx, arg_idx)) + }) + .flatten() + } + + #[inline] + fn find_struct_tokens<'b, F, T>( + tokens: impl IntoIterator> + 'b, + map_fn: F, + ) -> impl Iterator + 'b + where + F: Fn(usize) -> T + 'b, + { + tokens + .into_iter() + .enumerate() + .filter_map(move |(arg_idx, token)| token.struct_handle().map(|_| map_fn(arg_idx))) + } +} + +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] +enum FunctionSignatureTokenIndex { + ReturnType(FunctionSignatureIndex, usize), + ArgType(FunctionSignatureIndex, usize), +} diff --git a/language/bytecode_verifier/invalid_mutations/src/bounds/code_unit.rs b/language/bytecode_verifier/invalid_mutations/src/bounds/code_unit.rs new file mode 100644 index 0000000000000..0c944b25ba8bf --- /dev/null +++ b/language/bytecode_verifier/invalid_mutations/src/bounds/code_unit.rs @@ -0,0 +1,230 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use proptest::{prelude::*, sample::Index as PropIndex}; +use proptest_helpers::pick_slice_idxs; +use std::collections::BTreeMap; +use vm::{ + errors::{VMStaticViolation, VerificationError}, + file_format::{ + AddressPoolIndex, ByteArrayPoolIndex, Bytecode, CodeOffset, CompiledModule, + FieldDefinitionIndex, FunctionHandleIndex, LocalIndex, StringPoolIndex, + StructDefinitionIndex, TableIndex, + }, + internals::ModuleIndex, + IndexKind, +}; + +/// Represents a single mutation onto a code unit to make it out of bounds. +#[derive(Debug)] +pub struct CodeUnitBoundsMutation { + function_def: PropIndex, + bytecode: PropIndex, + offset: usize, +} + +impl CodeUnitBoundsMutation { + pub fn strategy() -> impl Strategy { + (any::(), any::(), 0..16 as usize).prop_map( + |(function_def, bytecode, offset)| Self { + function_def, + bytecode, + offset, + }, + ) + } +} + +impl AsRef for CodeUnitBoundsMutation { + #[inline] + fn as_ref(&self) -> &PropIndex { + &self.bytecode + } +} + +pub struct ApplyCodeUnitBoundsContext<'a> { + module: &'a mut CompiledModule, + // This is so apply_one can be called after mutations has been iterated on. + mutations: Option>, +} + +macro_rules! new_bytecode { + ($dst_len: expr, $offset: expr, $idx_type: ident, $bytecode_ident: tt) => {{ + let dst_len = $dst_len; + let new_idx = (dst_len + $offset) as TableIndex; + ( + $bytecode_ident($idx_type::new(new_idx)), + VMStaticViolation::IndexOutOfBounds($idx_type::KIND, dst_len, new_idx as usize), + ) + }}; +} + +macro_rules! code_bytecode { + ($code_len: expr, $offset: expr, $bytecode_ident: tt) => {{ + let code_len = $code_len; + let new_idx = code_len + $offset; + ( + $bytecode_ident(new_idx as CodeOffset), + VMStaticViolation::IndexOutOfBounds(IndexKind::CodeDefinition, code_len, new_idx), + ) + }}; +} + +macro_rules! 
locals_bytecode { + ($locals_len: expr, $offset: expr, $bytecode_ident: tt) => {{ + let locals_len = $locals_len; + let new_idx = locals_len + $offset; + ( + $bytecode_ident(new_idx as LocalIndex), + VMStaticViolation::IndexOutOfBounds(IndexKind::LocalPool, locals_len, new_idx), + ) + }}; +} + +impl<'a> ApplyCodeUnitBoundsContext<'a> { + pub fn new(module: &'a mut CompiledModule, mutations: Vec) -> Self { + Self { + module, + mutations: Some(mutations), + } + } + + pub fn apply(mut self) -> Vec { + let function_def_len = self.module.function_defs.len(); + + let mut mutation_map = BTreeMap::new(); + for mutation in self + .mutations + .take() + .expect("mutations should always be present") + { + let picked_idx = mutation.function_def.index(function_def_len); + mutation_map + .entry(picked_idx) + .or_insert_with(|| vec![]) + .push(mutation); + } + + let mut results = vec![]; + + for (idx, mutations) in mutation_map { + results.extend(self.apply_one(idx, mutations)); + } + results + } + + fn apply_one( + &mut self, + idx: usize, + mutations: Vec, + ) -> Vec { + // For this function def, find all the places where a bounds mutation can be applied. + let (code_len, locals_len) = { + let code = &mut self.module.function_defs[idx].code; + ( + code.code.len(), + self.module.locals_signatures[code.locals.into_index()].len(), + ) + }; + + let mut interesting: Vec<&mut Bytecode> = self.module.function_defs[idx] + .code + .code + .iter_mut() + .filter(|bytecode| is_interesting(*bytecode)) + .collect(); + let to_mutate = pick_slice_idxs(interesting.len(), &mutations); + + // These have to be computed upfront because self.module is being mutated below. + let address_pool_len = self.module.address_pool.len(); + let string_pool_len = self.module.string_pool.len(); + let byte_array_pool_len = self.module.byte_array_pool.len(); + let function_handles_len = self.module.function_handles.len(); + let field_defs_len = self.module.field_defs.len(); + let struct_defs_len = self.module.struct_defs.len(); + + mutations + .iter() + .zip(to_mutate) + .map(|(mutation, bytecode_idx)| { + let offset = mutation.offset; + use Bytecode::*; + + let (new_bytecode, err) = match interesting[bytecode_idx] { + LdAddr(_) => new_bytecode!(address_pool_len, offset, AddressPoolIndex, LdAddr), + LdStr(_) => new_bytecode!(string_pool_len, offset, StringPoolIndex, LdStr), + LdByteArray(_) => { + new_bytecode!(byte_array_pool_len, offset, ByteArrayPoolIndex, LdByteArray) + } + BorrowField(_) => { + new_bytecode!(field_defs_len, offset, FieldDefinitionIndex, BorrowField) + } + Call(_) => { + new_bytecode!(function_handles_len, offset, FunctionHandleIndex, Call) + } + Pack(_) => new_bytecode!(struct_defs_len, offset, StructDefinitionIndex, Pack), + Unpack(_) => { + new_bytecode!(struct_defs_len, offset, StructDefinitionIndex, Unpack) + } + Exists(_) => { + new_bytecode!(struct_defs_len, offset, StructDefinitionIndex, Exists) + } + BorrowGlobal(_) => { + new_bytecode!(struct_defs_len, offset, StructDefinitionIndex, BorrowGlobal) + } + MoveFrom(_) => { + new_bytecode!(struct_defs_len, offset, StructDefinitionIndex, MoveFrom) + } + MoveToSender(_) => { + new_bytecode!(struct_defs_len, offset, StructDefinitionIndex, MoveToSender) + } + BrTrue(_) => code_bytecode!(code_len, offset, BrTrue), + BrFalse(_) => code_bytecode!(code_len, offset, BrFalse), + Branch(_) => code_bytecode!(code_len, offset, Branch), + CopyLoc(_) => locals_bytecode!(locals_len, offset, CopyLoc), + MoveLoc(_) => locals_bytecode!(locals_len, offset, MoveLoc), + StLoc(_) => 
locals_bytecode!(locals_len, offset, StLoc), + BorrowLoc(_) => locals_bytecode!(locals_len, offset, BorrowLoc), + + // List out the other options explicitly so there's a compile error if a new + // bytecode gets added. + FreezeRef | ReleaseRef | Pop | Ret | LdConst(_) | LdTrue | LdFalse + | ReadRef | WriteRef | Add | Sub | Mul | Mod | Div | BitOr | BitAnd | Xor + | Or | And | Not | Eq | Neq | Lt | Gt | Le | Ge | Assert + | GetTxnGasUnitPrice | GetTxnMaxGasUnits | GetGasRemaining + | GetTxnSenderAddress | CreateAccount | EmitEvent | GetTxnSequenceNumber + | GetTxnPublicKey => panic!( + "Bytecode has no internal index: {:?}", + interesting[bytecode_idx] + ), + }; + + *interesting[bytecode_idx] = new_bytecode; + + VerificationError { + kind: IndexKind::FunctionDefinition, + idx, + err, + } + }) + .collect() + } +} + +fn is_interesting(bytecode: &Bytecode) -> bool { + use Bytecode::*; + + match bytecode { + LdAddr(_) | LdStr(_) | LdByteArray(_) | BorrowField(_) | Call(_) | Pack(_) | Unpack(_) + | Exists(_) | BorrowGlobal(_) | MoveFrom(_) | MoveToSender(_) | BrTrue(_) | BrFalse(_) + | Branch(_) | CopyLoc(_) | MoveLoc(_) | StLoc(_) | BorrowLoc(_) => true, + + // List out the other options explicitly so there's a compile error if a new + // bytecode gets added. + FreezeRef | ReleaseRef | Pop | Ret | LdConst(_) | LdTrue | LdFalse | ReadRef | WriteRef + | Add | Sub | Mul | Mod | Div | BitOr | BitAnd | Xor | Or | And | Not | Eq | Neq | Lt + | Gt | Le | Ge | Assert | GetTxnGasUnitPrice | GetTxnMaxGasUnits | GetGasRemaining + | GetTxnSenderAddress | CreateAccount | EmitEvent | GetTxnSequenceNumber + | GetTxnPublicKey => false, + } +} diff --git a/language/bytecode_verifier/invalid_mutations/src/lib.rs b/language/bytecode_verifier/invalid_mutations/src/lib.rs new file mode 100644 index 0000000000000..81ed355898ec5 --- /dev/null +++ b/language/bytecode_verifier/invalid_mutations/src/lib.rs @@ -0,0 +1,5 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +pub mod bounds; +pub mod signature; diff --git a/language/bytecode_verifier/invalid_mutations/src/signature.rs b/language/bytecode_verifier/invalid_mutations/src/signature.rs new file mode 100644 index 0000000000000..99698d1c2c6fa --- /dev/null +++ b/language/bytecode_verifier/invalid_mutations/src/signature.rs @@ -0,0 +1,248 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use proptest::{ + prelude::*, + sample::{select, Index as PropIndex}, +}; +use proptest_helpers::{pick_slice_idxs, RepeatVec}; +use std::collections::BTreeMap; +use vm::{ + errors::{VMStaticViolation, VerificationError}, + file_format::{CompiledModule, SignatureToken}, + internals::ModuleIndex, + IndexKind, SignatureTokenKind, +}; + +/// Represents a mutation that wraps a signature token up in a double reference (or an array of +/// references. +#[derive(Clone, Debug)] +pub struct DoubleRefMutation { + idx: PropIndex, + kind: DoubleRefMutationKind, +} + +impl DoubleRefMutation { + pub fn strategy() -> impl Strategy { + (any::(), DoubleRefMutationKind::strategy()) + .prop_map(|(idx, kind)| Self { idx, kind }) + } +} + +impl AsRef for DoubleRefMutation { + #[inline] + fn as_ref(&self) -> &PropIndex { + &self.idx + } +} + +/// Context for applying a list of `DoubleRefMutation` instances. 
+pub struct ApplySignatureDoubleRefContext<'a> { + module: &'a mut CompiledModule, + mutations: Vec, +} + +impl<'a> ApplySignatureDoubleRefContext<'a> { + pub fn new(module: &'a mut CompiledModule, mutations: Vec) -> Self { + Self { module, mutations } + } + + pub fn apply(self) -> Vec { + // Apply double refs before field refs -- XXX is this correct? + let sig_indexes = self.all_sig_indexes(); + let picked = sig_indexes.pick_uniform(&self.mutations); + + let mut errs = vec![]; + + for (double_ref, (sig_idx, idx2)) in self.mutations.iter().zip(picked) { + // When there's one level of indexing (e.g. Type), idx2 represents that level. + // When there's two levels of indexing (e.g. FunctionArg), idx1 represents the outer + // level (signature index) and idx2 the inner level (token index). + let (token, kind, error_idx) = match sig_idx { + SignatureIndex::Type => ( + &mut self.module.type_signatures[idx2].0, + IndexKind::TypeSignature, + idx2, + ), + SignatureIndex::FunctionReturn(idx1) => ( + &mut self.module.function_signatures[*idx1].return_types[idx2], + IndexKind::FunctionSignature, + *idx1, + ), + SignatureIndex::FunctionArg(idx1) => ( + &mut self.module.function_signatures[*idx1].arg_types[idx2], + IndexKind::FunctionSignature, + *idx1, + ), + SignatureIndex::Locals(idx1) => ( + &mut self.module.locals_signatures[*idx1].0[idx2], + IndexKind::LocalsSignature, + *idx1, + ), + }; + + *token = double_ref.kind.wrap(token.clone()); + errs.push(VerificationError { + kind, + idx: error_idx, + err: VMStaticViolation::InvalidSignatureToken( + token.clone(), + double_ref.kind.outer, + double_ref.kind.inner, + ), + }); + } + + errs + } + + fn all_sig_indexes(&self) -> RepeatVec { + let mut res = RepeatVec::new(); + res.extend(SignatureIndex::Type, self.module.type_signatures.len()); + for (idx, sig) in self.module.function_signatures.iter().enumerate() { + res.extend(SignatureIndex::FunctionReturn(idx), sig.return_types.len()); + } + for (idx, sig) in self.module.function_signatures.iter().enumerate() { + res.extend(SignatureIndex::FunctionArg(idx), sig.arg_types.len()); + } + for (idx, sig) in self.module.locals_signatures.iter().enumerate() { + res.extend(SignatureIndex::Locals(idx), sig.0.len()); + } + res + } +} + +/// Represents a mutation that turns a field definition's type into a reference. +#[derive(Clone, Debug)] +pub struct FieldRefMutation { + idx: PropIndex, + is_mutable: bool, +} + +impl FieldRefMutation { + pub fn strategy() -> impl Strategy { + (any::(), any::()).prop_map(|(idx, is_mutable)| Self { idx, is_mutable }) + } +} + +impl AsRef for FieldRefMutation { + #[inline] + fn as_ref(&self) -> &PropIndex { + &self.idx + } +} + +/// Context for applying a list of `FieldRefMutation` instances. +pub struct ApplySignatureFieldRefContext<'a> { + module: &'a mut CompiledModule, + mutations: Vec, +} + +impl<'a> ApplySignatureFieldRefContext<'a> { + pub fn new(module: &'a mut CompiledModule, mutations: Vec) -> Self { + Self { module, mutations } + } + + #[inline] + pub fn apply(self) -> Vec { + // One field definition might be associated with more than one signature, so collect all + // the interesting ones in a map of type_sig_idx => field_def_idx. + let mut interesting_idxs = BTreeMap::new(); + for (field_def_idx, field_def) in self.module.field_defs.iter().enumerate() { + interesting_idxs + .entry(field_def.signature) + .or_insert_with(|| vec![]) + .push(field_def_idx); + } + // Convert into a Vec of pairs to allow pick_slice_idxs return vvalues to work. 
+ let interesting_idxs: Vec<_> = interesting_idxs.into_iter().collect(); + + let picked = pick_slice_idxs(interesting_idxs.len(), &self.mutations); + let mut errs = vec![]; + for (mutation, picked_idx) in self.mutations.iter().zip(picked) { + let (type_sig_idx, field_def_idxs) = &interesting_idxs[picked_idx]; + let token = &mut self.module.type_signatures[type_sig_idx.into_index()].0; + let (new_token, token_kind) = if mutation.is_mutable { + ( + SignatureToken::MutableReference(Box::new(token.clone())), + SignatureTokenKind::MutableReference, + ) + } else { + ( + SignatureToken::Reference(Box::new(token.clone())), + SignatureTokenKind::Reference, + ) + }; + + *token = new_token; + + let violation = VMStaticViolation::InvalidFieldDefReference(token.clone(), token_kind); + errs.extend( + field_def_idxs + .iter() + .map(|field_def_idx| VerificationError { + kind: IndexKind::FieldDefinition, + idx: *field_def_idx, + err: violation.clone(), + }), + ); + } + + errs + } +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum SignatureIndex { + Type, + FunctionReturn(usize), + FunctionArg(usize), + Locals(usize), +} + +#[derive(Clone, Debug)] +struct DoubleRefMutationKind { + outer: SignatureTokenKind, + inner: SignatureTokenKind, +} + +impl DoubleRefMutationKind { + fn strategy() -> impl Strategy { + (Self::outer_strategy(), Self::inner_strategy()) + .prop_map(|(outer, inner)| Self { outer, inner }) + } + + fn wrap(&self, token: SignatureToken) -> SignatureToken { + let token = Self::wrap_one(token, self.inner); + Self::wrap_one(token, self.outer) + } + + fn wrap_one(token: SignatureToken, kind: SignatureTokenKind) -> SignatureToken { + match kind { + SignatureTokenKind::Reference => SignatureToken::Reference(Box::new(token)), + SignatureTokenKind::MutableReference => { + SignatureToken::MutableReference(Box::new(token)) + } + SignatureTokenKind::Value => panic!("invalid wrapping kind: {}", kind), + } + } + + #[inline] + fn outer_strategy() -> impl Strategy { + static VALID_OUTERS: &[SignatureTokenKind] = &[ + SignatureTokenKind::Reference, + SignatureTokenKind::MutableReference, + ]; + select(VALID_OUTERS) + } + + #[inline] + fn inner_strategy() -> impl Strategy { + static VALID_INNERS: &[SignatureTokenKind] = &[ + SignatureTokenKind::Reference, + SignatureTokenKind::MutableReference, + ]; + + select(VALID_INNERS) + } +} diff --git a/language/bytecode_verifier/src/abstract_interpreter.rs b/language/bytecode_verifier/src/abstract_interpreter.rs new file mode 100644 index 0000000000000..727e2e61a6ffa --- /dev/null +++ b/language/bytecode_verifier/src/abstract_interpreter.rs @@ -0,0 +1,868 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines the abstract interpretater for verifying type and memory safety on a +//! function body. 
+use crate::{
+    abstract_state::{AbstractState, AbstractValue, JoinResult},
+    code_unit_verifier::VerificationPass,
+    control_flow_graph::{BlockId, ControlFlowGraph, VMControlFlowGraph},
+    nonce::Nonce,
+};
+use std::collections::{BTreeMap, BTreeSet};
+use vm::{
+    access::{BaseAccess, ModuleAccess},
+    errors::VMStaticViolation,
+    file_format::{
+        Bytecode, CompiledModule, FieldDefinitionIndex, FunctionDefinition, LocalIndex,
+        SignatureToken, StructHandleIndex,
+    },
+    views::{
+        FieldDefinitionView, FunctionDefinitionView, FunctionSignatureView, LocalsSignatureView,
+        SignatureTokenView, StructDefinitionView, ViewInternals,
+    },
+};
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+struct StackAbstractValue {
+    signature: SignatureToken,
+    value: AbstractValue,
+}
+
+pub struct AbstractInterpreter<'a> {
+    module: &'a CompiledModule,
+    function_definition_view: FunctionDefinitionView<'a, CompiledModule>,
+    locals_signature_view: LocalsSignatureView<'a, CompiledModule>,
+    cfg: &'a VMControlFlowGraph,
+    block_id_to_state: BTreeMap<BlockId, AbstractState>,
+    erroneous_blocks: BTreeSet<BlockId>,
+    work_list: Vec<BlockId>,
+    stack: Vec<StackAbstractValue>,
+    next_nonce: usize,
+}
+
+impl<'a> VerificationPass<'a> for AbstractInterpreter<'a> {
+    fn new(
+        module: &'a CompiledModule,
+        function_definition: &'a FunctionDefinition,
+        cfg: &'a VMControlFlowGraph,
+    ) -> Self {
+        let function_definition_view = FunctionDefinitionView::new(module, function_definition);
+        let locals_signature_view = function_definition_view.locals_signature();
+        let function_signature_view = function_definition_view.signature();
+        let mut block_id_to_state = BTreeMap::new();
+        let erroneous_blocks = BTreeSet::new();
+        let mut locals = BTreeMap::new();
+        for (arg_idx, arg_type_view) in function_signature_view.arg_tokens().enumerate() {
+            if arg_type_view.is_reference() {
+                locals.insert(
+                    arg_idx as LocalIndex,
+                    AbstractValue::Reference(Nonce::new(arg_idx)),
+                );
+            } else {
+                locals.insert(
+                    arg_idx as LocalIndex,
+                    AbstractValue::full_value(arg_type_view.is_resource()),
+                );
+            }
+        }
+        block_id_to_state.insert(0, AbstractState::new(locals));
+        let next_nonce = function_signature_view.arg_count();
+        Self {
+            module,
+            function_definition_view,
+            locals_signature_view,
+            cfg,
+            block_id_to_state,
+            erroneous_blocks,
+            work_list: vec![0],
+            stack: vec![],
+            next_nonce,
+        }
+    }
+
+    fn verify(mut self) -> Vec<VMStaticViolation> {
+        let mut errors = vec![];
+        while !self.work_list.is_empty() {
+            let block_id = self.work_list.pop().unwrap();
+            errors.append(&mut self.propagate(block_id));
+        }
+        errors
+    }
+}
+
+impl<'a> AbstractInterpreter<'a> {
+    fn propagate(&mut self, block_id: BlockId) -> Vec<VMStaticViolation> {
+        match self.compute(block_id) {
+            Ok(flow_state) => {
+                let state = flow_state.construct_canonical_state();
+                let mut errors = vec![];
+                let block = &self
+                    .cfg
+                    .block_of_id(block_id)
+                    .expect("block_id is not the start offset of a block");
+                for next_block_id in &block.successors {
+                    if self.erroneous_blocks.contains(next_block_id) {
+                        continue;
+                    }
+                    if !self.block_id_to_state.contains_key(next_block_id) {
+                        self.work_list.push(*next_block_id);
+                        self.block_id_to_state
+                            .entry(*next_block_id)
+                            .or_insert_with(|| state.clone());
+                    } else {
+                        let curr_state = &self.block_id_to_state[next_block_id];
+                        let join_result = curr_state.join(&state);
+                        match join_result {
+                            JoinResult::Unchanged => {}
+                            JoinResult::Changed(next_state) => {
+                                self.block_id_to_state
+                                    .entry(*next_block_id)
+                                    .and_modify(|entry| *entry = next_state);
+                                self.work_list.push(*next_block_id);
+                            }
+                            JoinResult::Error => {
+                                errors.append(&mut vec![VMStaticViolation::JoinFailure(
+                                    *next_block_id as usize,
+                                )]);
+                            }
+                        }
+                    }
+                }
+                errors
+            }
+            Err(es) => {
+                self.erroneous_blocks.insert(block_id);
+                es
+            }
+        }
+    }
+
+    fn compute(&mut self, block_id: BlockId) -> Result<AbstractState, Vec<VMStaticViolation>> {
+        let mut state = self.block_id_to_state[&block_id].clone();
+        let block = &self.cfg.block_of_id(block_id).unwrap();
+        let mut offset = block.entry;
+        while offset <= block.exit {
+            let result = self.next(
+                state,
+                offset as usize,
+                &self.function_definition_view.code().code[offset as usize],
+            );
+            match result {
+                Ok(next_state) => state = next_state,
+                Err(errors) => return Err(errors),
+            }
+            offset += 1;
+        }
+        Ok(state)
+    }
+
+    fn get_field_signature(&self, field_definition_index: FieldDefinitionIndex) -> SignatureToken {
+        let field_definition = self.module.field_def_at(field_definition_index);
+        let field_definition_view = FieldDefinitionView::new(self.module, field_definition);
+        field_definition_view
+            .type_signature()
+            .token()
+            .as_inner()
+            .clone()
+    }
+
+    fn is_field_in_struct(
+        &self,
+        field_definition_index: FieldDefinitionIndex,
+        struct_handle_index: StructHandleIndex,
+    ) -> bool {
+        let field_definition = self.module.field_def_at(field_definition_index);
+        struct_handle_index == field_definition.struct_
+    }
+
+    fn get_nonce(&mut self, state: &mut AbstractState) -> Nonce {
+        let nonce = Nonce::new(self.next_nonce);
+        state.add_nonce(nonce.clone());
+        self.next_nonce += 1;
+        nonce
+    }
+
+    fn freeze_ok(&self, state: &AbstractState, existing_borrows: BTreeSet<Nonce>) -> bool {
+        for (arg_idx, arg_type_view) in self.locals_signature_view.tokens().enumerate() {
+            if arg_type_view.as_inner().is_mutable_reference()
+                && state.is_available(arg_idx as LocalIndex)
+            {
+                if let AbstractValue::Reference(nonce) = state.local(arg_idx as LocalIndex) {
+                    if existing_borrows.contains(nonce) {
+                        return false;
+                    }
+                }
+            }
+        }
+        for stack_value in &self.stack {
+            if stack_value.signature.is_mutable_reference() {
+                if let AbstractValue::Reference(nonce) = &stack_value.value {
+                    if existing_borrows.contains(nonce) {
+                        return false;
+                    }
+                }
+            }
+        }
+        true
+    }
+
+    fn write_borrow_ok(existing_borrows: BTreeSet<Nonce>) -> bool {
+        existing_borrows.is_empty()
+    }
+
+    fn extract_nonce(value: &AbstractValue) -> Option<&Nonce> {
+        match value {
+            AbstractValue::Reference(nonce) => Some(nonce),
+            AbstractValue::Value(_, _) => None,
+        }
+    }
+
+    fn is_safe_to_destroy(&self, state: &AbstractState, idx: LocalIndex) -> bool {
+        match state.local(idx) {
+            AbstractValue::Reference(_) => false,
+            AbstractValue::Value(is_resource, borrowed_nonces) => {
+                !is_resource && borrowed_nonces.is_empty()
+            }
+        }
+    }
+
+    fn next(
+        &mut self,
+        state: AbstractState,
+        offset: usize,
+        bytecode: &Bytecode,
+    ) -> Result<AbstractState, Vec<VMStaticViolation>> {
+        match bytecode {
+            Bytecode::Pop => {
+                let operand = self.stack.pop().unwrap();
+                if SignatureTokenView::new(self.module, &operand.signature).is_resource() {
+                    Err(vec![VMStaticViolation::PopResourceError(offset)])
+                } else if operand.value.is_reference() {
+                    Err(vec![VMStaticViolation::PopReferenceError(offset)])
+                } else {
+                    Ok(state)
+                }
+            }
+
+            Bytecode::ReleaseRef => {
+                let operand = self.stack.pop().unwrap();
+                if let AbstractValue::Reference(nonce) = operand.value {
+                    let mut next_state = state;
+                    next_state.destroy_nonce(nonce);
+                    Ok(next_state)
+                } else {
+                    Err(vec![VMStaticViolation::ReleaseRefTypeMismatchError(offset)])
+                }
+            }
+
+            Bytecode::BrTrue(_) | Bytecode::BrFalse(_) => {
+                let operand = self.stack.pop().unwrap();
+                if operand.signature ==
SignatureToken::Bool { + Ok(state) + } else { + Err(vec![VMStaticViolation::BrTypeMismatchError(offset)]) + } + } + + Bytecode::Assert => { + let condition = self.stack.pop().unwrap(); + let error_code = self.stack.pop().unwrap(); + if condition.signature == SignatureToken::Bool + && error_code.signature == SignatureToken::U64 + { + Ok(state) + } else { + Err(vec![VMStaticViolation::AssertTypeMismatchError(offset)]) + } + } + + Bytecode::StLoc(idx) => { + let operand = self.stack.pop().unwrap(); + if operand.signature != *self.locals_signature_view.token_at(*idx).as_inner() { + return Err(vec![VMStaticViolation::StLocTypeMismatchError(offset)]); + } + let mut next_state = state; + if next_state.is_available(*idx) { + if self.is_safe_to_destroy(&next_state, *idx) { + next_state.destroy_local(*idx); + } else { + return Err(vec![VMStaticViolation::StLocUnsafeToDestroyError(offset)]); + } + } + next_state.insert_local(*idx, operand.value); + Ok(next_state) + } + + Bytecode::Ret => { + for arg_idx in 0..self.locals_signature_view.len() { + let idx = arg_idx as LocalIndex; + if state.is_available(idx) && !self.is_safe_to_destroy(&state, idx) { + return Err(vec![VMStaticViolation::RetUnsafeToDestroyError(offset)]); + } + } + for return_type_view in self + .function_definition_view + .signature() + .return_tokens() + .rev() + { + let operand = self.stack.pop().unwrap(); + if operand.signature != *return_type_view.as_inner() { + return Err(vec![VMStaticViolation::RetTypeMismatchError(offset)]); + } + } + Ok(AbstractState::new(BTreeMap::new())) + } + + Bytecode::Branch(_) => Ok(state), + + Bytecode::FreezeRef => { + let operand = self.stack.pop().unwrap(); + if let SignatureToken::MutableReference(signature) = operand.signature { + let operand_nonce = Self::extract_nonce(&operand.value).unwrap().clone(); + let borrowed_nonces = state.borrowed_nonces(operand_nonce.clone()); + if self.freeze_ok(&state, borrowed_nonces) { + self.stack.push(StackAbstractValue { + signature: SignatureToken::Reference(signature), + value: operand.value, + }); + return Ok(state); + } else { + Err(vec![VMStaticViolation::FreezeRefExistsMutableBorrowError( + offset, + )]) + } + } else { + Err(vec![VMStaticViolation::FreezeRefTypeMismatchError(offset)]) + } + } + + Bytecode::BorrowField(field_definition_index) => { + let operand = self.stack.pop().unwrap(); + if let Some(struct_handle_index) = + SignatureToken::get_struct_handle_from_reference(&operand.signature) + { + if self.is_field_in_struct(*field_definition_index, struct_handle_index) { + let field_signature = self.get_field_signature(*field_definition_index); + let operand_nonce = Self::extract_nonce(&operand.value).unwrap().clone(); + let mut next_state = state; + let nonce = self.get_nonce(&mut next_state); + if operand.signature.is_mutable_reference() { + let borrowed_nonces = next_state.borrowed_nonces_for_field( + *field_definition_index, + operand_nonce.clone(), + ); + if Self::write_borrow_ok(borrowed_nonces) { + self.stack.push(StackAbstractValue { + signature: SignatureToken::MutableReference(Box::new( + field_signature, + )), + value: AbstractValue::Reference(nonce.clone()), + }); + next_state.borrow_field_from_nonce( + *field_definition_index, + operand_nonce.clone(), + nonce, + ); + next_state.destroy_nonce(operand_nonce); + Ok(next_state) + } else { + Err(vec![ + VMStaticViolation::BorrowFieldExistsMutableBorrowError(offset), + ]) + } + } else { + self.stack.push(StackAbstractValue { + signature: SignatureToken::Reference(Box::new(field_signature)), + 
value: AbstractValue::Reference(nonce.clone()), + }); + next_state.borrow_field_from_nonce( + *field_definition_index, + operand_nonce.clone(), + nonce, + ); + next_state.destroy_nonce(operand_nonce); + Ok(next_state) + } + } else { + Err(vec![VMStaticViolation::BorrowFieldBadFieldError(offset)]) + } + } else { + Err(vec![VMStaticViolation::BorrowFieldTypeMismatchError( + offset, + )]) + } + } + + Bytecode::LdConst(_) => { + self.stack.push(StackAbstractValue { + signature: SignatureToken::U64, + value: AbstractValue::full_value(false), + }); + Ok(state) + } + + Bytecode::LdAddr(_) => { + self.stack.push(StackAbstractValue { + signature: SignatureToken::Address, + value: AbstractValue::full_value(false), + }); + Ok(state) + } + + Bytecode::LdStr(_) => { + self.stack.push(StackAbstractValue { + signature: SignatureToken::String, + value: AbstractValue::full_value(false), + }); + Ok(state) + } + + Bytecode::LdByteArray(_) => { + self.stack.push(StackAbstractValue { + signature: SignatureToken::ByteArray, + value: AbstractValue::full_value(false), + }); + Ok(state) + } + + Bytecode::LdTrue | Bytecode::LdFalse => { + self.stack.push(StackAbstractValue { + signature: SignatureToken::Bool, + value: AbstractValue::full_value(false), + }); + Ok(state) + } + + Bytecode::CopyLoc(idx) => { + let signature_view = self.locals_signature_view.token_at(*idx); + if !state.is_available(*idx) { + Err(vec![VMStaticViolation::CopyLocUnavailableError(offset)]) + } else if signature_view.is_reference() { + let mut next_state = state; + let nonce = self.get_nonce(&mut next_state); + next_state.borrow_from_local_reference(*idx, nonce.clone()); + self.stack.push(StackAbstractValue { + signature: signature_view.as_inner().clone(), + value: AbstractValue::Reference(nonce), + }); + Ok(next_state) + } else if signature_view.is_resource() { + Err(vec![VMStaticViolation::CopyLocResourceError(offset)]) + } else if state.is_full(state.local(*idx)) { + self.stack.push(StackAbstractValue { + signature: signature_view.as_inner().clone(), + value: AbstractValue::full_value(false), + }); + Ok(state) + } else { + Err(vec![VMStaticViolation::CopyLocExistsBorrowError(offset)]) + } + } + + Bytecode::MoveLoc(idx) => { + let signature = self.locals_signature_view.token_at(*idx).as_inner().clone(); + if !state.is_available(*idx) { + Err(vec![VMStaticViolation::MoveLocUnavailableError(offset)]) + } else if signature.is_reference() || state.is_full(state.local(*idx)) { + let mut next_state = state; + let value = next_state.remove_local(*idx); + self.stack.push(StackAbstractValue { signature, value }); + Ok(next_state) + } else { + Err(vec![VMStaticViolation::MoveLocExistsBorrowError(offset)]) + } + } + + Bytecode::BorrowLoc(idx) => { + let signature = self.locals_signature_view.token_at(*idx).as_inner().clone(); + if signature.is_reference() { + Err(vec![VMStaticViolation::BorrowLocReferenceError(offset)]) + } else if !state.is_available(*idx) { + Err(vec![VMStaticViolation::BorrowLocUnavailableError(offset)]) + } else if state.is_full(state.local(*idx)) { + let mut next_state = state; + let nonce = self.get_nonce(&mut next_state); + next_state.borrow_from_local_value(*idx, nonce.clone()); + self.stack.push(StackAbstractValue { + signature: SignatureToken::MutableReference(Box::new(signature)), + value: AbstractValue::Reference(nonce), + }); + Ok(next_state) + } else { + Err(vec![VMStaticViolation::BorrowLocExistsBorrowError(offset)]) + } + } + + Bytecode::Call(idx) => { + let function_handle = self.module.function_handle_at(*idx); + 
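+                // Argument checking below pops operands right to left (matching
+                // push order), checks each against the declared argument type, and
+                // rejects passing a mutable reference that is itself still borrowed.
+                // Reference arguments are also collected so that references returned
+                // by the callee can be recorded as borrowing from them: a returned
+                // mutable reference may only borrow from the mutable reference
+                // arguments, while an immutable one may borrow from any reference
+                // argument.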
let function_signature = + self.module.function_signature_at(function_handle.signature); + let function_signature_view = + FunctionSignatureView::new(self.module, function_signature); + let mut all_references_to_borrow_from = BTreeSet::new(); + let mut mutable_references_to_borrow_from = BTreeSet::new(); + for arg_type in function_signature.arg_types.iter().rev() { + let arg = self.stack.pop().unwrap(); + if arg.signature != *arg_type { + return Err(vec![VMStaticViolation::CallTypeMismatchError(offset)]); + } + if arg_type.is_mutable_reference() && !state.is_full(&arg.value) { + return Err(vec![VMStaticViolation::CallBorrowedMutableReferenceError( + offset, + )]); + } + if let AbstractValue::Reference(nonce) = arg.value { + all_references_to_borrow_from.insert(nonce.clone()); + if arg_type.is_mutable_reference() { + mutable_references_to_borrow_from.insert(nonce.clone()); + } + } + } + let mut next_state = state; + for return_type_view in function_signature_view.return_tokens() { + if return_type_view.is_reference() { + let nonce = self.get_nonce(&mut next_state); + if return_type_view.is_mutable_reference() { + next_state.borrow_from_nonces( + &mutable_references_to_borrow_from, + nonce.clone(), + ); + } else { + next_state + .borrow_from_nonces(&all_references_to_borrow_from, nonce.clone()); + } + self.stack.push(StackAbstractValue { + signature: return_type_view.as_inner().clone(), + value: AbstractValue::Reference(nonce), + }); + } else { + self.stack.push(StackAbstractValue { + signature: return_type_view.as_inner().clone(), + value: AbstractValue::full_value(return_type_view.is_resource()), + }); + } + } + for x in all_references_to_borrow_from { + next_state.destroy_nonce(x); + } + Ok(next_state) + } + + Bytecode::Pack(idx) => { + let struct_definition = self.module.struct_def_at(*idx); + let struct_definition_view = + StructDefinitionView::new(self.module, struct_definition); + for field_definition_view in struct_definition_view.fields().rev() { + let field_signature_view = field_definition_view.type_signature(); + let arg = self.stack.pop().unwrap(); + if arg.signature != *field_signature_view.token().as_inner() { + return Err(vec![VMStaticViolation::PackTypeMismatchError(offset)]); + } + } + self.stack.push(StackAbstractValue { + signature: SignatureToken::Struct(struct_definition.struct_handle), + value: AbstractValue::full_value(struct_definition_view.is_resource()), + }); + Ok(state) + } + + Bytecode::Unpack(idx) => { + let struct_definition = self.module.struct_def_at(*idx); + let struct_arg = self.stack.pop().unwrap(); + if struct_arg.signature != SignatureToken::Struct(struct_definition.struct_handle) { + return Err(vec![VMStaticViolation::UnpackTypeMismatchError(offset)]); + } + let struct_definition_view = + StructDefinitionView::new(self.module, struct_definition); + for field_definition_view in struct_definition_view.fields() { + let field_signature_view = field_definition_view.type_signature(); + self.stack.push(StackAbstractValue { + signature: field_signature_view.token().as_inner().clone(), + value: AbstractValue::full_value(field_signature_view.is_resource()), + }) + } + Ok(state) + } + + Bytecode::ReadRef => { + let operand = self.stack.pop().unwrap(); + match operand.signature { + SignatureToken::Reference(signature) => { + let operand_nonce = Self::extract_nonce(&operand.value).unwrap().clone(); + if SignatureTokenView::new(self.module, &signature).is_resource() { + Err(vec![VMStaticViolation::ReadRefResourceError(offset)]) + } else { + 
self.stack.push(StackAbstractValue { + signature: *signature, + value: AbstractValue::full_value(false), + }); + let mut next_state = state; + next_state.destroy_nonce(operand_nonce); + Ok(next_state) + } + } + SignatureToken::MutableReference(signature) => { + let operand_nonce = Self::extract_nonce(&operand.value).unwrap().clone(); + if SignatureTokenView::new(self.module, &signature).is_resource() { + Err(vec![VMStaticViolation::ReadRefResourceError(offset)]) + } else { + let borrowed_nonces = state.borrowed_nonces(operand_nonce.clone()); + if self.freeze_ok(&state, borrowed_nonces) { + self.stack.push(StackAbstractValue { + signature: *signature, + value: AbstractValue::full_value(false), + }); + let mut next_state = state; + next_state.destroy_nonce(operand_nonce); + Ok(next_state) + } else { + Err(vec![VMStaticViolation::ReadRefExistsMutableBorrowError( + offset, + )]) + } + } + } + _ => Err(vec![VMStaticViolation::ReadRefTypeMismatchError(offset)]), + } + } + + Bytecode::WriteRef => { + let ref_operand = self.stack.pop().unwrap(); + let val_operand = self.stack.pop().unwrap(); + if let SignatureToken::MutableReference(signature) = ref_operand.signature { + if SignatureTokenView::new(self.module, &signature).is_resource() { + Err(vec![VMStaticViolation::WriteRefResourceError(offset)]) + } else if val_operand.signature != *signature { + Err(vec![VMStaticViolation::WriteRefTypeMismatchError(offset)]) + } else if state.is_full(&ref_operand.value) { + let ref_operand_nonce = + Self::extract_nonce(&ref_operand.value).unwrap().clone(); + let mut next_state = state; + next_state.destroy_nonce(ref_operand_nonce); + Ok(next_state) + } else { + Err(vec![VMStaticViolation::WriteRefExistsBorrowError(offset)]) + } + } else { + Err(vec![VMStaticViolation::WriteRefNoMutableReferenceError( + offset, + )]) + } + } + + Bytecode::Add + | Bytecode::Sub + | Bytecode::Mul + | Bytecode::Mod + | Bytecode::Div + | Bytecode::BitOr + | Bytecode::BitAnd + | Bytecode::Xor => { + let operand1 = self.stack.pop().unwrap(); + let operand2 = self.stack.pop().unwrap(); + if operand1.signature == SignatureToken::U64 + && operand2.signature == SignatureToken::U64 + { + self.stack.push(StackAbstractValue { + signature: SignatureToken::U64, + value: AbstractValue::full_value(false), + }); + Ok(state) + } else { + Err(vec![VMStaticViolation::IntegerOpTypeMismatchError(offset)]) + } + } + + Bytecode::Or | Bytecode::And => { + let operand1 = self.stack.pop().unwrap(); + let operand2 = self.stack.pop().unwrap(); + if operand1.signature == SignatureToken::Bool + && operand2.signature == SignatureToken::Bool + { + self.stack.push(StackAbstractValue { + signature: SignatureToken::Bool, + value: AbstractValue::full_value(false), + }); + Ok(state) + } else { + Err(vec![VMStaticViolation::BooleanOpTypeMismatchError(offset)]) + } + } + + Bytecode::Not => { + let operand = self.stack.pop().unwrap(); + if operand.signature == SignatureToken::Bool { + self.stack.push(StackAbstractValue { + signature: SignatureToken::Bool, + value: AbstractValue::full_value(false), + }); + Ok(state) + } else { + Err(vec![VMStaticViolation::BooleanOpTypeMismatchError(offset)]) + } + } + + Bytecode::Eq | Bytecode::Neq => { + let operand1 = self.stack.pop().unwrap(); + let operand2 = self.stack.pop().unwrap(); + if operand1.signature.allows_equality() && operand1.signature == operand2.signature + { + let mut next_state = state; + if let AbstractValue::Reference(nonce) = operand1.value { + next_state.destroy_nonce(nonce); + } + if let 
AbstractValue::Reference(nonce) = operand2.value { + next_state.destroy_nonce(nonce); + } + self.stack.push(StackAbstractValue { + signature: SignatureToken::Bool, + value: AbstractValue::full_value(false), + }); + Ok(next_state) + } else { + Err(vec![VMStaticViolation::EqualityOpTypeMismatchError(offset)]) + } + } + + Bytecode::Lt | Bytecode::Gt | Bytecode::Le | Bytecode::Ge => { + let operand1 = self.stack.pop().unwrap(); + let operand2 = self.stack.pop().unwrap(); + if operand1.signature == SignatureToken::U64 + && operand2.signature == SignatureToken::U64 + { + self.stack.push(StackAbstractValue { + signature: SignatureToken::Bool, + value: AbstractValue::full_value(false), + }); + Ok(state) + } else { + Err(vec![VMStaticViolation::IntegerOpTypeMismatchError(offset)]) + } + } + + Bytecode::Exists(_) => { + let operand = self.stack.pop().unwrap(); + if operand.signature == SignatureToken::Address { + self.stack.push(StackAbstractValue { + signature: SignatureToken::Bool, + value: AbstractValue::full_value(false), + }); + Ok(state) + } else { + Err(vec![VMStaticViolation::ExistsResourceTypeMismatchError( + offset, + )]) + } + } + + Bytecode::BorrowGlobal(idx) => { + let struct_definition = self.module.struct_def_at(*idx); + if !StructDefinitionView::new(self.module, struct_definition).is_resource() { + return Err(vec![VMStaticViolation::BorrowGlobalNoResourceError(offset)]); + } + + let operand = self.stack.pop().unwrap(); + if operand.signature == SignatureToken::Address { + let mut next_state = state; + let nonce = self.get_nonce(&mut next_state); + self.stack.push(StackAbstractValue { + signature: SignatureToken::MutableReference(Box::new( + SignatureToken::Struct(struct_definition.struct_handle), + )), + value: AbstractValue::Reference(nonce), + }); + Ok(next_state) + } else { + Err(vec![VMStaticViolation::BorrowGlobalTypeMismatchError( + offset, + )]) + } + } + + Bytecode::MoveFrom(idx) => { + let struct_definition = self.module.struct_def_at(*idx); + if !StructDefinitionView::new(self.module, struct_definition).is_resource() { + return Err(vec![VMStaticViolation::MoveFromNoResourceError(offset)]); + } + + let operand = self.stack.pop().unwrap(); + if operand.signature == SignatureToken::Address { + self.stack.push(StackAbstractValue { + signature: SignatureToken::Struct(struct_definition.struct_handle), + value: AbstractValue::full_value(true), + }); + Ok(state) + } else { + Err(vec![VMStaticViolation::MoveFromTypeMismatchError(offset)]) + } + } + + Bytecode::MoveToSender(idx) => { + let struct_definition = self.module.struct_def_at(*idx); + if !StructDefinitionView::new(self.module, struct_definition).is_resource() { + return Err(vec![VMStaticViolation::MoveToSenderNoResourceError(offset)]); + } + + let value_operand = self.stack.pop().unwrap(); + if value_operand.signature + == SignatureToken::Struct(struct_definition.struct_handle) + { + Ok(state) + } else { + Err(vec![VMStaticViolation::MoveToSenderTypeMismatchError( + offset, + )]) + } + } + + Bytecode::GetTxnGasUnitPrice + | Bytecode::GetTxnMaxGasUnits + | Bytecode::GetGasRemaining + | Bytecode::GetTxnSequenceNumber => { + self.stack.push(StackAbstractValue { + signature: SignatureToken::U64, + value: AbstractValue::full_value(false), + }); + Ok(state) + } + + Bytecode::GetTxnSenderAddress => { + self.stack.push(StackAbstractValue { + signature: SignatureToken::Address, + value: AbstractValue::full_value(false), + }); + Ok(state) + } + + Bytecode::GetTxnPublicKey => { + self.stack.push(StackAbstractValue { + signature: 
SignatureToken::ByteArray,
+                    value: AbstractValue::full_value(false),
+                });
+                Ok(state)
+            }
+
+            Bytecode::CreateAccount => {
+                let operand = self.stack.pop().unwrap();
+                if operand.signature == SignatureToken::Address {
+                    Ok(state)
+                } else {
+                    Err(vec![VMStaticViolation::CreateAccountTypeMismatchError(
+                        offset,
+                    )])
+                }
+            }
+
+            Bytecode::EmitEvent => {
+                // TODO: EmitEvent is currently unimplemented; the following is a
+                // workaround to skip the check.
+                self.stack.pop();
+                self.stack.pop();
+                self.stack.pop();
+                Ok(state)
+            }
+        }
+    }
+}
diff --git a/language/bytecode_verifier/src/abstract_state.rs b/language/bytecode_verifier/src/abstract_state.rs
new file mode 100644
index 0000000000000..42c7713ed629d
--- /dev/null
+++ b/language/bytecode_verifier/src/abstract_state.rs
@@ -0,0 +1,603 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module defines the abstract state over which abstract interpretation is executed.
+use crate::{nonce::Nonce, partition::Partition};
+use mirai_annotations::checked_verify;
+use std::collections::{BTreeMap, BTreeSet};
+use vm::file_format::{FieldDefinitionIndex, LocalIndex};
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum AbstractValue {
+    Reference(Nonce),
+    Value(bool, BTreeSet<Nonce>),
+}
+
+impl AbstractValue {
+    pub fn is_reference(&self) -> bool {
+        match self {
+            AbstractValue::Reference(_) => true,
+            AbstractValue::Value(_, _) => false,
+        }
+    }
+
+    pub fn is_value(&self) -> bool {
+        !self.is_reference()
+    }
+
+    pub fn is_unrestricted_value(&self) -> bool {
+        match self {
+            AbstractValue::Reference(_) => false,
+            AbstractValue::Value(is_resource, _) => !*is_resource,
+        }
+    }
+
+    pub fn is_borrowed_value(&self) -> bool {
+        match self {
+            AbstractValue::Reference(_) => false,
+            AbstractValue::Value(_, nonce_set) => !nonce_set.is_empty(),
+        }
+    }
+
+    pub fn full_value(is_resource: bool) -> Self {
+        AbstractValue::Value(is_resource, BTreeSet::new())
+    }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+enum BorrowInfo {
+    BorrowedBy(BTreeSet<Nonce>),
+    FieldsBorrowedBy(BTreeMap<FieldDefinitionIndex, BTreeSet<Nonce>>),
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct AbstractState {
+    locals: BTreeMap<LocalIndex, AbstractValue>,
+    borrows: BTreeMap<Nonce, BorrowInfo>,
+    partition: Partition,
+}
+
+#[derive(Debug)]
+pub enum JoinResult {
+    Unchanged,
+    Changed(AbstractState),
+    Error,
+}
+
+impl AbstractState {
+    /// create a new abstract state
+    pub fn new(locals: BTreeMap<LocalIndex, AbstractValue>) -> Self {
+        let borrows = BTreeMap::new();
+        let mut partition = Partition::default();
+        for value in locals.values() {
+            if let AbstractValue::Reference(nonce) = value {
+                partition.add_nonce(nonce.clone());
+            }
+        }
+        AbstractState {
+            locals,
+            borrows,
+            partition,
+        }
+    }
+
+    /// checks if local@idx is available
+    pub fn is_available(&self, idx: LocalIndex) -> bool {
+        self.locals.contains_key(&idx)
+    }
+
+    /// returns local@idx
+    pub fn local(&self, idx: LocalIndex) -> &AbstractValue {
+        &self.locals[&idx]
+    }
+
+    /// removes local@idx
+    pub fn remove_local(&mut self, idx: LocalIndex) -> AbstractValue {
+        self.locals.remove(&idx).unwrap()
+    }
+
+    /// inserts local@idx
+    pub fn insert_local(&mut self, idx: LocalIndex, value: AbstractValue) {
+        self.locals.insert(idx, value);
+    }
+
+    /// checks if local@idx is a reference
+    pub fn is_reference(&self, idx: LocalIndex) -> bool {
+        self.locals[&idx].is_reference()
+    }
+
+    /// checks if local@idx is a value
+    pub fn is_value(&self, idx: LocalIndex) -> bool {
+        self.locals[&idx].is_value()
+    }
+
+    /// destroys local@idx
+    /// call only if self.is_safe_to_destroy(idx) returns true
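+    /// (for a reference local this also destroys its nonce; for a value local,
+    /// the value is checked to be neither a resource nor currently borrowed)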
+    pub fn destroy_local(&mut self, idx: LocalIndex) {
+        let local = self.locals.remove(&idx).unwrap();
+        match local {
+            AbstractValue::Reference(nonce) => self.destroy_nonce(nonce),
+            AbstractValue::Value(is_resource, borrowed_nonces) => {
+                checked_verify!(!is_resource && borrowed_nonces.is_empty());
+            }
+        }
+    }
+
+    /// nonce must be fresh
+    pub fn add_nonce(&mut self, nonce: Nonce) {
+        self.partition.add_nonce(nonce);
+    }
+
+    /// destroys nonce
+    /// borrows of nonce become borrows of any nonce' such that nonce borrows from nonce'
+    pub fn destroy_nonce(&mut self, nonce: Nonce) {
+        let mut nonce_set = BTreeSet::new();
+        let mut new_locals = BTreeMap::new();
+        let mut new_borrows = BTreeMap::new();
+
+        if let Some(borrow_info) = self.borrows.remove(&nonce) {
+            new_borrows = self.strong_propagate(nonce.clone(), &borrow_info);
+            match borrow_info {
+                BorrowInfo::BorrowedBy(x) => {
+                    nonce_set = x.clone();
+                }
+                BorrowInfo::FieldsBorrowedBy(y) => {
+                    nonce_set = Self::get_union_of_sets(&BTreeSet::new(), &y);
+                }
+            }
+        }
+
+        for (x, value) in &self.locals {
+            if let AbstractValue::Value(is_resource, y) = value {
+                if y.contains(&nonce) {
+                    let mut y_restrict = y.clone();
+                    y_restrict.remove(&nonce);
+                    new_locals.insert(
+                        x.clone(),
+                        AbstractValue::Value(
+                            *is_resource,
+                            y_restrict.union(&nonce_set).cloned().collect(),
+                        ),
+                    );
+                } else {
+                    new_locals.insert(x.clone(), AbstractValue::Value(*is_resource, y.clone()));
+                }
+            } else {
+                new_locals.insert(x.clone(), value.clone());
+            }
+        }
+
+        for (x, borrow_info) in &self.borrows {
+            if new_borrows.contains_key(x) {
+                continue;
+            }
+            match borrow_info {
+                BorrowInfo::BorrowedBy(y) => {
+                    if y.contains(&nonce) {
+                        let mut y_restrict = y.clone();
+                        y_restrict.remove(&nonce);
+                        let y_update = y_restrict
+                            .union(&nonce_set)
+                            .cloned()
+                            .collect::<BTreeSet<_>>();
+                        if !y_update.is_empty() {
+                            new_borrows.insert(x.clone(), BorrowInfo::BorrowedBy(y_update));
+                        }
+                    } else {
+                        new_borrows.insert(x.clone(), BorrowInfo::BorrowedBy(y.clone()));
+                    }
+                }
+                BorrowInfo::FieldsBorrowedBy(w) => {
+                    let mut new_index_to_nonce_set = BTreeMap::new();
+                    for (idx, y) in w {
+                        if y.contains(&nonce) {
+                            let mut y_restrict = y.clone();
+                            y_restrict.remove(&nonce);
+                            let y_update = y_restrict
+                                .union(&nonce_set)
+                                .cloned()
+                                .collect::<BTreeSet<_>>();
+                            if !y_update.is_empty() {
+                                new_index_to_nonce_set.insert(idx.clone(), y_update);
+                            }
+                        } else {
+                            new_index_to_nonce_set.insert(idx.clone(), y.clone());
+                        }
+                    }
+                    if !new_index_to_nonce_set.is_empty() {
+                        new_borrows.insert(
+                            x.clone(),
+                            BorrowInfo::FieldsBorrowedBy(new_index_to_nonce_set),
+                        );
+                    }
+                }
+            }
+        }
+
+        self.locals = new_locals;
+        self.borrows = new_borrows;
+        self.partition.remove_nonce(nonce);
+    }
+
+    /// checks that there are no pending borrows on value
+    pub fn is_full(&self, value: &AbstractValue) -> bool {
+        match value {
+            AbstractValue::Reference(nonce) => !self.borrows.contains_key(&nonce),
+            AbstractValue::Value(_, nonce_set) => nonce_set.is_empty(),
+        }
+    }
+
+    /// returns the set of nonces borrowing from nonce that might alias some idx-extension of nonce
+    pub fn borrowed_nonces_for_field(
+        &self,
+        idx: FieldDefinitionIndex,
+        nonce: Nonce,
+    ) -> BTreeSet<Nonce> {
+        if self.borrows.contains_key(&nonce) {
+            match &self.borrows[&nonce] {
+                BorrowInfo::BorrowedBy(x) => x.clone(),
+                BorrowInfo::FieldsBorrowedBy(y) => {
+                    if y.contains_key(&idx) {
+                        y[&idx].clone()
+                    } else {
+                        BTreeSet::new()
+                    }
+                }
+            }
+        } else {
+            BTreeSet::new()
+        }
+    }
+
+    /// returns the set of nonces borrowing from nonce that might alias some extension of nonce
+    pub fn borrowed_nonces(&self, nonce: Nonce) -> BTreeSet<Nonce> {
+        if self.borrows.contains_key(&nonce) {
+            match &self.borrows[&nonce] {
+                BorrowInfo::BorrowedBy(x) => x.clone(),
+                BorrowInfo::FieldsBorrowedBy(y) => {
+                    let empty_set = BTreeSet::new();
+                    Self::get_union_of_sets(&empty_set, y)
+                }
+            }
+        } else {
+            BTreeSet::new()
+        }
+    }
+
+    /// update self to reflect a borrow of idx from nonce by new_nonce
+    pub fn borrow_field_from_nonce(
+        &mut self,
+        idx: FieldDefinitionIndex,
+        nonce: Nonce,
+        new_nonce: Nonce,
+    ) {
+        self.borrows
+            .entry(nonce.clone())
+            .and_modify(|borrow_info| match borrow_info {
+                BorrowInfo::BorrowedBy(nonce_set) => {
+                    nonce_set.insert(new_nonce.clone());
+                }
+                BorrowInfo::FieldsBorrowedBy(index_to_nonce_set) => {
+                    index_to_nonce_set
+                        .entry(idx)
+                        .and_modify(|nonce_set| {
+                            nonce_set.insert(new_nonce.clone());
+                        })
+                        .or_insert({
+                            let mut x = BTreeSet::new();
+                            x.insert(new_nonce.clone());
+                            x
+                        });
+                }
+            })
+            .or_insert({
+                let mut x = BTreeSet::new();
+                x.insert(new_nonce.clone());
+                let mut y = BTreeMap::new();
+                y.insert(idx, x);
+                BorrowInfo::FieldsBorrowedBy(y)
+            });
+    }
+
+    /// update self to reflect a borrow of a value local@idx by new_nonce
+    pub fn borrow_from_local_value(&mut self, idx: LocalIndex, new_nonce: Nonce) {
+        checked_verify!(self.locals[&idx].is_value());
+        self.locals.entry(idx).and_modify(|value| {
+            if let AbstractValue::Value(_, nonce_set) = value {
+                nonce_set.insert(new_nonce);
+            }
+        });
+    }
+
+    /// update self to reflect a borrow of a reference local@idx by new_nonce
+    pub fn borrow_from_local_reference(&mut self, idx: LocalIndex, new_nonce: Nonce) {
+        checked_verify!(self.locals[&idx].is_reference());
+        if let AbstractValue::Reference(borrowee) = &self.locals[&idx] {
+            if let Some(info) = self.borrows.remove(borrowee) {
+                self.borrows.insert(new_nonce.clone(), info);
+            }
+            self.borrows.entry(borrowee.clone()).or_insert({
+                let mut x = BTreeSet::new();
+                x.insert(new_nonce.clone());
+                BorrowInfo::BorrowedBy(x)
+            });
+            self.partition.add_equality(new_nonce, borrowee.clone());
+        }
+    }
+
+    /// update self to reflect a borrow from each nonce in to_borrow_from by new_nonce
+    pub fn borrow_from_nonces(&mut self, to_borrow_from: &BTreeSet<Nonce>, new_nonce: Nonce) {
+        for x in to_borrow_from {
+            self.borrow_from_nonce(x.clone(), new_nonce.clone());
+        }
+    }
+
+    /// checks if self is canonical
+    pub fn is_canonical(&self) -> bool {
+        let mut values = BTreeMap::new();
+        let mut references = BTreeMap::new();
+        Self::split_locals(&self.locals, &mut values, &mut references);
+        references.iter().all(|(x, y)| y.is(*x as usize))
+    }
+
+    /// returns the canonical representation of self
+    pub fn construct_canonical_state(self) -> Self {
+        let mut values = BTreeMap::new();
+        let mut references = BTreeMap::new();
+        Self::split_locals(&self.locals, &mut values, &mut references);
+
+        let mut locals = BTreeMap::new();
+        let mut nonce_map = BTreeMap::new();
+        for (x, y) in references {
+            nonce_map.insert(y, Nonce::new(x as usize));
+            locals.insert(x, AbstractValue::Reference(Nonce::new(x as usize)));
+        }
+        for (x, (is_resource, nonce_set)) in values {
+            locals.insert(
+                x,
+                AbstractValue::Value(is_resource, Self::map_nonce_set(&nonce_map, &nonce_set)),
+            );
+        }
+        let mut borrows = BTreeMap::new();
+        for (x, borrow_info) in &self.borrows {
+            match borrow_info {
+                BorrowInfo::BorrowedBy(y) => {
+                    borrows.insert(
+                        nonce_map[&x].clone(),
+                        BorrowInfo::BorrowedBy(Self::map_nonce_set(&nonce_map, &y)),
+                    );
+                }
+                BorrowInfo::FieldsBorrowedBy(w) => {
+                    let mut index_to_nonce_set = BTreeMap::new();
+                    for (idx, y) in w {
+                        index_to_nonce_set.insert(idx.clone(), Self::map_nonce_set(&nonce_map, &y));
+                    }
+                    borrows.insert(
+                        nonce_map[&x].clone(),
+                        BorrowInfo::FieldsBorrowedBy(index_to_nonce_set),
+                    );
+                }
+            }
+        }
+        let partition = self.partition.construct_canonical_partition(&nonce_map);
+
+        AbstractState {
+            locals,
+            borrows,
+            partition,
+        }
+    }
+
+    /// attempts to join state to self and returns the result
+    /// both self.is_canonical() and state.is_canonical() must be true
+    pub fn join(&self, state: &AbstractState) -> JoinResult {
+        // A join failure occurs in each of the following situations:
+        // - an unrestricted value is borrowed along one path but unavailable along the other
+        // - a value that is not unrestricted, i.e., either reference or resource, is available
+        //   along one path but not the other
+        if Self::unrestricted_borrowed_value_unavailable(self, state)
+            || Self::unrestricted_borrowed_value_unavailable(state, self)
+        {
+            return JoinResult::Error;
+        }
+        if self
+            .locals
+            .keys()
+            .filter(|x| !self.locals[x].is_unrestricted_value())
+            .collect::<BTreeSet<_>>()
+            != state
+                .locals
+                .keys()
+                .filter(|x| !state.locals[x].is_unrestricted_value())
+                .collect::<BTreeSet<_>>()
+        {
+            return JoinResult::Error;
+        }
+
+        let mut values1 = BTreeMap::new();
+        let mut references1 = BTreeMap::new();
+        Self::split_locals(&self.locals, &mut values1, &mut references1);
+        let mut values2 = BTreeMap::new();
+        let mut references2 = BTreeMap::new();
+        Self::split_locals(&state.locals, &mut values2, &mut references2);
+        checked_verify!(references1 == references2);
+
+        let mut locals = BTreeMap::new();
+        for (x, y) in &references1 {
+            locals.insert(x.clone(), AbstractValue::Reference(y.clone()));
+        }
+        for (x, (is_resource1, nonce_set1)) in &values1 {
+            if let Some((is_resource2, nonce_set2)) = values2.get(x) {
+                checked_verify!(is_resource1 == is_resource2);
+                locals.insert(
+                    x.clone(),
+                    AbstractValue::Value(
+                        *is_resource1,
+                        nonce_set1.union(nonce_set2).cloned().collect(),
+                    ),
+                );
+            }
+        }
+
+        let mut borrows = BTreeMap::new();
+        for (x, borrow_info) in &self.borrows {
+            if state.borrows.contains_key(x) {
+                match borrow_info {
+                    BorrowInfo::BorrowedBy(y1) => match &state.borrows[x] {
+                        BorrowInfo::BorrowedBy(y2) => {
+                            borrows.insert(
+                                x.clone(),
+                                BorrowInfo::BorrowedBy(y1.union(y2).cloned().collect()),
+                            );
+                        }
+                        BorrowInfo::FieldsBorrowedBy(w2) => {
+                            borrows.insert(
+                                x.clone(),
+                                BorrowInfo::BorrowedBy(Self::get_union_of_sets(y1, w2)),
+                            );
+                        }
+                    },
+                    BorrowInfo::FieldsBorrowedBy(w1) => match &state.borrows[x] {
+                        BorrowInfo::BorrowedBy(y2) => {
+                            borrows.insert(
+                                x.clone(),
+                                BorrowInfo::BorrowedBy(Self::get_union_of_sets(y2, w1)),
+                            );
+                        }
+                        BorrowInfo::FieldsBorrowedBy(w2) => {
+                            borrows.insert(
+                                x.clone(),
+                                BorrowInfo::FieldsBorrowedBy(Self::get_union_of_maps(w1, w2)),
+                            );
+                        }
+                    },
+                }
+            } else {
+                borrows.insert(x.clone(), borrow_info.clone());
+            }
+        }
+        for (x, borrow_info) in &state.borrows {
+            if !borrows.contains_key(x) {
+                borrows.insert(x.clone(), borrow_info.clone());
+            }
+        }
+
+        let partition = self.partition.join(&state.partition);
+
+        let next_state = AbstractState {
+            locals,
+            borrows,
+            partition,
+        };
+        if next_state == *self {
+            JoinResult::Unchanged
+        } else {
+            JoinResult::Changed(next_state)
+        }
+    }
+
+    fn unrestricted_borrowed_value_unavailable(
+        state1: &AbstractState,
+        state2: &AbstractState,
+    ) -> bool {
+        state1.locals.keys().any(|x| {
+            state1.locals[x].is_unrestricted_value()
+                && state1.locals[x].is_borrowed_value()
+                && !state2.locals.contains_key(x)
+        })
+    }
+
+    fn strong_propagate(
+        &self,
+        nonce: Nonce,
+        borrow_info: &BorrowInfo,
+    ) -> BTreeMap<Nonce, BorrowInfo> {
+        let mut new_borrows = BTreeMap::new();
+        let mut singleton_nonce_set = BTreeSet::new();
+        singleton_nonce_set.insert(nonce.clone());
+        for (x, y) in &self.borrows {
+            if self.partition.is_equal(x.clone(), nonce.clone()) {
+                if let BorrowInfo::BorrowedBy(nonce_set) = y {
+                    if nonce_set == &singleton_nonce_set {
+                        new_borrows.insert(x.clone(), borrow_info.clone());
+                    }
+                }
+            }
+        }
+        new_borrows
+    }
+
+    fn borrow_from_nonce(&mut self, nonce: Nonce, new_nonce: Nonce) {
+        self.borrows.entry(nonce.clone()).or_insert({
+            let mut x = BTreeSet::new();
+            x.insert(new_nonce);
+            BorrowInfo::BorrowedBy(x)
+        });
+    }
+
+    fn map_nonce_set(
+        nonce_map: &BTreeMap<Nonce, Nonce>,
+        nonce_set: &BTreeSet<Nonce>,
+    ) -> BTreeSet<Nonce> {
+        let mut mapped_nonce_set = BTreeSet::new();
+        for x in nonce_set {
+            mapped_nonce_set.insert(nonce_map[x].clone());
+        }
+        mapped_nonce_set
+    }
+
+    fn split_locals(
+        locals: &BTreeMap<LocalIndex, AbstractValue>,
+        values: &mut BTreeMap<LocalIndex, (bool, BTreeSet<Nonce>)>,
+        references: &mut BTreeMap<LocalIndex, Nonce>,
+    ) {
+        for (x, y) in locals {
+            match y {
+                AbstractValue::Reference(nonce) => {
+                    references.insert(x.clone(), nonce.clone());
+                }
+                AbstractValue::Value(is_resource, nonces) => {
+                    values.insert(x.clone(), (*is_resource, nonces.clone()));
+                }
+            }
+        }
+    }
+
+    fn get_union_of_sets(
+        nonce_set: &BTreeSet<Nonce>,
+        index_to_nonce_set: &BTreeMap<FieldDefinitionIndex, BTreeSet<Nonce>>,
+    ) -> BTreeSet<Nonce> {
+        index_to_nonce_set
+            .values()
+            .fold(nonce_set.clone(), |mut acc, set| {
+                for x in set {
+                    acc.insert(x.clone());
+                }
+                acc
+            })
+    }
+
+    fn get_union_of_maps(
+        index_to_nonce_set1: &BTreeMap<FieldDefinitionIndex, BTreeSet<Nonce>>,
+        index_to_nonce_set2: &BTreeMap<FieldDefinitionIndex, BTreeSet<Nonce>>,
+    ) -> BTreeMap<FieldDefinitionIndex, BTreeSet<Nonce>> {
+        let mut index_to_nonce_set = BTreeMap::new();
+        for (x, y) in index_to_nonce_set1 {
+            if index_to_nonce_set2.contains_key(x) {
+                index_to_nonce_set.insert(
+                    x.clone(),
+                    y.union(&index_to_nonce_set2[x]).cloned().collect(),
+                );
+            } else {
+                index_to_nonce_set.insert(x.clone(), y.clone());
+            }
+        }
+        for (x, y) in index_to_nonce_set2 {
+            if index_to_nonce_set.contains_key(x) {
+                continue;
+            }
+            index_to_nonce_set.insert(x.clone(), y.clone());
+        }
+        index_to_nonce_set
+    }
+}
diff --git a/language/bytecode_verifier/src/check_duplication.rs b/language/bytecode_verifier/src/check_duplication.rs
new file mode 100644
index 0000000000000..d72d3f713246a
--- /dev/null
+++ b/language/bytecode_verifier/src/check_duplication.rs
@@ -0,0 +1,236 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module implements a checker for verifying that each vector in a CompiledModule contains
+//! distinct values. Successful verification implies that an index into a vector can be used to
+//! uniquely name the entry at that index. Additionally, the checker verifies the
+//! following:
+//! - struct and field definitions are consistent
+//! - the handles in struct and function definitions point to IMPLEMENTED_MODULE_INDEX
+//! - all struct and function handles pointing to IMPLEMENTED_MODULE_INDEX have a definition
+use std::{collections::HashSet, hash::Hash};
+use vm::{
+    access::{BaseAccess, ModuleAccess},
+    errors::{VMStaticViolation, VerificationError},
+    file_format::{
+        CompiledModule, FieldDefinitionIndex, FunctionHandleIndex, ModuleHandleIndex,
+        StructHandleIndex,
+    },
+    IndexKind,
+};
+
+pub struct DuplicationChecker<'a> {
+    module: &'a CompiledModule,
+}
+
+impl<'a> DuplicationChecker<'a> {
+    pub fn new(module: &'a CompiledModule) -> Self {
+        Self { module }
+    }
+
+    pub fn verify(self) -> Vec<VerificationError> {
+        let mut errors = vec![];
+
+        if let Some(idx) = Self::first_duplicate_element(self.module.string_pool()) {
+            errors.push(VerificationError {
+                kind: IndexKind::StringPool,
+                idx,
+                err: VMStaticViolation::DuplicateElement,
+            })
+        }
+        if let Some(idx) = Self::first_duplicate_element(self.module.byte_array_pool()) {
+            errors.push(VerificationError {
+                kind: IndexKind::ByteArrayPool,
+                idx,
+                err: VMStaticViolation::DuplicateElement,
+            })
+        }
+        if let Some(idx) = Self::first_duplicate_element(self.module.address_pool()) {
+            errors.push(VerificationError {
+                kind: IndexKind::AddressPool,
+                idx,
+                err: VMStaticViolation::DuplicateElement,
+            })
+        }
+        if let Some(idx) = Self::first_duplicate_element(self.module.type_signatures()) {
+            errors.push(VerificationError {
+                kind: IndexKind::TypeSignature,
+                idx,
+                err: VMStaticViolation::DuplicateElement,
+            })
+        }
+        if let Some(idx) = Self::first_duplicate_element(self.module.function_signatures()) {
+            errors.push(VerificationError {
+                kind: IndexKind::FunctionSignature,
+                idx,
+                err: VMStaticViolation::DuplicateElement,
+            })
+        }
+        if let Some(idx) = Self::first_duplicate_element(self.module.locals_signatures()) {
+            errors.push(VerificationError {
+                kind: IndexKind::LocalsSignature,
+                idx,
+                err: VMStaticViolation::DuplicateElement,
+            })
+        }
+        if let Some(idx) = Self::first_duplicate_element(self.module.module_handles()) {
+            errors.push(VerificationError {
+                kind: IndexKind::ModuleHandle,
+                idx,
+                err: VMStaticViolation::DuplicateElement,
+            })
+        }
+        if let Some(idx) =
+            Self::first_duplicate_element(self.module.struct_handles().map(|x| (x.module, x.name)))
+        {
+            errors.push(VerificationError {
+                kind: IndexKind::StructHandle,
+                idx,
+                err: VMStaticViolation::DuplicateElement,
+            })
+        }
+        if let Some(idx) = Self::first_duplicate_element(
+            self.module.function_handles().map(|x| (x.module, x.name)),
+        ) {
+            errors.push(VerificationError {
+                kind: IndexKind::FunctionHandle,
+                idx,
+                err: VMStaticViolation::DuplicateElement,
+            })
+        }
+        if let Some(idx) =
+            Self::first_duplicate_element(self.module.struct_defs().map(|x| x.struct_handle))
+        {
+            errors.push(VerificationError {
+                kind: IndexKind::StructDefinition,
+                idx,
+                err: VMStaticViolation::DuplicateElement,
+            })
+        }
+        if let Some(idx) =
+            Self::first_duplicate_element(self.module.function_defs().map(|x| x.function))
+        {
+            errors.push(VerificationError {
+                kind: IndexKind::FunctionDefinition,
+                idx,
+                err: VMStaticViolation::DuplicateElement,
+            })
+        }
+        if let Some(idx) =
+            Self::first_duplicate_element(self.module.field_defs().map(|x| (x.struct_, x.name)))
+        {
+            errors.push(VerificationError {
+                kind: IndexKind::FieldDefinition,
+                idx,
+                err: VMStaticViolation::DuplicateElement,
+            })
+        }
+
+        // Check that:
+        // (1) the order of struct definitions matches the order of field definitions,
+        // (2) each struct definition and its field definitions point to the same struct handle,
+        // (3) there are no unused fields.
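+        //
+        // As an illustration: if struct A declares two fields and struct B one,
+        // the only layout accepted below is field_defs = [A.0, A.1, B.0], with
+        // A.fields = 0, A.field_count = 2, B.fields = 2 and B.field_count = 1.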
+        let mut start_field_index: usize = 0;
+        let mut idx_opt = None;
+        for (idx, struct_def) in self.module.struct_defs().enumerate() {
+            if FieldDefinitionIndex::new(start_field_index as u16) != struct_def.fields {
+                idx_opt = Some(idx);
+                break;
+            }
+            let next_start_field_index = start_field_index + struct_def.field_count as usize;
+            if !(start_field_index..next_start_field_index)
+                .all(|i| struct_def.struct_handle == self.module.field_defs[i].struct_)
+            {
+                idx_opt = Some(idx);
+                break;
+            }
+            start_field_index = next_start_field_index;
+        }
+        if let Some(idx) = idx_opt {
+            errors.push(VerificationError {
+                kind: IndexKind::StructDefinition,
+                idx,
+                err: VMStaticViolation::InconsistentFields,
+            });
+        } else if start_field_index != self.module.field_defs.len() {
+            errors.push(VerificationError {
+                kind: IndexKind::FieldDefinition,
+                idx: start_field_index,
+                err: VMStaticViolation::UnusedFields,
+            });
+        }
+
+        // Check that each struct definition is pointing to module handle with index
+        // IMPLEMENTED_MODULE_INDEX.
+        if let Some(idx) = self.module.struct_defs().position(|x| {
+            self.module.struct_handle_at(x.struct_handle).module
+                != ModuleHandleIndex::new(CompiledModule::IMPLEMENTED_MODULE_INDEX)
+        }) {
+            errors.push(VerificationError {
+                kind: IndexKind::StructDefinition,
+                idx,
+                err: VMStaticViolation::InvalidModuleHandle,
+            })
+        }
+        // Check that each function definition is pointing to module handle with index
+        // IMPLEMENTED_MODULE_INDEX.
+        if let Some(idx) = self.module.function_defs().position(|x| {
+            self.module.function_handle_at(x.function).module
+                != ModuleHandleIndex::new(CompiledModule::IMPLEMENTED_MODULE_INDEX)
+        }) {
+            errors.push(VerificationError {
+                kind: IndexKind::FunctionDefinition,
+                idx,
+                err: VMStaticViolation::InvalidModuleHandle,
+            })
+        }
+        // Check that each struct handle with module handle index IMPLEMENTED_MODULE_INDEX is
+        // implemented.
+        let implemented_struct_handles: HashSet<StructHandleIndex> =
+            self.module.struct_defs().map(|x| x.struct_handle).collect();
+        if let Some(idx) = (0..self.module.struct_handles.len()).position(|x| {
+            let y = StructHandleIndex::new(x as u16);
+            self.module.struct_handle_at(y).module
+                == ModuleHandleIndex::new(CompiledModule::IMPLEMENTED_MODULE_INDEX)
+                && !implemented_struct_handles.contains(&y)
+        }) {
+            errors.push(VerificationError {
+                kind: IndexKind::StructHandle,
+                idx,
+                err: VMStaticViolation::UnimplementedHandle,
+            })
+        }
+        // Check that each function handle with module handle index IMPLEMENTED_MODULE_INDEX is
+        // implemented.
+        let implemented_function_handles: HashSet<FunctionHandleIndex> =
+            self.module.function_defs().map(|x| x.function).collect();
+        if let Some(idx) = (0..self.module.function_handles.len()).position(|x| {
+            let y = FunctionHandleIndex::new(x as u16);
+            self.module.function_handle_at(y).module
+                == ModuleHandleIndex::new(CompiledModule::IMPLEMENTED_MODULE_INDEX)
+                && !implemented_function_handles.contains(&y)
+        }) {
+            errors.push(VerificationError {
+                kind: IndexKind::FunctionHandle,
+                idx,
+                err: VMStaticViolation::UnimplementedHandle,
+            })
+        }
+
+        errors
+    }
+
+    fn first_duplicate_element<T>(iter: T) -> Option<usize>
+    where
+        T: Iterator,
+        T::Item: Eq + Hash,
+    {
+        let mut uniq = HashSet::new();
+        for (i, x) in iter.enumerate() {
+            if !uniq.insert(x) {
+                return Some(i);
+            }
+        }
+        None
+    }
+}
diff --git a/language/bytecode_verifier/src/code_unit_verifier.rs b/language/bytecode_verifier/src/code_unit_verifier.rs
new file mode 100644
index 0000000000000..ee4858c8e8890
--- /dev/null
+++ b/language/bytecode_verifier/src/code_unit_verifier.rs
@@ -0,0 +1,79 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module implements the checker for verifying correctness of function bodies.
+//! The overall verification is split between stack_usage_verifier.rs and
+//! abstract_interpreter.rs. CodeUnitVerifier simply orchestrates calls into these two files.
+use crate::control_flow_graph::{ControlFlowGraph, VMControlFlowGraph};
+use vm::{
+    access::ModuleAccess,
+    errors::{VMStaticViolation, VerificationError},
+    file_format::{CompiledModule, FunctionDefinition},
+    IndexKind,
+};
+
+use crate::{abstract_interpreter::AbstractInterpreter, stack_usage_verifier::StackUsageVerifier};
+
+pub trait VerificationPass<'a> {
+    fn new(
+        module: &'a CompiledModule,
+        function_definition: &'a FunctionDefinition,
+        cfg: &'a VMControlFlowGraph,
+    ) -> Self;
+
+    fn verify(self) -> Vec<VMStaticViolation>;
+}
+
+pub struct CodeUnitVerifier<'a> {
+    module: &'a CompiledModule,
+}
+
+impl<'a> CodeUnitVerifier<'a> {
+    pub fn new(module: &'a CompiledModule) -> Self {
+        Self { module }
+    }
+
+    pub fn verify(&self) -> Vec<VerificationError> {
+        self.module
+            .function_defs()
+            .enumerate()
+            .map(move |(idx, function_definition)| {
+                self.verify_function(function_definition)
+                    .into_iter()
+                    .map(move |err| VerificationError {
+                        kind: IndexKind::FunctionDefinition,
+                        idx,
+                        err,
+                    })
+            })
+            .flatten()
+            .collect()
+    }
+
+    fn verify_function(
+        &self,
+        function_definition: &'a FunctionDefinition,
+    ) -> Vec<VMStaticViolation> {
+        if function_definition.is_native() {
+            return vec![];
+        }
+        let result: Result<VMControlFlowGraph, VMStaticViolation> =
+            VMControlFlowGraph::new(&function_definition.code.code);
+        match result {
+            Ok(cfg) => self.verify_function_inner(function_definition, &cfg),
+            Err(e) => vec![e],
+        }
+    }
+
+    fn verify_function_inner(
+        &self,
+        function_definition: &'a FunctionDefinition,
+        cfg: &'a VMControlFlowGraph,
+    ) -> Vec<VMStaticViolation> {
+        let errors = StackUsageVerifier::new(self.module, function_definition, cfg).verify();
+        if !errors.is_empty() {
+            return errors;
+        }
+        AbstractInterpreter::new(self.module, function_definition, cfg).verify()
+    }
+}
diff --git a/language/bytecode_verifier/src/control_flow_graph.rs b/language/bytecode_verifier/src/control_flow_graph.rs
new file mode 100644
index 0000000000000..960a85891a139
--- /dev/null
+++ b/language/bytecode_verifier/src/control_flow_graph.rs
@@ -0,0 +1,248 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module defines the control-flow graph used for bytecode verification.
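+
+// As a small illustration of the construction below (not part of the original
+// module docs): for the bytecode [LdTrue, BrTrue(3), LdConst(0), Ret], offsets
+// 0, 2 and 3 begin basic blocks, giving block 0..1 with successors {2, 3},
+// block 2..2 with successor {3}, and block 3..3 with no successors.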
+use std::{
+    collections::{BTreeMap, BTreeSet},
+    marker::Sized,
+    result::Result,
+};
+use vm::{
+    errors::VMStaticViolation,
+    file_format::{Bytecode, CodeOffset},
+};
+
+// BTree/Hash agnostic type wrappers
+type Map<K, V> = BTreeMap<K, V>;
+type Set<K> = BTreeSet<K>;
+
+pub type BlockId = CodeOffset;
+
+/// A trait that specifies the basic requirements for a CFG
+pub trait ControlFlowGraph: Sized {
+    /// Given a vector of bytecodes, constructs the control flow graph for it.
+    /// Return a VMStaticViolation if we were unable to construct a control flow graph, or
+    /// if we encounter an invalid jump instruction.
+    fn new(code: &[Bytecode]) -> Result<Self, VMStaticViolation>;
+
+    /// Given a block ID, return the reachable blocks from that block
+    /// including the block itself.
+    fn reachable_from(&self, block_id: BlockId) -> Vec<&BasicBlock>;
+
+    /// Given an offset into the bytecode return the basic block ID that contains
+    /// that offset
+    fn block_id_of_offset(&self, code_offset: CodeOffset) -> Option<BlockId>;
+
+    /// Given a block ID, return the corresponding basic block. Return None if
+    /// the block ID is invalid.
+    fn block_of_id(&self, block_id: BlockId) -> Option<&BasicBlock>;
+
+    /// Return the number of blocks (vertices) in the control flow graph
+    fn num_blocks(&self) -> u16;
+}
+
+/// A basic block
+pub struct BasicBlock {
+    /// Start index into bytecode vector
+    pub entry: CodeOffset,
+
+    /// End index into bytecode vector
+    pub exit: CodeOffset,
+
+    /// Flows-to
+    pub successors: Set<BlockId>,
+}
+
+/// The control flow graph that we build from the bytecode.
+pub struct VMControlFlowGraph {
+    /// The basic blocks
+    pub blocks: Map<BlockId, BasicBlock>,
+}
+
+impl BasicBlock {
+    pub fn display(&self) {
+        println!("+=======================+");
+        println!("| Enter: {} |", self.entry);
+        println!("+-----------------------+");
+        println!("==> Children: {:?}", self.successors);
+        println!("+-----------------------+");
+        println!("| Exit: {} |", self.exit);
+        println!("+=======================+");
+    }
+}
+
+impl VMControlFlowGraph {
+    pub fn display(&self) {
+        for block in self.blocks.values() {
+            block.display();
+        }
+    }
+
+    fn is_end_of_block(pc: CodeOffset, code: &[Bytecode], block_ids: &Set<BlockId>) -> bool {
+        pc + 1 == (code.len() as CodeOffset) || block_ids.contains(&(pc + 1))
+    }
+
+    fn record_block_ids(pc: CodeOffset, code: &[Bytecode], block_ids: &mut Set<BlockId>) {
+        let bytecode = &code[pc as usize];
+
+        if let Some(offset) = VMControlFlowGraph::offset(bytecode) {
+            block_ids.insert(*offset);
+        }
+
+        if VMControlFlowGraph::is_branch(bytecode) && pc + 1 < (code.len() as CodeOffset) {
+            block_ids.insert(pc + 1);
+        }
+    }
+
+    fn is_unconditional_branch(bytecode: &Bytecode) -> bool {
+        match bytecode {
+            Bytecode::Ret | Bytecode::Branch(_) => true,
+            _ => false,
+        }
+    }
+
+    fn is_conditional_branch(bytecode: &Bytecode) -> bool {
+        match bytecode {
+            Bytecode::BrFalse(_) | Bytecode::BrTrue(_) => true,
+            _ => false,
+        }
+    }
+
+    fn is_branch(bytecode: &Bytecode) -> bool {
+        VMControlFlowGraph::is_conditional_branch(bytecode)
+            || VMControlFlowGraph::is_unconditional_branch(bytecode)
+    }
+
+    fn offset(bytecode: &Bytecode) -> Option<&CodeOffset> {
+        match bytecode {
+            Bytecode::BrFalse(offset) | Bytecode::BrTrue(offset) | Bytecode::Branch(offset) => {
+                Some(offset)
+            }
+            _ => None,
+        }
+    }
+
+    fn get_successors(pc: CodeOffset, code: &[Bytecode]) -> Set<BlockId> {
+        let bytecode = &code[pc as usize];
+        let mut v = Set::new();
+
+        if let Some(offset) = VMControlFlowGraph::offset(bytecode) {
+            v.insert(*offset);
+        }
+
+        if pc + 1 >= code.len() as CodeOffset {
+            return v;
+        }
+
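+        // Fall through to pc + 1 for everything except unconditional branches
+        // (Branch and Ret): non-branch instructions flow to the next offset,
+        // and conditional branches do so on their not-taken path.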
+        if !VMControlFlowGraph::is_branch(bytecode)
+            || VMControlFlowGraph::is_conditional_branch(bytecode)
+        {
+            v.insert(pc + 1);
+        }
+
+        v
+    }
+
+    /// A utility function that implements BFS-reachability from block_id with
+    /// respect to get_targets function
+    fn traverse_by(
+        &self,
+        get_targets: fn(&BasicBlock) -> &Set<BlockId>,
+        block_id: BlockId,
+    ) -> Vec<&BasicBlock> {
+        let mut ret = Vec::new();
+        // We use this index to keep track of our frontier.
+        let mut index = 0;
+        // Guard against cycles
+        let mut seen = Set::new();
+
+        let block = &self.blocks[&block_id];
+        ret.push(block);
+        seen.insert(&block_id);
+
+        while index < ret.len() {
+            let block = ret[index];
+            index += 1;
+            let successors = get_targets(&block);
+            for block_id in successors.iter() {
+                if !seen.contains(&block_id) {
+                    ret.push(&self.blocks[&block_id]);
+                    seen.insert(block_id);
+                }
+            }
+        }
+
+        ret
+    }
+}
+
+impl ControlFlowGraph for VMControlFlowGraph {
+    fn num_blocks(&self) -> u16 {
+        self.blocks.len() as u16
+    }
+
+    fn block_id_of_offset(&self, code_offset: CodeOffset) -> Option<BlockId> {
+        let mut index = None;
+
+        for (block_id, block) in &self.blocks {
+            // A block contains the offset iff entry <= offset <= exit.
+            if block.entry <= code_offset && code_offset <= block.exit {
+                index = Some(*block_id);
+            }
+        }
+
+        index
+    }
+
+    fn block_of_id(&self, block_id: BlockId) -> Option<&BasicBlock> {
+        if self.blocks.contains_key(&block_id) {
+            Some(&self.blocks[&block_id])
+        } else {
+            None
+        }
+    }
+
+    fn reachable_from(&self, block_id: BlockId) -> Vec<&BasicBlock> {
+        self.traverse_by(|block: &BasicBlock| &block.successors, block_id)
+    }
+
+    fn new(code: &[Bytecode]) -> Result<Self, VMStaticViolation> {
+        // Check to make sure that the bytecode vector ends with a branching instruction.
+        if let Some(bytecode) = code.last() {
+            if !VMControlFlowGraph::is_branch(bytecode) {
+                return Err(VMStaticViolation::InvalidFallThrough);
+            }
+        } else {
+            return Err(VMStaticViolation::InvalidFallThrough);
+        }
+
+        // First go through and collect block ids, i.e., offsets that begin basic blocks.
+        // Need to do this first in order to handle backwards edges.
+        let mut block_ids = Set::new();
+        block_ids.insert(0);
+        for pc in 0..code.len() {
+            VMControlFlowGraph::record_block_ids(pc as CodeOffset, code, &mut block_ids);
+        }
+
+        // Create basic blocks
+        let mut ret = VMControlFlowGraph { blocks: Map::new() };
+        let mut entry = 0;
+        for pc in 0..code.len() {
+            let co_pc: CodeOffset = pc as CodeOffset;
+
+            // Create a basic block
+            if VMControlFlowGraph::is_end_of_block(co_pc, code, &block_ids) {
+                let successors = VMControlFlowGraph::get_successors(co_pc, code);
+                let bb = BasicBlock {
+                    entry,
+                    exit: co_pc,
+                    successors,
+                };
+                ret.blocks.insert(entry, bb);
+                entry = co_pc + 1;
+            }
+        }
+
+        assert!(entry == code.len() as CodeOffset);
+        Ok(ret)
+    }
+}
diff --git a/language/bytecode_verifier/src/lib.rs b/language/bytecode_verifier/src/lib.rs
new file mode 100644
index 0000000000000..e477c86a99604
--- /dev/null
+++ b/language/bytecode_verifier/src/lib.rs
@@ -0,0 +1,32 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Verifies bytecode sanity.
+
+#![feature(exhaustive_patterns)]
+#![feature(never_type)]
+
+// Bounds checks are implemented in the `vm` crate.
+pub mod abstract_interpreter; +pub mod abstract_state; +pub mod check_duplication; +pub mod code_unit_verifier; +pub mod control_flow_graph; +pub mod nonce; +pub mod partition; +pub mod resources; +pub mod signature; +pub mod stack_usage_verifier; +pub mod struct_defs; +pub mod verifier; + +pub use check_duplication::DuplicationChecker; +pub use code_unit_verifier::CodeUnitVerifier; +pub use resources::ResourceTransitiveChecker; +pub use signature::SignatureChecker; +pub use stack_usage_verifier::StackUsageVerifier; +pub use struct_defs::RecursiveStructDefChecker; +pub use verifier::{ + verify_main_signature, verify_module, verify_module_dependencies, verify_script, + verify_script_dependencies, +}; diff --git a/language/bytecode_verifier/src/nonce.rs b/language/bytecode_verifier/src/nonce.rs new file mode 100644 index 0000000000000..ef71eb2f6cb43 --- /dev/null +++ b/language/bytecode_verifier/src/nonce.rs @@ -0,0 +1,23 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module implements the Nonce type used for borrow checking in the abstract interpreter. +//! A Nonce instance represents an arbitrary reference or access path. +//! The integer inside a Nonce is meaningless; only equality and borrow relationships are +//! meaningful. +#[derive(Clone, Debug, Hash, PartialEq, Eq, Ord, PartialOrd)] +pub struct Nonce(usize); + +impl Nonce { + pub fn new(n: usize) -> Self { + Self(n) + } + + pub fn is(&self, n: usize) -> bool { + self.0 == n + } + + pub fn inner(&self) -> usize { + self.0 + } +} diff --git a/language/bytecode_verifier/src/partition.rs b/language/bytecode_verifier/src/partition.rs new file mode 100644 index 0000000000000..abeb90e1c39af --- /dev/null +++ b/language/bytecode_verifier/src/partition.rs @@ -0,0 +1,129 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This module defines the partition data structure used for tracking equality among nonces. 
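Before the implementation, a small usage sketch of the `Partition` API defined below. This is a hypothetical driver, assuming `nonce` and `partition` are public as exported in `lib.rs` above.

```rust
use bytecode_verifier::{nonce::Nonce, partition::Partition};

fn main() {
    let mut p = Partition::default();
    for i in 0..3 {
        p.add_nonce(Nonce::new(i));
    }
    // Record that nonces 0 and 1 stand for the same access path.
    p.add_equality(Nonce::new(0), Nonce::new(1));
    assert!(p.is_equal(Nonce::new(0), Nonce::new(1)));
    // Nonce 2 was never equated with anything.
    assert!(!p.is_equal(Nonce::new(0), Nonce::new(2)));
}
```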
+use crate::nonce::Nonce;
+use mirai_annotations::checked_verify;
+use std::{
+    collections::{BTreeMap, BTreeSet},
+    usize::MAX,
+};
+
+#[derive(Clone, Debug, Default, Eq, PartialEq)]
+pub struct Partition {
+    nonce_to_id: BTreeMap<Nonce, usize>,
+    id_to_nonce_set: BTreeMap<usize, BTreeSet<Nonce>>,
+}
+
+impl Partition {
+    // adds a nonce to the partition; new_nonce must be a fresh nonce
+    pub fn add_nonce(&mut self, new_nonce: Nonce) {
+        let nonce_const = new_nonce.inner();
+        self.nonce_to_id.insert(new_nonce.clone(), nonce_const);
+        let mut singleton_set = BTreeSet::new();
+        singleton_set.insert(new_nonce);
+        self.id_to_nonce_set.insert(nonce_const, singleton_set);
+    }
+
+    // removes a nonce that already exists in the partition
+    pub fn remove_nonce(&mut self, nonce: Nonce) {
+        let id = self.nonce_to_id.remove(&nonce).unwrap();
+        self.id_to_nonce_set.entry(id).and_modify(|x| {
+            x.remove(&nonce);
+        });
+    }
+
+    // adds an equality between nonce1 and nonce2
+    pub fn add_equality(&mut self, nonce1: Nonce, nonce2: Nonce) {
+        let id1 = self.nonce_to_id[&nonce1];
+        let id2 = self.nonce_to_id.remove(&nonce2).unwrap();
+        self.nonce_to_id.insert(nonce2, id1);
+        let mut nonce_set = self.id_to_nonce_set.remove(&id2).unwrap();
+        self.id_to_nonce_set.entry(id1).and_modify(|x| {
+            x.append(&mut nonce_set);
+        });
+    }
+
+    // checks if nonce1 and nonce2 are known to be equal
+    pub fn is_equal(&self, nonce1: Nonce, nonce2: Nonce) -> bool {
+        self.nonce_to_id[&nonce1] == self.nonce_to_id[&nonce2]
+    }
+
+    // returns a canonical version of self in which the id of each set is
+    // the least element of that set.
+    // the choice of returned id is arbitrary, but it must be a function of the nonce set.
+    pub fn construct_canonical_partition(self, nonce_map: &BTreeMap<Nonce, Nonce>) -> Self {
+        let mut id_to_nonce_set = BTreeMap::new();
+        for nonce_set in self.id_to_nonce_set.values() {
+            let canonical_nonce_set: BTreeSet<Nonce> = nonce_set
+                .iter()
+                .map(|nonce| nonce_map[nonce].clone())
+                .collect();
+            let canonical_id = Self::canonical_id(&canonical_nonce_set);
+            id_to_nonce_set.insert(canonical_id, canonical_nonce_set);
+        }
+        let nonce_to_id = Self::compute_nonce_to_id(&id_to_nonce_set);
+        Self {
+            nonce_to_id,
+            id_to_nonce_set,
+        }
+    }
+
+    pub fn nonces(&self) -> BTreeSet<Nonce> {
+        self.nonce_to_id.keys().cloned().collect()
+    }
+
+    // both self and partition must be canonical and over the same set of nonces
+    pub fn join(&self, partition: &Partition) -> Self {
+        checked_verify!(self.nonces() == partition.nonces());
+        // The join algorithm exploits the property that both self and partition are partitions
+        // over the same set of nonces. The algorithm does partition refinement by constructing,
+        // for each nonce, the intersection of the two sets containing it in self and partition.
+        // In the resulting partition, the nonce is mapped to this intersection set.
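+        //
+        // Illustration: if self groups the nonces as {n0, n1} {n2} and
+        // partition groups them as {n0} {n1, n2}, the per-nonce intersections are
+        //   n0: {n0, n1} ∩ {n0}     = {n0}
+        //   n1: {n0, n1} ∩ {n1, n2} = {n1}
+        //   n2: {n2}     ∩ {n1, n2} = {n2}
+        // so the join is {n0} {n1} {n2}: an equality survives only if both
+        // input partitions agree on it.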
+        let mut nonce_to_id_pair = BTreeMap::new();
+        let mut id_pair_to_nonce_set = BTreeMap::new();
+        for (nonce, id) in self.nonce_to_id.iter() {
+            let id_pair = (id, partition.nonce_to_id[nonce]);
+            nonce_to_id_pair.insert(nonce.clone(), id_pair);
+            id_pair_to_nonce_set.entry(id_pair).or_insert({
+                let nonce_set_for_id_pair: BTreeSet<Nonce> = self.id_to_nonce_set[&id_pair.0]
+                    .intersection(&partition.id_to_nonce_set[&id_pair.1])
+                    .cloned()
+                    .collect();
+                nonce_set_for_id_pair
+            });
+        }
+        let id_to_nonce_set: BTreeMap<usize, BTreeSet<Nonce>> = id_pair_to_nonce_set
+            .into_iter()
+            .map(|(_, nonce_set)| (Self::canonical_id(&nonce_set), nonce_set))
+            .collect();
+        let nonce_to_id = Self::compute_nonce_to_id(&id_to_nonce_set);
+        Self {
+            nonce_to_id,
+            id_to_nonce_set,
+        }
+    }
+
+    fn canonical_id(nonce_set: &BTreeSet<Nonce>) -> usize {
+        let mut minimum_id = MAX;
+        for nonce in nonce_set {
+            let id = nonce.inner();
+            if minimum_id > id {
+                minimum_id = id;
+            }
+        }
+        minimum_id
+    }
+
+    fn compute_nonce_to_id(
+        id_to_nonce_set: &BTreeMap<usize, BTreeSet<Nonce>>,
+    ) -> BTreeMap<Nonce, usize> {
+        let mut nonce_to_id = BTreeMap::new();
+        for (id, nonce_set) in id_to_nonce_set.iter() {
+            for nonce in nonce_set {
+                nonce_to_id.insert(nonce.clone(), *id);
+            }
+        }
+        nonce_to_id
+    }
+}
diff --git a/language/bytecode_verifier/src/resources.rs b/language/bytecode_verifier/src/resources.rs
new file mode 100644
index 0000000000000..494d3c6da7319
--- /dev/null
+++ b/language/bytecode_verifier/src/resources.rs
@@ -0,0 +1,42 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module implements a checker for verifying that a non-resource struct does not
+//! have resource fields inside it.
+use vm::{
+    errors::{VMStaticViolation, VerificationError},
+    file_format::CompiledModule,
+    views::ModuleView,
+    IndexKind,
+};
+
+pub struct ResourceTransitiveChecker<'a> {
+    module_view: ModuleView<'a, CompiledModule>,
+}
+
+impl<'a> ResourceTransitiveChecker<'a> {
+    pub fn new(module: &'a CompiledModule) -> Self {
+        Self {
+            module_view: ModuleView::new(module),
+        }
+    }
+
+    pub fn verify(self) -> Vec<VerificationError> {
+        let mut errors = vec![];
+        for (idx, struct_def) in self.module_view.structs().enumerate() {
+            let def_is_resource = struct_def.is_resource();
+            if !def_is_resource {
+                let mut fields = struct_def.fields();
+                let any_resource_field = fields.any(|field| field.type_signature().is_resource());
+                if any_resource_field {
+                    errors.push(VerificationError {
+                        kind: IndexKind::StructDefinition,
+                        idx,
+                        err: VMStaticViolation::InvalidResourceField,
+                    });
+                }
+            }
+        }
+        errors
+    }
+}
diff --git a/language/bytecode_verifier/src/signature.rs b/language/bytecode_verifier/src/signature.rs
new file mode 100644
index 0000000000000..7e7a25cb87226
--- /dev/null
+++ b/language/bytecode_verifier/src/signature.rs
@@ -0,0 +1,72 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module implements a checker for verifying that the signature tokens used in the types of
+//! function parameters, locals, and struct fields are well-formed. References can only occur at
+//! the top level in all tokens. Additionally, references cannot occur at all in field types.
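To make the reference rules just stated concrete, here are a few token shapes. This is a sketch, assuming only the `SignatureToken` variants from file_format.rs (the same variants matched by `InferredType::from_signature_token` later in this patch):

```rust
use vm::file_format::SignatureToken;

fn signature_examples() {
    // Fine as a parameter or local type: one reference at the top level.
    let _ok = SignatureToken::Reference(Box::new(SignatureToken::U64));

    // Rejected by this checker: a reference nested inside another reference.
    let _double_ref = SignatureToken::Reference(Box::new(SignatureToken::Reference(
        Box::new(SignatureToken::Bool),
    )));

    // Also rejected, but only when used as a struct field type:
    // fields may not be references at all.
    let _field_ref = SignatureToken::MutableReference(Box::new(SignatureToken::Address));
}
```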
+use vm::{
+    checks::SignatureCheck, errors::VerificationError, file_format::CompiledModule,
+    views::ModuleView, IndexKind,
+};
+
+pub struct SignatureChecker<'a> {
+    module_view: ModuleView<'a, CompiledModule>,
+}
+
+impl<'a> SignatureChecker<'a> {
+    pub fn new(module: &'a CompiledModule) -> Self {
+        Self {
+            module_view: ModuleView::new(module),
+        }
+    }
+
+    pub fn verify(self) -> Vec<VerificationError> {
+        let mut errors: Vec<Vec<VerificationError>> = vec![];
+
+        errors.push(Self::verify_impl(
+            IndexKind::TypeSignature,
+            self.module_view.type_signatures(),
+        ));
+        errors.push(Self::verify_impl(
+            IndexKind::FunctionSignature,
+            self.module_view.function_signatures(),
+        ));
+        errors.push(Self::verify_impl(
+            IndexKind::LocalsSignature,
+            self.module_view.locals_signatures(),
+        ));
+
+        let signature_ref_errors = self
+            .module_view
+            .fields()
+            .enumerate()
+            .filter_map(move |(idx, view)| {
+                view.check_signature_refs()
+                    .map(move |err| VerificationError {
+                        kind: IndexKind::FieldDefinition,
+                        idx,
+                        err,
+                    })
+            })
+            .collect();
+        errors.push(signature_ref_errors);
+
+        errors.into_iter().flatten().collect()
+    }
+
+    #[inline]
+    fn verify_impl(
+        kind: IndexKind,
+        views: impl Iterator<Item = impl SignatureCheck>,
+    ) -> Vec<VerificationError> {
+        views
+            .enumerate()
+            .map(move |(idx, view)| {
+                view.check_signatures()
+                    .into_iter()
+                    .map(move |err| VerificationError { kind, idx, err })
+            })
+            .flatten()
+            .collect()
+    }
+}
diff --git a/language/bytecode_verifier/src/stack_usage_verifier.rs b/language/bytecode_verifier/src/stack_usage_verifier.rs
new file mode 100644
index 0000000000000..232a8203f9158
--- /dev/null
+++ b/language/bytecode_verifier/src/stack_usage_verifier.rs
@@ -0,0 +1,156 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module implements a checker for verifying that basic blocks in the bytecode instruction
+//! sequence of a function use the evaluation stack in a balanced manner. Every basic block,
+//! except those that end in a Ret (return to caller) opcode, must leave the stack height the
+//! same as at the beginning of the block. A basic block that ends in a Ret opcode must increase
+//! the stack height by the number of values returned by the function, as indicated in its
+//! signature. Additionally, the stack height must not dip below that at the beginning of the
+//! block for any basic block.
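A hand-worked instance of the balance rule. The block is made up; the per-instruction deltas match the `instruction_effect` table further down in this file:

```rust
// Stack-height trace of one balanced, non-Ret block:
//
//   LdConst(1)   +1   height 1
//   LdConst(2)   +1   height 2
//   Add          -1   height 1   (pops two operands, pushes the sum)
//   StLoc(0)     -1   height 0   <- balanced, and the height never went negative
```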
+use crate::{
+    code_unit_verifier::VerificationPass,
+    control_flow_graph::{BasicBlock, VMControlFlowGraph},
+};
+use vm::{
+    access::{BaseAccess, ModuleAccess},
+    errors::VMStaticViolation,
+    file_format::{Bytecode, CompiledModule, FunctionDefinition},
+    views::FunctionDefinitionView,
+};
+
+pub struct StackUsageVerifier<'a> {
+    module: &'a CompiledModule,
+    function_definition_view: FunctionDefinitionView<'a, CompiledModule>,
+    cfg: &'a VMControlFlowGraph,
+}
+
+impl<'a> VerificationPass<'a> for StackUsageVerifier<'a> {
+    fn new(
+        module: &'a CompiledModule,
+        function_definition: &'a FunctionDefinition,
+        cfg: &'a VMControlFlowGraph,
+    ) -> Self {
+        let function_definition_view = FunctionDefinitionView::new(module, function_definition);
+        Self {
+            module,
+            function_definition_view,
+            cfg,
+        }
+    }
+
+    fn verify(self) -> Vec<VMStaticViolation> {
+        let mut errors = vec![];
+        for (_, block) in self.cfg.blocks.iter() {
+            errors.append(&mut self.verify_block(&block));
+        }
+        errors
+    }
+}
+
+impl<'a> StackUsageVerifier<'a> {
+    fn verify_block(&self, block: &BasicBlock) -> Vec<VMStaticViolation> {
+        let code = &self.function_definition_view.code().code;
+        let mut stack_size_increment = 0;
+        for i in block.entry..=block.exit {
+            stack_size_increment += self.instruction_effect(&code[i as usize]);
+            if stack_size_increment < 0 {
+                return vec![VMStaticViolation::NegativeStackSizeInsideBlock(
+                    block.entry as usize,
+                    i as usize,
+                )];
+            }
+        }
+
+        if stack_size_increment == 0 {
+            vec![]
+        } else {
+            vec![VMStaticViolation::PositiveStackSizeAtBlockEnd(
+                block.entry as usize,
+            )]
+        }
+    }
+
+    fn instruction_effect(&self, instruction: &Bytecode) -> i32 {
+        match instruction {
+            Bytecode::Pop | Bytecode::BrTrue(_) | Bytecode::BrFalse(_) | Bytecode::StLoc(_) => -1,
+
+            Bytecode::Ret => {
+                let return_count = self.function_definition_view.signature().return_count() as i32;
+                -return_count
+            }
+
+            Bytecode::Branch(_) | Bytecode::BorrowField(_) => 0,
+
+            Bytecode::LdConst(_)
+            | Bytecode::LdAddr(_)
+            | Bytecode::LdStr(_)
+            | Bytecode::LdTrue
+            | Bytecode::LdFalse
+            | Bytecode::CopyLoc(_)
+            | Bytecode::MoveLoc(_)
+            | Bytecode::BorrowLoc(_) => 1,
+
+            Bytecode::Call(idx) => {
+                let function_handle = self.module.function_handle_at(*idx);
+                let signature = self.module.function_signature_at(function_handle.signature);
+                let arg_count = signature.arg_types.len() as i32;
+                let return_count = signature.return_types.len() as i32;
+                return_count - arg_count
+            }
+
+            Bytecode::Pack(idx) => {
+                let struct_definition = self.module.struct_def_at(*idx);
+                let num_fields = i32::from(struct_definition.field_count);
+                1 - num_fields
+            }
+
+            Bytecode::Unpack(idx) => {
+                let struct_definition = self.module.struct_def_at(*idx);
+                let num_fields = i32::from(struct_definition.field_count);
+                num_fields - 1
+            }
+
+            Bytecode::ReadRef => 0,
+
+            Bytecode::WriteRef | Bytecode::Assert => -2,
+
+            Bytecode::Add
+            | Bytecode::Sub
+            | Bytecode::Mul
+            | Bytecode::Mod
+            | Bytecode::Div
+            | Bytecode::BitOr
+            | Bytecode::BitAnd
+            | Bytecode::Xor
+            | Bytecode::Or
+            | Bytecode::And
+            | Bytecode::Eq
+            | Bytecode::Neq
+            | Bytecode::Lt
+            | Bytecode::Gt
+            | Bytecode::Le
+            | Bytecode::Ge => -1,
+
+            Bytecode::Not => 0,
+
+            Bytecode::FreezeRef => 0,
+            Bytecode::Exists(_) => 0,
+            Bytecode::BorrowGlobal(_) => 0,
+            Bytecode::ReleaseRef => -1,
+            Bytecode::MoveFrom(_) => 0,
+            Bytecode::MoveToSender(_) => -1,
+
+            Bytecode::GetTxnGasUnitPrice
+            | Bytecode::GetTxnMaxGasUnits
+            | Bytecode::GetGasRemaining
+            | Bytecode::GetTxnPublicKey
+            | Bytecode::GetTxnSequenceNumber
+            | Bytecode::GetTxnSenderAddress => 1,
+            Bytecode::CreateAccount => -1,
+            Bytecode::EmitEvent => -3,
+
+            Bytecode::LdByteArray(_) => 1,
+        }
+    }
+}
diff --git a/language/bytecode_verifier/src/struct_defs.rs b/language/bytecode_verifier/src/struct_defs.rs
new file mode 100644
index 0000000000000..1570936db511b
--- /dev/null
+++ b/language/bytecode_verifier/src/struct_defs.rs
@@ -0,0 +1,121 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module provides a checker for verifying that struct definitions in a module are not
+//! recursive. Since the module dependency graph is acyclic by construction, applying this checker
+//! to each module in isolation guarantees that there is no structural recursion globally.
+use petgraph::{algo::toposort, Directed, Graph};
+use std::collections::BTreeMap;
+use vm::{
+    access::ModuleAccess,
+    errors::{VMStaticViolation, VerificationError},
+    file_format::{CompiledModule, StructDefinitionIndex, StructHandleIndex, TableIndex},
+    internals::ModuleIndex,
+    views::StructDefinitionView,
+    IndexKind,
+};
+
+pub struct RecursiveStructDefChecker<'a> {
+    module: &'a CompiledModule,
+}
+
+impl<'a> RecursiveStructDefChecker<'a> {
+    pub fn new(module: &'a CompiledModule) -> Self {
+        Self { module }
+    }
+
+    pub fn verify(self) -> Vec<VerificationError> {
+        let graph_builder = StructDefGraphBuilder::new(self.module);
+
+        let graph = graph_builder.build();
+
+        // toposort is iterative while petgraph::algo::is_cyclic_directed is recursive. Prefer
+        // the iterative solution here as this code may be dealing with untrusted data.
+        match toposort(&graph, None) {
+            Ok(_) => {
+                // Is the result of this useful elsewhere?
+                vec![]
+            }
+            Err(cycle) => {
+                let sd_idx = graph[cycle.node_id()];
+                vec![VerificationError {
+                    kind: IndexKind::StructDefinition,
+                    idx: sd_idx.into_index(),
+                    err: VMStaticViolation::RecursiveStructDef,
+                }]
+            }
+        }
+    }
+}
+
+/// Given a module, build a graph of struct definitions. This is useful when figuring out whether
+/// the struct definitions in a module form a cycle.
+pub struct StructDefGraphBuilder<'a> {
+    module: &'a CompiledModule,
+    /// Used to follow field definitions' signatures' struct handles to their struct definitions.
+    handle_to_def: BTreeMap<StructHandleIndex, StructDefinitionIndex>,
+}
+
+impl<'a> StructDefGraphBuilder<'a> {
+    pub fn new(module: &'a CompiledModule) -> Self {
+        let mut handle_to_def = BTreeMap::new();
+        // the mapping from struct definitions to struct handles is already checked to be 1-1 by
+        // DuplicationChecker
+        for (idx, struct_def) in module.struct_defs.iter().enumerate() {
+            let sh_idx = struct_def.struct_handle;
+            handle_to_def.insert(sh_idx, StructDefinitionIndex::new(idx as TableIndex));
+        }
+
+        Self {
+            module,
+            handle_to_def,
+        }
+    }
+
+    pub fn build(self) -> Graph<StructDefinitionIndex, (), Directed> {
+        let mut graph = Graph::new();
+
+        let struct_def_count = self.module.struct_defs.len();
+
+        let nodes: Vec<_> = (0..struct_def_count)
+            .map(|idx| graph.add_node(StructDefinitionIndex::new(idx as TableIndex)))
+            .collect();
+
+        for idx in 0..struct_def_count {
+            let sd_idx = StructDefinitionIndex::new(idx as TableIndex);
+            for followed_idx in self.member_struct_defs(sd_idx) {
+                graph.add_edge(nodes[idx], nodes[followed_idx.into_index()], ());
+            }
+        }
+
+        graph
+    }
+
+    fn member_struct_defs(
+        &'a self,
+        idx: StructDefinitionIndex,
+    ) -> impl Iterator<Item = StructDefinitionIndex> + 'a {
+        let struct_def = self.module.struct_def_at(idx);
+        let struct_def = StructDefinitionView::new(self.module, struct_def);
+        let fields = struct_def.fields();
+        let handle_to_def = &self.handle_to_def;
+
+        fields.filter_map(move |field| {
+            let type_signature = field.type_signature();
+            let sh_idx = match type_signature.token().struct_index() {
+                Some(sh_idx) => sh_idx,
+                None => {
+                    // This field doesn't refer to a struct.
+                    return None;
+                }
+            };
+            match handle_to_def.get(&sh_idx) {
+                Some(sd_idx) => Some(*sd_idx),
+                None => {
+                    // This field refers to a struct in another module.
+                    None
+                }
+            }
+        })
+    }
+}
diff --git a/language/bytecode_verifier/src/verifier.rs b/language/bytecode_verifier/src/verifier.rs
new file mode 100644
index 0000000000000..8bbaa42a27db8
--- /dev/null
+++ b/language/bytecode_verifier/src/verifier.rs
@@ -0,0 +1,237 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! This module contains the public APIs supported by the bytecode verifier.
+use crate::{
+    check_duplication::DuplicationChecker, code_unit_verifier::CodeUnitVerifier,
+    resources::ResourceTransitiveChecker, signature::SignatureChecker,
+    struct_defs::RecursiveStructDefChecker,
+};
+use std::collections::BTreeMap;
+use types::language_storage::CodeKey;
+use vm::{
+    checks::BoundsChecker,
+    errors::{VMStaticViolation, VerificationError},
+    file_format::{CompiledModule, CompiledScript},
+    resolver::Resolver,
+    views::{ModuleView, ViewInternals},
+    IndexKind,
+};
+
+/// Verification of a module is performed through a sequence of checks.
+/// There is a partial order on the checks. For example, bounds checking must precede all other
+/// checks, and the duplication check must precede the structural recursion check. In general,
+/// later checks are more expensive.
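As a usage sketch of the pipeline below: this is hypothetical caller code, assuming only that a `CompiledModule` has been obtained somewhere (from the compiler or a deserializer) and that `VerificationError` implements `Debug`, as the test files in this patch rely on.

```rust
use bytecode_verifier::verifier::verify_module;
use vm::file_format::CompiledModule;

fn check(module: CompiledModule) -> Option<CompiledModule> {
    // verify_module hands the module back along with any errors, so the
    // caller can keep using it when the error list is empty.
    let (module, errors) = verify_module(module);
    if errors.is_empty() {
        Some(module)
    } else {
        for err in &errors {
            eprintln!("verification error: {:?}", err);
        }
        None
    }
}
```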
+pub fn verify_module(module: CompiledModule) -> (CompiledModule, Vec<VerificationError>) {
+    let mut errors = BoundsChecker::new(&module).verify();
+    if errors.is_empty() {
+        errors.append(&mut DuplicationChecker::new(&module).verify());
+    }
+    if errors.is_empty() {
+        errors.append(&mut SignatureChecker::new(&module).verify());
+        errors.append(&mut ResourceTransitiveChecker::new(&module).verify());
+    }
+    if errors.is_empty() {
+        errors.append(&mut RecursiveStructDefChecker::new(&module).verify());
+    }
+    if errors.is_empty() {
+        errors.append(&mut CodeUnitVerifier::new(&module).verify());
+    }
+    (module, errors)
+}
+
+/// Verification of a script is done in two steps:
+/// - Convert the script into a module and run all the usual verification performed on a module
+/// - Check the signature of the main function of the script
+/// This approach works because critical operations such as MoveFrom, MoveToSender, and
+/// BorrowGlobal, which are not allowed in the script function, take a StructDefinitionIndex as
+/// an argument. Since the module constructed from a script is guaranteed to have an empty vector
+/// of struct definitions, the bounds checker will catch any occurrences of these illegal
+/// operations.
+pub fn verify_script(script: CompiledScript) -> (CompiledScript, Vec<VerificationError>) {
+    let fake_module = script.into_module();
+    let (fake_module, mut errors) = verify_module(fake_module);
+    let script = fake_module.into_script();
+    errors.append(
+        &mut verify_main_signature(&script)
+            .into_iter()
+            .map(move |err| VerificationError {
+                kind: IndexKind::FunctionDefinition,
+                idx: 0,
+                err,
+            })
+            .collect(),
+    );
+    (script, errors)
+}
+
+/// This function checks the extra requirements on the signature of the main function of a script.
+pub fn verify_main_signature(script: &CompiledScript) -> Vec<VMStaticViolation> {
+    let function_handle = &script.function_handles[script.main.function.0 as usize];
+    let function_signature = &script.function_signatures[function_handle.signature.0 as usize];
+    if !function_signature.return_types.is_empty() {
+        return vec![VMStaticViolation::InvalidMainFunctionSignature];
+    }
+    for arg_type in &function_signature.arg_types {
+        if !arg_type.is_primitive() {
+            return vec![VMStaticViolation::InvalidMainFunctionSignature];
+        }
+    }
+    vec![]
+}
+
+/// Verification of a module in isolation (using verify_module) trusts that struct and function
+/// handles not implemented in the module are declared correctly. The following procedure justifies
+/// this trust by checking that these declarations match the definitions in the module dependencies.
+/// Each dependency of 'module' is looked up in 'dependencies'. If not found, an error is included
+/// in the returned list of errors. If found, usage of types and functions of the dependency in
+/// 'module' is checked against the declarations in the found module and mismatch errors are
+/// returned.
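Combining the two entry points, a caller might verify a module in isolation first and only then check it against its dependencies. A sketch, assuming `deps` contains every module that `module` imports:

```rust
use bytecode_verifier::verifier::{verify_module, verify_module_dependencies};
use vm::{errors::VerificationError, file_format::CompiledModule};

fn verify_with_deps(
    module: CompiledModule,
    deps: &[CompiledModule],
) -> (CompiledModule, Vec<VerificationError>) {
    let (module, errors) = verify_module(module);
    if errors.is_empty() {
        // Only compare declarations against dependency definitions once the
        // module itself is known to be internally well-formed.
        return verify_module_dependencies(module, deps);
    }
    (module, errors)
}
```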
+pub fn verify_module_dependencies(
+    module: CompiledModule,
+    dependencies: &[CompiledModule],
+) -> (CompiledModule, Vec<VerificationError>) {
+    let module_code_key = module.self_code_key();
+    let mut dependency_map = BTreeMap::new();
+    for dependency in dependencies {
+        let dependency_code_key = dependency.self_code_key();
+        if module_code_key != dependency_code_key {
+            dependency_map.insert(dependency_code_key, dependency);
+        }
+    }
+    let mut errors = vec![];
+    let module_view = ModuleView::new(&module);
+    errors.append(&mut verify_struct_kind(&module_view, &dependency_map));
+    errors.append(&mut verify_function_visibility_and_type(
+        &module_view,
+        &dependency_map,
+    ));
+    errors.append(&mut verify_all_dependencies_provided(
+        &module_view,
+        &dependency_map,
+    ));
+    (module, errors)
+}
+
+/// Verifying the dependencies of a script follows the same recipe as verify_script: convert to a
+/// module and invoke verify_module_dependencies. Each dependency of 'script' is looked up in
+/// 'dependencies'. If not found, an error is included in the returned list of errors. If found,
+/// usage of types and functions of the dependency in 'script' is checked against the
+/// declarations in the found module and mismatch errors are returned.
+pub fn verify_script_dependencies(
+    script: CompiledScript,
+    dependencies: &[CompiledModule],
+) -> (CompiledScript, Vec<VerificationError>) {
+    let fake_module = script.into_module();
+    let (fake_module, errors) = verify_module_dependencies(fake_module, dependencies);
+    let script = fake_module.into_script();
+    (script, errors)
+}
+
+fn verify_all_dependencies_provided(
+    module_view: &ModuleView<CompiledModule>,
+    dependency_map: &BTreeMap<CodeKey, &CompiledModule>,
+) -> Vec<VerificationError> {
+    let mut errors = vec![];
+    for (idx, module_handle_view) in module_view.module_handles().enumerate() {
+        let module_id = module_handle_view.module_code_key();
+        if idx != CompiledModule::IMPLEMENTED_MODULE_INDEX as usize
+            && !dependency_map.contains_key(&module_id)
+        {
+            errors.push(VerificationError {
+                kind: IndexKind::ModuleHandle,
+                idx,
+                err: VMStaticViolation::MissingDependency,
+            });
+        }
+    }
+    errors
+}
+
+fn verify_struct_kind(
+    module_view: &ModuleView<CompiledModule>,
+    dependency_map: &BTreeMap<CodeKey, &CompiledModule>,
+) -> Vec<VerificationError> {
+    let mut errors = vec![];
+    for (idx, struct_handle_view) in module_view.struct_handles().enumerate() {
+        let owner_module_id = struct_handle_view.module_code_key();
+        if !dependency_map.contains_key(&owner_module_id) {
+            continue;
+        }
+        let struct_name = struct_handle_view.name();
+        let owner_module = &dependency_map[&owner_module_id];
+        let owner_module_view = ModuleView::new(*owner_module);
+        if let Some(struct_definition_view) = owner_module_view.struct_definition(struct_name) {
+            if struct_handle_view.is_resource() != struct_definition_view.is_resource() {
+                errors.push(VerificationError {
+                    kind: IndexKind::StructHandle,
+                    idx,
+                    err: VMStaticViolation::TypeMismatch,
+                });
+            }
+        } else {
+            errors.push(VerificationError {
+                kind: IndexKind::StructHandle,
+                idx,
+                err: VMStaticViolation::LookupFailed,
+            });
+        }
+    }
+    errors
+}
+
+fn verify_function_visibility_and_type(
+    module_view: &ModuleView<CompiledModule>,
+    dependency_map: &BTreeMap<CodeKey, &CompiledModule>,
+) -> Vec<VerificationError> {
+    let resolver = Resolver::new(module_view.as_inner());
+    let mut errors = vec![];
+    for (idx, function_handle_view) in module_view.function_handles().enumerate() {
+        let owner_module_id = function_handle_view.module_code_key();
+        if !dependency_map.contains_key(&owner_module_id) {
+            continue;
+        }
+        let function_name = function_handle_view.name();
+        let owner_module = &dependency_map[&owner_module_id];
+        let
owner_module_view = ModuleView::new(*owner_module); + if let Some(function_definition_view) = owner_module_view.function_definition(function_name) + { + if function_definition_view.is_public() { + let function_definition_signature = function_definition_view.signature().as_inner(); + match resolver + .import_function_signature(owner_module, &function_definition_signature) + { + Ok(imported_function_signature) => { + let function_handle_signature = function_handle_view.signature().as_inner(); + if imported_function_signature != *function_handle_signature { + errors.push(VerificationError { + kind: IndexKind::FunctionHandle, + idx, + err: VMStaticViolation::TypeMismatch, + }); + } + } + Err(err) => { + errors.push(VerificationError { + kind: IndexKind::FunctionHandle, + idx, + err, + }); + } + } + } else { + errors.push(VerificationError { + kind: IndexKind::FunctionHandle, + idx, + err: VMStaticViolation::VisibilityMismatch, + }); + } + } else { + errors.push(VerificationError { + kind: IndexKind::FunctionHandle, + idx, + err: VMStaticViolation::LookupFailed, + }); + } + } + errors +} diff --git a/language/bytecode_verifier/tests/bounds_tests.proptest-regressions b/language/bytecode_verifier/tests/bounds_tests.proptest-regressions new file mode 100644 index 0000000000000..489fbbdff4871 --- /dev/null +++ b/language/bytecode_verifier/tests/bounds_tests.proptest-regressions @@ -0,0 +1,10 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 2beb0a0e65962432af560e626fa109d269b07db8807968413425f0bb14bb3667 # shrinks to module = CompiledModule: { struct_handles: [ StructHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), is_resource: false }, StructHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), is_resource: false },] function_handles: [ FunctionHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), signature: FunctionSignatureIndex(0) }, FunctionHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), signature: FunctionSignatureIndex(1) },] struct_defs: [ StructDefinition { struct_handle: 1, access: 0x4, field_count: 0, fields: 0 },] field_defs: [] function_defs: [ FunctionDefinition { function: 1, access: 0x2, code: CodeUnit { max_stack_size: 0, locals: 0 code: [] } },] type_signatures: [ TypeSignature(Unit),] function_signatures: [ FunctionSignature { return_type: Unit, arg_types: [] }, FunctionSignature { return_type: Unit, arg_types: [] },] locals_signatures: [ LocalsSignature([]),] string_pool: [ "",] address_pool: [ Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),] } +cc c14ae393a6eefae82c0f4ede2acaa0aa0e993c1bba3fe3e5958e6e31cb5d2957 # shrinks to module = CompiledModule: { module_handles: [ ModuleHandle { address: AddressPoolIndex(0), name: StringPoolIndex(0) },] struct_handles: [ StructHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), is_resource: false }, StructHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), is_resource: false },] function_handles: [ FunctionHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), signature: FunctionSignatureIndex(0) }, FunctionHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), signature: FunctionSignatureIndex(1) },] struct_defs: [ 
StructDefinition { struct_handle: 1, access: 0x4, field_count: 0, fields: 0 },] field_defs: [] function_defs: [ FunctionDefinition { function: 1, access: 0x2, code: CodeUnit { max_stack_size: 0, locals: 0 code: [] } },] type_signatures: [ TypeSignature(Unit),] function_signatures: [ FunctionSignature { return_type: Unit, arg_types: [] }, FunctionSignature { return_type: Unit, arg_types: [] },] locals_signatures: [ LocalsSignature([]),] string_pool: [ "",] address_pool: [ Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),] } , oob_mutations = [] +cc 88615e15ef42d29405cd91d6d0a573ccbeb833d0c7471f718ee794bc5ba399ca # shrinks to module = CompiledModule: { module_handles: [ ModuleHandle { address: AddressPoolIndex(0), name: StringPoolIndex(0) },] struct_handles: [ StructHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), is_resource: false }, StructHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), is_resource: false }, StructHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), is_resource: false },] function_handles: [ FunctionHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), signature: FunctionSignatureIndex(0) }, FunctionHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), signature: FunctionSignatureIndex(1) },] struct_defs: [ StructDefinition { struct_handle: 1, access: 0x4, field_count: 0, fields: 0 }, StructDefinition { struct_handle: 2, access: 0x4, field_count: 0, fields: 0 },] field_defs: [] function_defs: [ FunctionDefinition { function: 1, access: 0x2, code: CodeUnit { max_stack_size: 0, locals: 0 code: [] } },] type_signatures: [ TypeSignature(Unit),] function_signatures: [ FunctionSignature { return_type: Unit, arg_types: [] }, FunctionSignature { return_type: Unit, arg_types: [] },] locals_signatures: [ LocalsSignature([]),] string_pool: [ "",] address_pool: [ Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),] } , oob_mutations = [OutOfBoundsMutation { src_kind: StructDefinition, src_idx: Index(0), dst_kind: FieldDefinition, offset: 0 }] +cc a34039f5d57751762a6eacf3ca3a2857781fb0bd0af0b7a06a9427f896f29aa9 # shrinks to module = CompiledModule: { module_handles: [ ModuleHandle { address: AddressPoolIndex(0), name: StringPoolIndex(0) },] struct_handles: [ StructHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), is_resource: false }, StructHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), is_resource: false },] function_handles: [ FunctionHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), signature: FunctionSignatureIndex(0) }, FunctionHandle { module: ModuleHandleIndex(0), name: StringPoolIndex(0), signature: FunctionSignatureIndex(1) },] struct_defs: [ StructDefinition { struct_handle: 1, access: 0x2, field_count: 0, fields: 0 },] field_defs: [] function_defs: [ FunctionDefinition { function: 1, access: 0x0, code: CodeUnit { max_stack_size: 0, locals: 0 code: [ BrTrue(1),] } },] type_signatures: [ TypeSignature(Unit), TypeSignature(Unit),] function_signatures: [ FunctionSignature { return_type: Unit, arg_types: [] }, FunctionSignature { return_type: Unit, arg_types: [] },] locals_signatures: [ LocalsSignature([]),] string_pool: [ "",] address_pool: [ Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),] } , oob_mutations = [] diff --git a/language/bytecode_verifier/tests/bounds_tests.rs 
b/language/bytecode_verifier/tests/bounds_tests.rs
new file mode 100644
index 0000000000000..670efea0c461f
--- /dev/null
+++ b/language/bytecode_verifier/tests/bounds_tests.rs
@@ -0,0 +1,115 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use invalid_mutations::bounds::{
+    ApplyCodeUnitBoundsContext, ApplyOutOfBoundsContext, CodeUnitBoundsMutation,
+    OutOfBoundsMutation,
+};
+use proptest::{collection::vec, prelude::*};
+use types::{account_address::AccountAddress, byte_array::ByteArray};
+use vm::{
+    checks::BoundsChecker,
+    errors::{VMStaticViolation, VerificationError},
+    proptest_types::CompiledModuleStrategyGen,
+    CompiledModule, IndexKind,
+};
+
+proptest! {
+    #[test]
+    fn valid_bounds(module in CompiledModule::valid_strategy(20)) {
+        let bounds_checker = BoundsChecker::new(&module);
+        prop_assert_eq!(bounds_checker.verify(), vec![]);
+    }
+}
+
+/// Ensure that valid modules that don't have any members (e.g. function args, struct fields) pass
+/// bounds checks.
+///
+/// There are some potentially tricky edge cases around ranges that are captured here.
+#[test]
+fn valid_bounds_no_members() {
+    let mut gen = CompiledModuleStrategyGen::new(20);
+    gen.member_count(0);
+    proptest!(|(module in gen.generate())| {
+        let bounds_checker = BoundsChecker::new(&module);
+        prop_assert_eq!(bounds_checker.verify(), vec![])
+    });
+}
+
+proptest! {
+    #[test]
+    fn invalid_out_of_bounds(
+        module in CompiledModule::valid_strategy(20),
+        oob_mutations in vec(OutOfBoundsMutation::strategy(), 0..40),
+    ) {
+        let mut module = module;
+        let mut expected_violations = {
+            let oob_context = ApplyOutOfBoundsContext::new(&mut module, oob_mutations);
+            oob_context.apply()
+        };
+        expected_violations.sort();
+
+        let bounds_checker = BoundsChecker::new(&module);
+        let mut actual_violations = bounds_checker.verify();
+        actual_violations.sort();
+        prop_assert_eq!(expected_violations, actual_violations);
+    }
+
+    #[test]
+    fn code_unit_out_of_bounds(
+        module in CompiledModule::valid_strategy(20),
+        mutations in vec(CodeUnitBoundsMutation::strategy(), 0..40),
+    ) {
+        let mut module = module;
+        let mut expected_violations = {
+            let context = ApplyCodeUnitBoundsContext::new(&mut module, mutations);
+            context.apply()
+        };
+        expected_violations.sort();
+
+        let bounds_checker = BoundsChecker::new(&module);
+        let mut actual_violations = bounds_checker.verify();
+        actual_violations.sort();
+        prop_assert_eq!(expected_violations, actual_violations);
+    }
+
+    #[test]
+    fn no_module_handles(
+        string_pool in vec(".*", 0..20),
+        address_pool in vec(any::<AccountAddress>(), 0..20),
+        byte_array_pool in vec(any::<ByteArray>(), 0..20),
+    ) {
+        // If there are no module handles, the only other things that can be stored are intrinsic
+        // data.
+        let mut module = CompiledModule::default();
+        module.string_pool = string_pool;
+        module.address_pool = address_pool;
+        module.byte_array_pool = byte_array_pool;
+
+        let bounds_checker = BoundsChecker::new(&module);
+        let actual_violations = bounds_checker.verify();
+        prop_assert_eq!(
+            actual_violations,
+            vec![
+                VerificationError {
+                    kind: IndexKind::ModuleHandle,
+                    idx: 0,
+                    err: VMStaticViolation::NoModuleHandles,
+                },
+            ]
+        );
+    }
+}
+
+proptest! {
+    // Generating arbitrary compiled modules is really slow, possibly because of
+    // https://github.com/AltSysrq/proptest/issues/143.
+    #![proptest_config(ProptestConfig::with_cases(16))]
+
+    /// Make sure that garbage inputs don't crash the bounds checker.
+    #[test]
+    fn garbage_inputs(module in any_with::<CompiledModule>(16)) {
+        let bounds_checker = BoundsChecker::new(&module);
+        bounds_checker.verify();
+    }
+}
diff --git a/language/bytecode_verifier/tests/duplication_tests.rs b/language/bytecode_verifier/tests/duplication_tests.rs
new file mode 100644
index 0000000000000..1e329d56bddf8
--- /dev/null
+++ b/language/bytecode_verifier/tests/duplication_tests.rs
@@ -0,0 +1,15 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use bytecode_verifier::DuplicationChecker;
+use proptest::prelude::*;
+use vm::{checks::BoundsChecker, file_format::CompiledModule};
+
+proptest! {
+    #[test]
+    fn valid_duplication(module in CompiledModule::valid_strategy(20)) {
+        prop_assert!(BoundsChecker::new(&module).verify().is_empty());
+        let duplication_checker = DuplicationChecker::new(&module);
+        prop_assert!(!duplication_checker.verify().is_empty());
+    }
+}
diff --git a/language/bytecode_verifier/tests/resources_tests.rs b/language/bytecode_verifier/tests/resources_tests.rs
new file mode 100644
index 0000000000000..43da53dea6d3b
--- /dev/null
+++ b/language/bytecode_verifier/tests/resources_tests.rs
@@ -0,0 +1,15 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use bytecode_verifier::ResourceTransitiveChecker;
+use proptest::prelude::*;
+use vm::{checks::BoundsChecker, file_format::CompiledModule};
+
+proptest! {
+    #[test]
+    fn valid_resource_transitivity(module in CompiledModule::valid_strategy(20)) {
+        prop_assert!(BoundsChecker::new(&module).verify().is_empty());
+        let resource_checker = ResourceTransitiveChecker::new(&module);
+        prop_assert!(resource_checker.verify().is_empty());
+    }
+}
diff --git a/language/bytecode_verifier/tests/signature_tests.rs b/language/bytecode_verifier/tests/signature_tests.rs
new file mode 100644
index 0000000000000..8b6082f077267
--- /dev/null
+++ b/language/bytecode_verifier/tests/signature_tests.rs
@@ -0,0 +1,71 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use bytecode_verifier::SignatureChecker;
+use invalid_mutations::signature::{
+    ApplySignatureDoubleRefContext, ApplySignatureFieldRefContext, DoubleRefMutation,
+    FieldRefMutation,
+};
+use proptest::{collection::vec, prelude::*};
+use vm::{checks::BoundsChecker, errors::VMStaticViolation, file_format::CompiledModule};
+
+proptest! {
+    #[test]
+    fn valid_signatures(module in CompiledModule::valid_strategy(20)) {
+        prop_assert!(BoundsChecker::new(&module).verify().is_empty());
+        let signature_checker = SignatureChecker::new(&module);
+        prop_assert_eq!(signature_checker.verify(), vec![]);
+    }
+
+    #[test]
+    fn double_refs(
+        module in CompiledModule::valid_strategy(20),
+        mutations in vec(DoubleRefMutation::strategy(), 0..40),
+    ) {
+        let mut module = module;
+        let mut expected_violations = {
+            let context = ApplySignatureDoubleRefContext::new(&mut module, mutations);
+            context.apply()
+        };
+        expected_violations.sort();
+
+        prop_assert!(BoundsChecker::new(&module).verify().is_empty());
+        let signature_checker = SignatureChecker::new(&module);
+
+        let actual_violations = signature_checker.verify();
+        // Since some type signatures are field definition references as well, actual_violations
+        // will also contain VMStaticViolation::InvalidFieldDefReference errors -- filter those
+        // out.
+        let mut actual_violations: Vec<_> = actual_violations
+            .into_iter()
+            .filter(|err| match &err.err {
+                VMStaticViolation::InvalidFieldDefReference(..) => false,
+                _ => true,
+            })
+            .collect();
+        actual_violations.sort();
+        prop_assert_eq!(expected_violations, actual_violations);
+    }
+
+    #[test]
+    fn field_def_references(
+        module in CompiledModule::valid_strategy(20),
+        mutations in vec(FieldRefMutation::strategy(), 0..40),
+    ) {
+        let mut module = module;
+        let mut expected_violations = {
+            let context = ApplySignatureFieldRefContext::new(&mut module, mutations);
+            context.apply()
+        };
+        expected_violations.sort();
+
+        prop_assert!(BoundsChecker::new(&module).verify().is_empty());
+        let signature_checker = SignatureChecker::new(&module);
+
+        let mut actual_violations = signature_checker.verify();
+        // Note that this shouldn't cause any InvalidSignatureToken errors because there are no
+        // double references involved. So no filtering is required here.
+        actual_violations.sort();
+        prop_assert_eq!(expected_violations, actual_violations);
+    }
+}
diff --git a/language/bytecode_verifier/tests/struct_defs_tests.rs b/language/bytecode_verifier/tests/struct_defs_tests.rs
new file mode 100644
index 0000000000000..987343529395b
--- /dev/null
+++ b/language/bytecode_verifier/tests/struct_defs_tests.rs
@@ -0,0 +1,15 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use bytecode_verifier::RecursiveStructDefChecker;
+use proptest::prelude::*;
+use vm::{checks::BoundsChecker, file_format::CompiledModule};
+
+proptest! {
+    #[test]
+    fn valid_recursive_struct_defs(module in CompiledModule::valid_strategy(20)) {
+        prop_assert!(BoundsChecker::new(&module).verify().is_empty());
+        let recursive_checker = RecursiveStructDefChecker::new(&module);
+        prop_assert!(recursive_checker.verify().is_empty());
+    }
+}
diff --git a/language/compiler/Cargo.toml b/language/compiler/Cargo.toml
new file mode 100644
index 0000000000000..0c480a15d4351
--- /dev/null
+++ b/language/compiler/Cargo.toml
@@ -0,0 +1,26 @@
+[package]
+name = "compiler"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+build = "build.rs"
+
+[dependencies]
+bytecode_verifier = { path = "../bytecode_verifier" }
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+types = { path = "../../types" }
+vm = { path = "../vm" }
+lalrpop-util = "0.16.3"
+log = "0.4.6"
+codespan = "0.1.3"
+codespan-reporting = "0.1.4"
+serde = { version = "1.0.89", features = ["derive"] }
+serde_json = "1.0.38"
+hex = "0.3.2"
+regex = "1.1.6"
+structopt = "0.2.15"
+
+[build-dependencies]
+lalrpop = "0.16.3"
diff --git a/language/compiler/README.md b/language/compiler/README.md
new file mode 100644
index 0000000000000..48a6bf0200865
--- /dev/null
+++ b/language/compiler/README.md
@@ -0,0 +1,79 @@
+---
+id: ir-to-bytecode
+title: Move IR Compiler
+custom_edit_url: https://github.com/libra/libra/edit/master/language/compiler/README.md
+---
+
+# Move IR Compiler
+
+## Summary
+
+The Move IR compiler compiles the Move IR down to its bytecode representation.
+
+## Overview
+
+The Move IR compiler compiles modules and scripts written in Move down to
+their respective bytecode representations. The two data types used to
+represent these outputs are `CompiledModule` and `CompiledScript`. These
+data types are defined in [file_format.rs](https://github.com/libra/libra/blob/master/language/vm/src/file_format.rs).
+
+Beyond translating Move IR to Move bytecode, the compiler also serves as a
+testing tool for the bytecode verifier.
+Because of this, its job is to
+output bytecode programs that correspond as closely as possible to the
+input IR; optimizations and advanced semantic checks are specifically not
+performed during the compilation process. In fact, the compiler goes out of
+its way to push these semantic checks into the bytecode, and compiles
+semantically invalid code in the Move IR to equivalent (semantically
+invalid) bytecode programs. The semantics of the compiled bytecode are
+then verified by the [bytecode verifier](https://github.com/libra/libra/blob/master/language/bytecode_verifier/README.md). The compiler command line
+automatically calls the bytecode verifier at the end of compilation.
+
+## Command-line options
+
+```text
+USAGE:
+    compiler [FLAGS] [OPTIONS] <source path>
+
+FLAGS:
+    -h, --help         Prints help information
+        --no-stdlib    Do not automatically compile stdlib dependencies
+        --no-verify    Do not automatically run the bytecode verifier
+    -s, --script       Treat input file as a script (default is to treat file as a module)
+    -V, --version      Prints version information
+
+OPTIONS:
+    -o, --output <output path>    Serialize and write the compiled output to this file
+
+ARGS:
+    <source path>    Path to the Move IR source to compile
+```
+
+### Example Usage
+
+To build the compiler:
+
+> cargo build --bin compiler
+
+* This will build the compiler+verifier binary.
+* The binary can be found at `libra/target/debug/compiler`.
+* Alternatively, the binary can be run directly with `cargo run -p compiler`.
+
+To compile and verify a `*.mvir` module file:
+> `compiler a.mvir`
+
+To compile and verify a `*.mvir` transaction script file:
+> `compiler -s *.mvir`
+
+## Folder Structure
+
+```text
+ *
+ |- parser/
+    |- ast.rs          # Contains all the data structures used to build the AST representing the parsed Move IR input.
+    |- syntax.lalrpop  # Description of the Move IR language, used by lalrpop to generate a parser.
+    |- syntax.rs       # Parser generated by lalrpop using the description in `syntax.lalrpop` - a clean checkout won't contain this file.
+ |- compiler.rs        # Main compiler logic - converts an AST generated by `syntax.rs` to a `CompiledModule` or `CompiledScript`.
+ |- main.rs            # Compiler driver - parses command line options and calls the parser, compiler, and bytecode verifier.
+ |- util.rs            # Misc compiler utilities.
+```
+
diff --git a/language/compiler/build.rs b/language/compiler/build.rs
new file mode 100644
index 0000000000000..f4592e5c8f1af
--- /dev/null
+++ b/language/compiler/build.rs
@@ -0,0 +1,9 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+fn main() {
+    lalrpop::Configuration::new()
+        .generate_in_source_tree()
+        .process()
+        .unwrap();
+}
diff --git a/language/compiler/src/compiler.rs b/language/compiler/src/compiler.rs
new file mode 100644
index 0000000000000..6a488963d7983
--- /dev/null
+++ b/language/compiler/src/compiler.rs
@@ -0,0 +1,2237 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::parser::ast::{
+    BinOp, Block, Builtin, Cmd, CopyableVal, Exp, Field, Fields, Function, FunctionBody,
+    FunctionCall, FunctionSignature as AstFunctionSignature, FunctionVisibility, IfElse, Kind,
+    Loop, ModuleDefinition, ModuleIdent, ModuleName, Program, Statement,
+    StructDefinition as MoveStruct, Tag, Type, UnaryOp, Var, Var_, While,
+};
+use bytecode_verifier::verifier::{verify_module, verify_module_dependencies};
+use failure::*;
+use std::{
+    clone::Clone,
+    collections::{
+        hash_map::Entry::{Occupied, Vacant},
+        HashMap, VecDeque,
+    },
+};
+use types::{account_address::AccountAddress, byte_array::ByteArray};
+use vm::{
+    errors::VerificationError,
+    file_format::{
+        AddressPoolIndex, ByteArrayPoolIndex, Bytecode, CodeUnit, CompiledModule, CompiledProgram,
+        CompiledScript, FieldDefinition, FieldDefinitionIndex, FunctionDefinition,
+        FunctionDefinitionIndex, FunctionHandle, FunctionHandleIndex, FunctionSignature,
+        FunctionSignatureIndex, LocalsSignature, LocalsSignatureIndex, MemberCount, ModuleHandle,
+        ModuleHandleIndex, SignatureToken, StringPoolIndex, StructDefinition,
+        StructDefinitionIndex, StructHandle, StructHandleIndex, TableIndex, TypeSignature,
+        TypeSignatureIndex, SELF_MODULE_NAME,
+    },
+    printers::TableAccess,
+};
+
+#[cfg(test)]
+#[path = "unit_tests/branch_tests.rs"]
+mod branch_tests;
+#[cfg(test)]
+#[path = "unit_tests/cfg_tests.rs"]
+mod cfg_tests;
+#[cfg(test)]
+#[path = "unit_tests/expression_tests.rs"]
+mod expression_tests;
+#[cfg(test)]
+#[path = "unit_tests/function_tests.rs"]
+mod function_tests;
+#[cfg(test)]
+#[path = "unit_tests/import_tests.rs"]
+mod import_tests;
+#[cfg(test)]
+#[path = "unit_tests/serializer_tests.rs"]
+mod serializer_tests;
+#[cfg(test)]
+#[path = "unit_tests/stdlib_scripts.rs"]
+mod stdlib_scripts;
+
+#[derive(Debug, Default)]
+struct LoopInfo {
+    start_loc: usize,
+    breaks: Vec<usize>,
+}
+
+// Ideally, we should capture all of this info into a CFG, but as we only have structured control
+// flow currently, it would be a bit overkill.
+// It will be a necessity if we add arbitrary branches
+// in the IR, as is expressible in the bytecode.
+struct ControlFlowInfo {
+    // A `break` is reachable iff it was used before a terminal node
+    reachable_break: bool,
+    // A terminal node is an infinite loop or a path that always returns
+    terminal_node: bool,
+}
+
+impl ControlFlowInfo {
+    fn join(f1: ControlFlowInfo, f2: ControlFlowInfo) -> ControlFlowInfo {
+        ControlFlowInfo {
+            reachable_break: f1.reachable_break || f2.reachable_break,
+            terminal_node: f1.terminal_node && f2.terminal_node,
+        }
+    }
+    fn successor(prev: ControlFlowInfo, next: ControlFlowInfo) -> ControlFlowInfo {
+        if prev.terminal_node {
+            prev
+        } else {
+            ControlFlowInfo {
+                reachable_break: prev.reachable_break || next.reachable_break,
+                terminal_node: next.terminal_node,
+            }
+        }
+    }
+}
+
+// Inferred representation of SignatureToken's
+// In essence, it's a signature token with a "bottom" type added
+enum InferredType {
+    // Result of the compiler failing to infer the type of an expression
+    // Not translatable to a signature token
+    Anything,
+
+    // Signature tokens
+    Bool,
+    U64,
+    String,
+    ByteArray,
+    Address,
+    Struct(StructHandleIndex),
+    Reference(Box<InferredType>),
+    MutableReference(Box<InferredType>),
+}
+
+impl InferredType {
+    fn from_signature_token(sig_token: &SignatureToken) -> Self {
+        use InferredType as I;
+        use SignatureToken as S;
+        match sig_token {
+            S::Bool => I::Bool,
+            S::U64 => I::U64,
+            S::String => I::String,
+            S::ByteArray => I::ByteArray,
+            S::Address => I::Address,
+            S::Struct(si) => I::Struct(*si),
+            S::Reference(s_inner) => {
+                let i_inner = Self::from_signature_token(&*s_inner);
+                I::Reference(Box::new(i_inner))
+            }
+            S::MutableReference(s_inner) => {
+                let i_inner = Self::from_signature_token(&*s_inner);
+                I::MutableReference(Box::new(i_inner))
+            }
+        }
+    }
+
+    fn get_struct_handle(&self) -> Result<StructHandleIndex> {
+        match self {
+            InferredType::Anything => bail!("could not infer struct type"),
+            InferredType::Bool => bail!("no struct type for Bool"),
+            InferredType::U64 => bail!("no struct type for U64"),
+            InferredType::String => bail!("no struct type for String"),
+            InferredType::ByteArray => bail!("no struct type for ByteArray"),
+            InferredType::Address => bail!("no struct type for Address"),
+            InferredType::Reference(inner) | InferredType::MutableReference(inner) => {
+                inner.get_struct_handle()
+            }
+            InferredType::Struct(idx) => Ok(*idx),
+        }
+    }
+}
+
+// Holds information about a function being compiled.
+#[derive(Debug, Default)]
+struct FunctionFrame {
+    local_count: u8,
+    locals: HashMap<Var, u8>,
+    local_types: LocalsSignature,
+    // i64 to allow the bytecode verifier to catch errors of
+    // - negative stack sizes
+    // - excessively large stack sizes
+    // The max stack depth of the file_format is set as u16.
+    // Theoretically, we could use a BigInt here, but that is probably overkill for any testing.
+    max_stack_depth: i64,
+    cur_stack_depth: i64,
+    loops: Vec<LoopInfo>,
+}
+
+impl FunctionFrame {
+    fn new() -> FunctionFrame {
+        FunctionFrame::default()
+    }
+
+    // Manage the stack info for the function
+    fn push(&mut self) -> Result<()> {
+        if self.cur_stack_depth == i64::max_value() {
+            bail!("ICE Stack depth accounting overflow. The compiler can only support a maximum stack depth of up to i64::max_value")
+        }
+        self.cur_stack_depth += 1;
+        self.max_stack_depth = std::cmp::max(self.max_stack_depth, self.cur_stack_depth);
+        Ok(())
+    }
+
+    fn pop(&mut self) -> Result<()> {
+        if self.cur_stack_depth == i64::min_value() {
+            bail!("ICE Stack depth accounting underflow. The compiler can only support a minimum stack depth of down to i64::min_value")
+        }
+        self.cur_stack_depth -= 1;
+        Ok(())
+    }
+
+    fn get_local(&self, var: &Var) -> Result<u8> {
+        match self.locals.get(var) {
+            None => bail!("variable {} undefined", var),
+            Some(idx) => Ok(*idx),
+        }
+    }
+
+    fn get_local_type(&self, idx: u8) -> Result<&SignatureToken> {
+        match self.local_types.0.get(idx as usize) {
+            None => bail!("variable {} undefined", idx),
+            Some(sig_token) => Ok(sig_token),
+        }
+    }
+
+    fn define_local(&mut self, var: &Var, type_: SignatureToken) -> Result<u8> {
+        if self.local_count >= u8::max_value() {
+            bail!("Max number of locals reached");
+        }
+
+        let cur_loc_idx = self.local_count;
+        let loc = var.clone();
+        let entry = self.locals.entry(loc);
+        match entry {
+            Occupied(_) => bail!("variable redefinition {}", var),
+            Vacant(e) => {
+                e.insert(cur_loc_idx);
+                self.local_types.0.push(type_);
+                self.local_count += 1;
+            }
+        }
+        Ok(cur_loc_idx)
+    }
+
+    fn push_loop(&mut self, start_loc: usize) -> Result<()> {
+        self.loops.push(LoopInfo {
+            start_loc,
+            breaks: Vec::new(),
+        });
+        Ok(())
+    }
+
+    fn pop_loop(&mut self) -> Result<()> {
+        match self.loops.pop() {
+            Some(_) => Ok(()),
+            None => bail!("Impossible: failed to pop loop!"),
+        }
+    }
+
+    fn get_loop_start(&self) -> Result<usize> {
+        match self.loops.last() {
+            Some(loop_) => Ok(loop_.start_loc),
+            None => bail!("continue outside loop"),
+        }
+    }
+
+    fn push_loop_break(&mut self, loc: usize) -> Result<()> {
+        match self.loops.last_mut() {
+            Some(loop_) => {
+                loop_.breaks.push(loc);
+                Ok(())
+            }
+            None => bail!("break outside loop"),
+        }
+    }
+
+    fn get_loop_breaks(&self) -> Result<&Vec<usize>> {
+        match self.loops.last() {
+            Some(loop_) => Ok(&loop_.breaks),
+            None => bail!("Impossible: failed to get loop breaks (no loops in stack)"),
+        }
+    }
+}
+
+type ModuleIndex = u8; // 2^8 max number of modules per compilation
+type FieldMap = HashMap<String, FieldDefinitionIndex>;
+
+/// Global scope where all modules are imported.
+/// This is a read-only scope that holds the compilation references.
+/// The handles are in the scope of the compilation unit, the definitions in the scope of the
+/// imported unit.
+/// Those maps also help in resolution of fields and functions, which is name based in the IR
+/// (as opposed to signature based in the VM - field type, function signature)
+#[derive(Debug)]
+struct CompilationScope<'a> {
+    // maps from handles in the compilation unit (reference) to the definitions
+    // in the imported unit (definitions).
+    imported_modules: HashMap<String, (ModuleIndex, ModuleHandleIndex)>,
+    function_definitions:
+        HashMap<ModuleHandleIndex, HashMap<String, (ModuleIndex, FunctionDefinitionIndex)>>,
+    // imported modules (external contracts you compile against)
+    pub modules: &'a [CompiledModule],
+}
+
+impl<'a> CompilationScope<'a> {
+    fn new(modules: &[CompiledModule]) -> CompilationScope {
+        CompilationScope {
+            imported_modules: HashMap::new(),
+            function_definitions: HashMap::new(),
+            modules,
+        }
+    }
+
+    // TODO: Change `module_name` to be something like a ModuleIdent when we have a better data
+    // structure for dependency module inputs.
+    fn link_module(
+        &mut self,
+        import_name: &str,
+        module_name: &str,
+        mh_idx: ModuleHandleIndex,
+    ) -> Result<()> {
+        for idx in 0..self.modules.len() {
+            if self.modules[idx].name() == module_name {
+                self.imported_modules
+                    .insert(import_name.to_string(), (idx as u8, mh_idx));
+                return Ok(());
+            }
+        }
+        bail!("can't find module {} in dependency list", import_name);
+    }
+
+    fn link_function(
+        &mut self,
+        module_name: &str,
+        function_name: &str,
+        fd_idx: FunctionDefinitionIndex,
+    ) -> Result<()> {
+        let (module_index, mh_idx) = self.imported_modules[module_name];
+        let func_map = self
+            .function_definitions
+            .entry(mh_idx)
+            .or_insert_with(HashMap::new);
+        func_map.insert(function_name.to_string(), (module_index, fd_idx));
+        Ok(())
+    }
+
+    fn get_imported_module_impl(&self, name: &str) -> Result<(&CompiledModule, ModuleHandleIndex)> {
+        match self.imported_modules.get(name) {
+            None => bail!("no module named {}", name),
+            Some((module_index, mh_idx)) => Ok((&self.modules[*module_index as usize], *mh_idx)),
+        }
+    }
+
+    fn get_imported_module(&self, name: &str) -> Result<&CompiledModule> {
+        let (module, _) = self.get_imported_module_impl(name)?;
+        Ok(module)
+    }
+
+    fn get_imported_module_handle(&self, name: &str) -> Result<ModuleHandleIndex> {
+        let (_, mh_idx) = self.get_imported_module_impl(name)?;
+        Ok(mh_idx)
+    }
+
+    fn get_function_signature(
+        &self,
+        mh_idx: ModuleHandleIndex,
+        name: &str,
+    ) -> Result<&FunctionSignature> {
+        let func_map = match self.function_definitions.get(&mh_idx) {
+            None => bail!(
+                "no module handle index {} in function definition table",
+                mh_idx
+            ),
+            Some(func_map) => func_map,
+        };
+
+        let (module_index, fd_idx) = match func_map.get(name) {
+            None => bail!("no function {} in module {}", name, mh_idx),
+            Some(res) => res,
+        };
+
+        let module = &self.modules[*module_index as usize];
+
+        let fh_idx = match module.function_defs.get(fd_idx.0 as usize) {
+            None => bail!(
+                "No function definition index {} in function definition table",
+                fd_idx
+            ),
+            Some(function_def) => function_def.function,
+        };
+
+        let fh = module.get_function_at(fh_idx)?;
+        module.get_function_signature_at(fh.signature)
+    }
+}
+
+#[derive(Debug)]
+struct ModuleScope<'a> {
+    // parent scope, the global module scope
+    pub compilation_scope: CompilationScope<'a>,
+    // builds a struct map based on handles and signatures
+    struct_definitions: HashMap<String, (bool, StructDefinitionIndex)>,
+    field_definitions: HashMap<StructHandleIndex, FieldMap>,
+    function_definitions: HashMap<String, FunctionDefinitionIndex>,
+    // the module being compiled
+    pub module: CompiledModule,
+}
+
+impl<'a> ModuleScope<'a> {
+    fn new(module: CompiledModule, modules: &[CompiledModule]) -> ModuleScope {
+        ModuleScope {
+            compilation_scope: CompilationScope::new(modules),
+            struct_definitions: HashMap::new(),
+            field_definitions: HashMap::new(),
+            function_definitions: HashMap::new(),
+            module,
+        }
+    }
+
+    fn add_item<T>(item: T, table: &mut Vec<T>) -> Result<TableIndex> {
+        let size = table.len();
+        if size >= TABLE_MAX_SIZE {
+            bail!("Max table size reached!")
+        }
+        table.push(item);
+        Ok(size as TableIndex)
+    }
+}
+
+#[derive(Debug)]
+struct ScriptScope<'a> {
+    // parent scope, the global module scope
+    compilation_scope: CompilationScope<'a>,
+    pub script: CompiledScript,
+}
+
+impl<'a> ScriptScope<'a> {
+    fn new(script: CompiledScript, modules: &[CompiledModule]) -> ScriptScope {
+        ScriptScope {
+            compilation_scope: CompilationScope::new(modules),
+            script,
+        }
+    }
+
+    fn add_item<T>(item: T, table: &mut Vec<T>) -> Result<TableIndex> {
+        let size = table.len();
+        if size >= TABLE_MAX_SIZE {
+            bail!("Max table size reached!")
+        }
table.push(item); + Ok(size as TableIndex) + } +} + +trait Scope { + fn make_string(&mut self, s: String) -> Result; + fn make_byte_array(&mut self, buf: ByteArray) -> Result; + fn make_address(&mut self, s: AccountAddress) -> Result; + fn make_type_signature(&mut self, s: TypeSignature) -> Result; + fn make_function_signature(&mut self, s: FunctionSignature) -> Result; + fn make_locals_signature(&mut self, s: LocalsSignature) -> Result; + + fn make_module_handle( + &mut self, + addr_idx: AddressPoolIndex, + name_idx: StringPoolIndex, + ) -> Result; + fn make_struct_handle( + &mut self, + module_idx: ModuleHandleIndex, + name_idx: StringPoolIndex, + is_resource: bool, + ) -> Result; + fn make_function_handle( + &mut self, + mh_idx: ModuleHandleIndex, + name_idx: StringPoolIndex, + sig_idx: FunctionSignatureIndex, + ) -> Result; + + fn publish_struct_def( + &mut self, + name: &str, + is_resource: bool, + struct_def: StructDefinition, + ) -> Result; + fn publish_function_def( + &mut self, + function_def: FunctionDefinition, + ) -> Result; + fn publish_field_def(&mut self, field_def: FieldDefinition) -> Result; + fn publish_code(&mut self, name: &str, code: CodeUnit) -> Result<()>; + + fn link_module( + &mut self, + import_name: &str, + module_name: &str, + mh_idx: ModuleHandleIndex, + ) -> Result<()>; + fn link_field( + &mut self, + sh_idx: StructHandleIndex, + name: &str, + fd_idx: FieldDefinitionIndex, + ) -> Result<()>; + fn link_function( + &mut self, + module_name: &str, + function_name: &str, + fd_idx: FunctionDefinitionIndex, + ) -> Result<()>; + + fn get_imported_module(&self, name: &str) -> Result<&CompiledModule>; + fn get_imported_module_handle(&self, name: &str) -> Result; + fn get_next_field_definition_index(&mut self) -> Result; + fn get_struct_def(&self, name: &str) -> Result<(bool, StructDefinitionIndex)>; + fn get_field_def(&self, sh_idx: StructHandleIndex, name: &str) -> Result; + fn get_field_type(&self, sh_idx: StructHandleIndex, name: &str) -> Result<&TypeSignature>; + fn get_function_signature( + &self, + mh_idx: ModuleHandleIndex, + name: &str, + ) -> Result<&FunctionSignature>; + + fn get_name(&self) -> Result; +} + +impl<'a> Scope for ModuleScope<'a> { + fn make_string(&mut self, s: String) -> Result { + ModuleScope::add_item(s, &mut self.module.string_pool).map(StringPoolIndex::new) + } + + fn make_byte_array(&mut self, buf: ByteArray) -> Result { + ModuleScope::add_item(buf, &mut self.module.byte_array_pool).map(ByteArrayPoolIndex::new) + } + + fn make_address(&mut self, addr: AccountAddress) -> Result { + ModuleScope::add_item(addr, &mut self.module.address_pool).map(AddressPoolIndex::new) + } + + fn make_type_signature(&mut self, sig: TypeSignature) -> Result { + ModuleScope::add_item(sig, &mut self.module.type_signatures).map(TypeSignatureIndex::new) + } + + fn make_function_signature( + &mut self, + sig: FunctionSignature, + ) -> Result { + ModuleScope::add_item(sig, &mut self.module.function_signatures) + .map(FunctionSignatureIndex::new) + } + + fn make_locals_signature(&mut self, sig: LocalsSignature) -> Result { + ModuleScope::add_item(sig, &mut self.module.locals_signatures) + .map(LocalsSignatureIndex::new) + } + + fn make_module_handle( + &mut self, + addr_idx: AddressPoolIndex, + name_idx: StringPoolIndex, + ) -> Result { + let mh = ModuleHandle { + address: addr_idx, + name: name_idx, + }; + let size = self.module.module_handles.len(); + if size >= STRUCTS_MAX_SIZE { + bail!("Max table size reached!") + } + self.module.module_handles.push(mh); + 
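+        // The new handle's index is the table length captured before the push.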
Ok(ModuleHandleIndex::new(size as u16)) + } + + fn make_struct_handle( + &mut self, + module_idx: ModuleHandleIndex, + name_idx: StringPoolIndex, + is_resource: bool, + ) -> Result { + let sh = StructHandle { + module: module_idx, + name: name_idx, + is_resource, + }; + let size = self.module.struct_handles.len(); + if size >= TABLE_MAX_SIZE { + bail!("Max table size reached!") + } + self.module.struct_handles.push(sh); + Ok(StructHandleIndex::new(size as u16)) + } + + fn make_function_handle( + &mut self, + mh_idx: ModuleHandleIndex, + name_idx: StringPoolIndex, + sig_idx: FunctionSignatureIndex, + ) -> Result { + let fh = FunctionHandle { + module: mh_idx, + name: name_idx, + signature: sig_idx, + }; + ModuleScope::add_item(fh, &mut self.module.function_handles).map(FunctionHandleIndex::new) + } + + fn publish_struct_def( + &mut self, + name: &str, + is_resource: bool, + struct_def: StructDefinition, + ) -> Result { + let idx = self.module.struct_defs.len(); + if idx >= STRUCTS_MAX_SIZE { + bail!("Max number of structs reached") + } + let sd_idx = StructDefinitionIndex::new(idx as TableIndex); + self.module.struct_defs.push(struct_def); + self.struct_definitions + .insert(name.to_string(), (is_resource, sd_idx)); + Ok(sd_idx) + } + + fn publish_function_def( + &mut self, + function_def: FunctionDefinition, + ) -> Result { + let idx = self.module.function_defs.len(); + if idx >= FUNCTIONS_MAX_SIZE { + bail!("Max number of functions reached") + } + let fd_idx = FunctionDefinitionIndex::new(idx as TableIndex); + self.module.function_defs.push(function_def); + Ok(fd_idx) + } + + /// Compute the index of the next field definition + fn get_next_field_definition_index(&mut self) -> Result { + let idx = self.module.field_defs.len(); + if idx >= FIELDS_MAX_SIZE { + bail!("Max number of fields reached") + } + let fd_idx = FieldDefinitionIndex::new(idx as TableIndex); + Ok(fd_idx) + } + + fn publish_field_def(&mut self, field_def: FieldDefinition) -> Result { + let fd_idx = self.get_next_field_definition_index()?; + self.module.field_defs.push(field_def); + Ok(fd_idx) + } + + fn publish_code(&mut self, name: &str, code: CodeUnit) -> Result<()> { + let fd_idx = match self.function_definitions.get(name) { + None => bail!("Cannot find function {}", name), + Some(def_idx) => def_idx, + }; + let func_def = match self.module.function_defs.get_mut(fd_idx.0 as usize) { + None => bail!("Cannot find function def for {}", name), + Some(func_def) => func_def, + }; + func_def.code = code; + Ok(()) + } + + fn link_module( + &mut self, + import_name: &str, + module_name: &str, + mh_idx: ModuleHandleIndex, + ) -> Result<()> { + self.compilation_scope + .link_module(import_name, module_name, mh_idx) + } + + fn link_field( + &mut self, + sh_idx: StructHandleIndex, + name: &str, + fd_idx: FieldDefinitionIndex, + ) -> Result<()> { + let field_map = self + .field_definitions + .entry(sh_idx) + .or_insert_with(HashMap::new); + field_map.insert(name.to_string(), fd_idx); + Ok(()) + } + + fn link_function( + &mut self, + module_name: &str, + function_name: &str, + fd_idx: FunctionDefinitionIndex, + ) -> Result<()> { + if module_name.is_empty() { + self.function_definitions + .insert(function_name.to_string(), fd_idx); + Ok(()) + } else { + self.compilation_scope + .link_function(module_name, function_name, fd_idx) + } + } + + fn get_imported_module(&self, name: &str) -> Result<&CompiledModule> { + self.compilation_scope.get_imported_module(name) + } + + fn get_imported_module_handle(&self, name: &str) -> Result { + 
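+        // All import resolution state lives in the shared CompilationScope; delegate to it.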
self.compilation_scope.get_imported_module_handle(name) + } + + fn get_struct_def(&self, name: &str) -> Result<(bool, StructDefinitionIndex)> { + match self.struct_definitions.get(name) { + None => bail!("No struct definition for name {}", name), + Some((is_resource, def_idx)) => Ok((*is_resource, *def_idx)), + } + } + + fn get_field_def(&self, sh_idx: StructHandleIndex, name: &str) -> Result { + let field_map = match self.field_definitions.get(&sh_idx) { + None => bail!("no struct handle index {}", sh_idx), + Some(map) => map, + }; + match field_map.get(name) { + None => bail!("no field {} in struct handle index {}", name, sh_idx), + Some(def_idx) => Ok(*def_idx), + } + } + + fn get_field_type(&self, sh_idx: StructHandleIndex, name: &str) -> Result<&TypeSignature> { + let fd_idx = self.get_field_def(sh_idx, name)?; + let sig_idx = match self.module.field_defs.get(fd_idx.0 as usize) { + None => bail!( + "No field definition index {} in field definition table", + fd_idx + ), + Some(field_def) => field_def.signature, + }; + match self.module.type_signatures.get(sig_idx.0 as usize) { + None => bail!("missing type signature index {}", sig_idx), + Some(sig) => Ok(sig), + } + } + + fn get_function_signature( + &self, + mh_idx: ModuleHandleIndex, + name: &str, + ) -> Result<&FunctionSignature> { + // Call into an external module. + if mh_idx.0 != 0 { + return self.compilation_scope.get_function_signature(mh_idx, name); + } + + let fd_idx = match self.function_definitions.get(name) { + None => bail!("no function {} in module {}", name, mh_idx), + Some(def_idx) => def_idx, + }; + + let fh_idx = match self.module.function_defs.get(fd_idx.0 as usize) { + None => bail!( + "No function definition index {} in function definition table", + fd_idx + ), + Some(function_def) => function_def.function, + }; + + let fh = self.module.get_function_at(fh_idx)?; + self.module.get_function_signature_at(fh.signature) + } + + fn get_name(&self) -> Result { + let mh = self.module.get_module_at(ModuleHandleIndex::new(0))?; + let name_ref = self.module.get_string_at(mh.name)?; + Ok(name_ref.clone()) + } +} + +impl<'a> Scope for ScriptScope<'a> { + fn make_string(&mut self, s: String) -> Result { + ScriptScope::add_item(s, &mut self.script.string_pool).map(StringPoolIndex::new) + } + + fn make_byte_array(&mut self, buf: ByteArray) -> Result { + ScriptScope::add_item(buf, &mut self.script.byte_array_pool).map(ByteArrayPoolIndex::new) + } + + fn make_address(&mut self, addr: AccountAddress) -> Result { + ScriptScope::add_item(addr, &mut self.script.address_pool).map(AddressPoolIndex::new) + } + + fn make_type_signature(&mut self, sig: TypeSignature) -> Result { + ScriptScope::add_item(sig, &mut self.script.type_signatures).map(TypeSignatureIndex::new) + } + + fn make_function_signature( + &mut self, + sig: FunctionSignature, + ) -> Result { + ModuleScope::add_item(sig, &mut self.script.function_signatures) + .map(FunctionSignatureIndex::new) + } + + fn make_locals_signature(&mut self, sig: LocalsSignature) -> Result { + ModuleScope::add_item(sig, &mut self.script.locals_signatures) + .map(LocalsSignatureIndex::new) + } + + fn make_module_handle( + &mut self, + addr_idx: AddressPoolIndex, + name_idx: StringPoolIndex, + ) -> Result { + let mh = ModuleHandle { + address: addr_idx, + name: name_idx, + }; + let size = self.script.module_handles.len(); + if size >= STRUCTS_MAX_SIZE { + bail!("Max table size reached!") + } + self.script.module_handles.push(mh); + Ok(ModuleHandleIndex::new(size as u16)) + } + + fn 
make_struct_handle(
+        &mut self,
+        module_idx: ModuleHandleIndex,
+        name_idx: StringPoolIndex,
+        is_resource: bool,
+    ) -> Result<StructHandleIndex> {
+        let sh = StructHandle {
+            module: module_idx,
+            name: name_idx,
+            is_resource,
+        };
+        let size = self.script.struct_handles.len();
+        if size >= TABLE_MAX_SIZE {
+            bail!("Max table size reached!")
+        }
+        self.script.struct_handles.push(sh);
+        Ok(StructHandleIndex::new(size as u16))
+    }
+
+    fn make_function_handle(
+        &mut self,
+        mh_idx: ModuleHandleIndex,
+        name_idx: StringPoolIndex,
+        sig_idx: FunctionSignatureIndex,
+    ) -> Result<FunctionHandleIndex> {
+        let fh = FunctionHandle {
+            module: mh_idx,
+            name: name_idx,
+            signature: sig_idx,
+        };
+        ScriptScope::add_item(fh, &mut self.script.function_handles).map(FunctionHandleIndex::new)
+    }
+
+    fn publish_struct_def(
+        &mut self,
+        _name: &str,
+        _is_resource: bool,
+        _struct_def: StructDefinition,
+    ) -> Result<StructDefinitionIndex> {
+        bail!("Cannot publish structs in scripts")
+    }
+
+    fn publish_function_def(
+        &mut self,
+        _function_def: FunctionDefinition,
+    ) -> Result<FunctionDefinitionIndex> {
+        bail!("Cannot publish functions in scripts")
+    }
+
+    fn publish_field_def(&mut self, _field_def: FieldDefinition) -> Result<FieldDefinitionIndex> {
+        bail!("Cannot publish fields in scripts")
+    }
+
+    fn publish_code(&mut self, _name: &str, _code: CodeUnit) -> Result<()> {
+        bail!("No function definitions in scripts")
+    }
+
+    fn link_module(
+        &mut self,
+        import_name: &str,
+        module_name: &str,
+        mh_idx: ModuleHandleIndex,
+    ) -> Result<()> {
+        self.compilation_scope
+            .link_module(import_name, module_name, mh_idx)
+    }
+
+    fn link_field(
+        &mut self,
+        _sh_idx: StructHandleIndex,
+        _name: &str,
+        _fd_idx: FieldDefinitionIndex,
+    ) -> Result<()> {
+        bail!("no field linking in scripts");
+    }
+
+    fn link_function(
+        &mut self,
+        module_name: &str,
+        function_name: &str,
+        fd_idx: FunctionDefinitionIndex,
+    ) -> Result<()> {
+        self.compilation_scope
+            .link_function(module_name, function_name, fd_idx)
+    }
+
+    fn get_imported_module(&self, name: &str) -> Result<&CompiledModule> {
+        self.compilation_scope.get_imported_module(name)
+    }
+
+    fn get_imported_module_handle(&self, name: &str) -> Result<ModuleHandleIndex> {
+        self.compilation_scope.get_imported_module_handle(name)
+    }
+
+    fn get_next_field_definition_index(&mut self) -> Result<FieldDefinitionIndex> {
+        bail!("no field definition referencing in scripts")
+    }
+
+    fn get_struct_def(&self, _name: &str) -> Result<(bool, StructDefinitionIndex)> {
+        bail!("no struct definition referencing in scripts")
+    }
+
+    fn get_field_def(
+        &self,
+        _sh_idx: StructHandleIndex,
+        _name: &str,
+    ) -> Result<FieldDefinitionIndex> {
+        bail!("no field definition referencing in scripts")
+    }
+
+    fn get_field_type(&self, _sh_idx: StructHandleIndex, _name: &str) -> Result<&TypeSignature> {
+        bail!("no field type referencing in scripts")
+    }
+
+    fn get_function_signature(
+        &self,
+        mh_idx: ModuleHandleIndex,
+        name: &str,
+    ) -> Result<&FunctionSignature> {
+        self.compilation_scope.get_function_signature(mh_idx, name)
+    }
+
+    fn get_name(&self) -> Result<String> {
+        bail!("no name for scripts")
+    }
+}
+
+struct Compiler<S: Scope + Sized> {
+    // identity maps
+    // Map a handle to its position in its table
+    // TODO: those could be expressed as references and it would make for better code.
+    // For now this is easier to do and those are intended to be "primitive" values so we'll get
+    // back to this...
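+    // Each map below interns one kind of handle or pool entry: a lookup hit reuses the
+    // existing index instead of pushing a duplicate entry into the unit being compiled.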
+    modules: HashMap<ModuleHandle, ModuleHandleIndex>,
+    structs: HashMap<StructHandle, StructHandleIndex>,
+    functions: HashMap<FunctionHandle, FunctionHandleIndex>,
+    strings: HashMap<String, StringPoolIndex>,
+    byte_arrays: HashMap<ByteArray, ByteArrayPoolIndex>,
+    addresses: HashMap<AccountAddress, AddressPoolIndex>,
+    type_signatures: HashMap<TypeSignature, TypeSignatureIndex>,
+    function_signatures: HashMap<FunctionSignature, FunctionSignatureIndex>,
+    locals_signatures: HashMap<LocalsSignature, LocalsSignatureIndex>,
+    // resolution scope
+    scope: S,
+}
+
+const STRUCTS_MAX_SIZE: usize = TABLE_MAX_SIZE;
+const FIELDS_MAX_SIZE: usize = TABLE_MAX_SIZE;
+const FUNCTIONS_MAX_SIZE: usize = TABLE_MAX_SIZE;
+const TABLE_MAX_SIZE: usize = u16::max_value() as usize;
+
+//
+// Module/Contract compilation
+//
+
+/// Compile a module
+pub fn compile_module(
+    address: &AccountAddress,
+    module: &ModuleDefinition,
+    modules: &[CompiledModule],
+) -> Result<CompiledModule> {
+    let compiled_module = CompiledModule::default();
+    let scope = ModuleScope::new(compiled_module, modules);
+    let mut compiler = Compiler::new(scope);
+    let addr_idx = compiler.make_address(&address)?;
+    let name_idx = compiler.make_string(module.name.name_ref())?;
+    let mh_idx = compiler.make_module_handle(addr_idx, name_idx)?;
+    assert!(mh_idx.0 == 0);
+    for import in &module.imports {
+        compiler.import_module(
+            match &import.ident {
+                ModuleIdent::Transaction(_) => address,
+                ModuleIdent::Qualified(id) => &id.address,
+            },
+            &import.ident.get_name().name_ref(),
+            &import.alias,
+        )?;
+    }
+    for struct_ in &module.structs {
+        compiler.define_struct(mh_idx, &struct_)?;
+    }
+    for (name, function) in &module.functions {
+        compiler.define_function(name.name_ref(), &function)?;
+    }
+    for (name, function) in &module.functions {
+        match &function.body {
+            FunctionBody::Move { locals, code } => {
+                debug!("compile move function: {} {}", name, &function.signature);
+                let compiled_code =
+                    compiler.compile_function(&function.signature.formals, locals, code)?;
+                compiler
+                    .scope
+                    .publish_code(name.name_ref(), compiled_code)?;
+            }
+            FunctionBody::Native => (),
+        }
+    }
+
+    Ok(compiler.scope.module)
+}
+
+/// Compile a module and invoke the bytecode verifier on it
+pub fn compile_and_verify_module(
+    address: &AccountAddress,
+    module: &ModuleDefinition,
+    modules: &[CompiledModule],
+) -> Result<(CompiledModule, Vec<VerificationError>)> {
+    let compiled_module = compile_module(address, module, modules)?;
+    let (compiled_module, verification_errors) = verify_module(compiled_module);
+    if verification_errors.is_empty() {
+        let (compiled_module, verification_errors) =
+            verify_module_dependencies(compiled_module, modules);
+        Ok((compiled_module, verification_errors))
+    } else {
+        Ok((compiled_module, verification_errors))
+    }
+}
+
+//
+// Transaction/Script compilation
+//
+
+/// Compile a transaction program
+pub fn compile_program(
+    address: &AccountAddress,
+    program: &Program,
+    deps: &[CompiledModule],
+) -> Result<CompiledProgram> {
+    // Compile modules in the program
+    let mut deps: Vec<CompiledModule> = deps.to_vec();
+    let n_external_deps = deps.len();
+    for m in &program.modules {
+        deps.push(compile_module(address, &m, &deps)?);
+    }
+
+    // Compile transaction script
+    let func_def: FunctionDefinition;
+    let compiled_script = CompiledScript::default();
+    let scope = ScriptScope::new(compiled_script, &deps);
+    let mut compiler = Compiler::new(scope);
+    let addr_idx = compiler.make_address(&address)?;
+    let name_idx = compiler.make_string(SELF_MODULE_NAME)?;
+    let mh_idx = compiler.make_module_handle(addr_idx, name_idx)?;
+    assert!(mh_idx.0 == 0);
+
+    for import in &program.script.imports {
+        compiler.import_module(
+            match &import.ident {
+                ModuleIdent::Transaction(_) => address,
+                ModuleIdent::Qualified(id) => &id.address,
+            },
+            &import.ident.get_name().name_ref(),
+            &import.alias,
+        )?;
+    }
+
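+    // With modules compiled and imports linked, compile the script's `main` into a
+    // FunctionDefinition below and attach it to the script.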
func_def = compiler.compile_main(&program.script.main)?; + + let mut script = compiler.scope.script; + script.main = func_def; + + Ok(CompiledProgram::new( + deps[n_external_deps..].to_vec(), + script, + )) +} + +impl Compiler { + fn new(scope: S) -> Self { + Compiler { + modules: HashMap::new(), + structs: HashMap::new(), + functions: HashMap::new(), + strings: HashMap::new(), + byte_arrays: HashMap::new(), + addresses: HashMap::new(), + type_signatures: HashMap::new(), + function_signatures: HashMap::new(), + locals_signatures: HashMap::new(), + // resolution scope + scope, + } + } + + fn import_module( + &mut self, + address: &AccountAddress, + name: &str, + module_alias: &ModuleName, + ) -> Result<()> { + let addr_idx = self.make_address(address)?; + let name_idx = self.make_string(&name)?; + let mh_idx = self.make_module_handle(addr_idx, name_idx)?; + self.scope + .link_module(module_alias.name_ref(), name, mh_idx) + } + + fn import_signature_token( + &mut self, + module_name: &str, + sig_token: SignatureToken, + ) -> Result { + match sig_token { + SignatureToken::Bool + | SignatureToken::U64 + | SignatureToken::String + | SignatureToken::ByteArray + | SignatureToken::Address => Ok(sig_token), + SignatureToken::Struct(sh_idx) => { + let (defining_module_name, name, is_resource) = { + let module = self.scope.get_imported_module(module_name)?; + let struct_handle = module.get_struct_at(sh_idx)?; + let defining_module_handle = module.get_module_at(struct_handle.module)?; + ( + module.get_string_at(defining_module_handle.name)?, + module.get_string_at(struct_handle.name)?.clone(), + struct_handle.is_resource, + ) + }; + let mh_idx = self + .scope + .get_imported_module_handle(defining_module_name)?; + let name_idx = self.make_string(&name)?; + let local_sh_idx = self.make_struct_handle(mh_idx, name_idx, is_resource)?; + Ok(SignatureToken::Struct(local_sh_idx)) + } + SignatureToken::Reference(sub_sig_token) => Ok(SignatureToken::Reference(Box::new( + self.import_signature_token(module_name, *sub_sig_token)?, + ))), + SignatureToken::MutableReference(sub_sig_token) => { + Ok(SignatureToken::MutableReference(Box::new( + self.import_signature_token(module_name, *sub_sig_token)?, + ))) + } + } + } + + fn import_function_signature( + &mut self, + module_name: &str, + func_sig: FunctionSignature, + ) -> Result { + if module_name == ModuleName::SELF { + Ok(func_sig) + } else { + let mut return_types = Vec::::new(); + let mut arg_types = Vec::::new(); + for e in func_sig.return_types { + return_types.push(self.import_signature_token(module_name, e)?); + } + for e in func_sig.arg_types { + arg_types.push(self.import_signature_token(module_name, e)?); + } + Ok(FunctionSignature { + return_types, + arg_types, + }) + } + } + + fn define_struct(&mut self, module_idx: ModuleHandleIndex, struct_: &MoveStruct) -> Result<()> { + let name = struct_.name.name_ref(); + let name_idx = self.make_string(name)?; + let sh_idx = self.make_struct_handle(module_idx, name_idx, struct_.resource_kind)?; + let struct_def = self.define_fields(sh_idx, &struct_.fields)?; + self.scope + .publish_struct_def(name, struct_.resource_kind, struct_def)?; + Ok(()) + } + + fn define_fields( + &mut self, + sh_idx: StructHandleIndex, + fields: &Fields, + ) -> Result { + let field_count = fields.len(); + let struct_def = StructDefinition { + struct_handle: sh_idx, + field_count: (field_count as MemberCount), + fields: self.scope.get_next_field_definition_index()?, + }; + + if field_count > FIELDS_MAX_SIZE { + bail!("too many fields 
{}", struct_def.struct_handle) + } + + for field in fields { + let field_name = field.0.name(); + let field_type = field.1; + self.publish_field(struct_def.struct_handle, field_name, field_type)?; + } + Ok(struct_def) + } + + // Compile a main function in a Script. + fn compile_main(&mut self, main: &Function) -> Result { + // make main entry point + let main_name = "main".to_string(); + let main_name_idx = self.make_string(&main_name)?; + let signature = self.build_function_signature(&main.signature)?; + let sig_idx = self.make_function_signature(&signature)?; + let fh_idx = + self.make_function_handle(ModuleHandleIndex::new(0), main_name_idx, sig_idx)?; + // compile script + let code = match &main.body { + FunctionBody::Move { code, locals } => { + self.compile_function(&main.signature.formals, locals, code)? + } + FunctionBody::Native => bail!("main() cannot be a native function"), + }; + Ok(FunctionDefinition { + function: fh_idx, + flags: CodeUnit::PUBLIC, + code, + }) + } + + // + // Reference tables filling: string, byte_array, address, signature, *handles + // + + fn make_string(&mut self, s: &str) -> Result { + let mut empty = false; + let idx; + { + let str_idx = self.strings.get(s); + idx = match str_idx { + None => { + empty = true; + self.scope.make_string(s.to_string())? + } + Some(idx) => *idx, + }; + } + if empty { + self.strings.insert(s.to_string(), idx); + } + Ok(idx) + } + + fn make_byte_array(&mut self, buf: &ByteArray) -> Result { + let mut empty = false; + let idx; + { + let byte_array_idx = self.byte_arrays.get(buf); + idx = match byte_array_idx { + None => { + empty = true; + self.scope.make_byte_array(buf.clone())? + } + Some(idx) => *idx, + }; + } + if empty { + self.byte_arrays.insert(buf.clone(), idx); + } + Ok(idx) + } + + fn make_address(&mut self, addr: &AccountAddress) -> Result { + let mut empty = false; + let idx; + { + let addr_idx = self.addresses.get(addr); + idx = match addr_idx { + None => { + empty = true; + self.scope.make_address(addr.clone())? + } + Some(idx) => *idx, + }; + } + if empty { + self.addresses.insert(addr.clone(), idx); + } + Ok(idx) + } + + fn make_type_signature(&mut self, _type: &Type) -> Result { + let signature = self.build_type_signature(_type)?; + let mut empty = false; + let idx; + { + let sig_idx = self.type_signatures.get(&signature); + idx = match sig_idx { + None => { + empty = true; + self.scope.make_type_signature(signature.clone())? + } + Some(idx) => *idx, + }; + } + if empty { + self.type_signatures.insert(signature.clone(), idx); + } + Ok(idx) + } + + fn make_function_signature( + &mut self, + signature: &FunctionSignature, + ) -> Result { + let mut empty = false; + let idx; + { + let sig_idx = self.function_signatures.get(&signature); + idx = match sig_idx { + None => { + empty = true; + self.scope.make_function_signature(signature.clone())? + } + Some(idx) => *idx, + }; + } + if empty { + self.function_signatures.insert(signature.clone(), idx); + } + Ok(idx) + } + + fn make_locals_signature( + &mut self, + signature: &LocalsSignature, + ) -> Result { + let mut empty = false; + let idx; + { + let sig_idx = self.locals_signatures.get(signature); + idx = match sig_idx { + None => { + empty = true; + self.scope.make_locals_signature(signature.clone())? 
+ } + Some(idx) => *idx, + }; + } + if empty { + self.locals_signatures.insert(signature.clone(), idx); + } + Ok(idx) + } + + fn make_module_handle( + &mut self, + addr_idx: AddressPoolIndex, + name_idx: StringPoolIndex, + ) -> Result { + let mh = ModuleHandle { + address: addr_idx, + name: name_idx, + }; + let mut empty = false; + let idx; + { + let mh_idx = self.modules.get(&mh); + idx = match mh_idx { + None => { + empty = true; + self.scope.make_module_handle(addr_idx, name_idx)? + } + Some(idx) => *idx, + }; + } + if empty { + self.modules.insert(mh, idx); + } + Ok(idx) + } + + fn make_struct_handle( + &mut self, + module_idx: ModuleHandleIndex, + name_idx: StringPoolIndex, + is_resource: bool, + ) -> Result { + let sh = StructHandle { + module: module_idx, + name: name_idx, + is_resource, + }; + Ok(match self.structs.get(&sh) { + None => { + let idx = self + .scope + .make_struct_handle(module_idx, name_idx, is_resource)?; + self.structs.insert(sh, idx); + idx + } + Some(idx) => *idx, + }) + } + + fn make_function_handle( + &mut self, + mh_idx: ModuleHandleIndex, + name_idx: StringPoolIndex, + sig_idx: FunctionSignatureIndex, + ) -> Result { + let fh = FunctionHandle { + module: mh_idx, + name: name_idx, + signature: sig_idx, + }; + let mut empty = false; + let idx; + { + let fh_idx = self.functions.get(&fh); + idx = match fh_idx { + None => { + empty = true; + self.scope.make_function_handle(mh_idx, name_idx, sig_idx)? + } + Some(idx) => *idx, + }; + } + if empty { + self.functions.insert(fh, idx); + } + Ok(idx) + } + + // + // Create definitions, this is effectively only used when compiling modules + // + + fn publish_field( + &mut self, + sh_idx: StructHandleIndex, + name: &str, + sig: &Type, + ) -> Result { + let name_idx = self.make_string(name)?; + let sig_idx = self.make_type_signature(sig)?; + let field_def = FieldDefinition { + struct_: sh_idx, + name: name_idx, + signature: sig_idx, + }; + let fd_idx = self.scope.publish_field_def(field_def)?; + self.scope.link_field(sh_idx, name, fd_idx)?; + Ok(fd_idx) + } + + fn define_function( + &mut self, + name: &str, + function: &Function, + ) -> Result { + // Use ModuleHandleIndex::new(0) here because module 0 refers to the module being currently + // compiled. + let mh = ModuleHandleIndex::new(0); + + let name_idx = self.make_string(name)?; + let sig = self.build_function_signature(&function.signature)?; + let sig_idx = self.make_function_signature(&sig)?; + let fh_idx = self.make_function_handle(mh, name_idx, sig_idx)?; + + let flags = match function.visibility { + FunctionVisibility::Internal => 0, + FunctionVisibility::Public => CodeUnit::PUBLIC, + } | match function.body { + FunctionBody::Move { .. 
} => 0, + FunctionBody::Native => CodeUnit::NATIVE, + }; + + let func_def = FunctionDefinition { + function: fh_idx, + flags, + code: CodeUnit::default(), // TODO: eliminate usage of default + }; + + let fd_idx = self.scope.publish_function_def(func_def)?; + self.scope.link_function("", name, fd_idx)?; + Ok(fd_idx) + } + + // + // Signatue building methods + // + + fn build_type_signature(&mut self, type_: &Type) -> Result { + let signature_token = self.build_signature_token(type_)?; + Ok(TypeSignature(signature_token)) + } + + fn build_function_signature( + &mut self, + signature: &AstFunctionSignature, + ) -> Result { + let mut ret_sig: Vec = Vec::new(); + for t in &signature.return_type { + ret_sig.push(self.build_signature_token(&t)?); + } + let mut arg_sig: Vec = Vec::new(); + for formal in &signature.formals { + arg_sig.push(self.build_signature_token(&formal.1)?) + } + Ok(FunctionSignature { + return_types: ret_sig, + arg_types: arg_sig, + }) + } + + fn build_signature_token(&mut self, type_: &Type) -> Result { + match type_ { + Type::Normal(kind, tag) => self.build_normal_signature_token(&kind, &tag), + Type::Reference { + is_mutable, + kind, + tag, + } => { + let inner_token = Box::new(self.build_normal_signature_token(kind, tag)?); + if *is_mutable { + Ok(SignatureToken::MutableReference(inner_token)) + } else { + Ok(SignatureToken::Reference(inner_token)) + } + } + } + } + + fn build_normal_signature_token(&mut self, kind: &Kind, tag: &Tag) -> Result { + match (kind, tag) { + (Kind::Value, Tag::Address) => Ok(SignatureToken::Address), + (Kind::Value, Tag::U64) => Ok(SignatureToken::U64), + (Kind::Value, Tag::Bool) => Ok(SignatureToken::Bool), + (Kind::Value, Tag::ByteArray) => Ok(SignatureToken::ByteArray), + (kind, Tag::Struct(ctype)) => { + let module_name = &ctype.module().name(); + let module_idx = if self.scope.get_name().is_ok() && module_name == ModuleName::SELF + { + ModuleHandleIndex::new(0) + } else { + self.scope.get_imported_module_handle(module_name)? 
+ }; + let name_idx = self.make_string(&ctype.name().name_ref())?; + let is_resource = match kind { + Kind::Value => false, + Kind::Resource => true, + }; + let sh_idx = self.make_struct_handle(module_idx, name_idx, is_resource)?; + Ok(SignatureToken::Struct(sh_idx)) + } + (Kind::Value, _) => bail!("unknown value type {:?}", tag), + (Kind::Resource, _) => bail!("unknown resource type {:?}", tag), + } + } + + // + // Code compilation functions, walk the IR and generate bytecodes + // + fn compile_function( + &mut self, + formals: &[(Var, Type)], + locals: &[(Var_, Type)], + body: &Block, + ) -> Result { + let mut code = CodeUnit::default(); + let mut function_frame = FunctionFrame::new(); + for (var, t) in formals { + let type_sig = self.build_signature_token(t)?; + function_frame.define_local(var, type_sig)?; + } + for (var_, t) in locals { + let type_sig = self.build_signature_token(t)?; + function_frame.define_local(&var_.value, type_sig)?; + } + self.compile_block(body, &mut code, &mut function_frame)?; + let sig_idx = self.make_locals_signature(&function_frame.local_types)?; + code.locals = sig_idx; + code.max_stack_size = if function_frame.max_stack_depth < 0 { + 0 + } else if function_frame.max_stack_depth > i64::from(u16::max_value()) { + u16::max_value() + } else { + function_frame.max_stack_depth as u16 + }; + Ok(code) + } + + fn compile_block( + &mut self, + body: &Block, + code: &mut CodeUnit, + function_frame: &mut FunctionFrame, + ) -> Result { + let mut cf_info = ControlFlowInfo { + reachable_break: false, + terminal_node: false, + }; + for stmt in &body.stmts { + debug!("{}", stmt); + let stmt_info; + match stmt { + Statement::CommandStatement(command) => { + stmt_info = self.compile_command(&command, code, function_frame)?; + debug!("{:?}", code); + } + Statement::WhileStatement(while_) => { + // always assume the loop might not be taken + stmt_info = self.compile_while(&while_, code, function_frame)?; + debug!("{:?}", code); + } + Statement::LoopStatement(loop_) => { + stmt_info = self.compile_loop(&loop_, code, function_frame)?; + debug!("{:?}", code); + } + Statement::IfElseStatement(if_else) => { + stmt_info = self.compile_if_else(&if_else, code, function_frame)?; + debug!("{:?}", code); + } + Statement::VerifyStatement(_) | Statement::AssumeStatement(_) => continue, + Statement::EmptyStatement => continue, + }; + cf_info = ControlFlowInfo::successor(cf_info, stmt_info); + } + Ok(cf_info) + } + + fn compile_if_else( + &mut self, + if_else: &IfElse, + code: &mut CodeUnit, + function_frame: &mut FunctionFrame, + ) -> Result { + self.compile_expression(&if_else.cond, code, function_frame)?; + + let brfalse_ins_loc = code.code.len(); + code.code.push(Bytecode::BrFalse(0)); // placeholder, final branch target replaced later + function_frame.pop()?; + let if_cf_info = self.compile_block(&if_else.if_block, code, function_frame)?; + + let mut else_block_location = code.code.len(); + + let else_cf_info = match if_else.else_block { + None => ControlFlowInfo { + reachable_break: false, + terminal_node: false, + }, + Some(ref else_block) => { + let branch_ins_loc = code.code.len(); + if !if_cf_info.terminal_node { + code.code.push(Bytecode::Branch(0)); // placeholder, final branch target replaced later + else_block_location += 1; + } + let else_cf_info = self.compile_block(else_block, code, function_frame)?; + if !if_cf_info.terminal_node { + code.code[branch_ins_loc] = Bytecode::Branch(code.code.len() as u16); + } + else_cf_info + } + }; + + code.code[brfalse_ins_loc] = 
Bytecode::BrFalse(else_block_location as u16); + + let cf_info = ControlFlowInfo::join(if_cf_info, else_cf_info); + Ok(cf_info) + } + + fn compile_while( + &mut self, + while_: &While, + code: &mut CodeUnit, + function_frame: &mut FunctionFrame, + ) -> Result { + let loop_start_loc = code.code.len(); + function_frame.push_loop(loop_start_loc)?; + self.compile_expression(&while_.cond, code, function_frame)?; + + let brfalse_loc = code.code.len(); + code.code.push(Bytecode::BrFalse(0)); // placeholder, final branch target replaced later + function_frame.pop()?; + + self.compile_block(&while_.block, code, function_frame)?; + code.code.push(Bytecode::Branch(loop_start_loc as u16)); + + let loop_end_loc = code.code.len() as u16; + code.code[brfalse_loc] = Bytecode::BrFalse(loop_end_loc); + let breaks = function_frame.get_loop_breaks()?; + for i in breaks { + code.code[*i] = Bytecode::Branch(loop_end_loc); + } + + function_frame.pop_loop()?; + Ok(ControlFlowInfo { + // this `reachable_break` break is for any outer loop + // not the loop that was just compiled + reachable_break: false, + // While always has the ability to break. + // Conceptually we treat + // `while (cond) { body }` + // as ` + // `loop { if (cond) { body; continue; } else { break; } }` + // So a `break` is always reachable + terminal_node: false, + }) + } + + fn compile_loop( + &mut self, + loop_: &Loop, + code: &mut CodeUnit, + function_frame: &mut FunctionFrame, + ) -> Result { + let loop_start_loc = code.code.len(); + function_frame.push_loop(loop_start_loc)?; + + let body_cf_info = self.compile_block(&loop_.block, code, function_frame)?; + code.code.push(Bytecode::Branch(loop_start_loc as u16)); + + let loop_end_loc = code.code.len() as u16; + let breaks = function_frame.get_loop_breaks()?; + for i in breaks { + code.code[*i] = Bytecode::Branch(loop_end_loc); + } + + function_frame.pop_loop()?; + // this `reachable_break` break is for any outer loop + // not the loop that was just compiled + let reachable_break = false; + // If the body of the loop does not have a break, it will loop forever + // and thus is a terminal node + let terminal_node = !body_cf_info.reachable_break; + Ok(ControlFlowInfo { + reachable_break, + terminal_node, + }) + } + + fn compile_command( + &mut self, + cmd: &Cmd, + code: &mut CodeUnit, + function_frame: &mut FunctionFrame, + ) -> Result { + debug!("compile command {}", cmd); + match cmd { + Cmd::Return(exps) => { + for exp in exps { + self.compile_expression(exp, code, function_frame)?; + } + code.code.push(Bytecode::Ret); + } + Cmd::Assign(lhs_variable, rhs_expression) => { + let _expr_type = self.compile_expression(rhs_expression, code, function_frame)?; + let loc_idx = function_frame.get_local(&lhs_variable.value)?; + let st_loc = Bytecode::StLoc(loc_idx); + code.code.push(st_loc); + function_frame.pop()?; + } + Cmd::Unpack(name, bindings, e) => { + self.compile_expression(e, code, function_frame)?; + + let (_is_resource, def_idx) = self.scope.get_struct_def(name.name_ref())?; + code.code.push(Bytecode::Unpack(def_idx)); + function_frame.pop()?; + + for lhs_variable in bindings.values().rev() { + let loc_idx = function_frame.get_local(&lhs_variable.value)?; + let st_loc = Bytecode::StLoc(loc_idx); + code.code.push(st_loc); + } + } + Cmd::Call { + ref return_bindings, + ref call, + ref actuals, + } => { + let mut actuals_tys = VecDeque::new(); + for exp in actuals.iter() { + actuals_tys.push_back(self.compile_expression(exp, code, function_frame)?); + } + let _ret_types = + 
self.compile_call(&call.value, code, function_frame, actuals_tys)?; + for return_var in return_bindings.iter().rev() { + let loc_idx = function_frame.get_local(&return_var.value)?; + let st_loc = Bytecode::StLoc(loc_idx); + code.code.push(st_loc); + } + } + Cmd::Mutate(exp, op) => { + self.compile_expression(op, code, function_frame)?; + self.compile_expression(exp, code, function_frame)?; + code.code.push(Bytecode::WriteRef); + function_frame.pop()?; + function_frame.pop()?; + } + Cmd::Assert(ref condition, ref error_code) => { + self.compile_expression(error_code, code, function_frame)?; + self.compile_expression(condition, code, function_frame)?; + code.code.push(Bytecode::Assert); + function_frame.pop()?; + function_frame.pop()?; + } + Cmd::Continue => { + let loc = function_frame.get_loop_start()?; + code.code.push(Bytecode::Branch(loc as u16)); + } + Cmd::Break => { + function_frame.push_loop_break(code.code.len())?; + // placeholder, to be replaced when the enclosing while is compiled + code.code.push(Bytecode::Branch(0)); + } + } + let (reachable_break, terminal_node) = match cmd { + // If we are in a loop, `continue` makes a terminal node + // Conceptually we treat + // `while (cond) { body }` + // as ` + // `loop { if (cond) { body; continue; } else { break; } }` + Cmd::Continue | + // `return` always makes a terminal node + Cmd::Return(_) => (false, true), + Cmd::Break => (true, false), + _ => (false, false), + }; + Ok(ControlFlowInfo { + reachable_break, + terminal_node, + }) + } + + fn compile_expression( + &mut self, + exp: &Exp, + code: &mut CodeUnit, + function_frame: &mut FunctionFrame, + ) -> Result { + debug!("compile expression {}", exp); + match exp { + Exp::Move(ref x) => self.compile_move_local(&x.value, code, function_frame), + Exp::Copy(ref x) => self.compile_copy_local(&x.value, code, function_frame), + Exp::BorrowLocal(ref is_mutable, ref x) => { + self.compile_borrow_local(&x.value, *is_mutable, code, function_frame) + } + Exp::Value(cv) => match cv.as_ref() { + CopyableVal::Address(address) => { + let addr_idx = self.make_address(&address)?; + code.code.push(Bytecode::LdAddr(addr_idx)); + function_frame.push()?; + Ok(InferredType::Address) + } + CopyableVal::U64(i) => { + code.code.push(Bytecode::LdConst(*i)); + function_frame.push()?; + Ok(InferredType::U64) + } + CopyableVal::ByteArray(buf) => { + let buf_idx = self.make_byte_array(buf)?; + code.code.push(Bytecode::LdByteArray(buf_idx)); + function_frame.push()?; + Ok(InferredType::ByteArray) + } + CopyableVal::Bool(b) => { + if *b { + code.code.push(Bytecode::LdTrue); + } else { + code.code.push(Bytecode::LdFalse); + } + function_frame.push()?; + Ok(InferredType::Bool) + } + CopyableVal::String(_) => bail!("nice try! 
come back later {:?}", cv), + }, + Exp::Pack(name, fields) => { + let module_idx = ModuleHandleIndex::new(0); + let name_idx = self.make_string(name.name_ref())?; + let (is_resource, def_idx) = self.scope.get_struct_def(name.name_ref())?; + let sh = self.make_struct_handle(module_idx, name_idx, is_resource)?; + for (_, exp) in fields.iter() { + self.compile_expression(exp, code, function_frame)?; + } + + code.code.push(Bytecode::Pack(def_idx)); + for _ in fields.iter() { + function_frame.pop()?; + } + function_frame.push()?; + Ok(InferredType::Struct(sh)) + } + Exp::UnaryExp(op, e) => { + self.compile_expression(e, code, function_frame)?; + match op { + UnaryOp::Not => { + code.code.push(Bytecode::Not); + Ok(InferredType::Bool) + } + } + } + Exp::BinopExp(e1, op, e2) => { + self.compile_expression(e1, code, function_frame)?; + self.compile_expression(e2, code, function_frame)?; + function_frame.pop()?; + match op { + BinOp::Add => { + code.code.push(Bytecode::Add); + Ok(InferredType::U64) + } + BinOp::Sub => { + code.code.push(Bytecode::Sub); + Ok(InferredType::U64) + } + BinOp::Mul => { + code.code.push(Bytecode::Mul); + Ok(InferredType::U64) + } + BinOp::Mod => { + code.code.push(Bytecode::Mod); + Ok(InferredType::U64) + } + BinOp::Div => { + code.code.push(Bytecode::Div); + Ok(InferredType::U64) + } + BinOp::BitOr => { + code.code.push(Bytecode::BitOr); + Ok(InferredType::U64) + } + BinOp::BitAnd => { + code.code.push(Bytecode::BitAnd); + Ok(InferredType::U64) + } + BinOp::Xor => { + code.code.push(Bytecode::Xor); + Ok(InferredType::U64) + } + BinOp::Or => { + code.code.push(Bytecode::Or); + Ok(InferredType::Bool) + } + BinOp::And => { + code.code.push(Bytecode::And); + Ok(InferredType::Bool) + } + BinOp::Eq => { + code.code.push(Bytecode::Eq); + Ok(InferredType::Bool) + } + BinOp::Neq => { + code.code.push(Bytecode::Neq); + Ok(InferredType::Bool) + } + BinOp::Lt => { + code.code.push(Bytecode::Lt); + Ok(InferredType::Bool) + } + BinOp::Gt => { + code.code.push(Bytecode::Gt); + Ok(InferredType::Bool) + } + BinOp::Le => { + code.code.push(Bytecode::Le); + Ok(InferredType::Bool) + } + BinOp::Ge => { + code.code.push(Bytecode::Ge); + Ok(InferredType::Bool) + } + } + } + Exp::Dereference(e) => { + let loc_type = self.compile_expression(e, code, function_frame)?; + code.code.push(Bytecode::ReadRef); + match loc_type { + InferredType::MutableReference(sig_ref_token) => Ok(*sig_ref_token), + InferredType::Reference(sig_ref_token) => Ok(*sig_ref_token), + _ => Ok(InferredType::Anything), + } + } + Exp::Borrow { + ref is_mutable, + ref exp, + ref field, + } => { + let this_type = self.compile_expression(exp, code, function_frame)?; + self.compile_load_field_reference( + this_type, + field, + *is_mutable, + code, + function_frame, + ) + } + } + } + + fn compile_load_field_reference( + &mut self, + struct_type: InferredType, + struct_field: &Field, + is_mutable: bool, + code: &mut CodeUnit, + function_frame: &mut FunctionFrame, + ) -> Result { + let sh_idx = struct_type.get_struct_handle()?; + // TODO: the clone is to avoid the problem with mut/immut references, review... 
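+        // BorrowField replaces the struct reference on the stack with a field reference.
+        // A mutable field borrow is only legal through a mutable struct reference; an
+        // immutable borrow taken through a mutable reference gets a FreezeRef below.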
+ let field_type = self.scope.get_field_type(sh_idx, struct_field.name())?; + let fd_idx = self.scope.get_field_def(sh_idx, struct_field.name())?; + function_frame.pop()?; + code.code.push(Bytecode::BorrowField(fd_idx)); + function_frame.push()?; + let input_is_mutable = match struct_type { + InferredType::Reference(_) => false, + _ => true, + }; + let inner_token = Box::new(InferredType::from_signature_token(&field_type.0)); + Ok(if is_mutable { + if !input_is_mutable { + bail!("Unsupported Syntax: Cannot take a mutable field reference in an immutable reference. It is not expressible in the bytecode"); + } + InferredType::MutableReference(inner_token) + } else { + if input_is_mutable { + code.code.push(Bytecode::FreezeRef); + } + InferredType::Reference(inner_token) + }) + } + + fn compile_copy_local( + &self, + v: &Var, + code: &mut CodeUnit, + function_frame: &mut FunctionFrame, + ) -> Result { + let loc_idx = function_frame.get_local(&v)?; + let load_loc = Bytecode::CopyLoc(loc_idx); + code.code.push(load_loc); + function_frame.push()?; + let loc_type = function_frame.get_local_type(loc_idx)?; + Ok(InferredType::from_signature_token(loc_type)) + } + + fn compile_move_local( + &self, + v: &Var, + code: &mut CodeUnit, + function_frame: &mut FunctionFrame, + ) -> Result { + let loc_idx = function_frame.get_local(&v)?; + let load_loc = Bytecode::MoveLoc(loc_idx); + code.code.push(load_loc); + function_frame.push()?; + let loc_type = function_frame.get_local_type(loc_idx)?; + Ok(InferredType::from_signature_token(loc_type)) + } + + fn compile_borrow_local( + &self, + v: &Var, + is_mutable: bool, + code: &mut CodeUnit, + function_frame: &mut FunctionFrame, + ) -> Result { + let loc_idx = function_frame.get_local(&v)?; + code.code.push(Bytecode::BorrowLoc(loc_idx)); + function_frame.push()?; + let loc_type = function_frame.get_local_type(loc_idx)?; + let inner_token = Box::new(InferredType::from_signature_token(loc_type)); + Ok(if is_mutable { + InferredType::MutableReference(inner_token) + } else { + code.code.push(Bytecode::FreezeRef); + InferredType::Reference(inner_token) + }) + } + + fn compile_call( + &mut self, + call: &FunctionCall, + code: &mut CodeUnit, + function_frame: &mut FunctionFrame, + mut argument_types: VecDeque, + ) -> Result> { + match call { + FunctionCall::Builtin(function) => { + match function { + Builtin::GetTxnGasUnitPrice => { + code.code.push(Bytecode::GetTxnGasUnitPrice); + function_frame.push()?; + Ok(vec![InferredType::U64]) + } + Builtin::GetTxnMaxGasUnits => { + code.code.push(Bytecode::GetTxnMaxGasUnits); + function_frame.push()?; + Ok(vec![InferredType::U64]) + } + Builtin::GetGasRemaining => { + code.code.push(Bytecode::GetGasRemaining); + function_frame.push()?; + Ok(vec![InferredType::U64]) + } + Builtin::GetTxnSender => { + code.code.push(Bytecode::GetTxnSenderAddress); + function_frame.push()?; + Ok(vec![InferredType::Address]) + } + Builtin::Exists(name) => { + let (_, def_idx) = self.scope.get_struct_def(name.name_ref())?; + code.code.push(Bytecode::Exists(def_idx)); + function_frame.pop()?; + function_frame.push()?; + Ok(vec![InferredType::Bool]) + } + Builtin::BorrowGlobal(name) => { + let (is_resource, def_idx) = self.scope.get_struct_def(name.name_ref())?; + code.code.push(Bytecode::BorrowGlobal(def_idx)); + function_frame.pop()?; + function_frame.push()?; + + let module_idx = ModuleHandleIndex::new(0); + let name_idx = self.make_string(name.name_ref())?; + let sh = self.make_struct_handle(module_idx, name_idx, is_resource)?; + 
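+                        // BorrowGlobal returns a mutable reference to the resource of this
+                        // struct type published under the address consumed from the stack.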
Ok(vec![InferredType::MutableReference(Box::new( + InferredType::Struct(sh), + ))]) + } + Builtin::Release => { + code.code.push(Bytecode::ReleaseRef); + function_frame.pop()?; + function_frame.push()?; + Ok(vec![]) + } + Builtin::CreateAccount => { + code.code.push(Bytecode::CreateAccount); + function_frame.pop()?; + function_frame.push()?; + Ok(vec![]) + } + Builtin::EmitEvent => { + code.code.push(Bytecode::EmitEvent); + function_frame.pop()?; + function_frame.pop()?; + function_frame.pop()?; + Ok(vec![]) + } + Builtin::MoveFrom(name) => { + let (is_resource, def_idx) = self.scope.get_struct_def(name.name_ref())?; + code.code.push(Bytecode::MoveFrom(def_idx)); + function_frame.pop()?; // pop the address + function_frame.push()?; // push the return value + + let module_idx = ModuleHandleIndex::new(0); + let name_idx = self.make_string(name.name_ref())?; + let sh = self.make_struct_handle(module_idx, name_idx, is_resource)?; + Ok(vec![InferredType::Struct(sh)]) + } + Builtin::MoveToSender(name) => { + let (_, def_idx) = self.scope.get_struct_def(name.name_ref())?; + code.code.push(Bytecode::MoveToSender(def_idx)); + function_frame.push()?; + Ok(vec![]) + } + Builtin::GetTxnSequenceNumber => { + code.code.push(Bytecode::GetTxnSequenceNumber); + function_frame.push()?; + Ok(vec![InferredType::U64]) + } + Builtin::GetTxnPublicKey => { + code.code.push(Bytecode::GetTxnPublicKey); + function_frame.push()?; + Ok(vec![InferredType::ByteArray]) + } + Builtin::Freeze => { + code.code.push(Bytecode::FreezeRef); + function_frame.pop()?; // pop mut ref + function_frame.push()?; // push imm ref + let inner_token = match argument_types.pop_front() { + Some(InferredType::Reference(inner_token)) + | Some(InferredType::MutableReference(inner_token)) => inner_token, + // Incorrect call + _ => Box::new(InferredType::Anything), + }; + Ok(vec![InferredType::Reference(inner_token)]) + } + _ => bail!("unsupported builtin function: {}", function), + } + } + FunctionCall::ModuleFunctionCall { module, name } => { + let scope_name = self.scope.get_name(); + + let mh = if scope_name.is_ok() && module.name() == ModuleName::SELF { + ModuleHandleIndex::new(0) + } else { + let target_module = self.scope.get_imported_module(module.name_ref())?; + + let mut idx = 0; + while idx < target_module.function_defs.len() { + let fh_idx = target_module.function_defs[idx].function; + let fh = target_module.get_function_at(fh_idx)?; + let func_name = target_module.get_string_at(fh.name)?; + if func_name == name.name_ref() { + break; + } + idx += 1; + } + if idx == target_module.function_defs.len() { + bail!( + "Cannot find function `{}' in module `{}'", + name.name_ref(), + module.name_ref() + ); + } + + self.scope.link_function( + module.name_ref(), + name.name_ref(), + FunctionDefinitionIndex::new(idx as u16), + )?; + self.scope.get_imported_module_handle(module.name_ref())? 
+ }; + + let name_ref = name.name_ref(); + let name_idx = self.make_string(name_ref)?; + let mut func_sig = self.scope.get_function_signature(mh, name_ref)?.clone(); + func_sig = self.import_function_signature(module.name_ref(), func_sig)?; + let return_types = func_sig + .return_types + .iter() + .map(InferredType::from_signature_token) + .collect(); + let args_count = func_sig.arg_types.len(); + let sig_idx = self.make_function_signature(&func_sig)?; + let fh_idx = self.make_function_handle(mh, name_idx, sig_idx)?; + let call = Bytecode::Call(fh_idx); + code.code.push(call); + for _ in 0..args_count { + function_frame.pop()?; + } + // Return value of current function is pushed onto the stack. + function_frame.push()?; + Ok(return_types) + } + } + } +} diff --git a/language/compiler/src/lib.rs b/language/compiler/src/lib.rs new file mode 100644 index 0000000000000..de2ccf69a9506 --- /dev/null +++ b/language/compiler/src/lib.rs @@ -0,0 +1,9 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#[macro_use] +extern crate log; + +pub mod compiler; +pub mod parser; +pub mod util; diff --git a/language/compiler/src/logging.rs b/language/compiler/src/logging.rs new file mode 100644 index 0000000000000..33b9a6a1797c2 --- /dev/null +++ b/language/compiler/src/logging.rs @@ -0,0 +1,20 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use log::{Record, Level, Metadata}; + +struct SimpleLogger; + +impl log::Log for SimpleLogger { + fn enabled(&self, metadata: &Metadata) -> bool { + metadata.level() <= Level::Info + } + + fn log(&self, record: &Record) { + if self.enabled(record.metadata()) { + println!("{} - {}", record.level(), record.args()); + } + } + + fn flush(&self) {} +} diff --git a/language/compiler/src/main.rs b/language/compiler/src/main.rs new file mode 100644 index 0000000000000..b338529f18735 --- /dev/null +++ b/language/compiler/src/main.rs @@ -0,0 +1,134 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use bytecode_verifier::verifier::{ + verify_module, verify_module_dependencies, verify_script, verify_script_dependencies, +}; +use compiler::{ + compiler::compile_program, + parser::parse_program, + util::{build_stdlib, do_compile_module}, +}; +use std::{fs, io::Write, path::PathBuf}; +use structopt::StructOpt; +use types::account_address::AccountAddress; +use vm::{ + errors::VerificationError, + file_format::{CompiledModule, CompiledScript}, +}; + +#[derive(Debug, StructOpt)] +#[structopt( + name = "IR Compiler", + author = "Libra", + about = "Move IR to bytecode compiler." +)] +struct Args { + /// Serialize and write the compiled output to this file + #[structopt(short = "o", long = "output")] + pub output_path: Option, + /// Treat input file as a module (default is to treat file as a program) + #[structopt(short = "m", long = "module")] + pub module_input: bool, + /// Do not automatically compile stdlib dependencies + #[structopt(long = "no-stdlib")] + pub no_stdlib: bool, + /// Do not automatically run the bytecode verifier + #[structopt(long = "no-verify")] + pub no_verify: bool, + /// Path to the Move IR source to compile + #[structopt(parse(from_os_str))] + pub source_path: PathBuf, +} + +fn check_verification_results(verification_errors: &[VerificationError]) { + if !verification_errors.is_empty() { + println!("Verification failed. 
Errors below:"); + for e in verification_errors { + println!("{:?}", e); + } + std::process::exit(1); + } +} + +fn do_verify_module(module: &CompiledModule, dependencies: &[CompiledModule]) { + let (verified_module, verification_errors) = verify_module(module.clone()); + check_verification_results(&verification_errors); + let (_verified_module, verification_errors) = + verify_module_dependencies(verified_module, dependencies); + check_verification_results(&verification_errors); +} + +fn do_verify_script(script: &CompiledScript, dependencies: &[CompiledModule]) { + let (verified_script, verification_errors) = verify_script(script.clone()); + check_verification_results(&verification_errors); + let (_verified_script, verification_errors) = + verify_script_dependencies(verified_script, dependencies); + check_verification_results(&verification_errors); +} + +fn write_output(path: &str, buf: &[u8]) { + let mut f = fs::File::create(path) + .unwrap_or_else(|err| panic!("Unable to open output file {}: {}", path, err)); + f.write_all(&buf) + .unwrap_or_else(|err| panic!("Unable to write to output file {}: {}", path, err)); +} + +fn main() { + let args = Args::from_args(); + + let address = AccountAddress::default(); + let mut dependencies = if args.no_stdlib { + vec![] + } else { + build_stdlib() + }; + + if !args.module_input { + let source = fs::read_to_string(args.source_path).expect("Unable to read file"); + let parsed_program = parse_program(&source).unwrap(); + + let compiled_program = compile_program(&address, &parsed_program, &dependencies).unwrap(); + + // TODO: Make this a do_verify_program helper function. + if !args.no_verify { + for m in &compiled_program.modules { + do_verify_module(m, &dependencies); + dependencies.push(m.clone()); + } + do_verify_script(&compiled_program.script, &dependencies); + } + + match args.output_path { + Some(path) => { + // TODO: Only the script is serialized. Shall we also serialize the modules? 
+ let mut out = vec![]; + compiled_program + .script + .serialize(&mut out) + .expect("Unable to serialize script"); + write_output(&path, &out); + } + None => { + println!("{}", compiled_program); + } + } + } else { + let compiled_module = do_compile_module(&args.source_path, &address, &dependencies); + if !args.no_verify { + do_verify_module(&compiled_module, &dependencies); + } + match args.output_path { + Some(path) => { + let mut out = vec![]; + compiled_module + .serialize(&mut out) + .expect("Unable to serialize module"); + write_output(&path, &out); + } + None => { + println!("{}", compiled_module); + } + } + } +} diff --git a/language/compiler/src/parser/ast.rs b/language/compiler/src/parser/ast.rs new file mode 100644 index 0000000000000..543d67c9fd5cf --- /dev/null +++ b/language/compiler/src/parser/ast.rs @@ -0,0 +1,1367 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use codespan::{ByteIndex, Span}; +use serde::{Deserialize, Serialize}; +use std::{ + collections::{BTreeMap, VecDeque}, + fmt, + ops::Deref, +}; +use types::{account_address::AccountAddress, byte_array::ByteArray}; + +/// Generic wrapper that keeps file locations for any ast-node +#[derive(Serialize, Deserialize, Debug, Copy, Clone, Eq, PartialEq, Default)] +pub struct Spanned { + #[serde(skip_serializing, skip_deserializing)] + /// The file location + pub span: Loc, + /// The value being wrapped + pub value: T, +} + +/// The file location type +pub type Loc = Span; + +//************************************************************************************************** +// Program +//************************************************************************************************** +#[derive(Serialize, Deserialize, Debug, Clone)] +/// A set of move modules and a Move transaction script + +pub struct Program { + /// The modules to publish + pub modules: Vec, + /// The transaction script to execute + pub script: Script, +} + +//************************************************************************************************** +// Script +//************************************************************************************************** + +#[derive(Serialize, Deserialize, Debug, Clone)] +/// The move transaction script to be executed +pub struct Script { + /// The dependencies of `main`, i.e. of the transaction script + pub imports: Vec, + /// The transaction script's `main` procedure + pub main: Function, +} + +//************************************************************************************************** +// Modules +//************************************************************************************************** + +/// Newtype for a name of a module +#[derive(Serialize, Deserialize, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] +pub struct ModuleName(String); + +/// Newtype of the address + the module name +/// `addr.m` +#[derive(Serialize, Deserialize, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] +pub struct QualifiedModuleIdent { + /// Name for the module. 
Will be unique among modules published under the same address + pub name: ModuleName, + /// Address that this module is published under + pub address: AccountAddress, +} + +/// A Move module +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct ModuleDefinition { + /// name of the module + pub name: ModuleName, + /// the module's dependencies + pub imports: Vec, + /// the structs (including resources) that the module defines + pub structs: Vec, + /// the procedure that the module defines + pub functions: Vec<(FunctionName, Function)>, +} + +/// Either a qualified module name like `addr.m` or `Transaction.m`, which refers to a module in +/// the same transaction. +#[derive(Serialize, Deserialize, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] +pub enum ModuleIdent { + Transaction(ModuleName), + Qualified(QualifiedModuleIdent), +} + +//************************************************************************************************** +// Imports +//************************************************************************************************** + +/// A dependency/import declaration +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct ImportDefinition { + /// the dependency + /// `addr.m` or `Transaction.m` + pub ident: ModuleIdent, + /// the alias for that dependency + /// `m` + pub alias: ModuleName, +} + +//************************************************************************************************** +// Structs +//************************************************************************************************** + +/// The file newtype +pub type Field = types::access_path::Field; +/// A field map +pub type Fields = BTreeMap; + +/// Newtype for the name of a struct +#[derive(Serialize, Deserialize, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] +pub struct StructName(String); + +/// A Move struct +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct StructDefinition { + /// The struct will have kind resource if `resource_kind` is true + /// and a value otherwise + pub resource_kind: bool, + /// Human-readable name for the struct that also serves as a nominal type + pub name: StructName, + /// the fields each instance has + pub fields: Fields, +} + +//************************************************************************************************** +// Functions +//************************************************************************************************** + +/// Newtype for the name of a function +#[derive(Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Clone, Serialize)] +pub struct FunctionName(String); + +/// The signature of a function +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] +pub struct FunctionSignature { + /// Possibly-empty list of (formal name, formal type) pairs. Names are unique. 
+ pub formals: Vec<(Var, Type)>, + /// Optional return types + pub return_type: Vec, +} + +/// Public or internal modifier for a procedure +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] +pub enum FunctionVisibility { + /// The procedure can be invoked anywhere + /// `public` + Public, + /// The procedure can be invoked only internally + /// `` + Internal, +} + +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] +pub enum FunctionAnnotation { + Requires(String), + Ensures(String), +} + +/// The body of a Move function +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] +pub enum FunctionBody { + /// The body is declared + /// `locals` are all of the declared locals + /// `code` is the code that defines the procedure + Move { + locals: Vec<(Var_, Type)>, + code: Block, + }, + /// The body is provided by the runtime + Native, +} + +/// A Move function/procedure +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] +pub struct Function { + /// The visibility (public or internal) + pub visibility: FunctionVisibility, + /// The type signature + pub signature: FunctionSignature, + /// Annotations on the function + pub annotations: Vec, + /// The code for the procedure + pub body: FunctionBody, +} + +//************************************************************************************************** +// Types +//************************************************************************************************** + +/// Used to annotate struct types as a resource or value +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum Kind { + /// `R` + Resource, + /// `V` + Value, +} + +/// Identifier for a struct definition. Tells us where to look in the storage layer to find the +/// code associated with the interface +#[derive(Serialize, Deserialize, Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] +pub struct StructType { + /// Module name and address in which the struct is contained + pub module: ModuleName, + /// Name for the struct class. 
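As a quick illustration of how these pieces compose (a sketch, not part of the commit): a native procedure is just a `Function` whose body is `FunctionBody::Native`, assembled here with the `Function::new` convenience constructor implemented further down in this file.

```rust
// Sketch only: uses the constructors implemented later in this file.
use crate::parser::ast::{Function, FunctionBody, FunctionVisibility, Type, Var};

fn native_getter() -> Function {
    Function::new(
        FunctionVisibility::Public,
        vec![(Var::new("addr"), Type::address())], // formals
        vec![Type::u64()],                         // return types
        vec![],                                    // no requires/ensures annotations
        FunctionBody::Native,                      // implementation supplied by the VM
    )
}
```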
Should be unique among structs published under the same + /// module+address + pub name: StructName, +} + +/// Type "name" of the type +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum Tag { + /// `address` + Address, + /// `u64` + U64, + /// `bool` + Bool, + /// `bytearray` + ByteArray, + /// `string` + String, + /// A module defined struct + /// `n` + Struct(StructType), +} + +/// The type of a single value +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum Type { + /// A non reference type + /// `g` or `k#d.n` + Normal(Kind, Tag), + /// A reference type + /// `&t` or `&mut t` + Reference { + /// true if `&mut` and false if `&` + is_mutable: bool, + /// the kind, value or resource + kind: Kind, + /// the "name" of the type + tag: Tag, + }, +} + +//************************************************************************************************** +// Statements +//************************************************************************************************** + +/// Newtype for a variable/local +#[derive(Serialize, Deserialize, Debug, PartialEq, Hash, Eq, Clone, Ord, PartialOrd)] +pub struct Var(String); +/// The type of a variable with a location +pub type Var_ = Spanned; + +/// Builtin "function"-like operators that often have a signature not expressable in the +/// type system and/or have access to some runtime/storage context +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum Builtin { + /// Intentionally destroy a resource (i.e., the inverse of `new`). + Release, + /// Check if there is a struct object (`StructName` resolved by current module) associated with + /// the given address + Exists(StructName), + /// Get the struct object (`StructName` resolved by current module) associated with the given + /// address + BorrowGlobal(StructName), + /// Returns the height of the current transaction. + GetHeight, + /// Returns the price per gas unit the current transaction is willing to pay + GetTxnGasUnitPrice, + /// Returns the maximum units of gas the current transaction is willing to use + GetTxnMaxGasUnits, + /// Returns the public key of the current transaction's sender + GetTxnPublicKey, + /// Returns the address of the current transaction's sender + GetTxnSender, + /// Returns the sequence number of the current transaction. + GetTxnSequenceNumber, + /// Returns the unit of gas remain to be used for now. + GetGasRemaining, + /// Emit an event + EmitEvent, + + /// Publishing, + /// Initialize a previously empty address by publishing a resource of type Account + CreateAccount, + /// Remove a resource of the given type from the account with the given address + MoveFrom(StructName), + /// Publish an instantiated struct object into sender's account. + MoveToSender(StructName), + + /// Convert a mutable reference into an immutable one + Freeze, +} + +/// Enum for different function calls +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum FunctionCall { + /// functions defined in the host environment + Builtin(Builtin), + /// The call of a module defined procedure + ModuleFunctionCall { + module: ModuleName, + name: FunctionName, + }, +} +/// The type for a function call and its location +pub type FunctionCall_ = Spanned; + +/// Enum for Move commands +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum Cmd { + /// `x_1, ..., x_j = call` + Call { + return_bindings: Vec, + call: FunctionCall_, + actuals: Vec, + }, + /// `x = e` + Assign(Var_, Exp_), + /// `n { f_1: x_1, ... 
, f_j: x_j } = e` + Unpack(StructName, Fields, Exp_), + /// `*e_1 = e_2` + Mutate(Exp_, Exp_), + /// `assert(e1, e2)` + Assert(Exp_, Exp_), + /// `return e_1, ... , e_j` + Return(Vec), + /// `break` + Break, + /// `continue` + Continue, +} +/// The type of a command with its location +pub type Cmd_ = Spanned; + +/// Struct defining an if statement +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct IfElse { + /// the if's condition + pub cond: Exp_, + /// the block taken if the condition is `true` + pub if_block: Block, + /// the block taken if the condition is `false` + pub else_block: Option, +} + +/// Struct defining a while statement +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct While { + /// The condition for a while statement + pub cond: Exp_, + /// The block taken if the condition is `true` + pub block: Block, +} + +/// Struct defining a loop statement +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub struct Loop { + /// The body of the loop + pub block: Block, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +#[allow(clippy::large_enum_variant)] +pub enum Statement { + /// `c;` + CommandStatement(Cmd_), + /// `if (e) { s_1 } else { s_2 }` + IfElseStatement(IfElse), + /// `while (e) { s }` + WhileStatement(While), + /// `loop { s }` + LoopStatement(Loop), + VerifyStatement(String), + AssumeStatement(String), + /// no-op that eases parsing in some places + EmptyStatement, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +/// `{ s }` +pub struct Block { + /// The statements that make up the block + pub stmts: VecDeque, +} + +//************************************************************************************************** +// Expressions +//************************************************************************************************** + +/// Bottom of the value hierarchy. These values can be trivially copyable and stored in statedb as a +/// single entry. 
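The helper constructors implemented later in this file make it straightforward to build these statement forms by hand. The following sketch (illustrative only, not part of the diff) assembles the AST for `if (copy(x) > 0) { return 1; }`.

```rust
// Sketch only: builds `if (copy(x) > 0) { return 1; }` with this file's helpers.
use crate::parser::ast::{BinOp, Block, Cmd, Exp, Spanned, Statement, Var};

fn example_if() -> Block {
    // copy(x) > 0
    let cond = Exp::binop(Exp::copy(Var::new_("x")), BinOp::Gt, Exp::u64(0));
    // { return 1; }
    let then_block = Block::new(vec![Statement::cmd(Spanned::no_loc(Cmd::return_(
        Exp::u64(1),
    )))]);
    Block::new(vec![Statement::if_block(cond, then_block)])
}
```

Every helper wraps its result in `Spanned::no_loc`, so hand-built trees carry default (empty) source locations.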
+#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)] +pub enum CopyableVal { + /// An address in the global storage + Address(AccountAddress), + /// An unsigned 64-bit integer + U64(u64), + /// true or false + Bool(bool), + /// `b""` + ByteArray(ByteArray), + /// Not yet supported in the parser + String(String), +} +/// The type of a value and its location +pub type CopyableVal_ = Spanned; +/// The type for fields and their bound expressions +pub type ExpFields = Fields; + +/// Enum for unary operators +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum UnaryOp { + /// Boolean negation + Not, +} + +/// Enum for binary operators +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum BinOp { + // u64 ops + /// `+` + Add, + /// `-` + Sub, + /// `*` + Mul, + /// `%` + Mod, + /// `/` + Div, + /// `|` + BitOr, + /// `&` + BitAnd, + /// `^` + Xor, + + // Bool ops + /// `&&` + And, + /// `||` + Or, + + // Compare Ops + /// `==` + Eq, + /// `!=` + Neq, + /// `<` + Lt, + /// `>` + Gt, + /// `<=` + Le, + /// `>=` + Ge, +} + +/// Enum for all expressions +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum Exp { + /// `*e` + Dereference(Box), + /// `op e` + UnaryExp(UnaryOp, Box), + /// `e_1 op e_2` + BinopExp(Box, BinOp, Box), + /// Wrapper to lift `CopyableVal` into `Exp` + /// `v` + Value(CopyableVal_), + /// Takes the given field values and instantiates the struct + /// Returns a fresh `StructInstance` whose type and kind (resource or otherwise) + /// as the current struct class (i.e., the class of the method we're currently executing). + /// `n { f_1: e_1, ... , f_j: e_j }` + Pack(StructName, ExpFields), + /// `&e.f`, `&mut e.f` + Borrow { + /// mutable or not + is_mutable: bool, + /// the expression containing the reference + exp: Box, + /// the field being borrowed + field: Field, + }, + /// `move(x)` + Move(Var_), + /// `copy(x)` + Copy(Var_), + /// `&x` or `&mut x` + BorrowLocal(bool, Var_), +} + +/// The type for a `Exp` and it's location +pub type Exp_ = Spanned; + +//************************************************************************************************** +// impls +//************************************************************************************************** + +impl Program { + /// Create a new `Program` from modules and transaction script + pub fn new(modules: Vec, script: Script) -> Self { + Program { modules, script } + } +} + +impl Script { + /// Create a new `Script` from the imports and the main function + pub fn new(imports: Vec, main: Function) -> Self { + Script { imports, main } + } + + /// Accessor for the body of the 'main' procedure + pub fn body(&self) -> &Block { + match self.main.body { + FunctionBody::Move { ref code, .. 
} => &code, + FunctionBody::Native => panic!("main() can't be native"), + } + } +} + +impl ModuleName { + /// Create a new `ModuleName` identifier from a string + pub fn new(name: String) -> Self { + assert!(name != ""); + ModuleName(name) + } + + /// String value for the current module handle + pub const SELF: &'static str = "Self"; + + /// Create a new `ModuleName` for the `SELF` constant + pub fn module_self() -> Self { + ModuleName::new(ModuleName::SELF.to_string()) + } + + /// Returns the raw bytes of the module name's string value + pub fn as_bytes(&self) -> Vec { + self.0.as_bytes().to_vec() + } + + /// Returns a cloned copy of the module name's string value + pub fn name(&self) -> String { + self.0.clone() + } + + /// Accessor for the module name's string value + pub fn name_ref(&self) -> &String { + &self.0 + } +} + +impl QualifiedModuleIdent { + /// Creates a new fully qualified module identifier from the module name and the address at + /// which it is published + pub fn new(name: ModuleName, address: AccountAddress) -> Self { + QualifiedModuleIdent { address, name } + } + + /// Accessor for the name of the fully qualified module identifier + pub fn get_name(&self) -> &ModuleName { + &self.name + } + + /// Accessor for the address at which the module is published + pub fn get_address(&self) -> &AccountAddress { + &self.address + } +} + +impl ModuleIdent { + pub fn get_name(&self) -> &ModuleName { + match self { + ModuleIdent::Transaction(name) => &name, + ModuleIdent::Qualified(id) => &id.name, + } + } +} + +impl ModuleDefinition { + /// Creates a new `ModuleDefinition` from its string name, dependencies, structs+resources, + /// and procedures + /// Does not verify the correctness of any internal properties of its elements + pub fn new( + name: String, + imports: Vec, + structs: Vec, + functions: Vec<(FunctionName, Function)>, + ) -> Self { + ModuleDefinition { + name: ModuleName::new(name), + imports, + structs, + functions, + } + } +} + +impl Type { + /// Creates a new non-reference type from the type's kind and tag + pub fn nonreference(kind: Kind, tag: Tag) -> Type { + Type::Normal(kind, tag) + } + + /// Creates a new reference type from its mutability and underlying type + pub fn reference(is_mutable: bool, annot: Type) -> Type { + match annot { + Type::Normal(kind, tag) => Type::Reference { + is_mutable, + kind, + tag, + }, + _ => panic!("ICE expected Normal annotation"), + } + } + + /// Creates a new address type + pub fn address() -> Type { + Type::Normal(Kind::Value, Tag::Address) + } + + /// Creates a new u64 type + pub fn u64() -> Type { + Type::Normal(Kind::Value, Tag::U64) + } + + /// Creates a new bool type + pub fn bool() -> Type { + Type::Normal(Kind::Value, Tag::Bool) + } + + /// Creates a new bytearray type + pub fn bytearray() -> Type { + Type::Normal(Kind::Value, Tag::ByteArray) + } +} + +impl StructType { + /// Creates a new StructType handle from the name of the module alias and the name of the struct + pub fn new(module: ModuleName, name: StructName) -> Self { + StructType { module, name } + } + + /// Accessor for the module alias + pub fn module(&self) -> &ModuleName { + &self.module + } + + /// Accessor for the struct name + pub fn name(&self) -> &StructName { + &self.name + } +} + +impl ImportDefinition { + /// Creates a new import definition from a module identifier and an optional alias + /// If the alias is `None`, the alias will be a cloned copy of the identifiers module name + pub fn new(ident: ModuleIdent, alias_opt: Option) -> Self { + let alias 
= match alias_opt { + Some(alias) => alias, + None => ident.get_name().clone(), + }; + ImportDefinition { ident, alias } + } +} + +impl StructName { + /// Create a new `StructName` identifier from a string + pub fn new(name: String) -> Self { + StructName(name) + } + + /// Returns the raw bytes of the struct name's string value + pub fn as_bytes(&self) -> Vec { + self.0.as_bytes().to_vec() + } + + /// Returns a cloned copy of the struct name's string value + pub fn name(&self) -> String { + self.0.clone() + } + + /// Accessor for the name of the struct + pub fn name_ref(&self) -> &String { + &self.0 + } +} + +impl StructDefinition { + /// Creates a new StructDefinition from the resource kind (true if resource), the string + /// representation of the name, and the field names with their types + /// Does not verify the correctness of any internal properties, e.g. doesn't check that the + /// fields do not have reference types + pub fn new(resource_kind: bool, name: String, fields: Fields) -> Self { + StructDefinition { + resource_kind, + name: StructName::new(name), + fields, + } + } +} + +impl FunctionName { + /// Create a new `FunctionName` identifier from a string + pub fn new(name: String) -> Self { + FunctionName(name) + } + + /// Returns a cloned copy of the function name's string value + pub fn name(&self) -> String { + self.0.clone() + } + + /// Accessor for the name of the function + pub fn name_ref(&self) -> &String { + &self.0 + } +} + +impl FunctionSignature { + /// Creates a new function signature from the parameters and the return types + pub fn new(formals: Vec<(Var, Type)>, return_type: Vec) -> Self { + FunctionSignature { + formals, + return_type, + } + } +} + +impl Function { + /// Creates a new function declaration from the components of the function + /// See the declaration of the struct `Function` for more details + pub fn new( + visibility: FunctionVisibility, + formals: Vec<(Var, Type)>, + return_type: Vec, + annotations: Vec, + body: FunctionBody, + ) -> Self { + let signature = FunctionSignature::new(formals, return_type); + Function { + visibility, + signature, + annotations, + body, + } + } +} + +impl Var { + /// Create a new `Var` identifier from a string + pub fn new(s: &str) -> Self { + Var(s.to_string()) + } + + /// Create a new `Var_` identifier from a string with an empty location + pub fn new_(s: &str) -> Var_ { + Spanned::no_loc(Var::new(s)) + } + + /// Accessor for the name of the var + pub fn name(&self) -> &str { + &self.0 + } +} + +impl FunctionCall { + /// Creates a `FunctionCall::ModuleFunctionCall` variant + pub fn module_call(module: ModuleName, name: FunctionName) -> Self { + FunctionCall::ModuleFunctionCall { module, name } + } + + /// Creates a `FunctionCall::Builtin` variant with no location information + pub fn builtin(bif: Builtin) -> FunctionCall_ { + Spanned::no_loc(FunctionCall::Builtin(bif)) + } +} + +impl Cmd { + /// Creates a command that returns no values + pub fn return_empty() -> Self { + Cmd::Return(vec![]) + } + + /// Creates a command that returns a single value + pub fn return_(op: Exp_) -> Self { + Cmd::Return(vec![op]) + } +} + +impl IfElse { + /// Creates an if-statement with no else branch + pub fn if_block(cond: Exp_, if_block: Block) -> Self { + IfElse { + cond, + if_block, + else_block: None, + } + } + + /// Creates an if-statement with an else branch + pub fn if_else(cond: Exp_, if_block: Block, else_block: Block) -> Self { + IfElse { + cond, + if_block, + else_block: Some(else_block), + } + } +} + +impl Statement { + /// 
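Putting the constructors above together, a complete (if trivial) transaction-script AST can be assembled directly; this sketch is illustrative and not part of the commit.

```rust
// Sketch only: an empty `main() { return; }` script built from this file's API.
use crate::parser::ast::{
    Block, Cmd, Function, FunctionBody, FunctionVisibility, Script, Spanned, Statement,
};

fn empty_script() -> Script {
    let body = Block::new(vec![Statement::cmd(Spanned::no_loc(Cmd::return_empty()))]);
    let main = Function::new(
        FunctionVisibility::Public,
        vec![], // no formals
        vec![], // no return values
        vec![], // no annotations
        FunctionBody::Move { locals: vec![], code: body },
    );
    Script::new(vec![], main) // no imports
}
```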
Lifts a command into a statement + pub fn cmd(c: Cmd_) -> Self { + Statement::CommandStatement(c) + } + + /// Creates an `Statement::IfElseStatement` variant with no else branch + pub fn if_block(cond: Exp_, if_block: Block) -> Self { + Statement::IfElseStatement(IfElse::if_block(cond, if_block)) + } + + /// Creates an `Statement::IfElseStatement` variant with an else branch + pub fn if_else(cond: Exp_, if_block: Block, else_block: Block) -> Self { + Statement::IfElseStatement(IfElse::if_else(cond, if_block, else_block)) + } +} + +impl Block { + /// Creates a new block from the vector of statements + pub fn new(stmts: Vec) -> Self { + Block { + stmts: VecDeque::from(stmts), + } + } + + /// Creates an empty block + pub fn empty() -> Self { + Block { + stmts: VecDeque::new(), + } + } +} + +impl Exp { + /// Creates a new address `Exp` with no location information + pub fn address(addr: AccountAddress) -> Exp_ { + Spanned::no_loc(Exp::Value(Spanned::no_loc(CopyableVal::Address(addr)))) + } + + /// Creates a new value `Exp` with no location information + pub fn value(b: CopyableVal) -> Exp_ { + Spanned::no_loc(Exp::Value(Spanned::no_loc(b))) + } + + /// Creates a new u64 `Exp` with no location information + pub fn u64(i: u64) -> Exp_ { + Exp::value(CopyableVal::U64(i)) + } + + /// Creates a new bool `Exp` with no location information + pub fn bool(b: bool) -> Exp_ { + Exp::value(CopyableVal::Bool(b)) + } + + /// Creates a new bytearray `Exp` with no location information + pub fn byte_array(buf: ByteArray) -> Exp_ { + Exp::value(CopyableVal::ByteArray(buf)) + } + + /// Creates a new pack/struct-instantiation `Exp` with no location information + pub fn instantiate(n: StructName, s: ExpFields) -> Exp_ { + Spanned::no_loc(Exp::Pack(n, s)) + } + + /// Creates a new binary operator `Exp` with no location information + pub fn binop(lhs: Exp_, op: BinOp, rhs: Exp_) -> Exp_ { + Spanned::no_loc(Exp::BinopExp(Box::new(lhs), op, Box::new(rhs))) + } + + /// Creates a new `e+e` `Exp` with no location information + pub fn add(lhs: Exp_, rhs: Exp_) -> Exp_ { + Exp::binop(lhs, BinOp::Add, rhs) + } + + /// Creates a new `e-e` `Exp` with no location information + pub fn sub(lhs: Exp_, rhs: Exp_) -> Exp_ { + Exp::binop(lhs, BinOp::Sub, rhs) + } + + /// Creates a new `*e` `Exp` with no location information + pub fn dereference(e: Exp_) -> Exp_ { + Spanned::no_loc(Exp::Dereference(Box::new(e))) + } + + /// Creates a new borrow field `Exp` with no location information + pub fn borrow(is_mutable: bool, exp: Box, field: Field) -> Exp_ { + Spanned::no_loc(Exp::Borrow { + is_mutable, + exp, + field, + }) + } + + /// Creates a new copy-local `Exp` with no location information + pub fn copy(v: Var_) -> Exp_ { + Spanned::no_loc(Exp::Copy(v)) + } + + /// Creates a new move-local `Exp` with no location information + pub fn move_(v: Var_) -> Exp_ { + Spanned::no_loc(Exp::Move(v)) + } +} + +//************************************************************************************************** +// Trait impls +//************************************************************************************************** + +impl Iterator for Script { + type Item = Statement; + + fn next(&mut self) -> Option { + match self.main.body { + FunctionBody::Move { ref mut code, .. 
} => code.stmts.pop_front(), + FunctionBody::Native => panic!("main() cannot be native code"), + } + } +} + +impl PartialEq for Script { + fn eq(&self, other: &Script) -> bool { + self.imports == other.imports && self.main.body == other.main.body + } +} + +impl Deref for Spanned { + type Target = T; + + fn deref(&self) -> &T { + &self.value + } +} + +impl AsRef for Spanned { + fn as_ref(&self) -> &T { + &self.value + } +} + +impl Spanned { + pub fn no_loc(value: T) -> Spanned { + Spanned { + value, + span: Span::default(), + } + } +} + +impl Iterator for Block { + type Item = Statement; + + fn next(&mut self) -> Option { + self.stmts.pop_front() + } +} + +impl Into for CopyableVal { + fn into(self) -> Field { + Field::new(self.to_string().as_ref()) + } +} + +//************************************************************************************************** +// Display +//************************************************************************************************** + +impl fmt::Display for Spanned +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.value) + } +} + +impl fmt::Display for ModuleName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl fmt::Display for QualifiedModuleIdent { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}.{}", self.address, self.name) + } +} + +impl fmt::Display for ModuleDefinition { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "Module({}, ", self.name.name())?; + write!(f, "Structs(")?; + for struct_def in &self.structs { + write!(f, "{}, ", struct_def)?; + } + write!(f, "Functions(")?; + for (fun_name, fun) in &self.functions { + write!(f, "({}, {}), ", fun_name, fun)?; + } + write!(f, ")") + } +} + +impl fmt::Display for StructDefinition { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "Struct({}, ", self.name)?; + writeln!(f, "{}", format_fields(&self.fields))?; + write!(f, ")") + } +} + +impl fmt::Display for Function { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} ({})", self.signature, self.body) + } +} + +impl fmt::Display for StructName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl fmt::Display for FunctionName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl fmt::Display for FunctionBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FunctionBody::Move { + ref locals, + ref code, + } => { + for (local, ty) in locals { + write!(f, "let {}: {};", local, ty)?; + } + writeln!(f, "{}", code) + } + FunctionBody::Native => write!(f, "native"), + } + } +} + +fn intersperse(items: &[T], join: &str) -> String { + items.iter().fold(String::new(), |acc, v| { + format!("{acc}{join}{v}", acc = acc, join = join, v = v) + }) +} + +fn format_fields(fields: &Fields) -> String { + fields.iter().fold(String::new(), |acc, (field, val)| { + format!("{} {}: {},", acc, field, val) + }) +} + +impl fmt::Display for FunctionSignature { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "(")?; + for (v, ty) in self.formals.iter() { + write!(f, "{}: {}, ", v, ty)?; + } + write!(f, ")")?; + Ok(()) + } +} + +impl fmt::Display for Kind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Kind::Resource => write!(f, "R"), + Kind::Value => write!(f, "V"), + } 
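One subtlety worth calling out: `Iterator` is implemented for `Script` and `Block` by popping statements from the front of the underlying `VecDeque`, so iteration consumes the AST rather than borrowing it. A minimal stand-alone sketch of the same pattern (not part of the diff):

```rust
// Stand-alone sketch of the draining-iterator pattern used by `Script` and `Block`.
use std::collections::VecDeque;

struct Stmts(VecDeque<&'static str>);

impl Iterator for Stmts {
    type Item = &'static str;
    fn next(&mut self) -> Option<Self::Item> {
        self.0.pop_front() // removes the statement: a second pass sees nothing
    }
}

fn main() {
    let mut b = Stmts(VecDeque::from(vec!["s1", "s2"]));
    assert_eq!(b.by_ref().count(), 2);
    assert_eq!(b.count(), 0); // already drained
}
```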
+ } +} + +impl fmt::Display for StructType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}.{}", self.module, self.name.name()) + } +} + +impl fmt::Display for Tag { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Tag::U64 => write!(f, "u64"), + Tag::Bool => write!(f, "bool"), + Tag::Address => write!(f, "address"), + Tag::ByteArray => write!(f, "bytearray"), + Tag::String => write!(f, "string"), + Tag::Struct(ty) => write!(f, "{}", ty), + } + } +} + +fn write_kind_tag(f: &mut fmt::Formatter<'_>, k: &Kind, t: &Tag) -> fmt::Result { + match t { + Tag::Struct(_) => write!(f, "{}#{}", k, t), + _ => write!(f, "{}", t), + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Type::Normal(k, t) => write_kind_tag(f, k, t), + Type::Reference { + kind, + tag, + is_mutable, + } => { + write!(f, "&{}", if *is_mutable { "mut " } else { "" })?; + write_kind_tag(f, kind, tag) + } + } + } +} + +impl fmt::Display for Var { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl fmt::Display for Builtin { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Builtin::CreateAccount => write!(f, "create_account"), + Builtin::Release => write!(f, "release"), + Builtin::EmitEvent => write!(f, "log"), + Builtin::Exists(t) => write!(f, "exists<{}>", t), + Builtin::BorrowGlobal(t) => write!(f, "borrow_global<{}>", t), + Builtin::GetHeight => write!(f, "get_height"), + Builtin::GetTxnMaxGasUnits => write!(f, "get_txn_max_gas_units"), + Builtin::GetTxnGasUnitPrice => write!(f, "get_txn_gas_unit_price"), + Builtin::GetTxnPublicKey => write!(f, "get_txn_public_key"), + Builtin::GetTxnSender => write!(f, "get_txn_sender"), + Builtin::GetTxnSequenceNumber => write!(f, "get_txn_sequence_number"), + Builtin::GetGasRemaining => write!(f, "get_gas_remaining"), + Builtin::MoveFrom(t) => write!(f, "move_from<{}>", t), + Builtin::MoveToSender(t) => write!(f, "move_to_sender<{}>", t), + Builtin::Freeze => write!(f, "freeze"), + } + } +} + +impl fmt::Display for FunctionCall { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FunctionCall::Builtin(fun) => write!(f, "{}", fun), + FunctionCall::ModuleFunctionCall { module, name } => write!(f, "{}.{}", module, name), + } + } +} + +impl fmt::Display for Cmd { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Cmd::Call { + return_bindings, + call, + actuals, + } => { + let args = intersperse(actuals, ", "); + if return_bindings.is_empty() { + write!(f, "{}({});", call, args) + } else { + write!( + f, + "let {} = {}({});", + intersperse(return_bindings, ", "), + call, + args + ) + } + } + Cmd::Assign(v, e) => write!(f, "{} = {};", v, e,), + Cmd::Unpack(n, bindings, e) => write!( + f, + "{} {{ {} }} = {}", + n, + bindings + .iter() + .fold(String::new(), |acc, (field, var)| format!( + "{} {} : {},", + acc, field, var + )), + e + ), + Cmd::Mutate(e, o) => write!(f, "*({}) = {};", e, o), + Cmd::Assert(e, err) => write!(f, "assert({}, {});", e, err), + Cmd::Return(exps) => write!(f, "return {};", intersperse(exps, ", ")), + Cmd::Break => write!(f, "break;"), + Cmd::Continue => write!(f, "continue;"), + } + } +} + +impl fmt::Display for IfElse { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "if ({}) {{\n{:indent$}\n}}", + self.cond, + self.if_block, + indent = 4 + )?; + match self.else_block { + None => Ok(()), + 
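Note that the fold in `intersperse` above emits the separator before every element, so `intersperse(&[1, 2], ", ")` produces `", 1, 2"` with a leading separator. Since the function only feeds `Display` output the artifact is cosmetic, but a join-based variant avoids it entirely (sketch only, not part of the diff):

```rust
// Sketch only: separator between elements, with no leading separator.
use std::fmt;

fn intersperse<T: fmt::Display>(items: &[T], join: &str) -> String {
    items
        .iter()
        .map(|v| v.to_string())
        .collect::<Vec<_>>()
        .join(join)
}

fn main() {
    assert_eq!(intersperse(&[1, 2, 3], ", "), "1, 2, 3");
}
```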
Some(ref block) => write!(f, " else {{\n{:indent$}\n}}", block, indent = 4), + } + } +} + +impl fmt::Display for While { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "while ({}) {{\n{:indent$}\n}}", + self.cond, + self.block, + indent = 4 + )?; + Ok(()) + } +} + +impl fmt::Display for Loop { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "loop {{\n{:indent$}\n}}", self.block, indent = 4)?; + Ok(()) + } +} + +impl fmt::Display for Statement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Statement::CommandStatement(cmd) => write!(f, "{}", cmd), + Statement::IfElseStatement(if_else) => write!(f, "{}", if_else), + Statement::WhileStatement(while_) => write!(f, "{}", while_), + Statement::LoopStatement(loop_) => write!(f, "{}", loop_), + Statement::VerifyStatement(cond) => write!(f, "verify<{}>)", cond), + Statement::AssumeStatement(cond) => write!(f, "assume<{}>", cond), + Statement::EmptyStatement => write!(f, ""), + } + } +} + +impl fmt::Display for Block { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for stmt in self.stmts.iter() { + writeln!(f, "{}", stmt)?; + } + Ok(()) + } +} + +impl fmt::Display for CopyableVal { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CopyableVal::U64(v) => write!(f, "{}", v), + CopyableVal::Bool(v) => write!(f, "{}", v), + CopyableVal::ByteArray(v) => write!(f, "{}", v), + CopyableVal::Address(v) => write!(f, "0x{}", hex::encode(&v)), + CopyableVal::String(v) => write!(f, "{}", v), + } + } +} + +impl fmt::Display for UnaryOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + UnaryOp::Not => "!", + } + ) + } +} + +impl fmt::Display for BinOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + BinOp::Add => "+", + BinOp::Sub => "-", + BinOp::Mul => "*", + BinOp::Mod => "%", + BinOp::Div => "/", + BinOp::BitOr => "|", + BinOp::BitAnd => "&", + BinOp::Xor => "^", + + // Bool ops + BinOp::Or => "||", + BinOp::And => "&&", + + // Compare Ops + BinOp::Eq => "==", + BinOp::Neq => "!=", + BinOp::Lt => "<", + BinOp::Gt => ">", + BinOp::Le => "<=", + BinOp::Ge => ">=", + } + ) + } +} + +impl fmt::Display for Exp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Exp::Dereference(e) => write!(f, "*({})", e), + Exp::UnaryExp(o, e) => write!(f, "({}{})", o, e), + Exp::BinopExp(e1, o, e2) => write!(f, "({} {} {})", o, e1, e2), + Exp::Value(v) => write!(f, "{}", v), + Exp::Pack(n, s) => write!( + f, + "{}{{{}}}", + n, + s.iter().fold(String::new(), |acc, (field, op)| format!( + "{} {} : {},", + acc, field, op + )) + ), + Exp::Borrow { + is_mutable, + exp, + field, + } => write!( + f, + "&{}{}.{}", + if *is_mutable { "mut " } else { "" }, + exp, + field + ), + Exp::Move(v) => write!(f, "move({})", v), + Exp::Copy(v) => write!(f, "copy({})", v), + Exp::BorrowLocal(is_mutable, v) => { + write!(f, "&{}{}", if *is_mutable { "mut " } else { "" }, v) + } + } + } +} diff --git a/language/compiler/src/parser/mod.rs b/language/compiler/src/parser/mod.rs new file mode 100644 index 0000000000000..63dfb5985899d --- /dev/null +++ b/language/compiler/src/parser/mod.rs @@ -0,0 +1,319 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#[rustfmt::skip] +#[allow(clippy::all)] +#[allow(deprecated)] +mod syntax; + +/// # Grammar +/// ## Identifiers +/// ```text +/// f ∈ FieldName // 
[a-zA-Z$_][a-zA-Z0-9$_]*
+/// p ∈ ProcedureName // [a-zA-Z$_][a-zA-Z0-9$_]*
+/// m ∈ ModuleName // [a-zA-Z$_][a-zA-Z0-9$_]*
+/// n ∈ StructName // [a-zA-Z$_][a-zA-Z0-9$_]*
+/// x ∈ Var // [a-zA-Z$_][a-zA-Z0-9$_]*
+/// ```
+///
+/// ## Types
+/// ```text
+/// k ∈ Kind ::=
+/// | R // Linear resource struct value. Must be used, cannot be copied
+/// | V // Non-resource struct value. Can be silently discarded, can be copied
+///
+/// g ∈ GroundType ::=
+/// | bool
+/// | u64 // unsigned 64 bit integer
+/// | address // 32 byte account address
+/// | bytearray // immutable, arbitrarily sized array of bytes
+///
+/// d ∈ ModuleAlias ::=
+/// | m // module name that is an alias to a declared module, addr.m
+/// | Self // current module
+///
+/// t ∈ BaseType ::=
+/// | g // ground type
+/// | k#d.n // struct 'n' declared in the module referenced by 'd' with kind 'k'
+/// // the kind 'k' cannot differ from the declared kind
+///
+/// 𝛕 ∈ Type ::=
+/// | t // base type
+/// | &t // immutable reference to a base type
+/// | &mut t // mutable reference to a base type
+///
+/// 𝛕-list ∈ [Type] ::=
+/// | unit // empty type list.
+/// // in the actual syntax, it is represented by the absence of a type
+/// | 𝛕_1 * ... * 𝛕_j // 'j' >= 1. list of multiple types. used for multiple return values
+/// ```
+///
+/// ## Values
+/// ```text
+/// u ∈ Unsigned64 // Unsigned, 64-bit Integer
+/// addr ∈ AccountAddress // addresses of blockchain accounts
+/// bytes ∈ ByteArray // byte array of arbitrary length
+/// v ∈ Value ::=
+/// | true
+/// | false
+/// | u // u64 literal
+/// | 0xaddr // 32 byte address literal
+/// | b"bytes" // arbitrary length bytearray literal
+/// ```
+///
+/// ## Expressions
+/// ```text
+/// o ∈ VarOp ::=
+/// | copy(x) // returns value bound to 'x'
+/// | move(x) // moves the value out of 'x', i.e. returns the value and makes 'x' unusable
+///
+/// r ∈ ReferenceOp ::=
+/// | &x // type: 't -> &mut t'
+/// // creates an exclusive, mutable reference to a local
+/// | &e.f // type: '&t_1 -> &t_2' or '&mut t_1 -> &mut t_2'
+/// // borrows a new reference to field 'f' of the struct 't_1'. inherits exclusive or shared from parent
+/// // 't_1' must be a struct declared in the current module, i.e. 'f' is "private"
+/// | *e // type: '&t -> t' or '&mut t -> t'. Dereferencing. Not valid for resources
+///
+/// e ∈ Exp ::=
+/// | v
+/// | o
+/// | r
+/// | n { f_1: e_1, ... , f_j: e_j } // type: '𝛕-list -> k#Self.n'
+/// // "constructor" for 'n'
+/// // "packs" the values, binding them to the fields, and creates a new instance of 'n'
+/// // 'n' must be declared in the current module
+/// // boolean operators
+/// | !e_1
+/// | e_1 || e_2
+/// | e_1 && e_2
+/// // u64 operators
+/// | e_1 >= e_2
+/// | e_1 <= e_2
+/// | e_1 > e_2
+/// | e_1 < e_2
+/// | e_1 + e_2
+/// | e_1 - e_2
+/// | e_1 * e_2
+/// | e_1 / e_2
+/// | e_1 % e_2
+/// | e_1 ^ e_2
+/// | e_1 | e_2
+/// | e_1 & e_2
+/// // operators over any ground type
+/// | e_1 == e_2
+/// | e_1 != e_2
+/// ```
+/// ## Commands
+/// ```text
+/// // module operators are available only inside the module that declares n.
+/// mop ∈ ModuleOp ::=
+/// | move_to_sender(e) // type: 'R#Self.n -> unit'
+/// // publishes resource struct 'n' under sender's address
+/// // fails if there is already a resource present for 'Self.n'
+/// | move_from(e) // type: 'address -> R#Self.n'
+/// // removes the resource struct 'n' at the specified address
+/// // fails if there is no resource present for 'Self.n'
+/// | borrow_global(e) // type: 'address -> &mut R#Self.n'
+/// // borrows a mutable reference to the resource struct 'n' at the specified address
+/// // fails if there is no resource
+/// // fails if it is already borrowed in this transaction's execution
+/// | exists(e) // type: 'address -> bool', s.t. 'n' is a resource struct
+/// // returns 'true' if the resource struct 'n' at the specified address exists
+/// // returns 'false' otherwise
+///
+/// builtin ∈ Builtin ::=
+/// | create_account(e) // type: 'addr -> unit'
+/// // creates new account at the specified address, failing if it already exists
+/// | release(e) // type: '&t -> unit' or '&mut t -> unit'
+/// // releases the reference given
+/// | freeze(x) // type: '&mut t -> &t'
+/// // coerce a mutable reference to an immutable reference
+/// | get_txn_gas_unit_price() // type: 'unit -> u64'
+/// // gives the price specified per gas unit
+/// | get_txn_max_gas_units() // type: 'unit -> u64'
+/// // gives the transaction's maximum amount of usable gas units
+/// | get_txn_public_key() // type: 'unit -> bytearray'
+/// // gives the transaction's public key
+/// | get_txn_sender() // type: 'unit -> address'
+/// // gives the transaction's sender's account address
+/// | get_txn_sequence_number() // type: 'unit -> u64'
+/// // gives the sequence number for this transaction
+/// | get_gas_remaining() // type: 'unit -> u64'
+/// // gives the amount of gas units remaining before transaction execution is forced to halt
+///
+/// call ∈ Call ::=
+/// | mop
+/// | builtin
+/// | d.p(e_1, ..., e_j) // procedure 'p' defined in the module referenced by 'd'
+///
+/// c ∈ Cmd ::=
+/// | x = e // assign the result of evaluating 'e' to 'x'
+/// | x_1, ..., x_j = call // Invokes 'call', assigns result to 'x_1' to 'x_j'
+/// | call // Invokes 'call' that has a return type of 'unit'
+/// | *x = e // mutation, s.t. 'x: &mut t' and 'e: t' and 't' is not of resource kind
+/// | assert(e_1, e_2) // type: 'bool * u64 -> unit'
+/// // halts execution with error code 'e_2' if 'e_1' evaluates to 'false'
+/// | break // exit a loop
+/// | continue // return to the top of a loop
+/// | return e_1, ..., e_n // return values from procedure
+/// | n { f_1: x_1, ... , f_j: x_j } = e // "de-constructor" for 'n'
+/// // "unpacks" a struct value 'e: _#Self.n'
+/// // value for 'f_i' is bound to local 'x_i'
+/// ```
+///
+/// ## Statements
+/// ```text
+/// s ∈ Stmt ::=
+/// | if (e) { s_1 } else { s_2 } // conditional
+/// | if (e) { s } // conditional without else branch
+/// | while (e) { s } // while loop
+/// | loop { s } // loops forever
+/// | c; // command
+/// | s_1 s_2 // sequencing
+/// ```
+///
+/// ## Imports
+/// ```text
+/// idecl ∈ Import ::=
+/// | import addr.m_1 as m_2; // imports 'addr.m_1' with the alias 'm_2'
+/// | import addr.m_1; // imports 'addr.m_1' with the alias 'm_1'
+/// ```
+/// ## Modules
+/// ```text
+/// sdecl ∈ StructDecl ::=
+/// | resource n { f_1: t_1, ..., f_j: t_j } // declaration of a resource struct
+/// | struct n { f_1: t_1, ..., f_j: t_j } // declaration of a non-resource (value) struct
+/// // s.t.
any 't_i' is not of resource kind +/// +/// body ∈ ProcedureBody ::= +/// | let x_1; ... let x_j; s // The locals declared in this procedure, and the code for that procedure +/// +/// pdecl ∈ ProcedureDecl ::= +/// | (public?) p(x_1: 𝛕_1, ..., x_j: 𝛕_j): 𝛕-list { body } // declaration of a defined procedure +/// // the procedure may be public, or internal to the module +/// | native (public?) p(x_1: 𝛕_1, ..., x_j: 𝛕_j): 𝛕-list; // declaration of a native procedure +/// // the implementation is provided by the VM +/// // the procedure may be public, or internal to the module +/// +/// mdecl ∈ ModuleDecl ::= +/// | module m { idecl_1 ... idecl_i sdecl_1 ... sdecl_j pdecl_1 ... pdecl_k } +/// ``` +/// +/// ## Transaction Scripts +/// ```text +/// TransactionScript ::= +/// // declaration of the transaction scripts procedure +/// // the 'main' procedure must be 'public' and any parameters must have a ground type +/// | idecl_1 ... idecl_i public main(x_1: g_1, ..., x_j: g_j) { s } +/// ``` +pub mod ast; + +use codespan::{ByteIndex, CodeMap, Span}; +use codespan_reporting::{emit, termcolor::Buffer, Diagnostic, Label, Severity}; +use failure::*; +use lalrpop_util::ParseError; +use regex::Regex; +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, +}; +use types::account_address::AccountAddress; + +// Since lalrpop can't handle comments without a custom lexer, we somewhat hackily remove all the +// comments from the input string before passing it off to lalrpop. We only support single line +// comments for now. Will later on add in other comment types. +fn strip_comments(string: &str) -> String { + // Remove line comments + let line_comments = Regex::new(r"//.*(\r\n|\n|\r)").unwrap(); + line_comments.replace_all(string, "$1").into_owned() +} + +/// Given the raw input of a file, creates a `Program` struct +/// Fails with `Err(_)` if the text cannot be parsed +pub fn parse_program(program_str: &str) -> Result { + let stripped_string = &strip_comments(program_str); + let parser = syntax::ProgramParser::new(); + match parser.parse(stripped_string) { + Ok(program) => Ok(program), + Err(e) => handle_error(e, program_str), + } +} + +/// Given the raw input of a file, creates a `Script` struct +/// Fails with `Err(_)` if the text cannot be parsed +pub fn parse_script(script_str: &str) -> Result { + let stripped_string = &strip_comments(script_str); + let parser = syntax::ScriptParser::new(); + match parser.parse(stripped_string) { + Ok(script) => Ok(script), + Err(e) => handle_error(e, script_str), + } +} + +/// Given the raw input of a file, creates a single `ModuleDefinition` struct +/// Fails with `Err(_)` if the text cannot be parsed +pub fn parse_module(modules_str: &str) -> Result { + let stripped_string = &strip_comments(modules_str); + let parser = syntax::ModuleParser::new(); + match parser.parse(stripped_string) { + Ok(module) => Ok(module), + Err(e) => handle_error(e, modules_str), + } +} + +/// Given the raw input of a file, creates a single `Cmd` struct +/// Fails with `Err(_)` if the text cannot be parsed +pub fn parse_cmd(cmd_str: &str, _sender_address: AccountAddress) -> Result { + let stripped_string = &strip_comments(cmd_str); + let parser = syntax::CmdParser::new(); + match parser.parse(stripped_string) { + Ok(cmd) => Ok(cmd), + Err(e) => handle_error(e, cmd_str), + } +} + +fn handle_error<'input, T, Token>( + e: lalrpop_util::ParseError, + code_str: &'input str, +) -> Result +where + Token: std::fmt::Display, +{ + let mut s = DefaultHasher::new(); + 
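For reference, the entry points above are used by feeding them raw Move IR source; this sketch (not part of the commit) parses a trivial script. Note also that `strip_comments` only matches comments terminated by a line break, so a `//` comment on the very last line of a file with no trailing newline would survive stripping.

```rust
// Sketch only: exercising the parser entry point on an embedded Move IR script.
use crate::parser::parse_script;

fn main() {
    let source = "
        main() {
            return;
        }
    ";
    let script = parse_script(source).expect("trivial script should parse");
    assert!(script.imports.is_empty());
}
```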
code_str.hash(&mut s); + let mut code = CodeMap::new(); + code.add_filemap(s.finish().to_string().into(), code_str.to_string()); + let msg = match &e { + ParseError::InvalidToken { location } => { + let error = + Diagnostic::new(Severity::Error, "Invalid Token").with_label(Label::new_primary( + Span::new(ByteIndex(*location as u32), ByteIndex(*location as u32)), + )); + let mut buffer = Buffer::no_color(); + emit(&mut buffer, &code, &error).unwrap(); + std::str::from_utf8(buffer.as_slice()).unwrap().to_string() + } + ParseError::UnrecognizedToken { + token: Some((l, tok, r)), + expected, + } => { + let error = Diagnostic::new(Severity::Error, format!("Unrecognized Token: {}", tok)) + .with_label( + Label::new_primary(Span::new(ByteIndex(*l as u32), ByteIndex(*r as u32))) + .with_message(format!( + "Expected: {}", + expected + .iter() + .fold(String::new(), |acc, token| format!("{} {},", acc, token)) + )), + ); + let mut buffer = Buffer::no_color(); + emit(&mut buffer, &code, &error).unwrap(); + std::str::from_utf8(buffer.as_slice()).unwrap().to_string() + } + _ => format!("{}", e), + }; + println!("{}", msg); + bail!("ParserError: {}", e) +} diff --git a/language/compiler/src/parser/syntax.lalrpop b/language/compiler/src/parser/syntax.lalrpop new file mode 100644 index 0000000000000..ec43ea427da4f --- /dev/null +++ b/language/compiler/src/parser/syntax.lalrpop @@ -0,0 +1,500 @@ +use std::str::FromStr; +use std::collections::BTreeMap; +use std::convert::TryFrom; +use codespan::{ByteIndex, Span}; + +use crate::parser::ast::{ModuleDefinition, StructDefinition, Script, Program}; +use crate::parser::ast::{ + FunctionAnnotation, FunctionBody, FunctionVisibility, ImportDefinition, ModuleName, + Kind, Block, Cmd, CopyableVal, Spanned, + Cmd_, Exp_, Exp, Var, Var_, FunctionCall, + FunctionName, Builtin, Statement, IfElse, While, Loop, Type, Tag, Field, Fields, + StructName, StructType, Function, BinOp, ModuleIdent, QualifiedModuleIdent, UnaryOp +}; +use types::{account_address::AccountAddress, byte_array::ByteArray}; +use hex; + +grammar(); + +U64: u64 = => u64::from_str(s).unwrap(); +Name: String = => s.to_string(); +ByteArray: ByteArray = { + => { + ByteArray::new(hex::decode(&s[2..s.len()-1]).unwrap_or_else(|_| panic!("The string {:?} is not a valid hex-encoded byte array", s))) + } +}; +AccountAddress: AccountAddress = { + < s: r"0[xX][0-9a-fA-F]+" > => { + let mut hex_string = String::from(&s[2..]); + if hex_string.len() % 2 != 0 { + hex_string.insert(0, '0'); + } + + let mut result = hex::decode(hex_string.as_str()).unwrap(); + let len = result.len(); + if len < 32 { + result.reverse(); + for i in len..32 { + result.push(0); + } + result.reverse(); + } + + assert!(result.len() >= 32); + AccountAddress::try_from(&result[..]) + .unwrap_or_else(|_| panic!("The address {:?} is of invalid length. 
Addresses are at most 32-bytes long", result)) + } +}; + +Comma: Vec = { + ",")*> => match e { + None => v, + Some(e) => { + let mut v = v; + v.push(e); + v + } + } +}; + +Sp: Spanned = + => + Spanned{span: Span::new(ByteIndex(l as u32), ByteIndex(r as u32)), value: rule}; + +Var: Var = { + => Var::new(n.as_str()), +}; + +Field: Field = { + => Field::new(n.as_str()), +}; + +CopyableVal: CopyableVal = { + AccountAddress => CopyableVal::Address(<>), + "true" => CopyableVal::Bool(true), + "false" => CopyableVal::Bool(false), + => CopyableVal::U64(i), + => CopyableVal::ByteArray(buf), +} + +Exp = BinopExp; +Exp_ = Sp; + +Tier: Exp = { + >>> >> => { + Exp::BinopExp(Box::new(e1), o, Box::new(e2)) + }, + NextTier +}; + +BinopExp = Tier; + +BinopExp_ = Sp; + +CmpOp: BinOp = { + "==" => BinOp::Eq, + "!=" => BinOp::Neq, + "<" => BinOp::Lt, + ">" => BinOp::Gt, + "<=" => BinOp::Le, + ">=" => BinOp::Ge, +} + +OrExp = Tier; + +OrOp: BinOp = { + "||" => BinOp::Or, +} + +AndExp = Tier; + +AndOp: BinOp = { + "&&" => BinOp::And, +} + +XorExp = Tier; + +XorOp: BinOp = { + "^" => BinOp::Xor, +} + +BinOrExp = Tier; + +BinOrOp: BinOp = { + "|" => BinOp::BitOr, +} + +BinAndExp = Tier; + +BinAndOp: BinOp = { + "&" => BinOp::BitAnd, +} + +AddSubExp = Tier; + +AddSubOp: BinOp = { + "+" => BinOp::Add, + "-" => BinOp::Sub, +} + +FactorExp = Tier; + +FactorOp: BinOp = { + "*" => BinOp::Mul, + "/" => BinOp::Div, + "%" => BinOp::Mod, +} + +UnaryExp : Exp = { + "!" >> => Exp::UnaryExp(UnaryOp::Not, Box::new(e)), + "*" >> => Exp::Dereference(Box::new(e)), + "&mut " >> "." => { + Exp::Borrow{ is_mutable: true, exp: Box::new(e), field: f } + }, + "&" >> "." => { + Exp::Borrow{ is_mutable: false, exp: Box::new(e), field: f } + }, + Term +} + +FieldExp: (Field, Exp_) = { + ":" > => (f, e) +} + +Term : Exp = { + "move(" > ")" => Exp::Move(v), + "copy(" > ")" => Exp::Copy(v), + "&mut " > => Exp::BorrowLocal(true, v), + "&" > => Exp::BorrowLocal(false, v), + Sp => Exp::Value(<>), + "{" > "}" => + Exp::Pack( + StructName::new(n), + fs.into_iter().collect::>() + ), + "(" ")" => <>, +} + +StructName: StructName = { + => StructName::new(n), +} + +StructType : StructType = { + "." => StructType::new(m, n), +} + + +ModuleName: ModuleName = { + => ModuleName::new(n), +} + +Builtin: Builtin = { + "create_account" => Builtin::CreateAccount, + "release" => Builtin::Release, + "exists<" ">" => Builtin::Exists(t), + "borrow_global<" ">" => Builtin::BorrowGlobal(t), + "get_height" => Builtin::GetHeight, + "get_txn_gas_unit_price" => Builtin::GetTxnGasUnitPrice, + "get_txn_max_gas_units" => Builtin::GetTxnMaxGasUnits, + "get_txn_public_key" => Builtin::GetTxnPublicKey, + "get_txn_sender" => Builtin::GetTxnSender, + "get_txn_sequence_number" => Builtin::GetTxnSequenceNumber, + "emit_event" => Builtin::EmitEvent, + "move_from<" ">" => Builtin::MoveFrom(t), + "move_to_sender<" ">" => Builtin::MoveToSender(t), + "get_gas_remaining" => Builtin::GetGasRemaining, + "freeze" => Builtin::Freeze, +} + +FunctionCallBody : FunctionCall = { + => FunctionCall::Builtin(f), + "." 
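The `Tier` macro above builds one left-recursive rule per precedence level, each delegating to the next-tighter tier, so the chain `Cmp -> Or -> And -> Xor -> BitOr -> BitAnd -> AddSub -> Factor -> Unary` fixes both precedence and left-associativity. A sketch (not part of the diff) of the tree this yields for `1 + 2 * 3`, built with the helpers from `ast.rs`:

```rust
// Sketch only: `1 + 2 * 3` groups as `1 + (2 * 3)` because Factor is a
// tighter tier than AddSub in the grammar above.
use crate::parser::ast::{BinOp, Exp};

fn main() {
    let e = Exp::add(Exp::u64(1), Exp::binop(Exp::u64(2), BinOp::Mul, Exp::u64(3)));
    // Display parenthesizes every binop, making the grouping visible.
    println!("{}", e);
}
```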
=> + FunctionCall::ModuleFunctionCall{ + module, + name: FunctionName::new(n), + }, +} + +ReturnBindings: Vec = { + > >)+> => { + let mut v = v; + v.reverse(); + v.push(l); + v.reverse(); + v + } +} + +FieldBindings: (Field, Var_) = { + ":" > => (f, v), + > => (f.value.clone(), Spanned { span: f.span, value: Var::new(f.value.name()) }), +} + +pub Cmd : Cmd = { + > "=" > => Cmd::Assign(v, e), + "*" > "=" > => Cmd::Mutate(e, op), + > "=" > "(" >> ")" => + Cmd::Call { + return_bindings: vec![binding], + call: f, + actuals: s, + }, + "=" > "(" >> ")" => + Cmd::Call { + return_bindings: bindings, + call: f, + actuals: s, + }, + > "(" >> ")" => Cmd::Call { + return_bindings: vec![], + call: f, + actuals: s, + }, + "{" > "}" "=" > => + Cmd::Unpack( + n, + bindings.into_iter().collect(), + e, + ), + "assert(" > "," > ")" => Cmd::Assert(e, err), + "return" >> => Cmd::Return(v), + "continue" => Cmd::Continue, + "break" => Cmd::Break, +} + +Cmd_ : Cmd_ = { + Sp, +} + +Statement : Statement = { + ";" => Statement::CommandStatement(cmd), + , + , + , + , + , + ";" => Statement::EmptyStatement, +} + +IfStatement : Statement = { + "if" "(" > ")" => { + Statement::IfElseStatement(IfElse::if_block(cond, block)) + }, + "if" "(" > ")" "else" => { + Statement::IfElseStatement(IfElse::if_else(cond, if_block, else_block)) + }, +} + +WhileStatement : Statement = { + "while" "(" > ")" => { + Statement::WhileStatement(While {cond, block}) + } +} + +LoopStatement : Statement = { + "loop" => { + Statement::LoopStatement(Loop {block}) + } +} + +VerifierCondition: String = { + "> => { + let mut res = expr.to_string(); + res.remove(0); + res.remove(res.len() - 1); + res + } +} + +VerifyStatement: Statement = { + "verify" => { + Statement::VerifyStatement(cond) + } +} + +AssumeStatement: Statement = { + "assume" => { + Statement::AssumeStatement(cond) + } +} + +Statements : Vec = { + +} + +Block : Block = { + "{" "}" => Block::new(stmts) +} + +Declaration: (Var_, Type) = { + "let" > ":" ";" => (v, t), +} + +Declarations: Vec<(Var_, Type)> = { + +} + +FunctionBlock: (Vec<(Var_, Type)>, Block) = { + "{" "}" => (locals, Block::new(stmts)) +} + +Kind : Kind = { + "R" => Kind::Resource, + "V" => Kind::Value, +} + +Annotation : Type = { + "address" => Type::address(), + "u64" => Type::u64(), + "bool" => Type::bool(), + "bytearray" => Type::bytearray(), + "#" => { + Type::Normal( + kind, + Tag::Struct(c), + ) + }, +} + +RefAnnotation: Type = { + => annot, + "&" => Type::reference(false, annot), + "&mut " => Type::reference(true, annot), +} + +ArgDecl : (Var, Type) = { + ":" ","? 
=> (v, t) +} + +NativeTag: () = { + "native" => () +} + +Public: () = { + "public" => () +} + +FunctionAnnotation: FunctionAnnotation = { + "requires" => FunctionAnnotation::Requires(cond.to_string()), + "ensures" => FunctionAnnotation::Ensures(cond.to_string()), +} + +ReturnType: Vec = { + ":" )*> => { + let mut v = v; + v.insert(0, t); + v + } +} + +FunctionDecl : (FunctionName, Function) = { + => (f.0, f.1), + => (f.0, f.1), +} + +MoveFunctionDecl : (FunctionName, Function) = { + "(" ")" + + => { + let (locals, body) = locals_body; + (FunctionName::new(n), Function::new( + if p.is_some() { FunctionVisibility::Public } else { FunctionVisibility::Internal }, + args, + ret.unwrap_or(vec![]), + annotations, + FunctionBody::Move{locals: locals, code: body}, + )) + } +} + +NativeFunctionDecl: (FunctionName, Function) = { + "(" ")" ";" => { + (FunctionName::new(n), Function::new( + if p.is_some() { FunctionVisibility::Public } else { FunctionVisibility::Internal }, + args, + ret.unwrap_or(vec![]), + vec![], + FunctionBody::Native, + )) + } +} + +FieldDecl : (Field, Type) = { + ":" ","? => (f, t) +} + +StructKind: bool = { + "struct" => false, + "resource" => true +} + +Modules: Vec = { + "modules:" "script:" => c, +} + +pub Program : Program = { + => { + let modules = match m { + Some(modules) => modules, + None => vec![], + }; + Program::new(modules, s) + } +} + +pub Script : Script = { + + "main" "(" ")" => { + let (locals, body) = locals_body; + let main = + Function::new( + FunctionVisibility::Public, + args, + vec![], + vec![], + FunctionBody::Move{ locals: locals, code: body }, + ); + Script::new(imports, main) + } +} + +StructDecl: StructDefinition = { + "{" "}" => { + let mut fields = Fields::new(); + for (field, type_) in data.into_iter() { + fields.insert(field, type_); + } + StructDefinition::new(kind, n, fields) + } +} + +QualifiedModuleIdent: QualifiedModuleIdent = { + "." => QualifiedModuleIdent::new(m, a), +} + +ModuleIdent: ModuleIdent = { + "Transaction" "." 
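Tying the `Modules` and `Program` rules together: a program is an optional `modules:` section followed by a `script:` section. The sketch below (illustrative only, not part of the commit) runs a one-module program through the `parse_program` entry point defined in `mod.rs`; the module and struct names are arbitrary.

```rust
// Sketch only: a minimal `modules:`/`script:` program fed to the parser.
use crate::parser::parse_program;

fn main() {
    let source = "
        modules:
        module Coin {
            struct T { value: u64 }
        }
        script:
        main() {
            return;
        }
    ";
    let program = parse_program(source).expect("program should parse");
    assert_eq!(program.modules.len(), 1);
}
```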
=> ModuleIdent::Transaction(m), + => ModuleIdent::Qualified(q), +} + +ImportAlias: ModuleName = { + "as" => { + if alias.name_ref() == ModuleName::SELF { + panic!("Invalid use of resesrved module alias '{}'", ModuleName::SELF); + } + alias + } +} + +ImportDecl: ImportDefinition = { + "import" ";" => + ImportDefinition::new(ident, alias) +} + +pub Module : ModuleDefinition = { + "module" "{" + + + + "}" => ModuleDefinition::new(n.to_string(), imports, structs, functions), +} diff --git a/language/compiler/src/unit_tests/branch_tests.rs b/language/compiler/src/unit_tests/branch_tests.rs new file mode 100644 index 0000000000000..8f6b5c876aedd --- /dev/null +++ b/language/compiler/src/unit_tests/branch_tests.rs @@ -0,0 +1,402 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#[macro_use] +mod testutils; +use super::*; +use testutils::compile_script_string; +use vm::file_format::Bytecode::*; + +#[test] +fn compile_if() { + let code = String::from( + " + main() { + let x: u64; + if (42 > 0) { + x = 1; + } + return; + } + ", + ); + + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, BrFalse(_)) == 1); + assert!(instr_count!(compiled_script, Branch(_)) == 0); +} + +#[test] +fn compile_if_else() { + let code = String::from( + " + main() { + let x: u64; + let y: u64; + if (42 > 0) { + x = 1; + } else { + y = 1; + } + return; + } + ", + ); + + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, BrFalse(_)) == 1); + assert!(instr_count!(compiled_script, Branch(_)) == 1); +} + +#[test] +fn compile_nested_if_else() { + let code = String::from( + " + main() { + let x: u64; + if (42 > 0) { + x = 1; + } else { + if (5 > 10) { + x = 2; + } else { + x = 3; + } + } + return; + } + ", + ); + + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, BrFalse(_)) == 2); + assert!(instr_count!(compiled_script, Branch(_)) == 2); +} + +#[test] +fn compile_if_else_with_if_return() { + let code = String::from( + " + main() { + let x: u64; + if (42 > 0) { + return; + } else { + x = 1; + } + return; + } + ", + ); + + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, BrFalse(_)) == 1); + assert!(instr_count!(compiled_script, Branch(_)) == 0); + assert!(instr_count!(compiled_script, Ret) == 2); +} + +#[test] +fn compile_if_else_with_else_return() { + let code = String::from( + " + main() { + let x: u64; + if (42 > 0) { + x = 1; + } else { + return; + } + return; + } + ", + ); + + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, BrFalse(_)) == 1); + assert!(instr_count!(compiled_script, Branch(_)) == 1); + assert!(instr_count!(compiled_script, Ret) == 2); +} + +#[test] +fn compile_if_else_with_two_returns() { + let code = String::from( + " + main() { + if (42 > 0) { + return; + } else { + return; + } + return; + } + ", + ); + + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, BrFalse(_)) == 1); + assert!(instr_count!(compiled_script, Branch(_)) == 0); + assert!(instr_count!(compiled_script, Ret) == 
3); +} + +#[test] +fn compile_while() { + let code = String::from( + " + main() { + let x: u64; + x = 0; + while (copy(x) < 5) { + x = copy(x) + 1; + } + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, BrFalse(_)) == 1); + assert!(instr_count!(compiled_script, Branch(_)) == 1); +} + +#[test] +fn compile_while_return() { + let code = String::from( + " + main() { + while (42 > 0) { + return; + } + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, BrFalse(_)) == 1); + assert!(instr_count!(compiled_script, Branch(_)) == 1); + assert!(instr_count!(compiled_script, Ret) == 2); +} + +#[test] +fn compile_nested_while() { + let code = String::from( + " + main() { + let x: u64; + let y: u64; + x = 0; + while (copy(x) < 5) { + x = move(x) + 1; + y = 0; + while (copy(y) < 5) { + y = move(y) + 1; + } + } + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, BrFalse(_)) == 2); + assert!(instr_count!(compiled_script, Branch(_)) == 2); +} + +#[test] +fn compile_break_outside_loop() { + let code = String::from( + " + main() { + break; + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + assert!(compiled_script_res.is_err()); +} + +#[test] +fn compile_continue_outside_loop() { + let code = String::from( + " + main() { + continue; + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + assert!(compiled_script_res.is_err()); +} + +#[test] +fn compile_while_break() { + let code = String::from( + " + main() { + while (true) { + break; + } + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, BrFalse(_)) == 1); + assert!(instr_count!(compiled_script, Branch(_)) == 2); +} + +#[test] +fn compile_while_continue() { + let code = String::from( + " + main() { + while (false) { + continue; + } + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, BrFalse(_)) == 1); + assert!(instr_count!(compiled_script, Branch(_)) == 2); +} + +#[test] +fn compile_while_break_continue() { + let code = String::from( + " + main() { + let x: u64; + x = 42; + while (false) { + x = move(x) / 3; + if (copy(x) == 0) { + break; + } + continue; + } + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, BrFalse(_)) == 2); + assert!(instr_count!(compiled_script, Branch(_)) == 3); +} + +#[test] +fn compile_loop_empty() { + let code = String::from( + " + main() { + loop { + } + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, Branch(_)) == 1); +} + +#[test] +fn compile_loop_nested_break() { + let code = String::from( + " + main() { + loop { + loop { + break; + } + break; + } + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + 
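These assertions rely on the `instr_count!` macro from the `testutils` module, which is not shown in this part of the diff. A plausible shape for it (a hypothetical sketch, not the actual implementation) is a filter-and-count over the compiled script's main code unit, using the same `main.code.code` field path the CFG tests below access directly:

```rust
// Hypothetical sketch of `instr_count!`; the real macro lives in
// src/unit_tests/testutils.rs, which is not shown here.
macro_rules! instr_count {
    ($script:expr, $pat:pat) => {
        $script
            .main
            .code
            .code
            .iter()
            .filter(|ins| match ins {
                $pat => true,
                _ => false,
            })
            .count()
    };
}
```

With that shape, `instr_count!(compiled_script, Branch(_))` is simply the number of `Branch` bytecodes emitted for `main`, which is what the loop and branch tests above pin down.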
assert!(instr_count!(compiled_script, Branch(_)) == 4); +} + +#[test] +fn compile_loop_continue() { + let code = String::from( + " + main() { + loop { + continue; + } + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, Branch(_)) == 2); +} + +#[test] +fn compile_loop_break_continue() { + let code = String::from( + " + main() { + let x: u64; + let y: u64; + x = 0; + y = 0; + + loop { + x = move(x) + 1; + if (copy(x) >= 10) { + break; + } + if (copy(x) % 2 == 0) { + continue; + } + y = move(y) + copy(x); + } + + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, Branch(_)) == 3); + assert!(instr_count!(compiled_script, BrFalse(_)) == 2); +} + +#[test] +fn compile_loop_return() { + let code = String::from( + " + main() { + loop { + loop { + return; + } + return; + } + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, Branch(_)) == 2); + assert!(instr_count!(compiled_script, Ret) == 3); +} diff --git a/language/compiler/src/unit_tests/cfg_tests.rs b/language/compiler/src/unit_tests/cfg_tests.rs new file mode 100644 index 0000000000000..ad1189935b60c --- /dev/null +++ b/language/compiler/src/unit_tests/cfg_tests.rs @@ -0,0 +1,212 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod testutils; +use super::*; +use bytecode_verifier::control_flow_graph::{ControlFlowGraph, VMControlFlowGraph}; +use testutils::compile_script_string; + +#[test] +fn cfg_compile_script_ret() { + let code = String::from( + " + main() { + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + let cfg: VMControlFlowGraph = ControlFlowGraph::new(&compiled_script.main.code.code).unwrap(); + cfg.display(); + assert!(cfg.blocks.len() == 1); + assert!(cfg.num_blocks() == 1); + assert!(cfg.reachable_from(0).len() == 1); +} + +#[test] +fn cfg_compile_script_let() { + let code = String::from( + " + main() { + let x: u64; + let y: u64; + let z: u64; + x = 3; + y = 5; + z = move(x) + copy(y) * 5 - copy(y); + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + let cfg: VMControlFlowGraph = ControlFlowGraph::new(&compiled_script.main.code.code).unwrap(); + println!("SCRIPT:\n {:?}", compiled_script); + cfg.display(); + assert!(cfg.blocks.len() == 1); + assert!(cfg.num_blocks() == 1); + assert!(cfg.reachable_from(0).len() == 1); +} + +#[test] +fn cfg_compile_if() { + let code = String::from( + " + main() { + let x: u64; + x = 0; + if (42 > 0) { + x = 1; + } + return; + } + ", + ); + + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + let cfg: VMControlFlowGraph = ControlFlowGraph::new(&compiled_script.main.code.code).unwrap(); + println!("SCRIPT:\n {:?}", compiled_script); + cfg.display(); + assert!(cfg.blocks.len() == 3); + assert!(cfg.num_blocks() == 3); + assert!(cfg.reachable_from(0).len() == 3); +} + +#[test] +fn cfg_compile_if_else() { + let code = String::from( + " + main() { + let x: u64; + let y: u64; + if (42 > 0) { + x = 1; + y = 2; + } else { + y = 2; + x = 1; + } + return; + } + ", + ); + let 
compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + let cfg: VMControlFlowGraph = ControlFlowGraph::new(&compiled_script.main.code.code).unwrap(); + println!("SCRIPT:\n {:?}", compiled_script); + cfg.display(); + assert!(cfg.blocks.len() == 4); + assert!(cfg.num_blocks() == 4); + assert!(cfg.reachable_from(0).len() == 4); +} + +#[test] +fn cfg_compile_if_else_with_else_return() { + let code = String::from( + " + main() { + let x: u64; + if (42 > 0) { + x = 1; + } else { + return; + } + return; + } + ", + ); + + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + let cfg: VMControlFlowGraph = ControlFlowGraph::new(&compiled_script.main.code.code).unwrap(); + println!("SCRIPT:\n {:?}", compiled_script); + cfg.display(); + assert!(cfg.blocks.len() == 4); + assert!(cfg.num_blocks() == 4); + assert!(cfg.reachable_from(0).len() == 4); +} + +#[test] +fn cfg_compile_nested_if() { + let code = String::from( + " + main() { + let x: u64; + if (42 > 0) { + x = 1; + } else { + if (5 > 10) { + x = 2; + } else { + x = 3; + } + } + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + let cfg: VMControlFlowGraph = ControlFlowGraph::new(&compiled_script.main.code.code).unwrap(); + println!("SCRIPT:\n {:?}", compiled_script); + cfg.display(); + assert!(cfg.blocks.len() == 6); + assert!(cfg.num_blocks() == 6); + assert!(cfg.reachable_from(7).len() == 4); +} + +#[test] +fn cfg_compile_if_else_with_if_return() { + let code = String::from( + " + main() { + let x: u64; + if (42 > 0) { + return; + } else { + x = 1; + } + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + let cfg: VMControlFlowGraph = ControlFlowGraph::new(&compiled_script.main.code.code).unwrap(); + println!("SCRIPT:\n {}", compiled_script); + cfg.display(); + assert!(cfg.blocks.len() == 3); + assert!(cfg.num_blocks() == 3); + assert!(cfg.reachable_from(0).len() == 3); + assert!(cfg.reachable_from(4).len() == 1); + assert!(cfg.reachable_from(5).len() == 1); +} + +#[test] +fn cfg_compile_if_else_with_two_returns() { + let code = String::from( + " + main() { + if (42 > 0) { + return; + } else { + return; + } + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + let cfg: VMControlFlowGraph = ControlFlowGraph::new(&compiled_script.main.code.code).unwrap(); + println!("SCRIPT:\n {}", compiled_script); + cfg.display(); + assert!(cfg.blocks.len() == 4); + assert!(cfg.num_blocks() == 4); + assert!(cfg.reachable_from(0).len() == 3); + assert!(cfg.reachable_from(4).len() == 1); + assert!(cfg.reachable_from(5).len() == 1); + assert!(cfg.reachable_from(6).len() == 1); +} diff --git a/language/compiler/src/unit_tests/expression_tests.rs b/language/compiler/src/unit_tests/expression_tests.rs new file mode 100644 index 0000000000000..3d240aedc8316 --- /dev/null +++ b/language/compiler/src/unit_tests/expression_tests.rs @@ -0,0 +1,241 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#[macro_use] +mod testutils; +use super::*; +use testutils::{ + compile_module_string, compile_script_string, compile_script_string_and_assert_error, + count_locals, +}; +use vm::file_format::Bytecode::*; + +#[test] +fn compile_script_expr_addition() { + let code = String::from( + 
" + main() { + let x: u64; + let y: u64; + let z: u64; + x = 3; + y = 5; + z = move(x) + move(y); + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(compiled_script.main.code.max_stack_size == 2); + assert!(count_locals(&compiled_script) == 3); + assert!(compiled_script.main.code.code.len() == 9); + assert!(compiled_script.struct_handles.is_empty()); + assert!(compiled_script.function_handles.len() == 1); + assert!(compiled_script.type_signatures.is_empty()); + assert!(compiled_script.function_signatures.len() == 1); // method sig + assert!(compiled_script.locals_signatures.len() == 1); // local variables sig + assert!(compiled_script.module_handles.len() == 1); // the module + assert!(compiled_script.string_pool.len() == 2); // the name of `main()` + the name of the "" module + assert!(compiled_script.address_pool.len() == 1); // the empty address of module +} + +#[test] +fn compile_script_expr_combined() { + let code = String::from( + " + main() { + let x: u64; + let y: u64; + let z: u64; + x = 3; + y = 5; + z = move(x) + copy(y) * 5 - copy(y); + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(compiled_script.main.code.max_stack_size == 3); + assert!(count_locals(&compiled_script) == 3); + assert!(compiled_script.main.code.code.len() == 13); + assert!(compiled_script.struct_handles.is_empty()); + assert!(compiled_script.function_handles.len() == 1); + assert!(compiled_script.type_signatures.is_empty()); + assert!(compiled_script.function_signatures.len() == 1); // method sig + assert!(compiled_script.locals_signatures.len() == 1); // local variables sig + assert!(compiled_script.module_handles.len() == 1); // the module + assert!(compiled_script.string_pool.len() == 2); // the name of `main()` + the name of the "" module + assert!(compiled_script.address_pool.len() == 1); // the empty address of module +} + +#[test] +fn compile_script_borrow_local() { + let code = String::from( + " + main() { + let x: u64; + let ref_x: &u64; + x = 3; + ref_x = &x; + release(move(ref_x)); + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(count_locals(&compiled_script) == 2); + assert!(compiled_script.struct_handles.is_empty()); + assert!(compiled_script.function_handles.len() == 1); + assert!(compiled_script.type_signatures.is_empty()); + assert!(compiled_script.function_signatures.len() == 1); // method sig + assert!(compiled_script.locals_signatures.len() == 1); // local variables sig + assert!(compiled_script.module_handles.len() == 1); // the module + assert!(compiled_script.string_pool.len() == 2); // the name of `main()` + the name of the "" module + assert!(compiled_script.address_pool.len() == 1); // the empty address of module +} + +#[test] +fn compile_script_borrow_local_mutable() { + let code = String::from( + " + main() { + let x: u64; + let ref_x: &mut u64; + x = 3; + ref_x = &mut x; + *move(ref_x) = 42; + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(count_locals(&&compiled_script) == 2); + assert!(compiled_script.struct_handles.is_empty()); + assert!(compiled_script.function_handles.len() == 1); + assert!(compiled_script.type_signatures.is_empty()); + assert!(compiled_script.function_signatures.len() == 1); // 
method sig + assert!(compiled_script.locals_signatures.len() == 1); // local variables sig + assert!(compiled_script.module_handles.len() == 1); // the module + assert!(compiled_script.string_pool.len() == 2); // the name of `main()` + the name of the "" module + assert!(compiled_script.address_pool.len() == 1); // the empty address of module +} + +#[test] +fn compile_script_borrow_reference() { + let code = String::from( + " + main() { + let x: u64; + let ref_x: &u64; + let ref_ref_x: &u64; + x = 3; + ref_x = &x; + ref_ref_x = &ref_x; + return; + } + ", + ); + let compiled_script_res = compile_script_string_and_assert_error(&code, None); + let compiled_script = compiled_script_res.unwrap(); + assert!(count_locals(&&compiled_script) == 3); + assert!(compiled_script.struct_handles.is_empty()); + assert!(compiled_script.function_handles.len() == 1); + assert!(compiled_script.type_signatures.is_empty()); + assert!(compiled_script.function_signatures.len() == 1); // method sig + assert!(compiled_script.locals_signatures.len() == 1); // local variables sig + assert!(compiled_script.module_handles.len() == 1); // the module + assert!(compiled_script.string_pool.len() == 2); // the name of `main()` + the name of the "" module + assert!(compiled_script.address_pool.len() == 1); // the empty address of module +} + +#[test] +fn compile_assert() { + let code = String::from( + " + main() { + let x: u64; + x = 3; + assert(copy(x) > 2, 42); + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let _compiled_script = compiled_script_res.unwrap(); +} + +#[test] +fn single_resource() { + let code = String::from( + " +module Test { + resource T { i: u64 } + + public new_t(): R#Self.T { + return T { i: 0 }; + } +}", + ); + let compiled_script = compile_module_string(&code).unwrap(); + assert!(compiled_script.struct_handles.len() == 1); +} + +#[test] +fn compile_immutable_borrow_local() { + let code = String::from( + " + main() { + let x: u64; + let ref_x: &u64; + + x = 5; + ref_x = &x; + + release(move(ref_x)); + + return; + } + ", + ); + let compiled_script_res = compile_script_string(&code); + let compiled_script = compiled_script_res.unwrap(); + assert!(instr_count!(compiled_script, FreezeRef) == 1); +} + +#[test] +fn compile_borrow_field() { + let code = String::from( + " + module Foobar { + resource FooCoin { value: u64 } + + public borrow_immut_field(arg: &R#Self.FooCoin) { + let field_ref: &u64; + field_ref = &move(arg).value; + release(move(field_ref)); + return; + } + + public borrow_immut_field_from_mut_ref(arg: &mut R#Self.FooCoin) { + let field_ref: &u64; + field_ref = &move(arg).value; + release(move(field_ref)); + return; + } + + public borrow_mut_field(arg: &mut R#Self.FooCoin) { + let field_ref: &mut u64; + field_ref = &mut move(arg).value; + release(move(field_ref)); + return; + } + } + ", + ); + let compiled_module_res = compile_module_string(&code); + let _compiled_module = compiled_module_res.unwrap(); +} diff --git a/language/compiler/src/unit_tests/function_tests.rs b/language/compiler/src/unit_tests/function_tests.rs new file mode 100644 index 0000000000000..612bf0621d405 --- /dev/null +++ b/language/compiler/src/unit_tests/function_tests.rs @@ -0,0 +1,106 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod testutils; +use super::*; +use testutils::compile_module_string; + +#[test] +fn compile_script_with_functions() { + let code = String::from( + " + module Foobar { + resource FooCoin { value: u64 } + + public 
value(this: &R#Self.FooCoin): u64 {
+            let value_ref: &u64;
+            value_ref = &move(this).value;
+            return *move(value_ref);
+        }
+
+        public deposit(this: &mut R#Self.FooCoin, check: R#Self.FooCoin) {
+            let value_ref: &mut u64;
+            let value: u64;
+            let check_ref: &R#Self.FooCoin;
+            let check_value: u64;
+            let new_value: u64;
+            let i: u64;
+            value_ref = &mut move(this).value;
+            value = *copy(value_ref);
+            check_ref = &check;
+            check_value = Self.value(move(check_ref));
+            new_value = copy(value) + copy(check_value);
+            *move(value_ref) = move(new_value);
+            FooCoin { value: i } = move(check);
+            return;
+        }
+    }
+    ",
+    );
+    let compiled_module_res = compile_module_string(&code);
+    assert!(compiled_module_res.is_ok());
+}
+
+fn generate_function(name: &str, num_formals: usize, num_locals: usize) -> String {
+    let mut code = format!("public {}(", name);
+
+    code.reserve(30 * (num_formals + num_locals));
+
+    for i in 0..num_formals {
+        code.push_str(&format!("formal_{}: u64", i));
+        if i < num_formals - 1 {
+            code.push_str(", ");
+        }
+    }
+
+    code.push_str(") {\n");
+
+    for i in 0..num_locals {
+        code.push_str(&format!("let x_{}: u64;\n", i));
+    }
+    for i in 0..num_locals {
+        code.push_str(&format!("x_{} = {};\n", i, i));
+    }
+
+    code.push_str("return;");
+
+    code.push_str("}");
+
+    code
+}
+
+#[test]
+fn compile_script_with_large_frame() {
+    let mut code = String::from(
+        "
+        module Foobar {
+            resource FooCoin { value: u64 }
+        ",
+    );
+
+    // Max number of locals (formals + local variables) is u8::max_value().
+    code.push_str(&generate_function("foo_func", 128, 127));
+
+    code.push_str("}");
+
+    let compiled_module_res = compile_module_string(&code);
+    assert!(compiled_module_res.is_ok());
+}
+
+#[test]
+fn compile_script_with_invalid_large_frame() {
+    let mut code = String::from(
+        "
+        module Foobar {
+            resource FooCoin { value: u64 }
+        ",
+    );
+
+    // Max number of locals (formals + local variables) is u8::max_value().
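+    // 128 formals + 128 locals = 256 locals in total, one more than u8::max_value() (255), so this must be rejected.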
+    code.push_str(&generate_function("foo_func", 128, 128));
+
+    code.push_str("}");
+
+    let compiled_module_res = compile_module_string(&code);
+    assert!(compiled_module_res.is_err());
+}
diff --git a/language/compiler/src/unit_tests/import_tests.rs b/language/compiler/src/unit_tests/import_tests.rs
new file mode 100644
index 0000000000000..0808de5772d19
--- /dev/null
+++ b/language/compiler/src/unit_tests/import_tests.rs
@@ -0,0 +1,63 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+mod testutils;
+use super::*;
+use testutils::{compile_module_string_with_stdlib, compile_script_string_with_stdlib};
+
+#[test]
+fn compile_script_with_imports() {
+    let code = String::from(
+        "
+        import 0x0000000000000000000000000000000000000000000000000000000000000001.LibraCoin;
+
+        main() {
+            let x: u64;
+            let y: u64;
+            x = 2;
+            y = copy(x) + copy(x);
+            return;
+        }
+        ",
+    );
+    let compiled_script_res = compile_script_string_with_stdlib(&code);
+    let _compiled_script = compiled_script_res.unwrap();
+}
+
+#[test]
+fn compile_module_with_imports() {
+    let code = String::from(
+        "
+        module Foobar {
+            import 0x1.LibraCoin;
+
+            resource FooCoin { value: u64 }
+
+            public value(this: &R#Self.FooCoin): u64 {
+                let value_ref: &u64;
+                value_ref = &move(this).value;
+                return *move(value_ref);
+            }
+
+            public deposit(this: &mut R#Self.FooCoin, check: R#Self.FooCoin) {
+                let value_ref: &mut u64;
+                let value: u64;
+                let check_ref: &R#Self.FooCoin;
+                let check_value: u64;
+                let new_value: u64;
+                let i: u64;
+                value_ref = &mut move(this).value;
+                value = *copy(value_ref);
+                check_ref = &check;
+                check_value = Self.value(move(check_ref));
+                new_value = copy(value) + copy(check_value);
+                *move(value_ref) = move(new_value);
+                FooCoin { value: i } = move(check);
+                return;
+            }
+        }
+        ",
+    );
+    let compiled_module_res = compile_module_string_with_stdlib(&code);
+    let _compiled_module = compiled_module_res.unwrap();
+}
diff --git a/language/compiler/src/unit_tests/serializer_tests.rs b/language/compiler/src/unit_tests/serializer_tests.rs
new file mode 100644
index 0000000000000..af4f40b775246
--- /dev/null
+++ b/language/compiler/src/unit_tests/serializer_tests.rs
@@ -0,0 +1,39 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+mod testutils;
+use super::*;
+use testutils::compile_script_string;
+
+#[test]
+fn serialize_script_ret() {
+    let code = String::from(
+        "
+        main() {
+            return;
+        }
+        ",
+    );
+    let compiled_script_res = compile_script_string(&code);
+    let compiled_script = compiled_script_res.unwrap();
+    let mut binary: Vec<u8> = Vec::new();
+    let res = compiled_script.serialize(&mut binary);
+    assert!(res.is_ok());
+    println!("SCRIPT:\n{:?}", compiled_script);
+    println!("Serialized Script:\n{:?}", binary);
+    println!("binary[74]: {:?}", binary.get(74));
+    println!("binary[76]: {:?}", binary.get(76));
+    println!("binary[79]: {:?}", binary.get(79));
+    println!("binary[82]: {:?}", binary.get(82));
+    println!("binary[84]: {:?}", binary.get(84));
+    println!("binary[96]: {:?}", binary.get(96));
+    println!("binary[128]: {:?}", binary.get(128));
+    println!("binary[75]: {:?}", binary.get(75));
+    println!("binary[77]: {:?}", binary.get(77));
+    println!("binary[80]: {:?}", binary.get(80));
+    println!("binary[83]: {:?}", binary.get(83));
+    println!("binary[85]: {:?}", binary.get(85));
+    println!("binary[97]: {:?}", binary.get(97));
+    println!("binary[129]: {:?}", binary.get(129));
+    // println!("SCRIPT:\n{}", compiled_script);
+}
diff --git 
a/language/compiler/src/unit_tests/stdlib_scripts.rs b/language/compiler/src/unit_tests/stdlib_scripts.rs new file mode 100644 index 0000000000000..997e64b6b50df --- /dev/null +++ b/language/compiler/src/unit_tests/stdlib_scripts.rs @@ -0,0 +1,57 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod testutils; +use super::*; +use testutils::{ + compile_module_string, compile_module_string_with_deps, compile_script_string_with_stdlib, +}; + +#[test] +fn compile_native_hash() { + let code = include_str!("../../../stdlib/modules/hash.mvir"); + let _compiled_module = compile_module_string(&code).unwrap(); +} + +#[test] +fn compile_libra_coin() { + let code = include_str!("../../../stdlib/modules/libra_coin.mvir"); + let _compiled_module = compile_module_string(&code).unwrap(); +} + +#[test] +fn compile_account_module() { + let hash_code = include_str!("../../../stdlib/modules/hash.mvir"); + let coin_code = include_str!("../../../stdlib/modules/libra_coin.mvir"); + let account_code = include_str!("../../../stdlib/modules/libra_account.mvir"); + + let hash_module = compile_module_string(hash_code).unwrap(); + let coin_module = compile_module_string(coin_code).unwrap(); + + let _compiled_module = + compile_module_string_with_deps(account_code, vec![hash_module, coin_module]).unwrap(); +} + +#[test] +fn compile_create_account_script() { + let code = include_str!("../../../stdlib/transaction_scripts/create_account.mvir"); + let _compiled_script = compile_script_string_with_stdlib(code).unwrap(); +} + +#[test] +fn compile_mint_script() { + let code = include_str!("../../../stdlib/transaction_scripts/mint.mvir"); + let _compiled_script = compile_script_string_with_stdlib(code).unwrap(); +} + +#[test] +fn compile_rotate_authentication_key_script() { + let code = include_str!("../../../stdlib/transaction_scripts/rotate_authentication_key.mvir"); + let _compiled_script = compile_script_string_with_stdlib(code).unwrap(); +} + +#[test] +fn compile_peer_to_peer_transfer_script() { + let code = include_str!("../../../stdlib/transaction_scripts/peer_to_peer_transfer.mvir"); + let _compiled_script = compile_script_string_with_stdlib(code).unwrap(); +} diff --git a/language/compiler/src/unit_tests/testutils.rs b/language/compiler/src/unit_tests/testutils.rs new file mode 100644 index 0000000000000..5b966b6eef31b --- /dev/null +++ b/language/compiler/src/unit_tests/testutils.rs @@ -0,0 +1,155 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use super::*; +use crate::parser::{parse_module, parse_program}; +use bytecode_verifier::verifier::{verify_module, verify_script}; +use vm::{ + errors::VerificationError, + file_format::{CompiledModule, CompiledScript}, +}; + +#[allow(unused_macros)] +macro_rules! 
instr_count {
+    ($compiled: expr, $instr: pat) => {
+        $compiled
+            .main
+            .code
+            .code
+            .iter()
+            .filter(|ins| match ins {
+                $instr => true,
+                _ => false,
+            })
+            .count();
+    };
+}
+
+fn stdlib_deps() -> Result<Vec<CompiledModule>> {
+    Ok(crate::util::build_stdlib())
+}
+
+fn compile_script_string_impl(
+    code: &str,
+    deps_opt: Option<Vec<CompiledModule>>,
+) -> Result<(CompiledScript, Vec<VerificationError>)> {
+    let deps = if let Some(x) = deps_opt {
+        x
+    } else {
+        stdlib_deps().unwrap()
+    };
+    let parsed_program = parse_program(code).unwrap();
+    let compiled_program = compile_program(&AccountAddress::default(), &parsed_program, &deps)?;
+
+    let mut serialized_script = Vec::<u8>::new();
+    compiled_program.script.serialize(&mut serialized_script)?;
+    let deserialized_script = CompiledScript::deserialize(&serialized_script)?;
+    assert_eq!(compiled_program.script, deserialized_script);
+
+    Ok(verify_script(compiled_program.script))
+}
+
+pub fn compile_script_string_and_assert_no_error(
+    code: &str,
+    deps_opt: Option<Vec<CompiledModule>>,
+) -> Result<CompiledScript> {
+    let (verified_script, verification_errors) = compile_script_string_impl(code, deps_opt)?;
+    assert!(verification_errors.is_empty());
+    Ok(verified_script)
+}
+
+#[allow(dead_code)]
+pub fn compile_script_string(code: &str) -> Result<CompiledScript> {
+    compile_script_string_and_assert_no_error(code, Some(vec![]))
+}
+
+#[allow(dead_code)]
+pub fn compile_script_string_with_deps(
+    code: &str,
+    deps: Vec<CompiledModule>,
+) -> Result<CompiledScript> {
+    compile_script_string_and_assert_no_error(code, Some(deps))
+}
+
+#[allow(dead_code)]
+pub fn compile_script_string_with_stdlib(code: &str) -> Result<CompiledScript> {
+    compile_script_string_and_assert_no_error(code, None)
+}
+
+#[allow(dead_code)]
+pub fn compile_script_string_and_assert_error(
+    code: &str,
+    deps_opt: Option<Vec<CompiledModule>>,
+) -> Result<CompiledScript> {
+    let (verified_script, verification_errors) = compile_script_string_impl(code, deps_opt)?;
+    assert!(!verification_errors.is_empty());
+    Ok(verified_script)
+}
+
+fn compile_module_string_impl(
+    code: &str,
+    deps_opt: Option<Vec<CompiledModule>>,
+) -> Result<(CompiledModule, Vec<VerificationError>)> {
+    let address = &AccountAddress::default();
+    let deps = if let Some(x) = deps_opt {
+        x
+    } else {
+        stdlib_deps().unwrap()
+    };
+    let module = parse_module(code).unwrap();
+    let compiled_module = compile_module(&address, &module, &deps)?;
+
+    let mut serialized_module = Vec::<u8>::new();
+    compiled_module.serialize(&mut serialized_module)?;
+    let deserialized_module = CompiledModule::deserialize(&serialized_module)?;
+    assert_eq!(compiled_module, deserialized_module);
+
+    Ok(verify_module(compiled_module))
+}
+
+pub fn compile_module_string_and_assert_no_error(
+    code: &str,
+    deps_opt: Option<Vec<CompiledModule>>,
+) -> Result<CompiledModule> {
+    let (verified_module, verification_errors) = compile_module_string_impl(code, deps_opt)?;
+    assert!(verification_errors.is_empty());
+    Ok(verified_module)
+}
+
+#[allow(dead_code)]
+pub fn compile_module_string(code: &str) -> Result<CompiledModule> {
+    compile_module_string_and_assert_no_error(code, Some(vec![]))
+}
+
+#[allow(dead_code)]
+pub fn compile_module_string_with_deps(
+    code: &str,
+    deps: Vec<CompiledModule>,
+) -> Result<CompiledModule> {
+    compile_module_string_and_assert_no_error(code, Some(deps))
+}
+
+#[allow(dead_code)]
+pub fn compile_module_string_with_stdlib(code: &str) -> Result<CompiledModule> {
+    compile_module_string_and_assert_no_error(code, None)
+}
+
+#[allow(dead_code)]
+pub fn compile_module_string_and_assert_error(
+    code: &str,
+    deps_opt: Option<Vec<CompiledModule>>,
+) -> Result<CompiledModule> {
+    let (verified_module, verification_errors) = compile_module_string_impl(code, deps_opt)?;
+    assert!(!verification_errors.is_empty());
+    Ok(verified_module)
+}
+
+#[allow(dead_code)]
+pub 
fn count_locals(script: &CompiledScript) -> usize {
+    script
+        .locals_signatures
+        .get(script.main.code.locals.0 as usize)
+        .unwrap()
+        .0
+        .len()
+}
diff --git a/language/compiler/src/util.rs b/language/compiler/src/util.rs
new file mode 100644
index 0000000000000..3b03e13b23708
--- /dev/null
+++ b/language/compiler/src/util.rs
@@ -0,0 +1,44 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{compiler::compile_module, parser::parse_module};
+use std::{
+    fs,
+    path::{Path, PathBuf},
+};
+use types::account_address::AccountAddress;
+use vm::file_format::CompiledModule;
+
+pub fn do_compile_module(
+    source_path: &Path,
+    address: &AccountAddress,
+    dependencies: &[CompiledModule],
+) -> CompiledModule {
+    let source = fs::read_to_string(source_path)
+        .unwrap_or_else(|_| panic!("Unable to read file: {:?}", source_path));
+    let parsed_module = parse_module(&source).unwrap();
+    compile_module(address, &parsed_module, dependencies).unwrap()
+}
+
+pub fn build_stdlib() -> Vec<CompiledModule> {
+    // TODO: Change source paths for stdlib when we have proper SDK packaging.
+    let mut stdlib_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
+    stdlib_root.pop();
+    stdlib_root.push("stdlib");
+
+    let address = AccountAddress::default();
+    let mut stdlib_modules = Vec::<CompiledModule>::new();
+    for e in [
+        "modules/hash.mvir",
+        "modules/signature.mvir",
+        "modules/libra_coin.mvir",
+        "modules/libra_account.mvir",
+        "modules/validator_set.mvir",
+    ]
+    .iter()
+    {
+        let res = do_compile_module(&Path::join(&stdlib_root, e), &address, &stdlib_modules);
+        stdlib_modules.push(res);
+    }
+    stdlib_modules
+}
diff --git a/language/functional_tests/Cargo.toml b/language/functional_tests/Cargo.toml
new file mode 100644
index 0000000000000..44b172769e653
--- /dev/null
+++ b/language/functional_tests/Cargo.toml
@@ -0,0 +1,20 @@
+[package]
+name = "functional_tests"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+edition = "2018"
+
+[dependencies]
+compiler = { path = "../compiler" }
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+types = { path = "../../types" }
+vm = { path = "../vm" }
+bytecode_verifier = { path = "../bytecode_verifier" }
+vm_runtime_tests = { path = "../vm/vm_runtime/vm_runtime_tests" }
+config = { path = "../../config" }
+termcolor = "1.0.4"
+filecheck = "0.4.0"
+datatest = "0.3.1"
+lazy_static = "1.3.0"
+regex = "1.1.6"
diff --git a/language/functional_tests/README.md b/language/functional_tests/README.md
new file mode 100644
index 0000000000000..ce263040b5d90
--- /dev/null
+++ b/language/functional_tests/README.md
@@ -0,0 +1,35 @@
+## Overview
+
+This crate implements a unified testing infrastructure that allows developers
+to write tests as individual Move IR programs, send them through the entire
+pipeline, and check the output of each stage using inline directives.
+
+## How to run functional tests
+
+Run `cargo test` inside this crate, or `cargo test -p functional_tests` anywhere
+in the repo. `cargo test` also accepts a filter: `cargo test foo` runs only
+the tests with `foo` in the name.
+
+## Adding a new test
+
+To add a new test, simply create a new .mvir file in `tests/testsuite`.
+The test harness will recursively search for all Move IR sources in
+the directory and register each of them as a test case.
+
+## Checking the test output using directives
+
+Directives are essentially comments with special meanings to the testing infra.
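+(The directive syntax comes from the `filecheck` crate in this crate's dependencies; a typical directive looks like `// check: <pattern>`.)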
+They can be used to define patterns that should appear in the test output. +If the test output does not match the pattern specified, the test is +considered a failure. + +The test output is a log-like structure that consists of the debug print +of the data structure each pipeline stage outputs. In case there is an +error, it is the debug print of the error. Good tests should match only +crucial details such as the name of the error and omit unimportant details +such as formatting, spaces, brackets etc. + +When no directives are specified, the testing infra requires a test program +to pass all stages of the pipeline. Any error will result in a test failure. + +See `tests/testsuite/examples` for more examples. diff --git a/language/functional_tests/_old_move_ir_tests/src/tests/epilogue.rs b/language/functional_tests/_old_move_ir_tests/src/tests/epilogue.rs new file mode 100644 index 0000000000000..74bca4ff5025f --- /dev/null +++ b/language/functional_tests/_old_move_ir_tests/src/tests/epilogue.rs @@ -0,0 +1,562 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::*; +use language_common::error_codes::*; +use move_ir::{assert_error_type, assert_no_error}; +use types::transaction::TransactionArgument; + +#[test] +fn increment_sequence_number_on_tx_script_success() { + let mut test_env = TestEnvironment::default(); + + let sequence_number = 0; + assert_no_error!(test_env + .run_with_sequence_number(sequence_number, to_script(b"main() { return; }", vec![]),)); + + // now we need to run with sequence number 1 + let sequence_number = 1; + assert_no_error!(test_env.run_with_sequence_number( + sequence_number, + to_script( + b" +import 0x0.LibraAccount; +main() { + let transaction_sequence_number; + let sender; + let sequence_number; + + transaction_sequence_number = get_txn_sequence_number(); + assert(copy(transaction_sequence_number) == 1, 42); + + sender = get_txn_sender(); + sequence_number = LibraAccount.sequence_number(move(sender)); + assert(move(sequence_number) == 1, 43); + + return; +}", + vec![] + ), + )) +} + +#[test] +fn increment_sequence_number_on_tx_script_failure() { + let mut test_env = TestEnvironment::default(); + + let sequence_number = 0; + assert_error_type!( + test_env.run_with_sequence_number( + sequence_number, + to_script( + b" +main() { + assert(false, 77); + return; +} +", + vec![] + ) + ), + ErrorKind::AssertError(77, _) + ); + + // there was a failure during the transaction script, but + // sequence number should still be bumped + let sequence_number = 1; + assert_no_error!(test_env + .run_with_sequence_number(sequence_number, to_script(b"main() { return; }", vec![]),)) +} + +#[test] +fn charge_more_gas_on_tx_script_failure1() { + // Make sure that we charge more for a transaction that was aborted + let mut test_env = TestEnvironment::default(); + + let program = b" + main(abrt: u64) { + assert(move(abrt) == 0, 78); + return; + }"; + let result = test_env.run_with_arguments( + vec![TransactionArgument::U64(0)], + to_script(program, vec![]), + ); + let good_gas_cost = match &result { + Ok(res) => res.gas_used(), + _ => 0, + }; + assert_no_error!(result); + + // Fail this one + assert_error_type!( + test_env.run_with_arguments( + vec![TransactionArgument::U64(1)], + to_script(program, vec![]) + ), + ErrorKind::AssertError(78, _) + ); + + let verify_script = b" + import 0x0.LibraAccount; + main(initial_balance: u64, good_gas_fees: u64) { + let sender; + let sender_balance; + sender = get_txn_sender(); + sender_balance = 
LibraAccount.balance(copy(sender)); + assert(move(initial_balance) - move(good_gas_fees) < move(sender_balance), 79); + return; + }"; + assert_no_error!(test_env.run_with_arguments( + vec![ + TransactionArgument::U64(TestEnvironment::INITIAL_BALANCE), + TransactionArgument::U64(2 * good_gas_cost * TestEnvironment::DEFAULT_GAS_COST) + ], + to_script(verify_script, vec![]) + )); +} + +#[test] +fn charge_more_gas_on_tx_script_failure2() { + // This then verifies that the verification would have failed if both transactions + // had succeeded in `charge_more_gas_on_tx_script_failure1`. + let mut test_env = TestEnvironment::default(); + + let program = b" + main(abrt: u64) { + assert(move(abrt) == 0, 80); + return; + }"; + let result = test_env.run_with_arguments( + vec![TransactionArgument::U64(0)], + to_script(program, vec![]), + ); + let good_gas_cost = match &result { + Ok(res) => res.gas_used(), + _ => 0, + }; + assert_no_error!(result); + + assert_no_error!(test_env.run_with_arguments( + vec![TransactionArgument::U64(0)], + to_script(program, vec![]) + )); + + let verify_script = b" + import 0x0.LibraAccount; + main(initial_balance: u64, good_gas_fees: u64) { + let sender; + let sender_balance; + sender = get_txn_sender(); + sender_balance = LibraAccount.balance(copy(sender)); + assert(move(initial_balance) - move(good_gas_fees) < move(sender_balance), 81); + return; + }"; + assert_error_type!( + test_env.run_with_arguments( + vec![ + TransactionArgument::U64(TestEnvironment::INITIAL_BALANCE), + TransactionArgument::U64(2 * good_gas_cost * TestEnvironment::DEFAULT_GAS_COST) + ], + to_script(verify_script, vec![]) + ), + ErrorKind::AssertError(81, _) + ); +} + +#[test] +fn dont_increment_sequence_number_on_sequence_number_too_new() { + let mut test_env = TestEnvironment::default(); + + let sequence_number = 7; + assert_error_type!( + test_env.run_with_sequence_number( + sequence_number, + to_script( + b" +main() { + assert(false, 77); + return; +} +", + vec![] + ) + ), + ErrorKind::AssertError(ESEQUENCE_NUMBER_TOO_NEW, _) + ); + + // running with 0 should succeed because sequence number wasn't bumped + let sequence_number = 0; + assert_no_error!(test_env + .run_with_sequence_number(sequence_number, to_script(b"main() { return; }", vec![]),)) +} + +#[test] +fn dont_increment_sequence_number_on_sequence_number_too_old() { + let mut test_env = TestEnvironment::default(); + + let sequence_number = 0; + assert_no_error!(test_env + .run_with_sequence_number(sequence_number, to_script(b"main() { return; }", vec![]),)); + + // running with 0 should fail + let sequence_number = 0; + assert_error_type!( + test_env + .run_with_sequence_number(sequence_number, to_script(b"main() { return; }", vec![])), + ErrorKind::AssertError(ESEQUENCE_NUMBER_TOO_OLD, _) + ); + + // but running with 1 should succeed + let sequence_number = 1; + assert_no_error!(test_env + .run_with_sequence_number(sequence_number, to_script(b"main() { return; }", vec![]),)) +} + +// You can call your own epilogue if you want to, but your sequence number will only be bumped once +#[test] +fn calling_own_epilogue_bumps_sequence_number_once() { + let mut test_env = TestEnvironment::default(); + + let sequence_number = 0; + assert_no_error!(test_env.run_with_sequence_number( + sequence_number, + to_script( + b" +import 0x0.LibraAccount; +main() { + LibraAccount.epilogue(); + LibraAccount.epilogue(); + LibraAccount.epilogue(); + + return; +}", + vec![] + ), + )); + + let sequence_number = 1; + assert_no_error!(test_env + 
.run_with_sequence_number(sequence_number, to_script(b"main() { return; }", vec![]),)) +} + +#[test] +fn gas_charge_different() { + let mut test_env1 = TestEnvironment::default(); + let mut test_env2 = TestEnvironment::default(); + let max_gas_deposit_fee = TestEnvironment::DEFAULT_MAX_GAS * TestEnvironment::DEFAULT_GAS_COST; + let program1 = b"main() { return; }"; + let program2 = b"main() { let x; x = 32 + 10; return; }"; + let verifier_program = b" + import 0x0.LibraAccount; + main(original_balance: u64, max_gas_fee: u64, actual_gas_fee: u64) { + let sender; + let sender_balance; + sender = get_txn_sender(); + sender_balance = LibraAccount.balance(copy(sender)); + assert(copy(original_balance) - move(max_gas_fee) < copy(sender_balance), 66); + assert(move(original_balance) - move(actual_gas_fee) == move(sender_balance), 66); + + return; + }"; + + // Run the first program + let result1 = test_env1.run(to_script(program1, vec![])).unwrap(); + let gas_used1 = result1.gas_used(); + let gas_used_fee1 = gas_used1 * TestEnvironment::DEFAULT_GAS_COST; + // Verify that this gas amount is correct + assert_no_error!(test_env1.run_with_arguments( + vec![ + TransactionArgument::U64(TestEnvironment::INITIAL_BALANCE), + TransactionArgument::U64(max_gas_deposit_fee), + TransactionArgument::U64(gas_used_fee1), + ], + to_script(verifier_program, vec![]), + )); + + // Now run the second program. Use a different test env so we don't have to do accounting for + // the cost of the verifier program. + let result2 = test_env2.run(to_script(program2, vec![])).unwrap(); + let gas_used2 = result2.gas_used(); + let gas_used_fee2 = gas_used2 * TestEnvironment::DEFAULT_GAS_COST; + // Verify that this gas amount for the second when is correct as well + assert_no_error!(test_env2.run_with_arguments( + vec![ + TransactionArgument::U64(TestEnvironment::INITIAL_BALANCE), + TransactionArgument::U64(max_gas_deposit_fee), + TransactionArgument::U64(gas_used_fee2), + ], + to_script(verifier_program, vec![]), + )); + // Make sure that the first one is less than the second. These numbers have been verified. + assert!(gas_used1 < gas_used2); +} + +#[test] +fn gas_charge_accurate() { + let mut test_env = TestEnvironment::default(); + let program1 = b"main() { return; }"; + // Ensures that we are not just charging max_gas for the transaction. 
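+    // (i.e. the balance should drop by strictly less than DEFAULT_MAX_GAS * DEFAULT_GAS_COST).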
+ // Ensures that the account was deducted for the gas fee + let verifier_program = b" + import 0x0.LibraAccount; + main(original_balance: u64, max_gas_fee: u64, actual_gas_fee: u64) { + let sender; + let sender_balance; + sender = get_txn_sender(); + sender_balance = LibraAccount.balance(copy(sender)); + assert(copy(original_balance) - move(max_gas_fee) < copy(sender_balance), 66); + assert(move(original_balance) - move(actual_gas_fee) == move(sender_balance), 66); + + return; + }"; + let result1 = test_env.run(to_script(program1, vec![])).unwrap(); + let gas_used1 = result1.gas_used(); + let max_gas_deposit_fee = TestEnvironment::DEFAULT_MAX_GAS * TestEnvironment::DEFAULT_GAS_COST; + let gas_used_fee = gas_used1 * TestEnvironment::DEFAULT_GAS_COST; + assert_no_error!(test_env.run_with_arguments( + vec![ + TransactionArgument::U64(TestEnvironment::INITIAL_BALANCE), + TransactionArgument::U64(max_gas_deposit_fee), + TransactionArgument::U64(gas_used_fee), + ], + to_script(verifier_program, vec![]), + )); +} + +#[test] +fn gas_deposit_withdraws() { + let mut test_env = TestEnvironment::default(); + + let result = test_env + .run(to_script(b"main() { return; }", vec![])) + .unwrap(); + let gas_used = result.gas_used(); + + // sender balance should be less the gas deposit after transaction execution + let gas_deposit_fee = TestEnvironment::DEFAULT_GAS_COST * gas_used; + assert_no_error!(test_env.run_with_arguments( + vec![ + TransactionArgument::U64(TestEnvironment::INITIAL_BALANCE), + TransactionArgument::U64(gas_deposit_fee), + ], + to_script( + b" +import 0x0.LibraAccount; +main(original_balance: u64, gas_deposit_amount: u64) { + let sender; + let sender_balance; + + sender = get_txn_sender(); + sender_balance = LibraAccount.balance(copy(sender)); + assert(move(sender_balance) == move(original_balance) - move(gas_deposit_amount), 66); + + return; +}", + vec![] + ), + )) +} + +#[test] +fn revert_tx_script_state_changes_after_failure() { + let mut test_env = TestEnvironment::default(); + let recipient = test_env.accounts.get_address(1); + + let amount = 10; + assert_error_type!( + test_env.run_with_arguments( + vec![ + TransactionArgument::Address(recipient), + TransactionArgument::U64(amount) + ], + to_script( + b" +import 0x0.LibraAccount; +main(payee: address, amount: u64) { + LibraAccount.pay_from_sender(move(payee), move(amount)); + assert(false, 66); + return; +} +", + vec![] + ) + ), + ErrorKind::AssertError(66, _) + ); + + // now we need to run with sequence number 1 + assert_no_error!(test_env.run_with_arguments( + vec![ + TransactionArgument::Address(recipient), + TransactionArgument::U64(TestEnvironment::INITIAL_BALANCE) + ], + to_script( + b" +import 0x0.LibraAccount; +main(recipient: address, recipient_original_balance: u64) { + let recipient_balance; + recipient_balance = LibraAccount.balance(move(recipient)); + assert(move(recipient_balance) == move(recipient_original_balance), 55); + + return; +}", + vec![] + ), + )) +} + +#[test] +fn revert_tx_script_state_changes_after_failed_epilogue() { + let mut test_env = TestEnvironment::default(); + let recipient = test_env.accounts.get_address(1); + + // transfer all of the sender's funds to recipients. this script will execute successfully, but + // the epilogue will fail because the user spent his gas deposit. 
the VM should revert the state + // changes and re-execute the epilogue + let amount = TestEnvironment::INITIAL_BALANCE; + assert_error_type!( + test_env.run_with_arguments( + vec![ + TransactionArgument::Address(recipient), + TransactionArgument::U64(amount) + ], + to_script( + b" + import 0x0.LibraAccount; + main(payee: address, amount: u64) { + LibraAccount.pay_from_sender(move(payee), move(amount)); + return; + }", + vec![] + ) + ), + ErrorKind::OutOfGas + ); + + // We need to calculate the gas used in order to verify that the correct amount has been + // debited below. Sadly the only way of doing this is re-running the transaction since the + // previous run ran out of gas. + let gas_used = { + let mut test_env2 = TestEnvironment::default(); + let recipient = test_env2.accounts.get_address(1); + let amount = 1; + let result = test_env2 + .run_with_arguments( + vec![ + TransactionArgument::Address(recipient), + TransactionArgument::U64(amount), + ], + to_script( + b" + import 0x0.LibraAccount; + main(payee: address, amount: u64) { + LibraAccount.pay_from_sender(move(payee), move(amount)); + return; + }", + vec![], + ), + ) + .unwrap(); + result.gas_used() + }; + + let gas_deposit_fee = TestEnvironment::DEFAULT_GAS_COST * gas_used; + // this script checks that the state changes from tx1's script were actually reverted; the + // recipient's balance should be the same as the initial state, and the sender's balance should + // be the same as the initial state less the gas deposit. + assert_no_error!(test_env.run_with_arguments( + vec![ + TransactionArgument::Address(recipient), + TransactionArgument::U64(TestEnvironment::INITIAL_BALANCE), + TransactionArgument::U64(gas_deposit_fee), + ], + to_script( + b" +import 0x0.LibraAccount; +main(recipient: address, original_balance: u64, gas_deposit_amount: u64) { + let recipient_balance; + let sender; + let sender_balance; + let sender_sequence_number; + + recipient_balance = LibraAccount.balance(move(recipient)); + + sender = get_txn_sender(); + sender_balance = LibraAccount.balance(copy(sender)); + assert(move(sender_balance) == move(original_balance) - move(gas_deposit_amount), 66); + + sender_sequence_number = LibraAccount.sequence_number(move(sender)); + assert(move(sender_sequence_number) == 1, 77); + + return; +}", + vec![] + ), + )) +} + +#[test] +fn recursion_out_of_gas_charges_max_gas() { + let mut test_env = TestEnvironment::default(); + let sender = test_env.accounts.get_account(0).addr; + + let program = format!( + " +modules: +module Looper {{ + public run_loop(n: u64) {{ + while (true) {{ + + }} + return; + }} + +}} + +script: +import 0x{0}.Looper; +main() {{ + Looper.run_loop(5); + return; +}} +", + hex::encode(sender) + ); + + let gas_amount = 423; + assert_error_type!( + test_env.run_with_max_gas_amount(gas_amount, to_standalone_script(program.as_bytes())), + ErrorKind::OutOfGas + ); + + let verify_script = b" + import 0x0.LibraAccount; + main(initial_balance: u64, gas_fees: u64) { + let sender; + let sender_balance; + sender = get_txn_sender(); + sender_balance = LibraAccount.balance(copy(sender)); + assert(move(initial_balance) - move(gas_fees) == move(sender_balance), 101); + return; + }"; + let gas_fee = gas_amount * TestEnvironment::DEFAULT_GAS_COST; + assert_no_error!(test_env.run_with_arguments( + vec![ + TransactionArgument::U64(TestEnvironment::INITIAL_BALANCE), + TransactionArgument::U64(gas_fee) + ], + to_script(verify_script, vec![]) + )); +} + +// TODO don't increment sequence number after: +// bad signature +// bad 
auth key
+// cant pay gas deposit
+
+// TODO: do increment sequence number after:
+// parse error
+// module publish error
+// bytecode verification error
+// non-assert runtime error
diff --git a/language/functional_tests/_old_move_ir_tests/src/tests/global_ref_count.rs b/language/functional_tests/_old_move_ir_tests/src/tests/global_ref_count.rs
new file mode 100644
index 0000000000000..336a2f0e9d42f
--- /dev/null
+++ b/language/functional_tests/_old_move_ir_tests/src/tests/global_ref_count.rs
@@ -0,0 +1,85 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::*;
+use move_ir::{assert_no_error, assert_other_error};
+
+#[test]
+fn increment_borrow_field() {
+    let mut test_env = TestEnvironment::default();
+    let sender = hex::encode(test_env.accounts.get_address(0));
+    let program = format!(
+        "
+modules:
+module Test {{
+    resource T {{ i: u64 }}
+
+    public test() {{
+        let t;
+        let t_ref;
+        let i_ref;
+        let sender;
+
+        t = T {{ i: 0 }};
+        move_to_sender(move(t));
+
+        sender = get_txn_sender();
+        t_ref = borrow_global(copy(sender));
+        i_ref = &copy(t_ref).i;
+        release(move(t_ref));
+
+        t_ref = borrow_global(copy(sender));
+        release(move(t_ref));
+        release(move(i_ref));
+    }}
+}}
+script:
+import 0x{0}.Test;
+main() {{
+    Test.test();
+    return;
+}}",
+        sender
+    );
+    assert_other_error!(test_env.run(to_script(program.as_bytes(), vec![])), format!("Invalid borrow of global resource 0x{0}.0x{0}.Test.T. There already exists a reference to this resource. You must free all references to this resource before calling \'borrow_global\' again.", sender))
+}
+
+#[test]
+fn increment_copy() {
+    let mut test_env = TestEnvironment::default();
+    let sender = hex::encode(test_env.accounts.get_address(0));
+    let program = format!(
+        "
+modules:
+module Test {{
+    resource T {{ i: u64 }}
+
+    public test() {{
+        let t;
+        let t_ref;
+        let i_ref;
+        let sender;
+
+        t = T {{ i: 0 }};
+        move_to_sender(move(t));
+
+        sender = get_txn_sender();
+        t_ref = borrow_global(copy(sender));
+        i_ref = copy(t_ref);
+        release(move(t_ref));
+
+        t_ref = borrow_global(copy(sender));
+        release(move(t_ref));
+        release(move(i_ref));
+    }}
+}}
+script:
+import 0x{0}.Test;
+main() {{
+    Test.test();
+    return;
+}}",
+        sender
+    );
+    assert_other_error!(test_env.run(to_script(program.as_bytes(), vec![])), format!("Invalid borrow of global resource 0x{0}.0x{0}.Test.T. There already exists a reference to this resource. 
You must free all references to this resource before calling \'borrow_global\' again.", sender)) +} diff --git a/language/functional_tests/_old_move_ir_tests/src/tests/key_rotation.rs b/language/functional_tests/_old_move_ir_tests/src/tests/key_rotation.rs new file mode 100644 index 0000000000000..e3efd3534a923 --- /dev/null +++ b/language/functional_tests/_old_move_ir_tests/src/tests/key_rotation.rs @@ -0,0 +1,80 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::*; +use crypto; +use language_common::{error_codes::*, tooling::fake_executor::Account}; +use move_ir::{assert_error_type, assert_no_error}; +use types::account_address::AccountAddress; + +#[test] +fn cant_send_transaction_with_old_key_after_rotation() { + let mut test_env = TestEnvironment::default(); + // Not a public key anyone can sign for + let new_key = "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef"; + + let program = format!( + " +import 0x0.LibraAccount; +main() {{ + let new_key; + new_key = b\"{}\"; + LibraAccount.rotate_authentication_key(move(new_key)); + + return; +}}", + new_key + ); + + // rotate key + assert_no_error!(test_env.run(to_standalone_script(program.as_bytes()))); + + // prologue will fail when signing with old key after rotation + assert_error_type!( + test_env.run(to_standalone_script(b"main() { return; }")), + ErrorKind::AssertError(EBAD_ACCOUNT_AUTHENTICATION_KEY, _) + ) +} + +#[test] +fn can_send_transaction_with_new_key_after_rotation() { + let mut test_env = TestEnvironment::default(); + + let (privkey, pubkey) = + crypto::signing::generate_keypair_for_testing(&mut test_env.accounts.randomness_source); + let program = format!( + " +import 0x0.LibraAccount; +main() {{ + let new_key; + new_key = b\"{}\"; + LibraAccount.rotate_authentication_key(move(new_key)); + + return; +}}", + hex::encode(AccountAddress::from(pubkey)) + ); + + // rotate key + assert_no_error!(test_env.run(to_standalone_script(program.as_bytes()))); + + // we need to use the new key in order to send a transaction + let old_account = test_env.accounts.get_account(0); + let new_account = Account { + addr: old_account.addr, + privkey, + pubkey, + }; + + let sequence_number = test_env.get_txn_sequence_number(0); + let txn = test_env.create_signed_txn( + to_standalone_script(b"main() { return; }"), + old_account.addr, + new_account, + sequence_number, + TestEnvironment::DEFAULT_MAX_GAS, + TestEnvironment::DEFAULT_GAS_COST, + ); + + assert_no_error!(test_env.run_txn(txn)) +} diff --git a/language/functional_tests/_old_move_ir_tests/src/tests/payments.rs b/language/functional_tests/_old_move_ir_tests/src/tests/payments.rs new file mode 100644 index 0000000000000..29a5f07e0908c --- /dev/null +++ b/language/functional_tests/_old_move_ir_tests/src/tests/payments.rs @@ -0,0 +1,103 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::*; +use language_common::error_codes::EINSUFFICIENT_BALANCE; +use move_ir::{assert_error_type, assert_no_error}; + +#[test] +fn cant_copy_resource() { + let mut test_env = TestEnvironment::default(); + + let program = b" +import 0x0.LibraAccount; +main() { + let addr; + let ten_coins; + let i_created_money; + addr = get_txn_sender(); + ten_coins = LibraAccount.withdraw_from_sender(10); + i_created_money = copy(ten_coins); + LibraAccount.deposit(copy(addr), move(ten_coins)); + LibraAccount.deposit(copy(addr), move(i_created_money)); + + return; +}"; + + assert_error_type!( + 
test_env.run(to_script(program, vec![])), + ErrorKind::InvalidCopy(_, _) + ) +} + +#[test] +fn cant_double_deposit() { + let program = b" +import 0x0.LibraAccount; +main() { + let addr; + let ten_coins; + addr = get_txn_sender(); + ten_coins = LibraAccount.withdraw_from_sender(10); + LibraAccount.deposit(copy(addr), move(ten_coins)); + LibraAccount.deposit(copy(addr), move(ten_coins)); + + return; +}"; + assert_error_type!(run(program), ErrorKind::UseAfterMove(_, _)) +} + +#[test] +fn cant_overdraft() { + let program = b" +import 0x0.LibraAccount; +main() { + let addr; + let sender_balance; + let all_coins; + let sender_new_balance; + let one_coin; + + addr = get_txn_sender(); + + sender_balance = LibraAccount.balance(copy(addr)); + + all_coins = LibraAccount.withdraw_from_sender(move(sender_balance)); + + sender_new_balance = LibraAccount.balance(copy(addr)); + assert(move(sender_new_balance) == 0, 41); + + one_coin = LibraAccount.withdraw_from_sender(1); + + return; +}"; + assert_error_type!( + run(program), + ErrorKind::AssertError(EINSUFFICIENT_BALANCE, _) + ) +} + +#[test] +fn zero_payment() { + let program = b" +import 0x0.LibraAccount; +import 0x0.LibraCoin; +main() { + let addr; + let sender_old_balance; + let zero_resource; + let sender_new_balance; + + addr = get_txn_sender(); + + sender_old_balance = LibraAccount.balance(copy(addr)); + zero_resource = LibraCoin.zero(); + LibraAccount.deposit(copy(addr), move(zero_resource)); + + sender_new_balance = LibraAccount.balance(move(addr)); + assert(move(sender_new_balance) == move(sender_old_balance), 42); + + return; +}"; + assert_error_type!(run(program), ErrorKind::AssertError(7, _)) +} diff --git a/language/functional_tests/_old_move_ir_tests/src/tests/prologue.rs b/language/functional_tests/_old_move_ir_tests/src/tests/prologue.rs new file mode 100644 index 0000000000000..f53274d382ec6 --- /dev/null +++ b/language/functional_tests/_old_move_ir_tests/src/tests/prologue.rs @@ -0,0 +1,99 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::*; +use language_common::error_codes::*; +use move_ir::{assert_error_type, assert_no_error}; + +#[test] +fn cant_run_with_future_sequence_number() { + let program = b" +main() { + return; +}"; + + // expecting sequence number 0 + assert_error_type!( + run_with_sequence_number(17, program), + ErrorKind::AssertError(ESEQUENCE_NUMBER_TOO_NEW, _) + ) +} + +#[test] +fn cant_run_with_stale_sequence_number() { + let mut test_env = TestEnvironment::default(); + + let program = " +import 0x0.LibraAccount; +main() { + let sender; + let sequence_number; + + sender = get_txn_sender(); + sequence_number = LibraAccount.sequence_number(move(sender)); + assert(move(sequence_number) == 0, 42); + + return; +}"; + assert_no_error!(test_env.run_with_sequence_number(0, to_script(program.as_bytes(), vec![]))); + + let program = " +main() { + return; +}"; + + // expecting sequence number 1 + assert_error_type!( + test_env.run_with_sequence_number(0, to_script(program.as_bytes(), vec![])), + ErrorKind::AssertError(ESEQUENCE_NUMBER_TOO_OLD, _) + ) +} + +#[test] +fn fail_if_cant_pay_deposit() { + let program = b" +main() { + let oh_no_ill_never_get_run; + oh_no_ill_never_get_run = 2; + return; +}"; + + assert_error_type!( + run_with_max_gas_amount(10_000_000_000_000_000, program), + ErrorKind::AssertError(ECANT_PAY_GAS_DEPOSIT, _) + ) +} + +#[test] +fn fail_if_signature_valid_but_pubkey_doesnt_match_auth_key() { + let mut test_env = TestEnvironment::default(); + + let 
sender_account = test_env.accounts.get_account(0);
+    let other_account = test_env.accounts.get_account(1);
+
+    let program = "main() { return; }";
+    let sender = 0;
+    let sequence_number = test_env.get_txn_sequence_number(sender);
+    let max_gas = TestEnvironment::DEFAULT_MAX_GAS;
+    let gas_cost = TestEnvironment::DEFAULT_GAS_COST;
+
+    assert!(sender_account.addr != other_account.addr);
+    let signed_transaction = test_env.create_signed_txn_with_args(
+        to_script(program.as_bytes(), vec![]),
+        vec![],
+        sender_account.addr,
+        other_account, // sender's address, but someone else's account
+        sequence_number,
+        max_gas,
+        gas_cost,
+    );
+
+    // creates a transaction whose signature is valid...
+    assert!(signed_transaction.verify_signature().is_ok());
+
+    // ...but whose authentication key doesn't match the public key used in the signature
+    assert_error_type!(
+        test_env.run_txn(signed_transaction),
+        ErrorKind::AssertError(EBAD_ACCOUNT_AUTHENTICATION_KEY, _)
+    );
+}
diff --git a/language/functional_tests/_old_move_ir_tests/src/tests/publish.rs b/language/functional_tests/_old_move_ir_tests/src/tests/publish.rs
new file mode 100644
index 0000000000000..9932e48fb5fb2
--- /dev/null
+++ b/language/functional_tests/_old_move_ir_tests/src/tests/publish.rs
@@ -0,0 +1,50 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::*;
+use move_ir::{assert_no_error, assert_other_error};
+
+#[test]
+fn publish_existing_module() {
+    let mut test_env = TestEnvironment::default();
+    let sender = hex::encode(test_env.accounts.get_address(0));
+    let program = b"
+modules:
+module Currency {
+    resource Coin{money: R#Self.Coin}
+    public new(m: R#Self.Coin): R#Self.Coin {
+        return Coin{money: move(m)};
+    }
+    public value(this: &R#Self.Coin): u64 {
+        let ref;
+        let val;
+        ref = &copy(this).money;
+        val = Self.value(move(ref));
+        release(move(this));
+        return move(val);
+    }
+}
+module Currency {
+    resource Coin{money: R#Self.Coin}
+    public new(m: R#Self.Coin): R#Self.Coin {
+        return Coin{money: move(m)};
+    }
+    public value(this: &R#Self.Coin): u64 {
+        let ref;
+        let val;
+        ref = &copy(this).money;
+        val = Self.value(move(ref));
+        release(move(this));
+        return move(val);
+    }
+}
+script:
+main() {
+    return;
+}";
+    // TODO: check error type once we add more macros.
+    assert_other_error!(
+        test_env.run(to_script(program, vec![])),
+        format!("Publish to an existing module 0x{}.Currency", sender)
+    )
+}
diff --git a/language/functional_tests/_old_move_ir_tests/src/tests/references.rs b/language/functional_tests/_old_move_ir_tests/src/tests/references.rs
new file mode 100644
index 0000000000000..c081fd247f05b
--- /dev/null
+++ b/language/functional_tests/_old_move_ir_tests/src/tests/references.rs
@@ -0,0 +1,398 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::*;
+use move_ir::{assert_no_error, assert_other_error};
+
+#[test]
+fn mutate_simple_reference() {
+    let mut test_env = TestEnvironment::default();
+    let sender = hex::encode(test_env.accounts.get_address(0));
+
+    // This test creates two references that refer to the same local value.
+    // Case 1: Mutating one of them will invalidate the other reference.
+    // Case 2: You can keep mutating through the same reference multiple times.
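+    // Case 1 is expected to be rejected by the borrow checker; Case 2 should compile and run without error.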
+
+    let program = format!(
+        "
+modules:
+module M {{
+    struct A{{f: u64}}
+
+    public new(x: u64): V#Self.A {{
+        return A{{f: move(x)}};
+    }}
+
+    public t(a: &mut V#Self.A) {{
+        let f1;
+        let f2;
+
+        f1 = &mut copy(a).f;
+        f2 = &mut copy(a).f;
+        *move(f1) = 4;
+        *move(f2) = 5;
+
+        release(move(a));
+        return;
+    }}
+}}
+script:
+import 0x{0}.M;
+main() {{
+    let x;
+    let x_ref;
+    let f1;
+    let f2;
+    x = M.new(0);
+    x_ref = &mut x;
+    f1 = M.t(move(x_ref));
+}}",
+        sender
+    );
+    assert_other_error!(
+        test_env.run(to_script(program.as_bytes(), vec![])),
+        "In function \'t\': Invalid usage of \'a\'. Field \'f\' is being borrowed by: \'f1\'"
+    );
+
+    let program2 = format!(
+        "
+modules:
+module M2 {{
+    struct A{{f: u64}}
+
+    public new(x: u64): V#Self.A {{
+        return A{{f: move(x)}};
+    }}
+
+    public f(a: &mut V#Self.A): &mut u64 {{
+        let f;
+        f = &mut copy(a).f;
+        release(move(a));
+        return move(f);
+    }}
+}}
+script:
+import 0x{0}.M2;
+main() {{
+    let x;
+    let x_ref;
+    let f1;
+    x = M2.new(0);
+    x_ref = &mut x;
+    f1 = M2.f(copy(x_ref));
+    *copy(f1) = 4;
+    *move(f1) = 5;
+    release(move(x_ref));
+    return;
+}}",
+        sender
+    );
+    assert_no_error!(test_env.run(to_script(program2.as_bytes(), vec![])))
+}
+
+#[test]
+fn mutate_nested_reference() {
+    let mut test_env = TestEnvironment::default();
+    let sender = hex::encode(test_env.accounts.get_address(0));
+
+    // This test creates a module that looks like the following:
+    //     A:    f
+    //           |
+    //           {g: u64}
+    //
+    // The two test cases are as follows:
+    // 1. let f = &mut f;
+    //    let g = &mut copy(f).g;
+    //    *copy(f).g = 1;
+    //    *copy(f) = ...;
+    // This should succeed because mutating the child shouldn't invalidate the
+    // read of the parent.
+    //
+    // 2. let f = &mut f;
+    //    let g = &mut copy(f).g;
+    //    *copy(f) = ...;
+    //    *copy(f).g = 1; <-- This line should fail because mutating the parent invalidates the child.
+    let program = format!(
+        "
+modules:
+module M {{
+    struct A{{f: V#Self.B}}
+    struct B{{g: u64}}
+
+    public A(f: V#Self.B): V#Self.A {{
+        return A{{f: move(f)}};
+    }}
+
+    public B(g: u64): V#Self.B {{
+        return B{{g: move(g)}};
+    }}
+
+    public t(a: &mut V#Self.A) {{
+        let f_ref;
+        let g_ref;
+        let b;
+
+        f_ref = &mut copy(a).f;
+        g_ref = &mut copy(f_ref).g;
+        *move(g_ref) = 5;
+        b = Self.B(2);
+        *move(f_ref) = move(b);
+        release(move(a));
+        return;
+    }}
+}}
+script:
+import 0x{0}.M;
+main() {{
+    let b;
+    let x;
+    let x_ref;
+    let f_ref;
+    let g_ref;
+
+    b = M.B(0);
+    x = M.A(move(b));
+    x_ref = &mut x;
+    M.t(move(x_ref));
+    return;
+}}",
+        sender,
+    );
+    assert_no_error!(test_env.run(to_script(program.as_bytes(), vec![])));
+
+    // Mutate the child after the parent
+    let program2 = format!(
+        "
+modules:
+module M2 {{
+    struct A{{f: V#Self.B}}
+    struct B{{g: u64}}
+
+    public A(f: V#Self.B): V#Self.A {{
+        return A{{f: move(f)}};
+    }}
+
+    public B(g: u64): V#Self.B {{
+        return B{{g: move(g)}};
+    }}
+
+    public t(a: &mut V#Self.A) {{
+        let f_ref;
+        let g_ref;
+        let b;
+
+        f_ref = &mut copy(a).f;
+        g_ref = &mut copy(f_ref).g;
+
+        b = Self.B(2);
+        *move(f_ref) = move(b);
+        *move(g_ref) = 5;
+
+        release(move(a));
+
+        return;
+    }}
+}}
+script:
+import 0x{0}.M2;
+main() {{
+    let b;
+    let x;
+    let x_ref;
+    let f_ref;
+    let g_ref;
+
+    b = M2.B(0);
+    x = M2.A(move(b));
+    x_ref = &mut x;
+    M2.t(move(x_ref));
+
+    return;
+}}",
+        sender,
+    );
+    assert_other_error!(
+        test_env.run(to_script(program2.as_bytes(), vec![])),
+        "In function \'t\': Invalid usage of \'f_ref\'. 
Field \'g\' is being borrowed by: \'g_ref\'"
+    )
+}
+
+#[test]
+fn mutate_sibling_reference() {
+    let mut test_env = TestEnvironment::default();
+    let sender = hex::encode(test_env.accounts.get_address(0));
+    // This test creates the following struct:
+    //     A:   f
+    //         / \
+    //        /   \
+    //    B: g     h
+    // The tests are as follows:
+    //     let f = &mut f;
+    //     let g = &mut copy(f).g;
+    //     let h = &mut copy(f).h;
+    // 1.  *copy(g) = 5;
+    //     let h2 = *copy(h);
+    //     let f2 = *copy(f);
+    // This should succeed because mutating a child invalidates neither its parent nor its siblings.
+    // 2.  *copy(f) = _;
+    //     let g2 = *copy(g); <-- This should fail because mutating a parent invalidates all of
+    //                            its children.
+    //     let h2 = *copy(h); <-- This will fail for the exact same reason.
+    let program = format!(
+        "
+modules:
+module M {{
+    struct A{{f: V#Self.B}}
+    struct B{{g: u64, h: u64}}
+
+    public A(f: V#Self.B): V#Self.A {{
+        return A{{f: move(f)}};
+    }}
+
+    public B(g: u64, h: u64): V#Self.B {{
+        return B{{g: move(g), h: move(h)}};
+    }}
+
+    public t(a: &mut V#Self.A) {{
+        let f;
+        let g;
+        let h;
+        let f_ref;
+        let g_ref;
+        let h_ref;
+
+        f_ref = &mut copy(a).f;
+        g_ref = &mut copy(f_ref).g;
+        h_ref = &mut copy(f_ref).h;
+        *move(g_ref) = 5;
+        h = *move(h_ref);
+        assert(move(h) == 1, 42);
+        f = *move(f_ref);
+        release(move(a));
+        return;
+    }}
+}}
+script:
+import 0x{0}.M;
+main() {{
+    let b;
+    let x;
+    let x_ref;
+    let f_ref;
+    let g_ref;
+    let h_ref;
+    let h;
+    let f;
+
+    b = M.B(0, 1);
+    x = M.A(move(b));
+    x_ref = &mut x;
+    M.t(move(x_ref));
+
+    return;
+}}",
+        sender,
+    );
+    assert_no_error!(test_env.run(to_script(program.as_bytes(), vec![])));
+
+    // Mutating the parent will invalidate both children
+    let program2 = format!(
+        "
+modules:
+module M2 {{
+    struct A{{f: V#Self.B}}
+    struct B{{g: u64, h: u64}}
+
+    public A(f: V#Self.B): V#Self.A {{
+        return A{{f: move(f)}};
+    }}
+
+    public B(g: u64, h: u64): V#Self.B {{
+        return B{{g: move(g), h: move(h)}};
+    }}
+
+    public t(a: &mut V#Self.A) {{
+        let f_ref;
+        let h_ref;
+        let b;
+        let h;
+
+        f_ref = &mut copy(a).f;
+        h_ref = &mut copy(f_ref).h;
+
+        b = Self.B(2, 3);
+        *move(f_ref) = move(b);
+        h = *move(h_ref);
+        release(move(a));
+    }}
+}}
+script:
+import 0x{0}.M2;
+main() {{
+    let b;
+    let x;
+    let x_ref;
+
+    b = M2.B(0, 1);
+    x = M2.A(move(b));
+    x_ref = &mut x;
+    M2.t(move(x_ref));
+    return;
+}}",
+        sender,
+    );
+    assert_other_error!(
+        test_env.run(to_script(program2.as_bytes(), vec![])),
+        "In function \'t\': Invalid usage of \'f_ref\'. Field \'h\' is being borrowed by: \'h_ref\'"
+    );
+
+    let program3 = format!(
+        "
+modules:
+module M3 {{
+    struct A{{f: V#Self.B}}
+    struct B{{g: u64, h: u64}}
+
+    public A(f: V#Self.B): V#Self.A {{
+        return A{{f: move(f)}};
+    }}
+
+    public B(g: u64, h: u64): V#Self.B {{
+        return B{{g: move(g), h: move(h)}};
+    }}
+
+    public t(a: &mut V#Self.A) {{
+        let f_ref;
+        let g_ref;
+        let b;
+        let g;
+
+        f_ref = &mut copy(a).f;
+        g_ref = &mut copy(f_ref).g;
+
+        b = Self.B(2, 3);
+        *move(f_ref) = move(b);
+        g = *move(g_ref);
+        release(move(a));
+    }}
+}}
+script:
+import 0x{0}.M3;
+main() {{
+    let b;
+    let x;
+    let x_ref;
+
+    b = M3.B(0, 1);
+    x = M3.A(move(b));
+    x_ref = &mut x;
+    M3.t(move(x_ref));
+    return;
+}}",
+        sender,
+    );
+    assert_other_error!(
+        test_env.run(to_script(program3.as_bytes(), vec![])),
+        "In function \'t\': Invalid usage of \'f_ref\'. 
Field \'g\' is being borrowed by: \'g_ref\'" + ); +} diff --git a/language/functional_tests/_old_move_ir_tests/src/tests/transaction_scripts.rs b/language/functional_tests/_old_move_ir_tests/src/tests/transaction_scripts.rs new file mode 100644 index 0000000000000..031d6018991f7 --- /dev/null +++ b/language/functional_tests/_old_move_ir_tests/src/tests/transaction_scripts.rs @@ -0,0 +1,171 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::*; +use hex; +use move_ir::assert_no_error; +use std::borrow::Borrow; + +#[test] +fn peer_to_peer_payment_script() { + let mut test_env = TestEnvironment::default(); + let recipient = test_env.accounts.get_address(1); + let program = move_ir::stdlib::transaction_scripts::peer_to_peer_transfer_transaction_bincode( + &recipient, 10, + ); + let result = test_env.run_program(program); + let gas_cost = match &result { + Ok(res) => res.gas_used(), + _ => 0, + }; + + assert_no_error!(result); + + let assert_balance = format!( + " +import 0x0.LibraAccount; +main() {{ + let sender_address; + let recipient_address; + let sender_initial_balance; + let transaction_cost; + let sender_balance; + let recipient_balance; + + sender_address = get_txn_sender(); + recipient_address = 0x{}; + sender_initial_balance = {}; + transaction_cost = {}; + sender_balance = LibraAccount.balance(move(sender_address)); + recipient_balance = LibraAccount.balance(move(recipient_address)); + + assert(move(sender_balance) == copy(sender_initial_balance) - copy(transaction_cost) - 10, 77); + assert(move(recipient_balance) == move(sender_initial_balance) + 10, 88); + + return; +}}", + hex::encode(recipient), + TestEnvironment::INITIAL_BALANCE, + TestEnvironment::DEFAULT_GAS_COST * gas_cost, + ); + + assert_no_error!(test_env.run(to_script(assert_balance.as_bytes(), vec![]))) +} + +#[test] +fn peer_to_peer_payment_script_create_account() { + let mut test_env = TestEnvironment::default(); + // some recipient whose account does not exist + let recipient = test_env.accounts.fresh_account().addr; + let program = move_ir::stdlib::transaction_scripts::peer_to_peer_transfer_transaction_bincode( + &recipient, 10, + ); + + assert_no_error!(test_env.run_program(program)); +} + +#[test] +fn create_account_script() { + let mut test_env = TestEnvironment::default(); + // some recipient whose account does not exist + let fresh_account = test_env.accounts.fresh_account(); + let fresh_address = &fresh_account.addr; + let program = move_ir::stdlib::transaction_scripts::create_account_transaction_bincode( + fresh_address, + TestEnvironment::DEFAULT_MAX_GAS + 100, + ); + + assert_no_error!(test_env.run_program(program)); + + // make sure the account has been created by sending a transaction from it + let sequence_number = 0; + let txn = test_env.create_signed_txn( + to_script(b"main() { return; }", vec![]), + fresh_address.clone(), + fresh_account, + sequence_number, + TestEnvironment::DEFAULT_MAX_GAS, + TestEnvironment::DEFAULT_GAS_COST, + ); + + assert_no_error!(test_env.run_txn(txn)) +} + +#[test] +fn rotate_authentication_key_script() { + let mut test_env = TestEnvironment::default(); + let fresh_account = test_env.accounts.fresh_account(); + let new_authentication_key = fresh_account.pubkey.borrow().into(); + let program = + move_ir::stdlib::transaction_scripts::rotate_authentication_key_transaction_bincode( + new_authentication_key, + ); + + assert_no_error!(test_env.run_program(program)); + + // we need to use the new key in order to send a transaction + 
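+    // (Key rotation changes which keypair authenticates for the account, but
+    // not its address, so below we pair the old address with the fresh keys.)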
let old_account = test_env.accounts.get_account(0); + let new_account = Account { + addr: old_account.addr, + privkey: fresh_account.privkey, + pubkey: fresh_account.pubkey, + }; + + // make sure rotation worked by sending with the new key + let sequence_number = 1; + let txn = test_env.create_signed_txn( + to_standalone_script(b"main() { return; }"), + old_account.addr, + new_account, + sequence_number, + TestEnvironment::DEFAULT_MAX_GAS, + TestEnvironment::DEFAULT_GAS_COST, + ); + + assert_no_error!(test_env.run_txn(txn)) +} + +// TODO: eliminate mint/this script +#[test] +fn mint_script() { + let mut test_env = TestEnvironment::default(); + let recipient = test_env.accounts.get_address(1); + let program = move_ir::stdlib::transaction_scripts::mint_transaction_bincode(&recipient, 10); + let result = test_env.run_program(program); + let gas_cost = match &result { + Ok(res) => res.gas_used(), + _ => 0, + }; + + assert_no_error!(result); + + let assert_balance = format!( + " +import 0x0.LibraAccount; +main() {{ + let sender_address; + let recipient_address; + let sender_initial_balance; + let transaction_cost; + let sender_balance; + let recipient_balance; + + sender_address = get_txn_sender(); + recipient_address = 0x{}; + sender_initial_balance = {}; + transaction_cost = {}; + sender_balance = LibraAccount.balance(move(sender_address)); + recipient_balance = LibraAccount.balance(move(recipient_address)); + + assert(move(sender_balance) == copy(sender_initial_balance) - copy(transaction_cost), 77); + assert(move(recipient_balance) == move(sender_initial_balance) + 10, 88); + + return; +}}", + hex::encode(recipient), + TestEnvironment::INITIAL_BALANCE, + TestEnvironment::DEFAULT_GAS_COST * gas_cost, + ); + + assert_no_error!(test_env.run(to_script(assert_balance.as_bytes(), vec![]))) +} diff --git a/language/functional_tests/_old_move_ir_tests/src/tests/transactions.rs b/language/functional_tests/_old_move_ir_tests/src/tests/transactions.rs new file mode 100644 index 0000000000000..105393c3a8939 --- /dev/null +++ b/language/functional_tests/_old_move_ir_tests/src/tests/transactions.rs @@ -0,0 +1,249 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::*; +use move_ir::{assert_error_type, assert_no_error}; +use proptest::prelude::*; +use types::{ + account_address::AccountAddress, + transaction::{TransactionArgument, TransactionPayload}, +}; + +// no arguments on a void main() is good +#[test] +fn tx_no_args_good() { + let mut test_env = TestEnvironment::default(); + + assert_no_error!(test_env.run_with_arguments(vec![], to_script(b"main() { return; }", vec![]))); +} + +// one arg to a void main() is an error (integer arg) +#[test] +fn tx_no_args_bad1() { + let mut test_env = TestEnvironment::default(); + + assert_error_type!( + test_env.run_with_arguments( + vec![TransactionArgument::U64(10)], + to_script(b"main() { return; }", vec![]) + ), + ErrorKind::BadTransactionArgs + ); +} + +// no arg to a void main(u64) is an error +#[test] +fn tx_one_arg_bad1() { + let mut test_env = TestEnvironment::default(); + + assert_error_type!( + test_env.run_with_arguments(vec![], to_script(b"main(value: u64) { return; }", vec![])), + ErrorKind::BadTransactionArgs + ) +} + +// two arg to a void main(u64) is an error +#[test] +fn tx_one_arg_bad2() { + let mut test_env = TestEnvironment::default(); + + assert_error_type!( + test_env.run_with_arguments( + vec![ + TransactionArgument::U64(10), + TransactionArgument::U64(1), + ], + to_script(b"main(value: u64) { 
return; }", vec![]) + ), + ErrorKind::BadTransactionArgs + ); +} + +// address arg to a void main(u64) is an error +#[test] +fn tx_one_arg_bad3() { + let mut test_env = TestEnvironment::default(); + + assert_error_type!( + test_env.run_with_arguments( + vec![TransactionArgument::Address(AccountAddress::default())], + to_script(b"main(value: u64) { return; }", vec![]) + ), + ErrorKind::BadTransactionArgs + ); +} + +// integer arg to main(u64) is good with assert true (value passed is expected) +#[test] +fn tx_one_arg_good1() { + let mut test_env = TestEnvironment::default(); + + assert_no_error!(test_env.run_with_arguments( + vec![TransactionArgument::U64(10),], + to_script( + b"main(value: u64) { assert(copy(value) == 10, 42); return; }", + vec![] + ) + )) +} + +// (u64, u64) args to main(u64, address) is an error +#[test] +fn tx_two_args_bad2() { + let mut test_env = TestEnvironment::default(); + + assert_error_type!( + test_env.run_with_arguments( + vec![ + TransactionArgument::U64(10), + TransactionArgument::U64(1), + ], + to_script(b"main(value: u64, addr: address) { return; }", vec![]) + ), + ErrorKind::BadTransactionArgs + ) +} + +// address arg to main(u64, address) is an error +#[test] +fn tx_two_args_bad3() { + let mut test_env = TestEnvironment::default(); + + assert_error_type!( + test_env.run_with_arguments( + vec![TransactionArgument::Address(AccountAddress::default())], + to_script(b"main(value: u64, addr: address) { return; }", vec![]) + ), + ErrorKind::BadTransactionArgs + ) +} + +// (address, u64) args to main(u64, address) is an error +#[test] +fn tx_two_args_bad4() { + let mut test_env = TestEnvironment::default(); + + assert_error_type!( + test_env.run_with_arguments( + vec![ + TransactionArgument::Address(AccountAddress::default()), + TransactionArgument::U64(10), + ], + to_script(b"main(value: u64, addr: address) { return; }", vec![]) + ), + ErrorKind::BadTransactionArgs + ) +} + +// (u64, address, u64) args to main(u64, address) is an error +#[test] +fn tx_two_args_bad5() { + let mut test_env = TestEnvironment::default(); + + assert_error_type!( + test_env.run_with_arguments( + vec![ + TransactionArgument::U64(10), + TransactionArgument::Address(AccountAddress::default()), + TransactionArgument::U64(10), + ], + to_script(b"main(value: u64, addr: address) { return; }", vec![]) + ), + ErrorKind::BadTransactionArgs + ) +} + +// (u64, address) args to main(u64, address) is good - with assert not firing +#[test] +fn tx_two_args_good1() { + let mut test_env = TestEnvironment::default(); + let account = test_env.accounts.get_account(1); + let address = account.addr; + let address_str = hex::encode(address); + + let program = format!( + " +main(value: u64, addr: address) {{ + assert(copy(value) == 10, 42); + assert(copy(addr) == 0x{}, 42); + return; +}} + ", + address_str + ); + + assert_no_error!(test_env.run_with_arguments( + vec![ + TransactionArgument::U64(10), + TransactionArgument::Address(address), + ], + to_script(program.as_bytes(), vec![]) + )); +} + +// (u64, address) args to main(u64, address) is good - with assert firing on u64 +#[test] +fn tx_two_args_good2() { + let mut test_env = TestEnvironment::default(); + + assert_error_type!( + test_env.run_with_arguments( + vec![ + TransactionArgument::U64(1), + TransactionArgument::Address(AccountAddress::default()), + ], + to_script( + b"main(value: u64, addr: address) { assert(copy(value) == 10, 42); return; }", + vec![], + ) + ), + ErrorKind::AssertError(_, _) + ); +} + +// (u64, u64, address) args to main(u64, u64, 
address) is good +#[test] +fn tx_three_args_good1() { + let mut test_env = TestEnvironment::default(); + + let address = test_env.accounts.get_address(1); + let address_str = hex::encode(address); + + let program = format!( + " +main(value1: u64, value2: u64, addr: address) {{ + assert(copy(value1) + copy(value2) == 8, 42); + assert(copy(addr) == 0x{}, 42); + return; +}} + ", + address_str + ); + + assert_no_error!(test_env.run_with_arguments( + vec![ + TransactionArgument::U64(3), + TransactionArgument::U64(5), + TransactionArgument::Address(address), + ], + to_script(program.as_bytes(), vec![]) + )) +} + +#[test] +fn write_set_txn_roundtrip() { + // Creating a new test environment is expensive so do it outside the proptest environment. + let test_env = TestEnvironment::default(); + + proptest!(|(signed_txn in SignedTransaction::genesis_strategy())| { + let write_set = match signed_txn.payload() { + TransactionPayload::WriteSet(write_set) => write_set.clone(), + TransactionPayload::Program(_) => unreachable!( + "write set strategy should only generate write set transactions", + ), + }; + let output = test_env.eval_txn(signed_txn) + .expect("write set transactions should succeed"); + prop_assert_eq!(output.write_set(), &write_set); + }); +} diff --git a/language/functional_tests/_old_move_ir_tests/src/tests/verify_transaction.rs b/language/functional_tests/_old_move_ir_tests/src/tests/verify_transaction.rs new file mode 100644 index 0000000000000..40292574615f9 --- /dev/null +++ b/language/functional_tests/_old_move_ir_tests/src/tests/verify_transaction.rs @@ -0,0 +1,148 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::*; +use move_ir::assert_no_error; +use proptest::prelude::*; +use std::time::Duration; +use types::transaction::{RawTransaction, SignedTransaction, TransactionPayload}; + +#[test] +fn verify_txn_accepts_good_sequence_number() { + let test_env = TestEnvironment::default(); + + let sequence_number = 0; + assert_no_error!(test_env.verify_txn_with_context( + to_script( + b" +main() { + let transaction_sequence_number; + let sender; + let sequence_number; + + transaction_sequence_number = get_txn_sequence_number(); + assert(copy(transaction_sequence_number) == 0, 42); + + sender = get_txn_sender(); + sequence_number = LibraAccount.sequence_number(move(sender)); + assert(move(sequence_number) == 0, 43); + + return; +}", + vec![] + ), + 0, + sequence_number, + TestEnvironment::DEFAULT_MAX_GAS, + TestEnvironment::DEFAULT_GAS_COST, + )); +} + +// TODO: ensure that verify_txn rejects stale sequence numbers, but accepts any sequence number >= +// the current one. this is needed to support the "parking lot" feature of the mempool. +#[test] +fn verify_txn_rejects_bad_sequence_number() { + let test_env = TestEnvironment::default(); + + let sequence_number = 1; + assert!(test_env + .verify_txn_with_context( + to_script( + b" +main() { + return; +}", + vec![] + ), + 0, + sequence_number, + TestEnvironment::DEFAULT_MAX_GAS, + TestEnvironment::DEFAULT_GAS_COST, + ) + .is_err()); +} + +#[test] +fn verify_txn_rejects_bad_signature() { + let test_env = TestEnvironment::default(); + + // Create a transaction signed by account 0 but has the pubkey of account 1. 
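+    // The signature below is produced with account 0's private key while the
+    // attached public key is account 1's, so the two cannot be consistent and
+    // verification should reject the transaction.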
+ let sender_account = test_env.accounts.get_account(0); + let public_key = test_env.accounts.get_account(1).pubkey; + + let raw_txn = RawTransaction::new( + sender_account.addr, + 0, + Program::new(to_script(b"main() { return; }", vec![]), vec![], vec![]), + "".to_string(), + TestEnvironment::DEFAULT_MAX_GAS, + TestEnvironment::DEFAULT_GAS_COST, + Duration::from_secs(u64::max_value()), + ); + + let signed_txn = raw_txn + .clone() + .sign(&sender_account.privkey, &public_key) + .unwrap(); + + let signed_txn_with_bad_pubkey = + SignedTransaction::new_for_test(raw_txn, public_key, signed_txn.signature()); + assert!(test_env.verify_txn(signed_txn_with_bad_pubkey).is_err()); +} + +#[test] +fn verify_txn_accepts_good_signature() { + let test_env = TestEnvironment::default(); + + let sender_account = test_env.accounts.get_account(0); + let signed_txn = test_env.create_signed_txn( + to_script(b"main() { return; }", vec![]), + sender_account.addr, + sender_account, + 0, + TestEnvironment::DEFAULT_MAX_GAS, + TestEnvironment::DEFAULT_GAS_COST, + ); + assert!(test_env.verify_txn(signed_txn).is_ok()); +} + +#[test] +fn verify_txn_rejects_write_set() { + let test_env = TestEnvironment::default(); + assert_ne!(test_env.get_version(), 0); + + proptest!(|(txn in SignedTransaction::write_set_strategy())| { + test_env.verify_txn(txn).expect_err("non-genesis write set txns should fail verification"); + }); +} + +#[test] +fn verify_txn_rejects_genesis_deletion() { + let test_env = TestEnvironment::empty(); + assert_eq!(test_env.get_version(), 0); + + proptest!(|(txn in SignedTransaction::write_set_strategy())| { + let write_set = match txn.payload() { + TransactionPayload::WriteSet(write_set) => write_set, + TransactionPayload::Program(_) => panic!( + "write_set_strategy shouldn't generate programs", + ), + }; + let any_deletions = write_set.iter().any(|(_, write_op)| write_op.is_deletion()); + if any_deletions { + test_env.verify_txn(txn).expect_err("genesis write set with deletes should be rejected"); + } else { + test_env.verify_txn(txn).expect("genesis write set txns should verify correctly"); + } + }); +} + +#[test] +fn verify_txn_accepts_genesis_write_set() { + let test_env = TestEnvironment::empty(); + assert_eq!(test_env.get_version(), 0); + + proptest!(|(txn in SignedTransaction::genesis_strategy())| { + test_env.verify_txn(txn).expect("genesis write set txns should verify correctly"); + }); +} diff --git a/language/functional_tests/src/checker.rs b/language/functional_tests/src/checker.rs new file mode 100644 index 0000000000000..d1a42cbb3e042 --- /dev/null +++ b/language/functional_tests/src/checker.rs @@ -0,0 +1,109 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + errors::*, + evaluator::{EvaluationResult, Stage, Status}, +}; +use filecheck; +use std::slice::SliceConcatExt; + +/// A directive specifies a pattern in the output. +/// Directives are extracted from comments starting with "//". +#[derive(Debug, Clone)] +pub enum Directive { + /// Matches the specified stage in the output. Acts as a barrier. + Stage(Stage), + /// Used to build the filecheck checker. Right now all comments except the ones that are + /// recognized as other directives goes here. + Check(String), +} + +impl Directive { + /// Tries to parse the given string into a directive. Returns an option indicating whether + /// the given input is a directive or not. Errors when the input looks like a directive but + /// is ill-formed. 
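+    ///
+    /// For example (illustrative, mirroring the cases covered by the tests):
+    ///
+    ///   "fn main() {}"        => Ok(None)            // not a comment
+    ///   "// check: abc"       => Ok(Some(Directive::Check(..)))
+    ///   "// stage: verifier"  => Ok(Some(Directive::Stage(Stage::Verifier)))
+    ///   "// stage: nonsense"  => Err(..)             // looks like a directive, but ill-formed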
+    pub fn try_parse(s: &str) -> Result<Option<Directive>> {
+        let s1 = s.trim_start();
+        if !s1.starts_with("//") {
+            return Ok(None);
+        }
+        let s2 = s1[2..].trim_start();
+        if s2.starts_with("stage: ") {
+            let s3 = s2[7..].trim_start().trim_end();
+            return Ok(Some(Directive::Stage(Stage::parse(s3)?)));
+        }
+        Ok(Some(Directive::Check(s.to_string())))
+    }
+}
+
+/// Check the output using filecheck checker.
+pub fn run_filecheck(output: &str, checks: &str) -> Result<bool> {
+    let mut builder = filecheck::CheckerBuilder::new();
+    builder.text(checks)?;
+    let checker = builder.finish();
+    // filecheck allows one to pass in a variable map, however we're not using it
+    if !checker.check(output, filecheck::NO_VARIABLES)? {
+        return Err(ErrorKind::CheckerFailure.into());
+    }
+    Ok(!checker.is_empty())
+}
+
+/// Verifies the directives against the given evaluation result.
+pub fn check(res: &EvaluationResult, directives: &[Directive]) -> Result<()> {
+    let mut checks: Vec<String> = vec![];
+    let mut outputs: Vec<String> = vec![];
+    let mut did_run_checks = false;
+
+    let mut i = 0;
+
+    for directive in directives {
+        match directive {
+            Directive::Check(check) => {
+                checks.push(check.clone());
+            }
+            Directive::Stage(barrier) => loop {
+                if i >= res.stages.len() {
+                    return Err(ErrorKind::Other(format!(
+                        "no stage '{:?}' in the output",
+                        barrier
+                    ))
+                    .into());
+                }
+                let (stage, output) = &res.stages[i];
+                if stage < barrier {
+                    outputs.push(output.to_string());
+                    i += 1;
+                } else if stage == barrier {
+                    did_run_checks |= run_filecheck(&outputs.join("\n"), &checks.join("\n"))?;
+                    checks.clear();
+                    outputs.clear();
+                    outputs.push(output.to_string());
+                    i += 1;
+                    break;
+                } else {
+                    return Err(ErrorKind::Other(format!(
+                        "no stage '{:?}' in the output",
+                        barrier
+                    ))
+                    .into());
+                }
+            },
+        }
+    }
+
+    for (_, output) in res.stages[i..].iter() {
+        outputs.push(output.clone());
+    }
+    did_run_checks |= run_filecheck(&outputs.join("\n"), &checks.join("\n"))?;
+
+    if res.status == Status::Failure && !did_run_checks {
+        return Err(ErrorKind::Other(format!(
+            "program failed at stage '{:?}', no directives found, assuming failure",
+            res.stages.last().unwrap().0
+        ))
+        .into());
+    }
+
+    Ok(())
+}
diff --git a/language/functional_tests/src/config.rs b/language/functional_tests/src/config.rs
new file mode 100644
index 0000000000000..d9e8913ea2189
--- /dev/null
+++ b/language/functional_tests/src/config.rs
@@ -0,0 +1,82 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+// The config holds the options that define the testing environment.
+// A config entry starts with "//!", differentiating it from a directive.
+
+use crate::errors::*;
+
+/// A raw config entry extracted from the input. Used to build the config.
+#[derive(Debug, Clone)]
+pub enum ConfigEntry {
+    NoVerify,
+    NoExecute,
+}
+
+impl ConfigEntry {
+    /// Tries to parse the input as an entry. Errors when the input looks
+    /// like a config but is ill-formed.
+    pub fn try_parse(s: &str) -> Result<Option<ConfigEntry>> {
+        let s1 = s.trim_start().trim_end();
+        if !s1.starts_with("//!") {
+            return Ok(None);
+        }
+        let s2 = s1[3..].trim_start();
+        match s2 {
+            "no-verify" => {
+                return Ok(Some(ConfigEntry::NoVerify));
+            }
+            "no-execute" => {
+                return Ok(Some(ConfigEntry::NoExecute));
+            }
+            _ => {}
+        }
+        Err(ErrorKind::Other(format!("invalid config option '{:?}'", s2)).into())
+    }
+}
+
+/// A table of options that customizes/defines the testing environment.
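+///
+/// For example, an input file that begins with the line "//! no-verify"
+/// builds to `Config { no_verify: true, no_execute: false }`.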
+#[derive(Debug)]
+pub struct Config {
+    /// If set to true, the compiled program is sent through execution without being verified
+    pub no_verify: bool,
+    /// If set to true, the compiled program will not get executed
+    pub no_execute: bool,
+}
+
+impl Config {
+    /// Builds a config from a collection of entries. Also sets the default values for entries that
+    /// are missing.
+    pub fn build(entries: &[ConfigEntry]) -> Result<Config> {
+        let mut no_verify: Option<bool> = None;
+        let mut no_execute: Option<bool> = None;
+        for entry in entries {
+            match entry {
+                ConfigEntry::NoVerify => match no_verify {
+                    None => {
+                        no_verify = Some(true);
+                    }
+                    _ => {
+                        return Err(
+                            ErrorKind::Other("flag 'no-verify' already set".to_string()).into()
+                        );
+                    }
+                },
+                ConfigEntry::NoExecute => match no_execute {
+                    None => {
+                        no_execute = Some(true);
+                    }
+                    _ => {
+                        return Err(
+                            ErrorKind::Other("flag 'no-execute' already set".to_string()).into(),
+                        );
+                    }
+                },
+            }
+        }
+        Ok(Config {
+            no_verify: no_verify.unwrap_or(false),
+            no_execute: no_execute.unwrap_or(false),
+        })
+    }
+}
diff --git a/language/functional_tests/src/errors.rs b/language/functional_tests/src/errors.rs
new file mode 100644
index 0000000000000..7fdb8ebf4786f
--- /dev/null
+++ b/language/functional_tests/src/errors.rs
@@ -0,0 +1,25 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use failure::{Error, Fail};
+use types::transaction::TransactionOutput;
+use vm::errors::VerificationError as VMVerificationError;
+
+/// Defines all errors in this crate.
+#[derive(Clone, Debug, Fail)]
+pub enum ErrorKind {
+    #[fail(display = "an error occurred when executing the program")]
+    VMExecutionFailure(TransactionOutput),
+    #[fail(display = "the transaction was discarded")]
+    DiscardedTransaction(TransactionOutput),
+    #[fail(display = "the checker has failed to match the directives against the output")]
+    CheckerFailure,
+    #[fail(display = "verification error {:?}", _0)]
+    VerificationFailure(Vec<VMVerificationError>),
+    #[fail(display = "other error: {}", _0)]
+    #[allow(dead_code)]
+    Other(String),
+}
+
+/// The common result type used in this crate.
+pub type Result<T> = std::result::Result<T, Error>;
diff --git a/language/functional_tests/src/evaluator.rs b/language/functional_tests/src/evaluator.rs
new file mode 100644
index 0000000000000..ae2fc2870ea21
--- /dev/null
+++ b/language/functional_tests/src/evaluator.rs
@@ -0,0 +1,200 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::{config::Config, errors::*};
+use bytecode_verifier::{
+    verify_module, verify_module_dependencies, verify_script, verify_script_dependencies,
+};
+use compiler::{compiler::compile_program, parser::parse_program, util::build_stdlib};
+use config::config::VMPublishingOption;
+use std::time::Duration;
+use types::{
+    transaction::{Program, RawTransaction, TransactionOutput, TransactionStatus},
+    vm_error::{ExecutionStatus, VMStatus},
+};
+use vm::{
+    errors::VerificationError,
+    file_format::{CompiledModule, CompiledProgram, CompiledScript},
+};
+use vm_runtime_tests::{account::AccountData, executor::FakeExecutor};
+
+/// Indicates one step in the pipeline the given move module/program goes through.
+// Ord is derived as we need to be able to determine if one stage is before another.
+#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)]
+pub enum Stage {
+    Parser,
+    // Right now parser is a separate stage.
+    // However it could be merged into the compiler.
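+    // Note: the derived Ord follows declaration order, so
+    // Parser < Compiler < Verifier < Runtime.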
+    Compiler,
+    Verifier,
+    Runtime,
+}
+
+impl Stage {
+    /// Parses the input string as Stage.
+    pub fn parse(s: &str) -> Result<Stage> {
+        match s {
+            "parser" => Ok(Stage::Parser),
+            "compiler" => Ok(Stage::Compiler),
+            "verifier" => Ok(Stage::Verifier),
+            "runtime" => Ok(Stage::Runtime),
+            _ => Err(ErrorKind::Other(format!("unrecognized stage '{:?}'", s)).into()),
+        }
+    }
+}
+
+/// Evaluation status: success or failure.
+#[derive(Debug, Clone, Eq, PartialEq)]
+pub enum Status {
+    Success,
+    Failure,
+}
+
+/// A log consisting of outputs from all stages and the final status.
+/// This is checked against the directives.
+#[derive(Debug)]
+pub struct EvaluationResult {
+    pub stages: Vec<(Stage, String)>,
+    pub status: Status,
+}
+
+impl EvaluationResult {
+    /// Appends another entry to the evaluation result.
+    pub fn append(&mut self, stage: Stage, output: String) {
+        self.stages.push((stage, output));
+    }
+}
+
+fn check_verification_errors(errors: Vec<VerificationError>) -> Result<()> {
+    if !errors.is_empty() {
+        return Err(ErrorKind::VerificationFailure(errors).into());
+    }
+    Ok(())
+}
+
+fn do_verify_module(module: &CompiledModule, deps: &[CompiledModule]) -> Result<()> {
+    check_verification_errors(verify_module(module.clone()).1)?;
+    check_verification_errors(verify_module_dependencies(module.clone(), deps).1)
+}
+
+fn do_verify_script(script: &CompiledScript, deps: &[CompiledModule]) -> Result<()> {
+    check_verification_errors(verify_script(script.clone()).1)?;
+    check_verification_errors(verify_script_dependencies(script.clone(), deps).1)
+}
+
+// TODO: Add a helper function to the verifier
+fn do_verify_program(program: &CompiledProgram, deps: &[CompiledModule]) -> Result<()> {
+    let mut deps = deps.to_vec();
+    for m in &program.modules {
+        do_verify_module(m, &deps)?;
+        deps.push(m.clone());
+    }
+    do_verify_script(&program.script, &deps)
+}
+
+fn create_transaction_program(program: &CompiledProgram) -> Result<Program> {
+    let mut script_blob = vec![];
+    program.script.serialize(&mut script_blob)?;
+
+    let module_blobs = program
+        .modules
+        .iter()
+        .map(|m| {
+            let mut module_blob = vec![];
+            m.serialize(&mut module_blob)?;
+            Ok(module_blob)
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    // Currently we do not support transaction arguments in functional tests.
+    Ok(Program::new(script_blob, module_blobs, vec![]))
+}
+
+/// Runs a single transaction using the fake executor.
+fn run_transaction(data: &AccountData, program: &CompiledProgram) -> Result<TransactionOutput> {
+    let mut exec = FakeExecutor::from_genesis_with_options(VMPublishingOption::Open);
+    exec.add_account_data(data);
+    let account = data.account();
+
+    let program = create_transaction_program(program)?;
+
+    let transaction = RawTransaction::new(
+        *data.address(),
+        data.sequence_number(),
+        program,
+        // Right now, the max gas amount is an arbitrarily large number.
+        // TODO: allow the user to specify this in the config.
+        1_000_000,
+        1,
+        Duration::from_secs(u64::max_value()),
+    )
+    .sign(&account.privkey, account.pubkey)?;
+
+    let mut outputs = exec.execute_block(vec![transaction]);
+    if outputs.len() == 1 {
+        let output = outputs.pop().unwrap();
+        match output.status() {
+            TransactionStatus::Keep(VMStatus::Execution(ExecutionStatus::Executed)) => Ok(output),
+            TransactionStatus::Keep(_) => Err(ErrorKind::VMExecutionFailure(output).into()),
+            TransactionStatus::Discard(_) => Err(ErrorKind::DiscardedTransaction(output).into()),
+        }
+    } else {
+        panic!("transaction outputs size mismatch");
+    }
+}
+
+/// Tries to unwrap the given result. Upon failure, logs the error and aborts.
+macro_rules!
unwrap_or_log { + ($res: expr, $log: expr, $stage: expr) => {{ + match $res { + Ok(r) => r, + Err(e) => { + $log.append($stage, format!("{:?}", e)); + return Ok($log); + } + } + }}; +} + +/// Feeds the input through the pipeline and produces an EvaluationResult. +pub fn eval(config: &Config, text: &str) -> Result { + let mut res = EvaluationResult { + stages: vec![], + status: Status::Failure, + }; + + let deps = build_stdlib(); + let account_data = AccountData::new(1_000_000, 0); + let addr = account_data.address(); + + let parsed_program = unwrap_or_log!(parse_program(&text), res, Stage::Parser); + res.append(Stage::Parser, format!("{:?}", parsed_program)); + + let compiled_program = unwrap_or_log!( + compile_program(addr, &parsed_program, &deps), + res, + Stage::Compiler + ); + res.append(Stage::Compiler, format!("{:?}", compiled_program)); + + if !config.no_verify { + unwrap_or_log!( + do_verify_program(&compiled_program, &deps), + res, + Stage::Verifier + ); + res.append(Stage::Verifier, "".to_string()); + } + + if !config.no_execute { + let txn_output = unwrap_or_log!( + run_transaction(&account_data, &compiled_program), + res, + Stage::Runtime + ); + res.append(Stage::Runtime, format!("{:?}", txn_output)); + } + + res.status = Status::Success; + Ok(res) +} diff --git a/language/functional_tests/src/lib.rs b/language/functional_tests/src/lib.rs new file mode 100644 index 0000000000000..151087502ff3e --- /dev/null +++ b/language/functional_tests/src/lib.rs @@ -0,0 +1,12 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![feature(slice_concat_ext)] + +pub mod checker; +pub mod config; +pub mod errors; +pub mod evaluator; + +#[cfg(test)] +pub mod tests; diff --git a/language/functional_tests/src/tests/checker_tests.rs b/language/functional_tests/src/tests/checker_tests.rs new file mode 100644 index 0000000000000..1af5c0f4cb417 --- /dev/null +++ b/language/functional_tests/src/tests/checker_tests.rs @@ -0,0 +1,145 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + checker::{check, run_filecheck, Directive}, + evaluator::{EvaluationResult, Stage, Status}, +}; + +#[test] +fn parse_directives() { + assert!(Directive::try_parse("abc").unwrap().is_none()); + Directive::try_parse("// check: abc").unwrap().unwrap(); + Directive::try_parse(" // check: abc").unwrap().unwrap(); + Directive::try_parse("//not: foo").unwrap().unwrap(); + Directive::try_parse("// stage: parser").unwrap().unwrap(); + + Directive::try_parse("// stage: compiler").unwrap().unwrap(); + Directive::try_parse("// stage: verifier").unwrap().unwrap(); + Directive::try_parse("// stage: runtime").unwrap().unwrap(); + Directive::try_parse("// stage: runtime ") + .unwrap() + .unwrap(); + + Directive::try_parse("// stage: runtime bad ").unwrap_err(); + Directive::try_parse("// stage: bad stage").unwrap_err(); + Directive::try_parse("// stage: ").unwrap_err(); +} + +#[rustfmt::skip] +#[test] +fn filecheck() { + run_filecheck("AAA BBB CCC", r" + // check: AAA + // check: CCC + ").unwrap(); + + run_filecheck("AAA BBB CCC", r" + // check: AAA + // not: BBB + // check: CCC + ").unwrap_err(); +} + +macro_rules! 
eval_result { + ($status: expr, $($stage: expr, $output: expr),* $(,)*) => { + { + EvaluationResult { + stages: vec![$(($stage, $output.to_string())),*], + status: $status, + } + } + }; +} + +fn make_directives(s: &str) -> Vec { + s.lines() + .filter_map(|s| { + if let Ok(directive) = Directive::try_parse(s) { + return directive; + } + None + }) + .collect() +} + +#[rustfmt::skip] +#[test] +fn check_basic() { + let res = eval_result!( + Status::Success, + Stage::Compiler, "foo", + Stage::Verifier, "baz", + Stage::Runtime, "bar" + ); + + check(&res, &make_directives(r" + // check: foo + // stage: runtime + // check: bar + ")).unwrap(); + + check(&res, &make_directives(r" + // stage: compiler + // stage: verifier + // check: bar + ")).unwrap(); + + check(&res, &make_directives(r" + // stage: verifier + // check: foo + ")).unwrap_err(); + + check(&res, &make_directives(r" + // check: foo + // check: bar + ")).unwrap(); + + check(&res, &make_directives(r" + // check: baz + // check: foo + ")).unwrap_err(); +} + +#[rustfmt::skip] +#[test] +fn check_match_twice() { + let res = eval_result!( + Status::Success, + Stage::Compiler, "foo", + Stage::Verifier, "bar", + ); + + check(&res, &make_directives(r" + // check: foo + // check: foo + ")).unwrap_err(); + + check(&res, &make_directives(r" + // stage: compiler + // check: foo + // check: foo + // stage: verifier + ")).unwrap_err(); +} + +#[rustfmt::skip] +#[test] +fn check_no_stage() { + let res = eval_result!( + Status::Success, + Stage::Verifier, "", + ); + + check(&res, &make_directives(r" + // stage: verifier + ")).unwrap(); + + check(&res, &make_directives(r" + // stage: compiler + ")).unwrap_err(); + + check(&res, &make_directives(r" + // stage: runtime + ")).unwrap_err(); +} diff --git a/language/functional_tests/src/tests/config_tests.rs b/language/functional_tests/src/tests/config_tests.rs new file mode 100644 index 0000000000000..d442d1947cbfb --- /dev/null +++ b/language/functional_tests/src/tests/config_tests.rs @@ -0,0 +1,55 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + config::{Config, ConfigEntry}, + errors::*, +}; + +#[test] +fn parse_config_entries() { + assert!(ConfigEntry::try_parse("abc").unwrap().is_none()); + ConfigEntry::try_parse("//! no-verify").unwrap().unwrap(); + ConfigEntry::try_parse("//!no-verify ").unwrap().unwrap(); + ConfigEntry::try_parse("//! no-execute").unwrap().unwrap(); + ConfigEntry::try_parse("//! no-execute ") + .unwrap() + .unwrap(); + ConfigEntry::try_parse("//!").unwrap_err(); + ConfigEntry::try_parse("//! ").unwrap_err(); + ConfigEntry::try_parse("//! garbage").unwrap_err(); +} + +fn parse_and_build_config(s: &str) -> Result { + let mut entries = vec![]; + for line in s.lines() { + if let Some(entry) = ConfigEntry::try_parse(line)? { + entries.push(entry); + } + } + Config::build(&entries) +} + +#[rustfmt::skip] +#[test] +fn build_config() { + parse_and_build_config(r"").unwrap(); + + parse_and_build_config(r" + //! no-verify + //! no-execute + ").unwrap(); + + parse_and_build_config(r" + //! no-execute + ").unwrap(); + + parse_and_build_config(r" + //! no-verify + ").unwrap(); + + parse_and_build_config(r" + //! no-verify + //! 
no-verify + ").unwrap_err(); +} diff --git a/language/functional_tests/src/tests/mod.rs b/language/functional_tests/src/tests/mod.rs new file mode 100644 index 0000000000000..cb6465d4878a9 --- /dev/null +++ b/language/functional_tests/src/tests/mod.rs @@ -0,0 +1,5 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +mod checker_tests; +mod config_tests; diff --git a/language/functional_tests/tests/testsuite.rs b/language/functional_tests/tests/testsuite.rs new file mode 100644 index 0000000000000..f6bc633296d4d --- /dev/null +++ b/language/functional_tests/tests/testsuite.rs @@ -0,0 +1,46 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +#![feature(custom_test_frameworks)] +#![test_runner(datatest::runner)] + +use functional_tests::{ + checker::{check, Directive}, + config::{Config, ConfigEntry}, + errors::*, + evaluator::eval, +}; + +fn parse_input(input: &str) -> Result<(Config, Vec, String)> { + let mut config_entries = vec![]; + let mut directives = vec![]; + let mut text = vec![]; + + for line in input.lines() { + if let Some(entry) = ConfigEntry::try_parse(line)? { + config_entries.push(entry); + continue; + } + if let Some(directive) = Directive::try_parse(line)? { + directives.push(directive); + continue; + } + text.push(line.to_string()); + } + + let text = text.join("\n"); + let config = Config::build(&config_entries)?; + Ok((config, directives, text)) +} + +// Runs all tests under the test/testsuite directory. +#[datatest::files("tests/testsuite", { input in r".*\.mvir" })] +fn functional_tests(input: &str) -> Result<()> { + let (config, directives, text) = parse_input(input)?; + let res = eval(&config, &text)?; + if let Err(e) = check(&res, &directives) { + println!("{:#?}", res); + return Err(e); + } + Ok(()) +} diff --git a/language/functional_tests/tests/testsuite/borrow_tests/borrow_copy_ok.mvir b/language/functional_tests/tests/testsuite/borrow_tests/borrow_copy_ok.mvir new file mode 100644 index 0000000000000..47e4668191c51 --- /dev/null +++ b/language/functional_tests/tests/testsuite/borrow_tests/borrow_copy_ok.mvir @@ -0,0 +1,31 @@ +modules: + +module B { + struct T {g: u64} + + public new(g: u64): V#Self.T { + return T{g: move(g)}; + } + + public t(this: &V#Self.T) { + let g: &u64; + let y: u64; + g = ©(this).g; + y = *move(g); + release(move(this)); + return; + } +} + +script: + +import Transaction.B; + +main() { + let x: V#B.T; + let y: &V#B.T; + x = B.new(5); + y = &x; + B.t(move(y)); + return; +} diff --git a/language/functional_tests/tests/testsuite/borrow_tests/borrow_field_ok.mvir b/language/functional_tests/tests/testsuite/borrow_tests/borrow_field_ok.mvir new file mode 100644 index 0000000000000..c95d18bfd8aff --- /dev/null +++ b/language/functional_tests/tests/testsuite/borrow_tests/borrow_field_ok.mvir @@ -0,0 +1,37 @@ +modules: + +module A { + + struct T{v: u64} + + struct K{f: V#Self.T} + + public new_T(v: u64) : V#Self.T { + return T{v: move(v)}; + } + + public new_K(f: V#Self.T) : V#Self.K { + return K{f: move(f)}; + } + + public value(this: &V#Self.K) : u64 { + let k: &u64; + k = &(&move(this).f).v; + return *move(k); + } +} + +script: + +import Transaction.A; + +main() { + let x: V#A.T; + let y: V#A.K; + let z: u64; + x = A.new_T(2); + y = A.new_K(move(x)); + z = A.value(&y); + assert(move(z) == 2, 42); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/borrow_tests/borrow_if.mvir 
b/language/functional_tests/tests/testsuite/borrow_tests/borrow_if.mvir new file mode 100644 index 0000000000000..eac87d4895276 --- /dev/null +++ b/language/functional_tests/tests/testsuite/borrow_tests/borrow_if.mvir @@ -0,0 +1,14 @@ +main() { + let x: u64; + let ref: &u64; + + x = 5; + if (true) { + ref = &x; + } + + assert(*move(ref) == 5, 42); + return; +} + +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: MoveLocUnavailableError(8) } diff --git a/language/functional_tests/tests/testsuite/borrow_tests/borrow_local_bad.mvir b/language/functional_tests/tests/testsuite/borrow_tests/borrow_local_bad.mvir new file mode 100644 index 0000000000000..8ae56e7bd0166 --- /dev/null +++ b/language/functional_tests/tests/testsuite/borrow_tests/borrow_local_bad.mvir @@ -0,0 +1,26 @@ +modules: + +module A { + struct T{v: u64} + + public new(u: u64) : V#Self.T { + return T{u: move(u)}; + } + + public value(this: V#Self.T): u64 { + let f: &u64; + f = &(&mut this).v; + return *move(f); + } +} + +script: +import Transaction.A; +main() { + let x: V#A.T ; + let x_ref: u64; + x = A.new(5); + x_ref = A.value(move(x)); + assert(move(x_ref) == 5, 42); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/borrow_tests/borrow_move_ok.mvir b/language/functional_tests/tests/testsuite/borrow_tests/borrow_move_ok.mvir new file mode 100644 index 0000000000000..8a7f903f98dca --- /dev/null +++ b/language/functional_tests/tests/testsuite/borrow_tests/borrow_move_ok.mvir @@ -0,0 +1,31 @@ +modules: + +module M { + struct T {v: u64} + + public new(v: u64): V#Self.T { + return T{v: move(v)}; + } + + public value(this: &V#Self.T) : u64 { + let f: &u64; + //borrow of move + f = &move(this).v; + return *move(f); + } +} + +script: + +import Transaction.M; + +main() { + let x: V#M.T; + let x_ref: &V#M.T; + let y: u64; + x = M.new(5); + x_ref = &x; + y = M.value(move(x_ref)); + assert(move(y) == 5, 42); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/borrow_tests/borrow_parens_ok.mvir b/language/functional_tests/tests/testsuite/borrow_tests/borrow_parens_ok.mvir new file mode 100644 index 0000000000000..5e012e097ce83 --- /dev/null +++ b/language/functional_tests/tests/testsuite/borrow_tests/borrow_parens_ok.mvir @@ -0,0 +1,27 @@ +modules: + +module M { + struct T{v: u64} + + public new(v: u64): V#Self.T { + return T{v: move(v)}; + } + public value(this: &V#Self.T) : u64 { + let f: &u64; + f = &(move(this)).v; + return *move(f); + } +} + +script: + +import Transaction.M; + +main(){ + let x: V#M.T; + let y: u64; + x = M.new(5); + y = M.value(&x); + assert(move(y) == 5, 42); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/borrow_tests/borrow_x_in_if_y_in_else.mvir b/language/functional_tests/tests/testsuite/borrow_tests/borrow_x_in_if_y_in_else.mvir new file mode 100644 index 0000000000000..2f8ad293ab097 --- /dev/null +++ b/language/functional_tests/tests/testsuite/borrow_tests/borrow_x_in_if_y_in_else.mvir @@ -0,0 +1,18 @@ +main() { + let x: u64; + let y: u64; + let ref: &u64; + + x = 1; + y = 2; + + if (true) { + ref = &x; + } + else { + ref = &y; + } + + assert(*move(ref) == 1, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/builtins/create_account.mvir b/language/functional_tests/tests/testsuite/builtins/create_account.mvir new file mode 100644 index 0000000000000..3dbe3e10a5de7 --- /dev/null +++ 
b/language/functional_tests/tests/testsuite/builtins/create_account.mvir @@ -0,0 +1,22 @@ +import 0x0.LibraAccount; +import 0x0.LibraCoin; + +main() { + let addr: address; + let account_exists: bool; + let ten_coins: R#LibraCoin.T; + let account_exists_now: bool; + + addr = 0x0111111111111111111111111111111111111011111111111111111111111110; + account_exists = LibraAccount.exists(copy(addr)); + assert(!move(account_exists), 83); + + ten_coins = LibraAccount.withdraw_from_sender(10); + create_account(copy(addr)); + LibraAccount.deposit(copy(addr), move(ten_coins)); + + account_exists_now = LibraAccount.exists(copy(addr)); + assert(move(account_exists_now), 84); + + return; +} diff --git a/language/functional_tests/tests/testsuite/builtins/create_account_and_get_balance.mvir b/language/functional_tests/tests/testsuite/builtins/create_account_and_get_balance.mvir new file mode 100644 index 0000000000000..b1ae59fc5efa5 --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/create_account_and_get_balance.mvir @@ -0,0 +1,22 @@ +import 0x0.LibraAccount; +import 0x0.LibraCoin; + +main() { + let addr2: address; + let account_exists: bool; + let ten_coins: R#LibraCoin.T; + let account2_balance: u64; + + addr2 = 0x0111111111111111111111101111111111111011111111111111111111101110; + account_exists = LibraAccount.exists(copy(addr2)); + assert(!move(account_exists), 83); + + ten_coins = LibraAccount.withdraw_from_sender(10); + create_account(copy(addr2)); + LibraAccount.deposit(copy(addr2), move(ten_coins)); + + account2_balance = LibraAccount.balance(move(addr2)); + assert(copy(account2_balance) == 10, 84); + + return; +} diff --git a/language/functional_tests/tests/testsuite/builtins/create_account_and_get_sequence_number.mvir b/language/functional_tests/tests/testsuite/builtins/create_account_and_get_sequence_number.mvir new file mode 100644 index 0000000000000..069bc91f4cde5 --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/create_account_and_get_sequence_number.mvir @@ -0,0 +1,22 @@ +import 0x0.LibraAccount; +import 0x0.LibraCoin; + +main() { + let addr: address; + let account_exists: bool; + let ten_coins: R#LibraCoin.T; + let account2_sequence_number: u64; + + addr = 0x0111111111111111111111111111111111111011111111111111111111111110; + account_exists = LibraAccount.exists(copy(addr)); + assert(!move(account_exists), 83); + + ten_coins = LibraAccount.withdraw_from_sender(10); + create_account(copy(addr)); + LibraAccount.deposit(copy(addr), move(ten_coins)); + + account2_sequence_number = LibraAccount.sequence_number(move(addr)); + assert(copy(account2_sequence_number) == 0, 84); + + return; +} diff --git a/language/functional_tests/tests/testsuite/builtins/freeze_makes_imm.mvir b/language/functional_tests/tests/testsuite/builtins/freeze_makes_imm.mvir new file mode 100644 index 0000000000000..357c660129c14 --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/freeze_makes_imm.mvir @@ -0,0 +1,13 @@ +main() { + let x: u64; + let x_ref: &mut u64; + let imm_ref: &u64; + + x = 5; + x_ref = &mut x; + imm_ref = freeze(move(x_ref)); + *move(imm_ref) = 0; + return; +} + +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: WriteRefNoMutableReferenceError(9) } diff --git a/language/functional_tests/tests/testsuite/builtins/freeze_on_imm.mvir b/language/functional_tests/tests/testsuite/builtins/freeze_on_imm.mvir new file mode 100644 index 0000000000000..a19e573b33fdc --- /dev/null +++ 
b/language/functional_tests/tests/testsuite/builtins/freeze_on_imm.mvir @@ -0,0 +1,12 @@ +main() { + let x: u64; + let imm_ref: &u64; + + x = 5; + imm_ref = &x; + imm_ref = freeze(move(imm_ref)); + release(move(imm_ref)); + return; +} + +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: FreezeRefTypeMismatchError(6) } diff --git a/language/functional_tests/tests/testsuite/builtins/freeze_valid.mvir b/language/functional_tests/tests/testsuite/builtins/freeze_valid.mvir new file mode 100644 index 0000000000000..c1240206b6802 --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/freeze_valid.mvir @@ -0,0 +1,12 @@ +main() { + let x: u64; + let x_ref: &mut u64; + let imm_ref: &u64; + + x = 5; + x_ref = &mut x; + imm_ref = freeze(move(x_ref)); + release(move(imm_ref)); + + return; +} diff --git a/language/functional_tests/tests/testsuite/builtins/freeze_wrong_type.mvir b/language/functional_tests/tests/testsuite/builtins/freeze_wrong_type.mvir new file mode 100644 index 0000000000000..70ee1d58f5376 --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/freeze_wrong_type.mvir @@ -0,0 +1,13 @@ +main() { + let x: u64; + let x_ref: &mut u64; + let imm_ref: &bool; + + x = 5; + x_ref = &mut x; + imm_ref = freeze(move(x_ref)); + release(move(imm_ref)); + return; +} + +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: StLocTypeMismatchError(6) } diff --git a/language/functional_tests/tests/testsuite/builtins/get_missing_account.mvir b/language/functional_tests/tests/testsuite/builtins/get_missing_account.mvir new file mode 100644 index 0000000000000..339c4960c885a --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/get_missing_account.mvir @@ -0,0 +1,14 @@ +import 0x0.LibraAccount; + +main() { + let addr: address; + let account_exists: bool; + let account_balance: u64; + + addr = 0x0111111111111111111111111111111111111111111111111111111111111110; + account_exists = LibraAccount.exists(copy(addr)); + account_balance = LibraAccount.balance(move(addr)); + return; +} + +// check: Execution(MissingData) diff --git a/language/functional_tests/tests/testsuite/builtins/get_missing_struct.mvir b/language/functional_tests/tests/testsuite/builtins/get_missing_struct.mvir new file mode 100644 index 0000000000000..a75fccd5a9a38 --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/get_missing_struct.mvir @@ -0,0 +1,41 @@ +modules: +module Token { + resource T { } + + public new(): R#Self.T { + return T{ }; + } + + public has(addr: address): bool { + let yes: bool; + yes = exists(move(addr)); + return move(yes); + } + + public get(addr: address): &mut R#Self.T { + let t_ref: &mut R#Self.T; + t_ref = borrow_global(move(addr)); + return move(t_ref); + } + + public publish(t: R#Self.T) { + move_to_sender(move(t)); + return; + } + +} +script: +import Transaction.Token; + +main() { + let addr1: address; + let struct1: &mut R#Token.T; + + addr1 = 0x0111111111111111111111111111111111111111111111111111111111111110; + struct1 = Token.get(move(addr1)); + release(move(struct1)); + + return; +} + +// check: Execution(MissingData) diff --git a/language/functional_tests/tests/testsuite/builtins/get_published_resource.mvir b/language/functional_tests/tests/testsuite/builtins/get_published_resource.mvir new file mode 100644 index 0000000000000..2969c43f26d5d --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/get_published_resource.mvir @@ -0,0 +1,54 @@ +modules: +module Token { + resource T {v: u64} + + 
public new(v: u64): R#Self.T { + return T{v: move(v)}; + } + + public value(this: &R#Self.T): u64 { + let vref: &u64; + vref = ©(this).v; + release(move(this)); + return *move(vref); + } + + public has(addr: address): bool { + let yes: bool; + yes = exists(move(addr)); + return move(yes); + } + + public get(addr: address): &mut R#Self.T { + let t_ref: &mut R#Self.T; + t_ref = borrow_global(move(addr)); + return move(t_ref); + } + + public publish(t: R#Self.T) { + move_to_sender(move(t)); + return; + } +} + +script: +import Transaction.Token; +main() { + let z: R#Token.T; + let addr1: address; + let struct1: &mut R#Token.T; + let imms1: &R#Token.T; + let struct1_original_balance: u64; + + z = Token.new(0); + Token.publish(move(z)); + + addr1 = get_txn_sender(); + struct1 = Token.get(move(addr1)); + imms1 = freeze(move(struct1)); + + struct1_original_balance = Token.value(move(imms1)); + assert(copy(struct1_original_balance) == 0, 42); + + return; +} diff --git a/language/functional_tests/tests/testsuite/builtins/get_txn_gas_unit_price.mvir b/language/functional_tests/tests/testsuite/builtins/get_txn_gas_unit_price.mvir new file mode 100644 index 0000000000000..8a49384fb3a08 --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/get_txn_gas_unit_price.mvir @@ -0,0 +1,7 @@ +main() { + let gas_price: u64; + gas_price = get_txn_gas_unit_price(); + assert(move(gas_price) > 0, 42); + + return; +} diff --git a/language/functional_tests/tests/testsuite/builtins/get_txn_max_gas_units.mvir b/language/functional_tests/tests/testsuite/builtins/get_txn_max_gas_units.mvir new file mode 100644 index 0000000000000..3cf3869c60479 --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/get_txn_max_gas_units.mvir @@ -0,0 +1,7 @@ +main() { + let gas_units: u64; + gas_units = get_txn_max_gas_units(); + assert(copy(gas_units) > 0, 42); + + return; +} diff --git a/language/functional_tests/tests/testsuite/builtins/get_txn_public_key.mvir b/language/functional_tests/tests/testsuite/builtins/get_txn_public_key.mvir new file mode 100644 index 0000000000000..6e2c42b5019f8 --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/get_txn_public_key.mvir @@ -0,0 +1,13 @@ +main(input_public_key: bytearray) { + let public_key: bytearray; + + public_key = get_txn_public_key(); + assert(move(public_key) == move(input_public_key), 42); + + return; +} + +// we do not support transaction arguments yet +// TODO: fix the test once support is added + +// check: Verification([Script(TypeMismatch("Actual Type Mismatch"))]) diff --git a/language/functional_tests/tests/testsuite/builtins/get_txn_sender.mvir b/language/functional_tests/tests/testsuite/builtins/get_txn_sender.mvir new file mode 100644 index 0000000000000..40c3435895f1a --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/get_txn_sender.mvir @@ -0,0 +1,11 @@ +main() { + let sender: address; + let addr: address; + + sender = get_txn_sender(); + addr = 0x0; + assert(copy(sender) != copy(addr), 42); + // TODO: do we really need to test this? 
+ + return; +} diff --git a/language/functional_tests/tests/testsuite/builtins/has_published_struct.mvir b/language/functional_tests/tests/testsuite/builtins/has_published_struct.mvir new file mode 100644 index 0000000000000..ab26daddc290e --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/has_published_struct.mvir @@ -0,0 +1,46 @@ +modules: +module Token { + resource T { } + + public new(): R#Self.T { + return T{ }; + } + + public has(addr: address): bool { + let yes: bool; + yes = exists(move(addr)); + return move(yes); + } + + public get(addr: address): &mut R#Self.T { + let t_ref: &mut R#Self.T; + t_ref = borrow_global(move(addr)); + return move(t_ref); + } + + public publish(t: R#Self.T) { + move_to_sender(move(t)); + return; + } +} +script: +import Transaction.Token; +main() { + let z: R#Token.T; + let sender: address; + let exists1: bool; + let addr1: address; + let exists2: bool; + + z = Token.new(); + Token.publish(move(z)); + sender = get_txn_sender(); + exists1 = Token.has(copy(sender)); + assert(copy(exists1), 42); + + addr1 = 0x0111111111111111111111111111111111111111111111111111111111111110; + exists2 = Token.has(copy(addr1)); + assert(!copy(exists2), 42); + + return; +} diff --git a/language/functional_tests/tests/testsuite/builtins/move_published_resource.mvir b/language/functional_tests/tests/testsuite/builtins/move_published_resource.mvir new file mode 100644 index 0000000000000..55e5e55c3f06f --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/move_published_resource.mvir @@ -0,0 +1,70 @@ +modules: +module TestMoveFrom { + resource Counter { i: u64 } + + public has(): bool { + let sender_address: address; + let yes: bool; + + sender_address = get_txn_sender(); + yes = exists(move(sender_address)); + return move(yes); + } + + public increment() { + let sender_address: address; + let t_ref: &mut R#Self.Counter; + let counter_ref: &mut u64; + + sender_address = get_txn_sender(); + t_ref = borrow_global(move(sender_address)); + counter_ref = &mut copy(t_ref).i; + release(move(t_ref)); + *move(counter_ref) = *copy(counter_ref) + 1; + + return; + } + + public publish() { + let t: R#Self.Counter; + + t = Counter { i: 0 }; + move_to_sender(move(t)); + + return; + } + + public unpublish() { + let sender_address: address; + let counter: R#Self.Counter; + let i: u64; + + sender_address = get_txn_sender(); + counter = move_from(move(sender_address)); + Counter { i } = move(counter); + + return; + } + +} + +script: +import Transaction.TestMoveFrom; + +main() { + let has1: bool; + let has2: bool; + + TestMoveFrom.publish(); + TestMoveFrom.increment(); + + has1 = TestMoveFrom.has(); + assert(move(has1), 77); + + TestMoveFrom.unpublish(); + + has2 = TestMoveFrom.has(); + assert(!move(has2), 88); + + return; +} diff --git a/language/functional_tests/tests/testsuite/builtins/release.mvir b/language/functional_tests/tests/testsuite/builtins/release.mvir new file mode 100644 index 0000000000000..5272b54589b7f --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/release.mvir @@ -0,0 +1,13 @@ +import 0x0.LibraAccount; +import 0x0.LibraCoin; + +main() { + let ten_coins: R#LibraCoin.T; + ten_coins = LibraAccount.withdraw_from_sender(10); + release(move(ten_coins)); + + return; +} + +// check: VerificationError +// check: ReleaseRefTypeMismatchError(4) diff --git a/language/functional_tests/tests/testsuite/builtins/verify_valid_bytearray.mvir b/language/functional_tests/tests/testsuite/builtins/verify_valid_bytearray.mvir new file mode 
100644 index 0000000000000..2075b9042afaa --- /dev/null +++ b/language/functional_tests/tests/testsuite/builtins/verify_valid_bytearray.mvir @@ -0,0 +1,6 @@ +main() { + let byte_array: bytearray; + byte_array = b"deadbeefdeadbeef"; + assert(move(byte_array) == b"deadbeefdeadbeef", 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/assign_copy.mvir b/language/functional_tests/tests/testsuite/commands/assign_copy.mvir new file mode 100644 index 0000000000000..739eda2be198c --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/assign_copy.mvir @@ -0,0 +1,9 @@ +main() { + let x: u64; + let y: u64; + + x = 5; + y = copy(x); + assert(copy(y) == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/assign_in_one_if_branch.mvir b/language/functional_tests/tests/testsuite/commands/assign_in_one_if_branch.mvir new file mode 100644 index 0000000000000..21165f11b2667 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/assign_in_one_if_branch.mvir @@ -0,0 +1,19 @@ +main() { + let x: u64; + let y: u64; + + if (true) { + x = 5; + } else { + assert(false, 42); + } + + if (true) { + y = 5; + } + + assert(copy(x) == copy(y), 42); + return; +} + +// check: VerificationError diff --git a/language/functional_tests/tests/testsuite/commands/assign_move.mvir b/language/functional_tests/tests/testsuite/commands/assign_move.mvir new file mode 100644 index 0000000000000..c3109c637063d --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/assign_move.mvir @@ -0,0 +1,9 @@ +main() { + let x: u64; + let y: u64; + + x = 5; + y = move(x); + assert(copy(y) == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/assign_resource.mvir b/language/functional_tests/tests/testsuite/commands/assign_resource.mvir new file mode 100644 index 0000000000000..f193a1bcf2025 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/assign_resource.mvir @@ -0,0 +1,11 @@ +import 0x0.LibraCoin; + +main() { + let z: R#LibraCoin.T; + z = LibraCoin.zero(); + z = LibraCoin.zero(); + return; +} + +// check: VerificationError +// check: StLocUnsafeToDestroyError diff --git a/language/functional_tests/tests/testsuite/commands/assign_wrong_if_branch.mvir b/language/functional_tests/tests/testsuite/commands/assign_wrong_if_branch.mvir new file mode 100644 index 0000000000000..61212a4b09650 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/assign_wrong_if_branch.mvir @@ -0,0 +1,13 @@ +main() { + let x: u64; + if (true) { + assert(true, 42); + } else { + x = 100; + } + + assert(copy(x) == 100, 42); + return; +} + +// check: VerificationError diff --git a/language/functional_tests/tests/testsuite/commands/assign_wrong_if_branch_no_else.mvir b/language/functional_tests/tests/testsuite/commands/assign_wrong_if_branch_no_else.mvir new file mode 100644 index 0000000000000..d332accca0d50 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/assign_wrong_if_branch_no_else.mvir @@ -0,0 +1,11 @@ +main() { + let x: u64; + if (false) { + x = 100; + } + + assert(copy(x) == 100, 42); + return; +} + +//check: VerificationError diff --git a/language/functional_tests/tests/testsuite/commands/assign_wrong_type.mvir b/language/functional_tests/tests/testsuite/commands/assign_wrong_type.mvir new file mode 100644 index 0000000000000..c563a8db584aa --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/assign_wrong_type.mvir @@ -0,0 +1,6 @@ +// check: 
VerificationError { kind: FunctionDefinition, idx: 0, err: StLocTypeMismatchError(1) } +main() { + let x: u64; + x = false; + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/branch_assigns_then_moves.mvir b/language/functional_tests/tests/testsuite/commands/branch_assigns_then_moves.mvir new file mode 100644 index 0000000000000..f48c65310c795 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/branch_assigns_then_moves.mvir @@ -0,0 +1,17 @@ +main() { + let x: u64; + let y: u64; + + if (true) { + x = 1; + y = move(x); + } else { + x = 0; + } + + assert(copy(x) == 5, 42); + return; +} + +// check: VerificationError +// check: CopyLocUnavailableError diff --git a/language/functional_tests/tests/testsuite/commands/branch_assigns_then_moves_then_assigns.mvir b/language/functional_tests/tests/testsuite/commands/branch_assigns_then_moves_then_assigns.mvir new file mode 100644 index 0000000000000..07168b1d8de92 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/branch_assigns_then_moves_then_assigns.mvir @@ -0,0 +1,15 @@ +main() { + let x: u64; + let y: u64; + + if (true) { + x = 1; + y = move(x); + x = 5; + } else { + x = 0; + } + + assert(copy(x) == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/break_accumulator.mvir b/language/functional_tests/tests/testsuite/commands/break_accumulator.mvir new file mode 100644 index 0000000000000..528ed34743c37 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/break_accumulator.mvir @@ -0,0 +1,12 @@ +main() { + let x: u64; + x = 0; + while (true) { + if (copy(x) >= 5) { + break; + } + x = move(x) + 1; + } + assert(move(x) == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/break_continue_simple.mvir b/language/functional_tests/tests/testsuite/commands/break_continue_simple.mvir new file mode 100644 index 0000000000000..61506c27650bf --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/break_continue_simple.mvir @@ -0,0 +1,14 @@ +main() { + let x: u64; + x = 0; + while (true) { + if (copy(x) >= 5) { + break; + } + x = move(x) + 1; + continue; + assert(false, 42); + } + assert(move(x) == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/break_continue_sum_of_odds.mvir b/language/functional_tests/tests/testsuite/commands/break_continue_sum_of_odds.mvir new file mode 100644 index 0000000000000..37cfcbb3fccd0 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/break_continue_sum_of_odds.mvir @@ -0,0 +1,20 @@ +main() { + let x: u64; + let y: u64; + x = 0; + y = 0; + loop { + if (copy(x) < 10) { + x = move(x) + 1; + if (copy(x) % 2 == 0) { + continue; + } + y = move(y) + copy(x); + } + else { + break; + } + } + assert(move(y) == 25, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/break_nested.mvir b/language/functional_tests/tests/testsuite/commands/break_nested.mvir new file mode 100644 index 0000000000000..369eb7ad9c232 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/break_nested.mvir @@ -0,0 +1,17 @@ +main() { + let x: u64; + let y: u64; + x = 0; + y = 0; + while (true) { + loop { + y = 5; + break; + } + x = 3; + break; + } + assert(move(x) == 3, 42); + assert(move(y) == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/break_outside_loop.mvir b/language/functional_tests/tests/testsuite/commands/break_outside_loop.mvir new file mode 100644 
index 0000000000000..701b5518dd9fc --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/break_outside_loop.mvir @@ -0,0 +1,6 @@ +main() { + break; + return; +} + +// check: break outside loop diff --git a/language/functional_tests/tests/testsuite/commands/break_outside_loop_in_else.mvir b/language/functional_tests/tests/testsuite/commands/break_outside_loop_in_else.mvir new file mode 100644 index 0000000000000..e9910a5115e5e --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/break_outside_loop_in_else.mvir @@ -0,0 +1,10 @@ +main() { + if (false) { + } + else { + break; + } + return; +} + +// check: break outside loop diff --git a/language/functional_tests/tests/testsuite/commands/break_outside_loop_in_if.mvir b/language/functional_tests/tests/testsuite/commands/break_outside_loop_in_if.mvir new file mode 100644 index 0000000000000..eeebc96687243 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/break_outside_loop_in_if.mvir @@ -0,0 +1,8 @@ +main() { + if (true) { + break; + } + return; +} + +// check: break outside loop diff --git a/language/functional_tests/tests/testsuite/commands/break_simple.mvir b/language/functional_tests/tests/testsuite/commands/break_simple.mvir new file mode 100644 index 0000000000000..2433841e6551d --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/break_simple.mvir @@ -0,0 +1,10 @@ +main() { + let x: u64; + x = 0; + while (true) { + x = move(x) + 1; + break; + } + assert(move(x) == 1, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/break_unreachable.mvir b/language/functional_tests/tests/testsuite/commands/break_unreachable.mvir new file mode 100644 index 0000000000000..f87aa544f217d --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/break_unreachable.mvir @@ -0,0 +1,12 @@ +main() { + let x: u64; + x = 1; + while (true) { + x = 3; + break; + x = 5; + break; + } + assert(move(x) == 3, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/continue_outside_loop.mvir b/language/functional_tests/tests/testsuite/commands/continue_outside_loop.mvir new file mode 100644 index 0000000000000..3c75b9a15339a --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/continue_outside_loop.mvir @@ -0,0 +1,6 @@ +main() { + continue; + return; +} + +// check: continue outside loop diff --git a/language/functional_tests/tests/testsuite/commands/continue_outside_loop_in_if.mvir b/language/functional_tests/tests/testsuite/commands/continue_outside_loop_in_if.mvir new file mode 100644 index 0000000000000..cd7b93d023161 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/continue_outside_loop_in_if.mvir @@ -0,0 +1,8 @@ +main() { + if (true) { + continue; + } + return; +} + +// check: continue outside loop diff --git a/language/functional_tests/tests/testsuite/commands/dead_return.mvir b/language/functional_tests/tests/testsuite/commands/dead_return.mvir new file mode 100644 index 0000000000000..00b7305387202 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/dead_return.mvir @@ -0,0 +1,19 @@ +// TODO: we might need to modify this test if we enable unreachable code +// checking in the verifier. 
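+// As things stand, the second return below is simply dead code: t() always takes the first return and yields 100, which is what the script asserts.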
+ +modules: +module Test { + public t(): u64 { + return 100; + return 0; + } +} + +script: +import Transaction.Test; +main() { + let x: u64; + x = Test.t(); + assert(copy(x) == 100, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/dead_return_local.mvir b/language/functional_tests/tests/testsuite/commands/dead_return_local.mvir new file mode 100644 index 0000000000000..ebe084bc83cb6 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/dead_return_local.mvir @@ -0,0 +1,7 @@ +// TODO: should the verifier allow this? + +main() { + return; + assert(false, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/deep_return_branch_doesnt_assign.mvir b/language/functional_tests/tests/testsuite/commands/deep_return_branch_doesnt_assign.mvir new file mode 100644 index 0000000000000..099ef4434a905 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/deep_return_branch_doesnt_assign.mvir @@ -0,0 +1,16 @@ +main() { + let x: u64; + + if (true) { + if (false) { + return; + } else { + return; + } + } else { + x = 0; + } + + assert(copy(x) == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/else_assigns_if_doesnt.mvir b/language/functional_tests/tests/testsuite/commands/else_assigns_if_doesnt.mvir new file mode 100644 index 0000000000000..69a63f2e066ad --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/else_assigns_if_doesnt.mvir @@ -0,0 +1,13 @@ +main() { + let x: u64; + let y: u64; + if (true) { + y = 0; + } else { + x = 42; + } + assert(copy(y) == 0, 42); + return; +} + +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: CopyLocUnavailableError(8) } diff --git a/language/functional_tests/tests/testsuite/commands/else_moves_if_doesnt.mvir b/language/functional_tests/tests/testsuite/commands/else_moves_if_doesnt.mvir new file mode 100644 index 0000000000000..2d1593a5fa54f --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/else_moves_if_doesnt.mvir @@ -0,0 +1,15 @@ +main() { + let x: u64; + let y: u64; + + x = 0; + if (true) { + y = 0; + } else { + y = move(x); + } + assert(copy(x) == 0, 42); + return; +} + +// check: VerificationError diff --git a/language/functional_tests/tests/testsuite/commands/if_assigns_else_doesnt.mvir b/language/functional_tests/tests/testsuite/commands/if_assigns_else_doesnt.mvir new file mode 100644 index 0000000000000..70777dde048a0 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_assigns_else_doesnt.mvir @@ -0,0 +1,14 @@ +main() { + let x: u64; + let y: u64; + if (true) { + x = 42; + } else { + y = 0; + } + assert(copy(x) == 42, 42); + return; +} + +// check: VerificationError +// check: CopyLocUnavailableError diff --git a/language/functional_tests/tests/testsuite/commands/if_assigns_no_else.mvir b/language/functional_tests/tests/testsuite/commands/if_assigns_no_else.mvir new file mode 100644 index 0000000000000..ad7ae59bf44bc --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_assigns_no_else.mvir @@ -0,0 +1,11 @@ +main() { + let x: u64; + if (true) { + x = 42; + } + assert(copy(x) == 42, 42); + return; +} + +// check: VerificationError +// check: CopyLocUnavailableError diff --git a/language/functional_tests/tests/testsuite/commands/if_branch_diverges_1.mvir b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_1.mvir new file mode 100644 index 0000000000000..07c3dbc3fc3c5 --- /dev/null +++ 
b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_1.mvir @@ -0,0 +1,9 @@ +main() { + let ret_if_val: bool; + ret_if_val = true; + if (move(ret_if_val)) { + return; + } else { + return; + } +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/commands/if_branch_diverges_10.mvir b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_10.mvir new file mode 100644 index 0000000000000..00e9faaf72704 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_10.mvir @@ -0,0 +1,10 @@ +main() { + let ret_if_val: bool; + ret_if_val = true; + if (move(ret_if_val)) { + loop {return; break; } + } else { + assert(false, 42); + return; + } +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/commands/if_branch_diverges_2.mvir b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_2.mvir new file mode 100644 index 0000000000000..ff669d65a2258 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_2.mvir @@ -0,0 +1,9 @@ +main() { + let ret_if_val: bool; + ret_if_val = true; + if (move(ret_if_val)) { + return; + } + assert(false, 42); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/commands/if_branch_diverges_3.mvir b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_3.mvir new file mode 100644 index 0000000000000..c9c0333326ff7 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_3.mvir @@ -0,0 +1,11 @@ +main() { + let ret_if_val: bool; + ret_if_val = true; + if (move(ret_if_val)) { + return; + } else { + assert(false, 42); + } + assert(false, 43); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/commands/if_branch_diverges_4.mvir b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_4.mvir new file mode 100644 index 0000000000000..affcc31e7e3e8 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_4.mvir @@ -0,0 +1,10 @@ +main() { + let ret_if_val: bool; + ret_if_val = true; + if (move(ret_if_val)) { + if (true) { return; } else { return; } + } else { + assert(false, 42); + return; + } +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/commands/if_branch_diverges_5.mvir b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_5.mvir new file mode 100644 index 0000000000000..7c1862c263985 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_5.mvir @@ -0,0 +1,17 @@ +// check: VerificationError +// check: IndexOutOfBounds + +main() { + let ret_if_val: bool; + ret_if_val = true; + if (move(ret_if_val)) { + // This must produce an out-of-bounds jump: + // - We will not insert a return; that is not the job of the IR + // - The branch must jump somewhere, or it would fall through to the else branch + // - As such, the jump here targets the 'end of the else', which is an invalid jump target + loop { break; } + } else { + assert(false, 42); + return; + } +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/commands/if_branch_diverges_6.mvir b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_6.mvir new file mode 100644 index 0000000000000..9c3d4e7c724d0 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_6.mvir @@ -0,0 +1,13 @@ +// check: VerificationError +// check:
IndexOutOfBounds + +main() { + let ret_if_val: bool; + ret_if_val = true; + if (move(ret_if_val)) { + loop { if (true) { return; } else { break; } } + } else { + assert(false, 42); + return; + } +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/commands/if_branch_diverges_7.mvir b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_7.mvir new file mode 100644 index 0000000000000..53cb42d896c62 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_7.mvir @@ -0,0 +1,10 @@ +main() { + let ret_if_val: bool; + ret_if_val = true; + if (move(ret_if_val)) { + loop { if (true) { return; } else { continue; }; break; } + } else { + assert(false, 42); + return; + } +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/commands/if_branch_diverges_8.mvir b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_8.mvir new file mode 100644 index 0000000000000..b0336bda909fc --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_8.mvir @@ -0,0 +1,13 @@ +// check: VerificationError +// check: IndexOutOfBounds + +main() { + let ret_if_val: bool; + ret_if_val = true; + if (move(ret_if_val)) { + loop { break; if (true) { return; } else { continue; }; } + } else { + assert(false, 42); + return; + } +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/commands/if_branch_diverges_9.mvir b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_9.mvir new file mode 100644 index 0000000000000..502704ed5b03d --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_branch_diverges_9.mvir @@ -0,0 +1,15 @@ +main() { + let b: bool; + b = false; + loop { + if (copy(b)) { + if (copy(b)) { + continue; + } + } + else { + break; + } + } + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/if_moves_else_doesnt.mvir b/language/functional_tests/tests/testsuite/commands/if_moves_else_doesnt.mvir new file mode 100644 index 0000000000000..bb71af9b57f3a --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_moves_else_doesnt.mvir @@ -0,0 +1,14 @@ +main() { + let x: u64; + let y: u64; + x = 0; + if (true) { + y = move(x); + } else { + y = 0; + } + assert(copy(x) == 0, 42); + return; +} + +// check: VerificationError diff --git a/language/functional_tests/tests/testsuite/commands/if_moves_no_else.mvir b/language/functional_tests/tests/testsuite/commands/if_moves_no_else.mvir new file mode 100644 index 0000000000000..0902f0bba2760 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_moves_no_else.mvir @@ -0,0 +1,13 @@ +main() { + let x: u64; + let y: u64; + x = 0; + if (true) { + y = move(x); + } + assert(copy(x) == 0, 42); + return; +} + +// check: VerificationError +// check: CopyLocUnavailableError diff --git a/language/functional_tests/tests/testsuite/commands/if_without_braces_1.mvir b/language/functional_tests/tests/testsuite/commands/if_without_braces_1.mvir new file mode 100644 index 0000000000000..24443a4138780 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_without_braces_1.mvir @@ -0,0 +1,10 @@ +main() { + let x: u64; + if (true) + x = 3; + else + x = 5; + return; +} + +// check: ParserError diff --git a/language/functional_tests/tests/testsuite/commands/if_without_braces_2.mvir b/language/functional_tests/tests/testsuite/commands/if_without_braces_2.mvir new file mode 100644 index 
0000000000000..d78fce7925346 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_without_braces_2.mvir @@ -0,0 +1,11 @@ +main() { + let x: u64; + if (true) { + x = 3; + } + else + x = 5; + return; +} + +// check: ParserError diff --git a/language/functional_tests/tests/testsuite/commands/if_without_braces_3.mvir b/language/functional_tests/tests/testsuite/commands/if_without_braces_3.mvir new file mode 100644 index 0000000000000..020dc7ad536b0 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_without_braces_3.mvir @@ -0,0 +1,11 @@ +main() { + let x: u64; + if (true) + x = 3; + else { + x = 5 + }; + return; +} + +// check: ParserError diff --git a/language/functional_tests/tests/testsuite/commands/if_without_braces_4.mvir b/language/functional_tests/tests/testsuite/commands/if_without_braces_4.mvir new file mode 100644 index 0000000000000..82b812678c421 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/if_without_braces_4.mvir @@ -0,0 +1,8 @@ +main() { + let x: u64; + if (true) + x = 3; + return; +} + +// check: ParserError diff --git a/language/functional_tests/tests/testsuite/commands/local_assigned_many_times.mvir b/language/functional_tests/tests/testsuite/commands/local_assigned_many_times.mvir new file mode 100644 index 0000000000000..0bae0d39948d3 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/local_assigned_many_times.mvir @@ -0,0 +1,31 @@ +main() { + let x: u64; + let y: u64; + let x_ref: &u64; + let y_ref: &u64; + + x = 5; + y = 2; + + x_ref = &x; + y_ref = &y; + release(move(y_ref)); + release(move(x_ref)); + + x_ref = &x; + y_ref = move(x_ref); + + if (true) { + release(move(y_ref)); + x_ref = &y; + y_ref = &x; + } else { + release(move(y_ref)); + x_ref = &x; + y_ref = &y; + } + + assert(*move(x_ref) == 2, 42); + release(move(y_ref)); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/loop_nested_breaks.mvir b/language/functional_tests/tests/testsuite/commands/loop_nested_breaks.mvir new file mode 100644 index 0000000000000..4bc80a4ab9690 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/loop_nested_breaks.mvir @@ -0,0 +1,11 @@ +main() { + loop { + loop { + break; + assert(false, 42); + } + break; + assert(false, 42); + } + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/loop_return.mvir b/language/functional_tests/tests/testsuite/commands/loop_return.mvir new file mode 100644 index 0000000000000..a287723bf7758 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/loop_return.mvir @@ -0,0 +1,7 @@ +main() { + loop { + return; + assert(false, 42); + } + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/loop_simple.mvir b/language/functional_tests/tests/testsuite/commands/loop_simple.mvir new file mode 100644 index 0000000000000..08b07f9f1bc77 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/loop_simple.mvir @@ -0,0 +1,10 @@ +main() { + let x: u64; + x = 0; + loop { + x = move(x) + 1; + break; + } + assert(move(x) == 1, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/move_before_assign.mvir b/language/functional_tests/tests/testsuite/commands/move_before_assign.mvir new file mode 100644 index 0000000000000..cf573aa8f4389 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/move_before_assign.mvir @@ -0,0 +1,9 @@ +main() { + let x: u64; + let y: u64; + y = move(x); + return; +} + +// check: 
VerificationError +// check: MoveLocUnavailableError diff --git a/language/functional_tests/tests/testsuite/commands/no_let_outside_if.mvir b/language/functional_tests/tests/testsuite/commands/no_let_outside_if.mvir new file mode 100644 index 0000000000000..6e5d62178ed4b --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/no_let_outside_if.mvir @@ -0,0 +1,11 @@ +main() { + if (true) { + y = 5; + } else { + y = 0; + } + assert(copy(y) == 5, 42); + return; +} + +// check: variable y undefined diff --git a/language/functional_tests/tests/testsuite/commands/no_rebind.mvir b/language/functional_tests/tests/testsuite/commands/no_rebind.mvir new file mode 100644 index 0000000000000..87f3816f68607 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/no_rebind.mvir @@ -0,0 +1,7 @@ +main() { + let x: u64; + let x: u64; + return; +} + +// check: variable redefinition x diff --git a/language/functional_tests/tests/testsuite/commands/return_branch_doesnt_assign.mvir b/language/functional_tests/tests/testsuite/commands/return_branch_doesnt_assign.mvir new file mode 100644 index 0000000000000..eb9b0c800fcf4 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/return_branch_doesnt_assign.mvir @@ -0,0 +1,12 @@ +main() { + let x: u64; + + if (true) { + return; + } else { + x = 0; + } + + assert(copy(x) == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/return_branch_moves.mvir b/language/functional_tests/tests/testsuite/commands/return_branch_moves.mvir new file mode 100644 index 0000000000000..4e94a58dc6ac9 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/return_branch_moves.mvir @@ -0,0 +1,13 @@ +main() { + let x: u64; + let y: u64; + + x = 0; + if (false) { + y = move(x); + return; + } + + assert(copy(x) == 0, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/return_in_if_branch_taken.mvir b/language/functional_tests/tests/testsuite/commands/return_in_if_branch_taken.mvir new file mode 100644 index 0000000000000..0b991100907a7 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/return_in_if_branch_taken.mvir @@ -0,0 +1,22 @@ +modules: +module Test { + public t(): u64 { + let x: u64; + if (true) { + return 100; + } else { + x = 0; + } + return copy(x); + } +} + +script: +import Transaction.Test; +main() { + let x: u64; + x = Test.t(); + assert(copy(x) == 100, 42); + return; +} + diff --git a/language/functional_tests/tests/testsuite/commands/return_in_if_branch_taken_local.mvir b/language/functional_tests/tests/testsuite/commands/return_in_if_branch_taken_local.mvir new file mode 100644 index 0000000000000..ec67fb5657712 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/return_in_if_branch_taken_local.mvir @@ -0,0 +1,10 @@ +main() { + let x: u64; + if (true) { + return; + } else { + x = 0; + } + assert(false, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/return_in_if_branch_taken_no_else.mvir b/language/functional_tests/tests/testsuite/commands/return_in_if_branch_taken_no_else.mvir new file mode 100644 index 0000000000000..e59cc15521b5d --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/return_in_if_branch_taken_no_else.mvir @@ -0,0 +1,18 @@ +modules: +module Test { + public t(): u64 { + if (true) { + return 100; + } + return 0; + } +} + +script: +import Transaction.Test; +main() { + let x: u64; + x = Test.t(); + assert(copy(x) == 100, 42); + return; +} diff --git 
a/language/functional_tests/tests/testsuite/commands/unpack_extra_binding.mvir b/language/functional_tests/tests/testsuite/commands/unpack_extra_binding.mvir new file mode 100644 index 0000000000000..8a9ea80d2c750 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/unpack_extra_binding.mvir @@ -0,0 +1,41 @@ +modules: +module Test { + struct X { } + struct T { i: u64, x: V#Self.X } + + public new_t(): V#Self.T { + let x: V#Self.T; + x = X { }; + return T { i: 0, x: move(x), b: false }; + } + + public destroy_t(t: V#Self.T): u64 * V#Self.X * bool { + let i: u64; + let x: V#Self.X; + let flag: bool; + T { i, x, b: flag } = move(t); + return move(i), move(x), move(flag); + } + +} +script: +import Transaction.Test; + +main() { + let x: V#Test.X; + let i: u64; + let t: V#Test.X; + let b: bool; + + t = Test.new_t(); + i, x, b = Test.destroy_t(move(t)); + + return; +} + +// check: VerificationError +// check: PositiveStackSizeAtBlockEnd + +// check: VerificationError +// check: NegativeStackSizeInsideBlock + diff --git a/language/functional_tests/tests/testsuite/commands/unpack_missing_binding.mvir b/language/functional_tests/tests/testsuite/commands/unpack_missing_binding.mvir new file mode 100644 index 0000000000000..c68828b5c3889 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/unpack_missing_binding.mvir @@ -0,0 +1,37 @@ +modules: +module Test { + struct X { } + struct T { i: u64, x: V#Self.X, b: bool, y: u64 } + + public new_t(): V#Self.T { + let x: V#Self.X; + x = X { }; + return T { i: 0, x: move(x), b: false, y: 0 }; + } + + public destroy_t(t: V#Self.T): u64 * V#Self.X * bool { + let i: u64; + let x: V#Self.X; + let flag: bool; + T { i, x, b: flag } = move(t); + return move(i), move(x), move(flag); + } + +} +script: +import Transaction.Test; + +main() { + let x: V#Test.X; + let i: u64; + let t: V#Test.T; + let b: bool; + + t = Test.new_t(); + i, x, b = Test.destroy_t(move(t)); + + return; +} + +// check: VerificationError +// check: PositiveStackSizeAtBlockEnd diff --git a/language/functional_tests/tests/testsuite/commands/unpack_resource.mvir b/language/functional_tests/tests/testsuite/commands/unpack_resource.mvir new file mode 100644 index 0000000000000..fb281cc080a0e --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/unpack_resource.mvir @@ -0,0 +1,44 @@ +modules: +module Test { + resource X { } + resource T { i: u64, x: R#Self.X, b: bool } + + public new_x(): R#Self.X { + return X { }; + } + + public new_t(x: R#Self.X): R#Self.T { + return T { i: 0, x: move(x), b: false }; + } + + public destroy_x(x: R#Self.X) { + X { } = move(x); + return; + } + + public destroy_t(t: R#Self.T): u64 * R#Self.X * bool { + let i: u64; + let x: R#Self.X; + let flag: bool; + T { i, x, b: flag } = move(t); + return move(i), move(x), move(flag); + } + +} + +script: +import Transaction.Test; + +main() { + let x: R#Test.X; + let i: u64; + let t: R#Test.T; + let b: bool; + + x = Test.new_x(); + t = Test.new_t(move(x)); + i, x, b = Test.destroy_t(move(t)); + Test.destroy_x(move(x)); + + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/unpack_top_level.mvir b/language/functional_tests/tests/testsuite/commands/unpack_top_level.mvir new file mode 100644 index 0000000000000..e5e289b2df0b6 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/unpack_top_level.mvir @@ -0,0 +1,22 @@ +modules: +module Test { + struct T { } + + public new_t(): V#Self.T { + return T { }; + } + +} +script: +import Transaction.Test; + 
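+// The unpack in main() must reference the struct definition of T, which is only legal inside the defining module; the compiler error checked at the end of this file should reject it.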
+main() { + let t: V#Test.T; + + t = Test.new_t(); + T { } = move(t); + + return; +} + +// check: no struct definition referencing in scripts diff --git a/language/functional_tests/tests/testsuite/commands/unpack_wrong_type.mvir b/language/functional_tests/tests/testsuite/commands/unpack_wrong_type.mvir new file mode 100644 index 0000000000000..45dc044565fe0 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/unpack_wrong_type.mvir @@ -0,0 +1,32 @@ +modules: +module Test { + struct X { } + struct T { } + + public new_t(): V#Self.T { + return T { }; + } + + public destroy_t(t: V#Self.T) { + X { } = move(t); + return; + } + +} +script: +import Transaction.Test; + +main() { + let x: V#Test.X; + let i: u64; + let t: V#Test.T; + let b: bool; + + t = Test.new_t(); + Test.destroy_t(move(t)); + + return; +} + +// check: VerificationError +// check: UnpackTypeMismatchError diff --git a/language/functional_tests/tests/testsuite/commands/use_before_assign.mvir b/language/functional_tests/tests/testsuite/commands/use_before_assign.mvir new file mode 100644 index 0000000000000..b056659cc9427 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/use_before_assign.mvir @@ -0,0 +1,9 @@ +main() { + let x: u64; + let y: u64; + y = copy(x); + return; +} + +// check: VerificationError +// check: CopyLocUnavailableError diff --git a/language/functional_tests/tests/testsuite/commands/while_false.mvir b/language/functional_tests/tests/testsuite/commands/while_false.mvir new file mode 100644 index 0000000000000..d610a1fa95409 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/while_false.mvir @@ -0,0 +1,9 @@ +main() { + let x: u64; + x = 0; + while (false) { + x = 1; + } + assert(move(x) == 0, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/while_move_local.mvir b/language/functional_tests/tests/testsuite/commands/while_move_local.mvir new file mode 100644 index 0000000000000..637fd101eb975 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/while_move_local.mvir @@ -0,0 +1,16 @@ +// check: VerificationError +// check: MoveLocUnavailableError + +main() { + let x: u64; + let y: u64; + let b: bool; + x = 0; + b = true; + while (copy(b)) { + y = move(x); + b = false; + } + assert(move(y) == 0, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/while_move_local_2.mvir b/language/functional_tests/tests/testsuite/commands/while_move_local_2.mvir new file mode 100644 index 0000000000000..b6ac987ee43ae --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/while_move_local_2.mvir @@ -0,0 +1,19 @@ +// check: VerificationError +// check: MoveLocUnavailableError + +main() { + let x: u64; + let y: u64; + let b: bool; + x = 0; + b = true; + while (true) { + if (copy(b)) { + y = move(x); + } else { + x = move(y); + } + b = false; + } + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/while_nested.mvir b/language/functional_tests/tests/testsuite/commands/while_nested.mvir new file mode 100644 index 0000000000000..4d9d2bdc6a870 --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/while_nested.mvir @@ -0,0 +1,17 @@ +main() { + let x: u64; + let y: u64; + let z: u64; + x = 0; + z = 0; + while (copy(x) < 3) { + x = move(x) + 1; + y = 0; + while (copy(y) < 7) { + y = move(y) + 1; + z = move(z) + 1; + } + } + assert(move(z) == 21, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/while_nested_return.mvir 
b/language/functional_tests/tests/testsuite/commands/while_nested_return.mvir new file mode 100644 index 0000000000000..7f6bab7f2813f --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/while_nested_return.mvir @@ -0,0 +1,11 @@ +main() { + while (true) { + while (true) { + return; + assert(false, 42); + } + assert(false, 42); + } + assert(false, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/while_return.mvir b/language/functional_tests/tests/testsuite/commands/while_return.mvir new file mode 100644 index 0000000000000..6ecc86e25b68e --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/while_return.mvir @@ -0,0 +1,7 @@ +main() { + while (true) { + return; + assert(false, 42); + } + return; +} diff --git a/language/functional_tests/tests/testsuite/commands/while_simple.mvir b/language/functional_tests/tests/testsuite/commands/while_simple.mvir new file mode 100644 index 0000000000000..ffc0a460da15d --- /dev/null +++ b/language/functional_tests/tests/testsuite/commands/while_simple.mvir @@ -0,0 +1,9 @@ +main() { + let x: u64; + x = 0; + while (copy(x) < 5) { + x = copy(x) + 1; + } + assert(move(x) == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/comments/multi_line_comment_commented_by_single_line_comment.mvir b/language/functional_tests/tests/testsuite/comments/multi_line_comment_commented_by_single_line_comment.mvir new file mode 100644 index 0000000000000..0a2a6f16f0371 --- /dev/null +++ b/language/functional_tests/tests/testsuite/comments/multi_line_comment_commented_by_single_line_comment.mvir @@ -0,0 +1,5 @@ +main() { + // /* + return; + // */ +} diff --git a/language/functional_tests/tests/testsuite/comments/multiple_single_line_comments_mixed.mvir b/language/functional_tests/tests/testsuite/comments/multiple_single_line_comments_mixed.mvir new file mode 100644 index 0000000000000..23c417849bffc --- /dev/null +++ b/language/functional_tests/tests/testsuite/comments/multiple_single_line_comments_mixed.mvir @@ -0,0 +1,4 @@ +main() { + // This is a comment + return; // This is another comment +} diff --git a/language/functional_tests/tests/testsuite/comments/multiple_single_line_comments_own_line.mvir b/language/functional_tests/tests/testsuite/comments/multiple_single_line_comments_own_line.mvir new file mode 100644 index 0000000000000..68d6d4eb5b382 --- /dev/null +++ b/language/functional_tests/tests/testsuite/comments/multiple_single_line_comments_own_line.mvir @@ -0,0 +1,5 @@ +main() { + // This is a comment + // This is another comment + return; +} diff --git a/language/functional_tests/tests/testsuite/comments/single_line_comment_line.mvir b/language/functional_tests/tests/testsuite/comments/single_line_comment_line.mvir new file mode 100644 index 0000000000000..c5aae1abadac0 --- /dev/null +++ b/language/functional_tests/tests/testsuite/comments/single_line_comment_line.mvir @@ -0,0 +1,3 @@ +main() { + return; // This is a comment +} diff --git a/language/functional_tests/tests/testsuite/comments/single_line_comment_own_line.mvir b/language/functional_tests/tests/testsuite/comments/single_line_comment_own_line.mvir new file mode 100644 index 0000000000000..6fc91aa5f3167 --- /dev/null +++ b/language/functional_tests/tests/testsuite/comments/single_line_comment_own_line.mvir @@ -0,0 +1,4 @@ +main() { + // This is a comment + return; +} diff --git a/language/functional_tests/tests/testsuite/data_types/empty_structs.mvir 
b/language/functional_tests/tests/testsuite/data_types/empty_structs.mvir new file mode 100644 index 0000000000000..82e94155702c5 --- /dev/null +++ b/language/functional_tests/tests/testsuite/data_types/empty_structs.mvir @@ -0,0 +1,14 @@ +modules: +module Test { + resource NonEmpty1 { f1: u64, f2: bool } + + struct Empty1 { } + resource Empty2 { } + + resource NonEmpty2 { f1: bool, f2: u64 } +} + +script: +main() { + return; +} diff --git a/language/functional_tests/tests/testsuite/dereference_tests/deref_borrow_field_ok.mvir b/language/functional_tests/tests/testsuite/dereference_tests/deref_borrow_field_ok.mvir new file mode 100644 index 0000000000000..d97079a0e68b7 --- /dev/null +++ b/language/functional_tests/tests/testsuite/dereference_tests/deref_borrow_field_ok.mvir @@ -0,0 +1,29 @@ +modules: + +module M { + struct T{f: u64} + + public new(g: u64): V#Self.T { + return T{f: move(g)}; + } + + public t(this: &V#Self.T) { + let y: u64; + y = *&move(this).f; + assert(copy(y) == 2, 42); + return; + } +} + +script: + +import Transaction.M; + +main(){ + let x: V#M.T; + let x_ref: &V#M.T; + x = M.new(2); + x_ref = &x; + M.t(move(x_ref)); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/dereference_tests/deref_borrow_local_ok.mvir b/language/functional_tests/tests/testsuite/dereference_tests/deref_borrow_local_ok.mvir new file mode 100644 index 0000000000000..b7d9e6bcaf3d1 --- /dev/null +++ b/language/functional_tests/tests/testsuite/dereference_tests/deref_borrow_local_ok.mvir @@ -0,0 +1,6 @@ +main() { + let x: u64; + x = 5; + assert(*&mut x == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/dereference_tests/deref_copy_bad.mvir b/language/functional_tests/tests/testsuite/dereference_tests/deref_copy_bad.mvir new file mode 100644 index 0000000000000..654c7e982f646 --- /dev/null +++ b/language/functional_tests/tests/testsuite/dereference_tests/deref_copy_bad.mvir @@ -0,0 +1,15 @@ +// check: VerificationError + +// Assigning local after move fails.
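+// The mutable reference x_ref still borrows x when x is moved out, so the verifier should flag the later write through x_ref.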
+ +main() { + let x: u64; + let x_ref: &mut u64; + let dead: u64; + x = 5; + x_ref = &mut x; + assert(*copy(x_ref) == 5, 42); + dead = move(x); + *move(x_ref) = 42; + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/dereference_tests/deref_move_module_ok.mvir b/language/functional_tests/tests/testsuite/dereference_tests/deref_move_module_ok.mvir new file mode 100644 index 0000000000000..ecef59a786a76 --- /dev/null +++ b/language/functional_tests/tests/testsuite/dereference_tests/deref_move_module_ok.mvir @@ -0,0 +1,30 @@ +modules: + +module M { + struct T {v : u64} + + public new(v: u64): V#Self.T { + return T{v: move(v)}; + } + + public value (this: &mut V#Self.T) : u64 { + let b: &u64; + b = &move(this).v; + return *move(b); + } +} + +script: + +import Transaction.M; + +main() { + let x: V#M.T; + let y: &mut V#M.T; + let z: u64; + x = M.new(5); + y = &mut x; + z = M.value(move(y)); + assert(move(z) == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/dereference_tests/deref_move_ok.mvir b/language/functional_tests/tests/testsuite/dereference_tests/deref_move_ok.mvir new file mode 100644 index 0000000000000..28bec4d16581e --- /dev/null +++ b/language/functional_tests/tests/testsuite/dereference_tests/deref_move_ok.mvir @@ -0,0 +1,8 @@ +main() { + let x: u64; + let x_ref: &u64; + x = 5; + x_ref = &x; + assert(*move(x_ref) == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/dereference_tests/deref_non_reference.mvir b/language/functional_tests/tests/testsuite/dereference_tests/deref_non_reference.mvir new file mode 100644 index 0000000000000..4c10ed11a0200 --- /dev/null +++ b/language/functional_tests/tests/testsuite/dereference_tests/deref_non_reference.mvir @@ -0,0 +1,10 @@ +// check: VerificationError +// check: ReadRefTypeMismatchError + +main() { + let x: u64; + let y: u64; + x = 0; + y = *move(x); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/dereference_tests/deref_not_reference_bad.mvir b/language/functional_tests/tests/testsuite/dereference_tests/deref_not_reference_bad.mvir new file mode 100644 index 0000000000000..df81ea25f4be5 --- /dev/null +++ b/language/functional_tests/tests/testsuite/dereference_tests/deref_not_reference_bad.mvir @@ -0,0 +1,10 @@ +// check: VerificationError +// check: BooleanOpTypeMismatchError + +main() { + let x: u64; + let y: u64; + x = 0; + y = *!move(x); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/dereference_tests/deref_parens_ok.mvir b/language/functional_tests/tests/testsuite/dereference_tests/deref_parens_ok.mvir new file mode 100644 index 0000000000000..7059a192daa6a --- /dev/null +++ b/language/functional_tests/tests/testsuite/dereference_tests/deref_parens_ok.mvir @@ -0,0 +1,10 @@ +main() { + let x: u64; + let x_ref: &u64; + x = 5; + x_ref = &x; + assert(*(copy(x_ref)) == 5, 42); + assert(*(move(x_ref)) == 5, 42); + assert(*(&mut x) == 5, 42); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/examples/multiple_stages.mvir b/language/functional_tests/tests/testsuite/examples/multiple_stages.mvir new file mode 100644 index 0000000000000..d89418251c078 --- /dev/null +++ b/language/functional_tests/tests/testsuite/examples/multiple_stages.mvir @@ -0,0 +1,42 @@ +modules: +module M { + public foo(x: u64): u64 { + return copy(x)*2; + } +} + +module N { + public bar(x: u64): u64 { + return copy(x)*3; + } +} + 
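+// For reference: N.bar(7) == 21 and M.foo(21) == 42, so the assertion against 41 in main() fails by design, yielding the AssertionFailure checked in the runtime stage.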
+script: +import Transaction.M; +import Transaction.N; + +main() { + let x: u64; + let y: u64; + x = N.bar(7); + y = M.foo(move(x)); + assert(move(y) == 41, 42); + return; +} + +// The following directives check that +// 1) There are two function definitions in the compiled module. +// 2) There are no verification errors. +// 3) There is an AssertionFailure in the transaction output. + +// stage: compiler +// check: CompiledModule +// check: FunctionDefinition +// check: CompiledModule +// check: FunctionDefinition + +// stage: verifier +// not: VerificationError + +// stage: runtime +// check: AssertionFailure diff --git a/language/functional_tests/tests/testsuite/examples/no_execute.mvir b/language/functional_tests/tests/testsuite/examples/no_execute.mvir new file mode 100644 index 0000000000000..91ef93f724fca --- /dev/null +++ b/language/functional_tests/tests/testsuite/examples/no_execute.mvir @@ -0,0 +1,6 @@ +//! no-execute + +main() { + assert(false, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/examples/script.mvir b/language/functional_tests/tests/testsuite/examples/script.mvir new file mode 100644 index 0000000000000..c678d618a5a0b --- /dev/null +++ b/language/functional_tests/tests/testsuite/examples/script.mvir @@ -0,0 +1,3 @@ +main() { + return; +} diff --git a/language/functional_tests/tests/testsuite/examples/script_with_module.mvir b/language/functional_tests/tests/testsuite/examples/script_with_module.mvir new file mode 100644 index 0000000000000..c3689fd2fe95b --- /dev/null +++ b/language/functional_tests/tests/testsuite/examples/script_with_module.mvir @@ -0,0 +1,11 @@ +modules: +module M { + foo() { + return; + } +} + +script: +main() { + return; +} diff --git a/language/functional_tests/tests/testsuite/examples/simple.mvir b/language/functional_tests/tests/testsuite/examples/simple.mvir new file mode 100644 index 0000000000000..0c3b60698043b --- /dev/null +++ b/language/functional_tests/tests/testsuite/examples/simple.mvir @@ -0,0 +1,9 @@ +// check: VerificationError + +main() { + let x: u64; + let y: u64; + x = 3; + y = move(x); + return move(x); +} diff --git a/language/functional_tests/tests/testsuite/expressions/address_equality.mvir b/language/functional_tests/tests/testsuite/expressions/address_equality.mvir new file mode 100644 index 0000000000000..d2de1128c2e7c --- /dev/null +++ b/language/functional_tests/tests/testsuite/expressions/address_equality.mvir @@ -0,0 +1,31 @@ +main() { + let a1: address; + let a2: address; + let a3: address; + let a4: address; + let a5: address; + let a6: address; + let a7: address; + let a8: address; + let a9: address; + + a1 = 0x1; + a2 = 0x01; + a3 = 0x0001; + a4 = 0x00000001; + a5 = 0x0000000000000001; + a6 = 0x00000000000000000000000000000001; + a7 = 0x000000000000000000000000000000001; + a8 = 0x000000000000000000000000000000000000000000000000000000000000001; + a9 = 0x0000000000000000000000000000000000000000000000000000000000000001; + + assert(copy(a1) == copy(a2), 42); + assert(copy(a2) == copy(a3), 43); + assert(copy(a3) == copy(a4), 44); + assert(copy(a4) == copy(a5), 45); + assert(copy(a5) == copy(a6), 46); + assert(copy(a6) == copy(a7), 47); + assert(copy(a7) == copy(a8), 48); + assert(copy(a8) == copy(a9), 49); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/expressions/cant_borrow_field_of_resource.mvir b/language/functional_tests/tests/testsuite/expressions/cant_borrow_field_of_resource.mvir new file mode 100644 index 0000000000000..f1d8347ca2cc7 
--- /dev/null +++ b/language/functional_tests/tests/testsuite/expressions/cant_borrow_field_of_resource.mvir @@ -0,0 +1,52 @@ +// check: no field type referencing in scripts +// TODO: is it possible to get this to compile, so that the bytecode verifier complains instead? + +modules: +module Token { + resource T {v: u64} + + public new(v: u64): R#Self.T { + return T{v: move(v)}; + } + + public value(this: &R#Self.T): u64 { + let vref: &u64; + vref = &copy(this).v; + release(move(this)); + return *move(vref); + } + + public exists(addr: address): bool { + let yes: bool; + yes = exists(move(addr)); + return move(yes); + } + + public get(addr: address): &mut R#Self.T { + let t_ref: &mut R#Self.T; + t_ref = borrow_global(move(addr)); + return move(t_ref); + } + + public publish(t: R#Self.T) { + move_to_sender(move(t)); + return; + } + +} +script: +import Transaction.Token; +main() { + let addr: address; + let t: R#Token.T; + let tref: &mut R#Token.T; + let balance_ref: &u64; + + addr = get_txn_sender(); + t = Token.new(1); + Token.publish(move(t)); + tref = Token.get(move(addr)); + balance_ref = &move(tref).v; + release(move(balance_ref)); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/expressions/cant_deref_resource.mvir b/language/functional_tests/tests/testsuite/expressions/cant_deref_resource.mvir new file mode 100644 index 0000000000000..8d2d0ca5d50a9 --- /dev/null +++ b/language/functional_tests/tests/testsuite/expressions/cant_deref_resource.mvir @@ -0,0 +1,51 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: ReadRefResourceError(11) } + +modules: +module Token { + resource T {v: u64} + + public new(v: u64): R#Self.T { + return T{v: move(v)}; + } + + public value(this: &R#Self.T): u64 { + let vref: &u64; + let res: u64; + vref = &move(this).v; + res = *move(vref); + return move(res); + } + + public exists(addr: address): bool { + let yes: bool; + yes = exists(move(addr)); + return move(yes); + } + + public get(addr: address): &mut R#Self.T { + let t_ref: &mut R#Self.T; + t_ref = borrow_global(move(addr)); + return move(t_ref); + } + + public publish(t: R#Self.T) { + move_to_sender(move(t)); + return; + } + +} +script: +import Transaction.Token; +main() { + let addr: address; + let t: R#Token.T; + let tref: &mut R#Token.T; + let y: R#Token.T; + + addr = get_txn_sender(); + t = Token.new(0); + Token.publish(move(t)); + tref = Token.get(move(addr)); + y = *move(tref); + return; +} diff --git a/language/functional_tests/tests/testsuite/expressions/deref_value.mvir b/language/functional_tests/tests/testsuite/expressions/deref_value.mvir new file mode 100644 index 0000000000000..bcf868e15e02b --- /dev/null +++ b/language/functional_tests/tests/testsuite/expressions/deref_value.mvir @@ -0,0 +1,30 @@ +modules: +module A { + struct T{f: u64} + + public new(f: u64): V#Self.T { + return T{f: move(f)}; + } + + public t(this: &V#Self.T) { + let f: &u64; + let y: u64; + f = &copy(this).f; + y = *move(f); + assert(copy(y) == 2, 42); + release(move(this)); + return; + } +} + +script: +import Transaction.A; +main() { + let x: V#A.T; + let x_ref: &V#A.T; + + x = A.new(2); + x_ref = &x; + A.t(move(x_ref)); + return; +} diff --git a/language/functional_tests/tests/testsuite/expressions/deref_value_nested.mvir b/language/functional_tests/tests/testsuite/expressions/deref_value_nested.mvir new file mode 100644 index 0000000000000..b2ea666b67eb4 --- /dev/null +++ b/language/functional_tests/tests/testsuite/expressions/deref_value_nested.mvir @@ -0,0 +1,49 @@
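+// Nested case: A.T wraps a B.T, so reading the innermost u64 takes two successive field borrows through immutable references.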
+modules: +module B { + struct T{g: u64} + + public new(g: u64): V#Self.T { + return T{g: move(g)}; + } + + public t(this: &V#Self.T) { + let g: &u64; + let y: u64; + g = &move(this).g; + y = *move(g); + assert(copy(y) == 2, 42); + return; + } +} + +module A { + import Transaction.B; + + struct T{f: V#B.T} + + public new(f: V#B.T): V#Self.T { + return T{f: move(f)}; + } + + public t(this: &V#Self.T) { + let f: &V#B.T; + f = &move(this).f; + B.t(move(f)); + return; + } +} + +script: +import Transaction.A; +import Transaction.B; +main() { + let b: V#B.T; + let x: V#A.T; + let x_ref: &V#A.T; + + b = B.new(2); + x = A.new(move(b)); + x_ref = &x; + A.t(move(x_ref)); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/expressions/multiple_return_values.mvir b/language/functional_tests/tests/testsuite/expressions/multiple_return_values.mvir new file mode 100644 index 0000000000000..6683d06b20b7e --- /dev/null +++ b/language/functional_tests/tests/testsuite/expressions/multiple_return_values.mvir @@ -0,0 +1,35 @@ +modules: +module Test { + resource T { } + + public new(): R#Self.T { + return T { }; + } + + public test(i: &u64, x: R#Self.T): u64 * R#Self.T * bool { + return *move(i), move(x), false; + } + + public destroy(x: R#Self.T) { + T { } = move(x); + return; + } +} +script: +import Transaction.Test; +main() { + let i: u64; + let t: R#Test.T; + let a: u64; + let x: R#Test.T; + let b: bool; + + i = 0; + t = Test.new(); + a, x, b = Test.test(&i, move(t)); + assert(move(a) == 0, 42); + Test.destroy(move(x)); + assert(!move(b), 43); + + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/expressions/multiple_return_values_extra_binding.mvir b/language/functional_tests/tests/testsuite/expressions/multiple_return_values_extra_binding.mvir new file mode 100644 index 0000000000000..e6793fdf2fe26 --- /dev/null +++ b/language/functional_tests/tests/testsuite/expressions/multiple_return_values_extra_binding.mvir @@ -0,0 +1,38 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: NegativeStackSizeInsideBlock(0, 11) } + +modules: +module Test { + resource T { } + + public new(): R#Self.T { + return T { }; + } + + public test(i: &u64, x: R#Self.T): u64 * R#Self.T * bool { + return *move(i), move(x), false; + } + + public destroy(x: R#Self.T) { + T { } = move(x); + return; + } +} +script: +import Transaction.Test; +main() { + let i: u64; + let t: R#Test.T; + let a: u64; + let b: bool; + let x: R#Test.T; + let z: address; + + i = 0; + t = Test.new(); + a, x, b, z = Test.test(&i, move(t)); + assert(move(a) == 0, 42); + Test.destroy(move(x)); + assert(!move(b), 43); + + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/expressions/multiple_return_values_extra_value.mvir b/language/functional_tests/tests/testsuite/expressions/multiple_return_values_extra_value.mvir new file mode 100644 index 0000000000000..c3cc670c9d820 --- /dev/null +++ b/language/functional_tests/tests/testsuite/expressions/multiple_return_values_extra_value.mvir @@ -0,0 +1,37 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 1, err: PositiveStackSizeAtBlockEnd(0) } + +modules: +module Test { + resource T { } + + public new(): R#Self.T { + return T { }; + } + + public test(i: &u64, x: R#Self.T): u64 * R#Self.T * bool { + return *move(i), move(x), false, 0x0; + } + + public destroy(x: R#Self.T) { + T { } = move(x); + return; + } +} +script: +import Transaction.Test; +main() { + let 
i: u64; + let t: R#Test.T; + let a: u64; + let b: bool; + let x: R#Test.T; + + i = 0; + t = Test.new(); + a, x, b = Test.test(&i, move(t)); + assert(move(a) == 0, 42); + Test.destroy(move(x)); + assert(!move(b), 43); + + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/expressions/multiple_return_values_missing_binding.mvir b/language/functional_tests/tests/testsuite/expressions/multiple_return_values_missing_binding.mvir new file mode 100644 index 0000000000000..0a7317eeaaf1a --- /dev/null +++ b/language/functional_tests/tests/testsuite/expressions/multiple_return_values_missing_binding.mvir @@ -0,0 +1,35 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: PositiveStackSizeAtBlockEnd(0) } + +modules: +module Test { + resource T { } + + public new(): R#Self.T { + return T { }; + } + + public test(i: &u64, x: R#Self.T): u64 * R#Self.T * bool { + return *move(i), move(x), false; + } + + public destroy(x: R#Self.T) { + T { } = move(x); + return; + } +} +script: +import Transaction.Test; +main() { + let i: u64; + let t: R#Test.T; + let a: u64; + let x: R#Test.T; + + i = 0; + t = Test.new(); + a, x = Test.test(&i, move(t)); + assert(move(a) == 0, 42); + Test.destroy(move(x)); + + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/expressions/multiple_return_values_missing_value.mvir b/language/functional_tests/tests/testsuite/expressions/multiple_return_values_missing_value.mvir new file mode 100644 index 0000000000000..b26c49d8e10b6 --- /dev/null +++ b/language/functional_tests/tests/testsuite/expressions/multiple_return_values_missing_value.mvir @@ -0,0 +1,37 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 1, err: NegativeStackSizeInsideBlock(0, 3) } + +modules: +module Test { + resource T { } + + public new(): R#Self.T { + return T { }; + } + + public test(i: &u64, x: R#Self.T): u64 * R#Self.T * bool { + return *move(i), move(x); + } + + public destroy(x: R#Self.T) { + T { } = move(x); + return; + } +} +script: +import Transaction.Test; +main() { + let i: u64; + let t: R#Test.T; + let a: u64; + let b: bool; + let x: R#Test.T; + + i = 0; + t = Test.new(); + a, x, b = Test.test(&i, move(t)); + assert(move(a) == 0, 42); + Test.destroy(move(x)); + assert(!move(b), 43); + + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/failure/failure_rollover.mvir b/language/functional_tests/tests/testsuite/failure/failure_rollover.mvir new file mode 100644 index 0000000000000..cd9033d7c80f4 --- /dev/null +++ b/language/functional_tests/tests/testsuite/failure/failure_rollover.mvir @@ -0,0 +1,17 @@ +import 0x0.LibraAccount; +import 0x0.LibraCoin; + +main() { + let addr: address; + let ten_coins: R#LibraCoin.T; + + addr = get_txn_sender(); + ten_coins = LibraAccount.withdraw_from_sender(10); + LibraAccount.deposit(copy(addr), move(ten_coins)); + LibraAccount.deposit(move(addr), move(ten_coins)); + + return; +} + +// check: VerificationError +// check: MoveLocUnavailableError(9) diff --git a/language/functional_tests/tests/testsuite/global_ref_count/decrement_dereference.mvir b/language/functional_tests/tests/testsuite/global_ref_count/decrement_dereference.mvir new file mode 100644 index 0000000000000..9675eb3ba853b --- /dev/null +++ b/language/functional_tests/tests/testsuite/global_ref_count/decrement_dereference.mvir @@ -0,0 +1,30 @@ +modules: +module Test { + resource T { i: u64 } + + public test() { + let t: R#Self.T; + let t_ref: &mut 
R#Self.T;
+        let i_ref: &u64;
+        let sender: address;
+
+        t = T { i: 0 };
+        move_to_sender(move(t));
+
+        sender = get_txn_sender();
+        t_ref = borrow_global(copy(sender));
+        i_ref = &copy(t_ref).i;
+        release(move(t_ref));
+        assert(*move(i_ref) == 0, 42);
+
+        t_ref = borrow_global(copy(sender));
+        release(move(t_ref));
+        return;
+    }
+}
+script:
+import Transaction.Test;
+main() {
+    Test.test();
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/global_ref_count/decrement_emit_event.mvir b/language/functional_tests/tests/testsuite/global_ref_count/decrement_emit_event.mvir
new file mode 100644
index 0000000000000..c3c77e9f0011d
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/global_ref_count/decrement_emit_event.mvir
@@ -0,0 +1,34 @@
+modules:
+module Test {
+    resource T { i: u64 }
+    struct Event { }
+    public test() {
+        let t: R#Self.T;
+        let t_ref: &mut R#Self.T;
+        let i_ref: &mut u64;
+        let sender: address;
+        let event: V#Self.Event;
+        let event_id: bytearray;
+
+        t = T { i: 0 };
+        move_to_sender(move(t));
+
+        sender = get_txn_sender();
+        t_ref = borrow_global(copy(sender));
+        i_ref = &mut copy(t_ref).i;
+        event = Event { };
+        event_id = b"69";
+        emit_event(move(i_ref), move(event_id), move(event));
+        release(move(t_ref));
+
+        t_ref = borrow_global(copy(sender));
+        release(move(t_ref));
+        return;
+    }
+}
+script:
+import Transaction.Test;
+main() {
+    Test.test();
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/global_ref_count/decrement_write.mvir b/language/functional_tests/tests/testsuite/global_ref_count/decrement_write.mvir
new file mode 100644
index 0000000000000..d8e370e4665db
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/global_ref_count/decrement_write.mvir
@@ -0,0 +1,30 @@
+modules:
+module Test {
+    resource T { i: u64 }
+
+    public test() {
+        let t: R#Self.T;
+        let t_ref: &mut R#Self.T;
+        let i_ref: &mut u64;
+        let sender: address;
+
+        t = T { i: 0 };
+        move_to_sender(move(t));
+
+        sender = get_txn_sender();
+        t_ref = borrow_global(copy(sender));
+        i_ref = &mut copy(t_ref).i;
+        release(move(t_ref));
+        *move(i_ref) = 1;
+
+        t_ref = borrow_global(copy(sender));
+        release(move(t_ref));
+        return;
+    }
+}
+script:
+import Transaction.Test;
+main() {
+    Test.test();
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/method_decorators/internal_function_invalid_call.mvir b/language/functional_tests/tests/testsuite/method_decorators/internal_function_invalid_call.mvir
new file mode 100644
index 0000000000000..bdf4cae8368b4
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/method_decorators/internal_function_invalid_call.mvir
@@ -0,0 +1,50 @@
+modules:
+module Test {
+    struct T{value: u64}
+
+    initial_value(): u64 {
+        return 42;
+    }
+
+    public new(): V#Self.T {
+        let initial_value: u64;
+        initial_value = Self.initial_value();
+        return T{value: move(initial_value)};
+    }
+
+    public get_value(this: &V#Self.T): u64 {
+        let x: &u64;
+        x = &copy(this).value;
+        release(move(this));
+        return *move(x);
+    }
+
+    public set_value(this: &mut V#Self.T, new_value: u64) {
+        Self.internal_set_value(move(this), move(new_value));
+        return;
+    }
+
+    internal_set_value(this: &mut V#Self.T, new_value: u64) {
+        let x: &mut u64;
+        x = &mut copy(this).value;
+        *move(x) = move(new_value);
+        release(move(this));
+        return;
+    }
+}
+
+script:
+import 0x10.Test;
+
+main() {
+    let obj: V#Test.T;
+    let ref: &V#Test.T;
+
+    obj = Test.new();
+    ref = &obj;
+    Test.internal_set_value(move(ref), 1);
+    return;
+}
+
+// check: VerificationError
+// check: CallTypeMismatchError(7)
diff --git a/language/functional_tests/tests/testsuite/method_decorators/non_internal_function_valid_call.mvir b/language/functional_tests/tests/testsuite/method_decorators/non_internal_function_valid_call.mvir
new file mode 100644
index 0000000000000..9af1f3e243845
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/method_decorators/non_internal_function_valid_call.mvir
@@ -0,0 +1,49 @@
+modules:
+module Test {
+    struct T{value: u64}
+
+    initial_value(): u64 {
+        return 42;
+    }
+
+    public new(): V#Self.T {
+        let initial_value: u64;
+        initial_value = Self.initial_value();
+        return T{value: move(initial_value)};
+    }
+
+    public get_value(this: &V#Self.T): u64 {
+        let x: &u64;
+        x = &copy(this).value;
+        release(move(this));
+        return *move(x);
+    }
+
+    public set_value(this: &mut V#Self.T, new_value: u64) {
+        Self.internal_set_value(move(this), move(new_value));
+        return;
+    }
+
+    internal_set_value(this: &mut V#Self.T, new_value: u64) {
+        let x: &mut u64;
+        x = &mut copy(this).value;
+        *move(x) = move(new_value);
+        release(move(this));
+        return;
+    }
+}
+
+script:
+import Transaction.Test;
+
+main() {
+    let obj: V#Test.T;
+    let ref: &V#Test.T;
+    let val: u64;
+
+    obj = Test.new();
+    ref = &obj;
+    val = Test.get_value(move(ref));
+
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/module_member_types/field_reads.mvir b/language/functional_tests/tests/testsuite/module_member_types/field_reads.mvir
new file mode 100644
index 0000000000000..e251adbd4df1b
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/module_member_types/field_reads.mvir
@@ -0,0 +1,97 @@
+modules:
+module VTest {
+    struct T{fint: u64, fv: bool}
+
+    public new(x: u64, y: bool): V#Self.T {
+        return T{fint: move(x), fv: move(y)};
+    }
+
+    public t1(this: &V#Self.T): u64 {
+        let x: &u64;
+        x = &copy(this).fint;
+        release(move(this));
+        return *move(x);
+    }
+
+    public t2(this: &V#Self.T): u64 {
+        let x: &u64;
+        x = &copy(this).fint;
+        release(move(this));
+        return *move(x);
+    }
+
+    public t3(this: &V#Self.T): bool {
+        let x: &bool;
+        x = &copy(this).fv;
+        release(move(this));
+        return *move(x);
+    }
+}
+module RTest {
+    resource T{fint: u64, fv: bool}
+
+    public new(x: u64, y: bool): R#Self.T {
+        return T{fint: move(x), fv: move(y)};
+    }
+
+    public t1(this: &R#Self.T): u64 {
+        let x: &u64;
+        x = &copy(this).fint;
+        release(move(this));
+        return *move(x);
+    }
+
+    public t2(this: &R#Self.T): u64 {
+        let x: &u64;
+        x = &copy(this).fint;
+        release(move(this));
+        return *move(x);
+    }
+
+    public t3(this: &R#Self.T): bool {
+        let x: &bool;
+        x = &copy(this).fv;
+        release(move(this));
+        return *move(x);
+    }
+
+    public destroy_t(t: R#Self.T) {
+        let fint: u64;
+        let fv: bool;
+        T{ fint, fv } = move(t);
+        return;
+    }
+}
+script:
+import Transaction.RTest;
+import Transaction.VTest;
+main() {
+    let vt: V#VTest.T;
+    let vref: &V#VTest.T;
+    let rt: R#RTest.T;
+    let rref: &R#RTest.T;
+    let r1: u64;
+    let r2: u64;
+    let r3: u64;
+    let r4: u64;
+    let r5: bool;
+    let r6: bool;
+
+    vt = VTest.new(0, false);
+    vref = &vt;
+    rt = RTest.new(0, false);
+    rref = &rt;
+
+    r1 = VTest.t1(copy(vref));
+    r2 = RTest.t1(copy(rref));
+
+    r3 = VTest.t2(copy(vref));
+    r4 = RTest.t2(copy(rref));
+
+    r5 = VTest.t3(move(vref));
+    r6 = RTest.t3(move(rref));
+
+    RTest.destroy_t(move(rt));
+
+    return;
+}
\ No newline at end of file
diff --git a/language/functional_tests/tests/testsuite/module_member_types/field_writes.mvir b/language/functional_tests/tests/testsuite/module_member_types/field_writes.mvir
new file mode 100644
index 
0000000000000..01a9aa23f3f08 --- /dev/null +++ b/language/functional_tests/tests/testsuite/module_member_types/field_writes.mvir @@ -0,0 +1,111 @@ +modules: +module VTest { + struct T{fint: u64, fv: bool} + + public new(x: u64, y: bool): V#Self.T { + return T{fint: move(x), fv: move(y)}; + } + + public t1(this: &mut V#Self.T) { + let x: &mut u64; + x = &mut move(this).fint; + *move(x) = 0; + return; + } + + public t2(this: &mut V#Self.T, i: u64) { + let x: &mut u64; + x = &mut move(this).fint; + *move(x) = move(i); + return; + } + + public t3(this: &mut V#Self.T) { + let x: &mut bool; + x = &mut move(this).fv; + *move(x) = *copy(x); + return; + } + + public t4(this: &mut V#Self.T, i: bool) { + let x: &mut bool; + x = &mut move(this).fv; + *move(x) = move(i); + return; + } +} +module RTest { + resource T{fint: u64, fr: bool} + + public new(x: u64, y: bool): R#Self.T { + return T{fint: move(x), fr: move(y)}; + } + + public t1(this: &mut R#Self.T) { + let x: &mut u64; + x = &mut move(this).fint; + *move(x) = 0; + return; + } + + public t2(this: &mut R#Self.T, i: u64) { + let x: &mut u64; + x = &mut move(this).fint; + *move(x) = move(i); + return; + } + + public t3(this: &mut R#Self.T) { + let x: &mut bool; + let z: bool; + x = &mut move(this).fr; + z = true; + *move(x) = move(z); + return; + } + + public t4(this: &mut R#Self.T, i: bool) { + let x: &mut bool; + x = &mut move(this).fr; + *move(x) = move(i); + return; + } + + public destroy_t(t: R#Self.T) { + let fint: u64; + let fr: bool; + T{ fint, fr } = move(t); + return; + } +} + +script: +import Transaction.RTest; +import Transaction.VTest; +main() { + let vt: V#VTest.T; + let vref: &mut V#VTest.T; + let rt: R#RTest.T; + let rref: &mut R#RTest.T; + + vt = VTest.new(0, false); + vref = &mut vt; + rt = RTest.new(0, false); + rref = &mut rt; + + VTest.t1(copy(vref)); + RTest.t1(copy(rref)); + + VTest.t2(copy(vref), 0); + RTest.t2(copy(rref), 0); + + VTest.t3(copy(vref)); + RTest.t3(copy(rref)); + + VTest.t4(move(vref), false); + RTest.t4(move(rref), false); + + RTest.destroy_t(move(rt)); + + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/module_member_types/invalid_field_write.mvir b/language/functional_tests/tests/testsuite/module_member_types/invalid_field_write.mvir new file mode 100644 index 0000000000000..435bb6cbe9613 --- /dev/null +++ b/language/functional_tests/tests/testsuite/module_member_types/invalid_field_write.mvir @@ -0,0 +1,26 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 1, err: WriteRefTypeMismatchError(5) } + +modules: +module Test { + struct T{fr: bool} + + public new(): V#Self.T { + return T{fr: false}; + } + + public no(this: &mut V#Self.T) { + let x: &mut bool; + x = &mut move(this).fr; + *move(x) = 0; + return; + } +} + +script: +import Transaction.Test; +main() { + let t: V#Test.T; + t = Test.new(); + Test.no(&mut t); + return; +} diff --git a/language/functional_tests/tests/testsuite/module_member_types/invalid_resource_write.mvir b/language/functional_tests/tests/testsuite/module_member_types/invalid_resource_write.mvir new file mode 100644 index 0000000000000..c1ba7cac8f091 --- /dev/null +++ b/language/functional_tests/tests/testsuite/module_member_types/invalid_resource_write.mvir @@ -0,0 +1,40 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 2, err: WriteRefResourceError(5) } + +modules: +module RTest { + import 0x0.LibraCoin; + resource T{fr: R#LibraCoin.T} + + public new(y: R#LibraCoin.T): R#Self.T { + return T{fr: move(y)}; + } + 
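+    // The write through a mutable reference in `t4` below is what this test
+    // exercises: assigning over a resource field would destroy the resource
+    // value already stored there, so the verifier rejects it (the
+    // WriteRefResourceError expected by the check above).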
+ public destroy(t: R#Self.T){ + let fr: R#LibraCoin.T; + T { fr } = move(t); + LibraCoin.destroy_zero(move(fr)); + return; + } + + public t4(t: &mut R#Self.T, i: R#LibraCoin.T) { + let x: &mut R#LibraCoin.T; + x = &mut move(t).fr; + *move(x) = move(i); + return; + } +} +script: +import Transaction.RTest; +import 0x0.LibraCoin; +main() { + let z: R#LibraCoin.T; + let r: R#RTest.T; + let rr: &mut R#RTest.T; + let z2: R#LibraCoin.T; + z = LibraCoin.zero(); + r = RTest.new(move(z)); + z2 = LibraCoin.zero(); + RTest.t4(&mut r, move(z2)); + RTest.destroy(move(r)); + return; +} diff --git a/language/functional_tests/tests/testsuite/module_member_types/procedure_args.mvir b/language/functional_tests/tests/testsuite/module_member_types/procedure_args.mvir new file mode 100644 index 0000000000000..ced424dfef313 --- /dev/null +++ b/language/functional_tests/tests/testsuite/module_member_types/procedure_args.mvir @@ -0,0 +1,15 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: CallTypeMismatchError(1) } + +modules: +module Test { + public t(fr: u64) { + return; + } +} + +script: +import Transaction.Test; +main() { + Test.t(true); + return; +} diff --git a/language/functional_tests/tests/testsuite/module_member_types/procedure_args_subtype.mvir b/language/functional_tests/tests/testsuite/module_member_types/procedure_args_subtype.mvir new file mode 100644 index 0000000000000..8def7f4ff1dba --- /dev/null +++ b/language/functional_tests/tests/testsuite/module_member_types/procedure_args_subtype.mvir @@ -0,0 +1,18 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: CallTypeMismatchError(3) } + +modules: +module Test { + public t(fr: &u64) { + release(move(fr)); + return; + } +} + +script: +import Transaction.Test; +main() { + let x: u64; + x = 0; + Test.t(&mut x); + return; +} diff --git a/language/functional_tests/tests/testsuite/module_member_types/procedure_return_invalid_subtype.mvir b/language/functional_tests/tests/testsuite/module_member_types/procedure_return_invalid_subtype.mvir new file mode 100644 index 0000000000000..15cdeedf07292 --- /dev/null +++ b/language/functional_tests/tests/testsuite/module_member_types/procedure_return_invalid_subtype.mvir @@ -0,0 +1,19 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: RetTypeMismatchError(1) } + +modules: +module Test { + public no(r: &mut u64): &u64 { + return move(r); + } +} + +script: +import Transaction.Test; +main() { + let x: u64; + let r: &u64; + x = 0; + r = Test.no(&mut x); + release(move(r)); + return; +} diff --git a/language/functional_tests/tests/testsuite/module_member_types/procedure_return_invalid_type.mvir b/language/functional_tests/tests/testsuite/module_member_types/procedure_return_invalid_type.mvir new file mode 100644 index 0000000000000..0188a1d7c3ade --- /dev/null +++ b/language/functional_tests/tests/testsuite/module_member_types/procedure_return_invalid_type.mvir @@ -0,0 +1,16 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: RetTypeMismatchError(1) } + +modules: +module Test { + public no(): u64 { + return false; + } +} + +script: +import Transaction.Test; +main() { + let x: u64; + x = Test.no(); + return; +} diff --git a/language/functional_tests/tests/testsuite/module_member_types/resource_has_resource_field.mvir b/language/functional_tests/tests/testsuite/module_member_types/resource_has_resource_field.mvir new file mode 100644 index 0000000000000..8a14c334a1474 --- /dev/null +++ 
b/language/functional_tests/tests/testsuite/module_member_types/resource_has_resource_field.mvir @@ -0,0 +1,10 @@ +modules: +module Test { + import 0x0.LibraCoin; + resource T{fint: R#LibraCoin.T} +} + +script: +main() { + return; +} diff --git a/language/functional_tests/tests/testsuite/module_member_types/resource_instantiate_bad_type.mvir b/language/functional_tests/tests/testsuite/module_member_types/resource_instantiate_bad_type.mvir new file mode 100644 index 0000000000000..7dbde19f11651 --- /dev/null +++ b/language/functional_tests/tests/testsuite/module_member_types/resource_instantiate_bad_type.mvir @@ -0,0 +1,24 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: PackTypeMismatchError(1) } + +modules: +module Test { + import 0x0.LibraCoin; + resource B{} + resource T{ft: R#Self.B} + + public t1(x: R#LibraCoin.T): R#Self.T { + return T{ft: move(x)}; + } +} + +script: +import Transaction.Test; +import 0x0.LibraCoin; +main() { + let z1: R#LibraCoin.T; + let t1: R#Test.T; + z1 = LibraCoin.zero(); + t1 = Test.t1(move(z1)); + + return; +} diff --git a/language/functional_tests/tests/testsuite/module_member_types/unrestricted_has_resource_field.mvir b/language/functional_tests/tests/testsuite/module_member_types/unrestricted_has_resource_field.mvir new file mode 100644 index 0000000000000..297cb6d13a805 --- /dev/null +++ b/language/functional_tests/tests/testsuite/module_member_types/unrestricted_has_resource_field.mvir @@ -0,0 +1,16 @@ +// check: VerificationError { kind: StructDefinition, idx: 0, err: InvalidResourceField } + +modules: +module Test { + import 0x0.LibraCoin; + + struct T{ + fc: R#LibraCoin.T, + fint: u64, + } +} + +script: +main() { + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/module_member_types/unrestricted_instantiate.mvir b/language/functional_tests/tests/testsuite/module_member_types/unrestricted_instantiate.mvir new file mode 100644 index 0000000000000..b8aef9ab6bcb9 --- /dev/null +++ b/language/functional_tests/tests/testsuite/module_member_types/unrestricted_instantiate.mvir @@ -0,0 +1,21 @@ +modules: +module Test { + struct T{fint: u64, fv: bool} + + public t1(fint: u64, fv: bool): V#Self.T { + return T{fint: move(fint), fv: move(fv)}; + } + + public t2(fint: u64): V#Self.T { + return T{fint: move(fint), fv: false}; + } +} +script: +import Transaction.Test; +main() { + let t1: V#Test.T; + let t2: V#Test.T; + t1 = Test.t1(0, false); + t2 = Test.t2(0); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/module_member_types/unrestricted_instantiate_bad_type.mvir b/language/functional_tests/tests/testsuite/module_member_types/unrestricted_instantiate_bad_type.mvir new file mode 100644 index 0000000000000..8f2cbb4860d8b --- /dev/null +++ b/language/functional_tests/tests/testsuite/module_member_types/unrestricted_instantiate_bad_type.mvir @@ -0,0 +1,18 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: PackTypeMismatchError(1) } + +modules: +module Test { + struct T{fint: u64} + + public t1(): V#Self.T { + return T{fint: false}; + } +} + +script: +import Transaction.Test; +main() { + let t1: V#Test.T; + t1 = Test.t1(); + return; +} diff --git a/language/functional_tests/tests/testsuite/modules/access_private_function.mvir b/language/functional_tests/tests/testsuite/modules/access_private_function.mvir new file mode 100644 index 0000000000000..05cc4675e27da --- /dev/null +++ 
b/language/functional_tests/tests/testsuite/modules/access_private_function.mvir
@@ -0,0 +1,19 @@
+modules:
+
+module M {
+    universal_truth(): u64 {
+        return 42;
+    }
+}
+
+script:
+import Transaction.M;
+
+main() {
+    let x: u64;
+    x = M.universal_truth();
+    return;
+}
+
+// check: VerificationError
+// check: VisibilityMismatch
diff --git a/language/functional_tests/tests/testsuite/modules/access_public_function.mvir b/language/functional_tests/tests/testsuite/modules/access_public_function.mvir
new file mode 100644
index 0000000000000..080be26aaed36
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/modules/access_public_function.mvir
@@ -0,0 +1,16 @@
+modules:
+
+module M {
+    public universal_truth(): u64 {
+        return 42;
+    }
+}
+
+script:
+import Transaction.M;
+
+main() {
+    let x: u64;
+    x = M.universal_truth();
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/modules/all_fields_accessible.mvir b/language/functional_tests/tests/testsuite/modules/all_fields_accessible.mvir
new file mode 100644
index 0000000000000..21e15939739c5
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/modules/all_fields_accessible.mvir
@@ -0,0 +1,67 @@
+modules:
+module M {
+    resource A{x: u64}
+    struct B{y: u64}
+
+    public a(x: u64): R#Self.A {
+        return A{x: move(x)};
+    }
+
+    public b(y: u64): V#Self.B {
+        return B{y: move(y)};
+    }
+
+    public set_a_with_b(a: &mut R#Self.A, b: &V#Self.B) {
+        let x_ref: &mut u64;
+        let y_ref: &u64;
+        x_ref = &mut copy(a).x;
+        y_ref = &copy(b).y;
+        *move(x_ref) = *move(y_ref);
+        release(move(a));
+        release(move(b));
+        return;
+    }
+
+    public set_b_with_a(b: &mut V#Self.B, a: &R#Self.A) {
+        let x_ref: &u64;
+        let y_ref: &mut u64;
+        y_ref = &mut copy(b).y;
+        x_ref = &copy(a).x;
+        *move(y_ref) = *move(x_ref);
+        release(move(a));
+        release(move(b));
+        return;
+    }
+
+    public destroy_a(a: R#Self.A) {
+        let x: u64;
+        A{ x } = move(a);
+        return;
+    }
+}
+
+script:
+import Transaction.M;
+main() {
+    let a: R#M.A;
+    let a_ref: &R#M.A;
+    let a_mut_ref: &mut R#M.A;
+    let b: V#M.B;
+    let b_ref: &V#M.B;
+    let b_mut_ref: &mut V#M.B;
+
+    a = M.a(0);
+    b = M.b(1);
+
+    a_mut_ref = &mut a;
+    b_ref = &b;
+    M.set_a_with_b(move(a_mut_ref), move(b_ref));
+
+    a_ref = &a;
+    b_mut_ref = &mut b;
+    M.set_b_with_a(move(b_mut_ref), move(a_ref));
+
+    M.destroy_a(move(a));
+
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/modules/duplicate_function_name.mvir b/language/functional_tests/tests/testsuite/modules/duplicate_function_name.mvir
new file mode 100644
index 0000000000000..bf7eeed3b531c
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/modules/duplicate_function_name.mvir
@@ -0,0 +1,12 @@
+// check: VerificationError { kind: FunctionDefinition, idx: 1, err: DuplicateElement }
+
+modules:
+module M {
+    f() {}
+    f() {}
+}
+
+script:
+main() {
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/modules/duplicate_struct_name.mvir b/language/functional_tests/tests/testsuite/modules/duplicate_struct_name.mvir
new file mode 100644
index 0000000000000..a937f8d1156bb
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/modules/duplicate_struct_name.mvir
@@ -0,0 +1,12 @@
+// check: VerificationError { kind: StructDefinition, idx: 1, err: DuplicateElement }
+
+modules:
+module M {
+    struct T{}
+    struct T{}
+}
+
+script:
+main() {
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/modules/function_in_struct.mvir b/language/functional_tests/tests/testsuite/modules/function_in_struct.mvir
new file mode 100644
index 
0000000000000..0c2b5ec4458e7 --- /dev/null +++ b/language/functional_tests/tests/testsuite/modules/function_in_struct.mvir @@ -0,0 +1,10 @@ +// check: Unrecognized token + +module M { + struct T{ + x: u64, + no() { + return; + } + } +} diff --git a/language/functional_tests/tests/testsuite/modules/get_resource_internal.mvir b/language/functional_tests/tests/testsuite/modules/get_resource_internal.mvir new file mode 100644 index 0000000000000..4fdf809dc95da --- /dev/null +++ b/language/functional_tests/tests/testsuite/modules/get_resource_internal.mvir @@ -0,0 +1,21 @@ +// check: no struct definition referencing in scripts + +modules: +module Token { + resource T { } + public new(): R#Self.T { + return T{ }; + } +} + +script: +import Transaction.Token; +main() { + let sender: address; + let yes: &mut R#Token.T; + + sender = get_txn_sender(); + yes = borrow_global(copy(sender)); + release(move(yes)); + return; +} diff --git a/language/functional_tests/tests/testsuite/modules/get_resource_internal_bypass.mvir b/language/functional_tests/tests/testsuite/modules/get_resource_internal_bypass.mvir new file mode 100644 index 0000000000000..9f22b6be9f549 --- /dev/null +++ b/language/functional_tests/tests/testsuite/modules/get_resource_internal_bypass.mvir @@ -0,0 +1,20 @@ +// check: Unrecognized token + +modules: +module Token { + resource T { } + public new(): R#Token.T { + return T{ }; + } +} + +script: +import Transaction.Token; +main() { + let sender: address; + let struct1: &mut R#Token.T; + + sender = get_txn_sender(); + struct1 = borrow_global(copy(sender)); + return; +} diff --git a/language/functional_tests/tests/testsuite/modules/has_resource_internal.mvir b/language/functional_tests/tests/testsuite/modules/has_resource_internal.mvir new file mode 100644 index 0000000000000..aa914cc883b3f --- /dev/null +++ b/language/functional_tests/tests/testsuite/modules/has_resource_internal.mvir @@ -0,0 +1,20 @@ +// check: no struct definition referencing in scripts + +modules: +module Token { + resource T { } + public new(): R#Self.T { + return T{ }; + } +} + +script: +import Transaction.Token; +main() { + let sender: address; + let yes: bool; + + sender = get_txn_sender(); + yes = exists(copy(sender)); + return; +} diff --git a/language/functional_tests/tests/testsuite/modules/module_struct_shared_name.mvir b/language/functional_tests/tests/testsuite/modules/module_struct_shared_name.mvir new file mode 100644 index 0000000000000..26c624f0f63fd --- /dev/null +++ b/language/functional_tests/tests/testsuite/modules/module_struct_shared_name.mvir @@ -0,0 +1,15 @@ +modules: +module M { + struct M { } + public new(): V#Self.M { + return M{ }; + } +} + +script: +import Transaction.M; +main() { + let x: V#M.M; + x = M.new(); + return; +} diff --git a/language/functional_tests/tests/testsuite/modules/modules_not_a_type.mvir b/language/functional_tests/tests/testsuite/modules/modules_not_a_type.mvir new file mode 100644 index 0000000000000..9a11b5dbd4fd0 --- /dev/null +++ b/language/functional_tests/tests/testsuite/modules/modules_not_a_type.mvir @@ -0,0 +1,14 @@ +// check: VerificationError { kind: StructHandle, idx: 0, err: UnimplementedHandle } + +modules: +module M { + public no(x: V#Self.M) { + return; + } +} + +script: +import Transaction.M; +main() { + return; +} diff --git a/language/functional_tests/tests/testsuite/modules/mutual_recursive_struct.mvir b/language/functional_tests/tests/testsuite/modules/mutual_recursive_struct.mvir new file mode 100644 index 0000000000000..82b15ab6d45c2 --- 
/dev/null
+++ b/language/functional_tests/tests/testsuite/modules/mutual_recursive_struct.mvir
@@ -0,0 +1,13 @@
+// check: VerificationError
+// check: RecursiveStructDef
+
+modules:
+module M {
+    struct A{b: V#Self.B}
+    struct B{a: V#Self.A}
+}
+
+script:
+main() {
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/modules/publish_resource_internal.mvir b/language/functional_tests/tests/testsuite/modules/publish_resource_internal.mvir
new file mode 100644
index 0000000000000..52e5e7f6716f3
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/modules/publish_resource_internal.mvir
@@ -0,0 +1,20 @@
+// check: no struct definition referencing in scripts
+
+modules:
+module Token {
+    resource T { }
+    public new(): R#Self.T {
+        return T{ };
+    }
+}
+
+script:
+import Transaction.Token;
+main() {
+    let addr1: address;
+    let t: R#Token.T;
+    addr1 = get_txn_sender();
+    t = Token.new();
+    move_to_sender(move(t));
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/move_getting_started_examples/create_account_script.mvir b/language/functional_tests/tests/testsuite/move_getting_started_examples/create_account_script.mvir
new file mode 100644
index 0000000000000..6cff361979cd5
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/move_getting_started_examples/create_account_script.mvir
@@ -0,0 +1,27 @@
+//! no-execute
+
+// A small variant of the peer-peer payment example that creates a fresh
+// account if one does not already exist.
+
+import 0x0.LibraAccount;
+import 0x0.LibraCoin;
+main(payee: address, amount: u64) {
+    let coin: R#LibraCoin.T;
+    let account_exists: bool;
+
+    // Acquire a LibraCoin.T resource with value `amount` from the sender's
+    // account. This will fail if the sender's balance is less than `amount`.
+    coin = LibraAccount.withdraw_from_sender(move(amount));
+
+    account_exists = LibraAccount.exists(copy(payee));
+
+    if (!move(account_exists)) {
+        // Creates a fresh account at the address `payee` by publishing a
+        // LibraAccount.T resource under this address. If there is already a
+        // LibraAccount.T resource under the address, this will fail.
+        create_account(copy(payee));
+    }
+
+    LibraAccount.deposit(move(payee), move(coin));
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/move_getting_started_examples/earmarked_libra.mvir b/language/functional_tests/tests/testsuite/move_getting_started_examples/earmarked_libra.mvir
new file mode 100644
index 0000000000000..c74de765604a3
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/move_getting_started_examples/earmarked_libra.mvir
@@ -0,0 +1,106 @@
+// NOTE: this module appears in the "Getting Started With Move" guide on the Libra website.
+// Any changes to this code should also be reflected there.
+
+modules:
+// A module for earmarking a coin for a specific recipient
module EarmarkedLibraCoin {
+    import 0x0.LibraCoin;
+
+    // A wrapper containing a Libra coin and the address of the recipient the
+    // coin is earmarked for.
+    resource T {
+        coin: R#LibraCoin.T,
+        recipient: address
+    }
+
+    // Create a new earmarked coin with the given `recipient`.
+    // Publish the coin under the transaction sender's account address.
+    public create(coin: R#LibraCoin.T, recipient: address) {
+        let t: R#Self.T;
+
+        // Construct or "pack" a new resource of type T. Only procedures of the
+        // `EarmarkedLibraCoin` module can create an `EarmarkedLibraCoin.T`.
+        t = T {
+            coin: move(coin),
+            recipient: move(recipient),
+        };
+
+        // Publish the earmarked coin under the transaction sender's account
+        // address. Each account can contain at most one resource of a given type;
+        // this call will fail if the sender already has a resource of this type.
+        move_to_sender(move(t));
+        return;
+    }
+
+    // Allow the transaction sender to claim a coin that was earmarked for her.
+    public claim_for_recipient(earmarked_coin_address: address): R#Self.T {
+        let t: R#Self.T;
+        let t_ref: &R#Self.T;
+        let sender: address;
+
+        // Remove the earmarked coin resource published under `earmarked_coin_address`.
+        // If there is no resource of type T published under the address, this will fail.
+        t = move_from(move(earmarked_coin_address));
+
+        t_ref = &t;
+        // This is a builtin that returns the address of the transaction sender.
+        sender = get_txn_sender();
+        // Ensure that the transaction sender is the recipient. If this assertion
+        // fails, the transaction will fail and none of its effects (e.g.,
+        // removing the earmarked coin) will be committed. 99 is an error code
+        // that will be emitted in the transaction output if the assertion fails.
+        assert(*(&move(t_ref).recipient) == move(sender), 99);
+
+        return move(t);
+    }
+
+    // Allow the creator of the earmarked coin to reclaim it.
+    public claim_for_creator(): R#Self.T {
+        let t: R#Self.T;
+        let coin: R#LibraCoin.T;
+        let recipient: address;
+        let sender: address;
+
+        sender = get_txn_sender();
+        // This will fail if there is no resource of type T under the sender's address.
+        t = move_from(move(sender));
+        return move(t);
+    }
+
+    // Extract the Libra coin from its wrapper and return it to the caller.
+    public unwrap(t: R#Self.T): R#LibraCoin.T {
+        let coin: R#LibraCoin.T;
+        let recipient: address;
+
+        // This "unpacks" a resource type by destroying the outer resource, but
+        // returning its contents. Only the module that declares a resource type
+        // can unpack it.
+        T { coin, recipient } = move(t);
+        return move(coin);
+    }
+
+}
+
+// TODO: lines below this are tests
+
+script:
+import 0x0.LibraAccount;
+import 0x0.LibraCoin;
+import Transaction.EarmarkedLibraCoin;
+main() {
+    let recipient_address: address;
+    let coin: R#LibraCoin.T;
+    let earmarked_coin: R#EarmarkedLibraCoin.T;
+    let sender: address;
+
+    recipient_address = 0xb0b;
+    coin = LibraAccount.withdraw_from_sender(1000);
+    EarmarkedLibraCoin.create(move(coin), move(recipient_address));
+
+    earmarked_coin = EarmarkedLibraCoin.claim_for_creator();
+    coin = EarmarkedLibraCoin.unwrap(move(earmarked_coin));
+    sender = get_txn_sender();
+    LibraAccount.deposit(move(sender), move(coin));
+
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/move_getting_started_examples/extended_p2p_script.mvir b/language/functional_tests/tests/testsuite/move_getting_started_examples/extended_p2p_script.mvir
new file mode 100644
index 0000000000000..994809763532c
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/move_getting_started_examples/extended_p2p_script.mvir
@@ -0,0 +1,33 @@
+//! no-execute
+
+// Simple peer-peer payment example.
+
+// Use LibraAccount module published on the blockchain at account address
+// 0x0...0 (with 64 zeroes). 0x0 is shorthand that the IR pads out to
+// 256 bits (64 digits) by adding leading zeroes.
+import 0x0.LibraAccount;
+import 0x0.LibraCoin;
+main(payee: address, amount: u64) {
+    // The bytecode (and consequently, the IR) has typed locals. The scope of
+    // each local is the entire procedure. All local variable declarations must
+    // be at the beginning of the procedure. Declaration and initialization of
+    // variables are separate operations, but the bytecode verifier will prevent
+    // any attempt to use an uninitialized variable.
+    let coin: R#LibraCoin.T;
+    // The R# part of the type above is one of two *kind annotations* R# and V#
+    // (shorthand for "Resource" and "unrestricted Value"). These annotations
+    // must match the kind of the type declaration (e.g., does the LibraCoin
+    // module declare `resource T` or `struct T`?).
+
+    // Acquire a LibraCoin.T resource with value `amount` from the sender's
+    // account. This will fail if the sender's balance is less than `amount`.
+    coin = LibraAccount.withdraw_from_sender(move(amount));
+    // Move the LibraCoin.T resource into the account of `payee`. If there is no
+    // account at the address `payee`, this step will fail.
+    LibraAccount.deposit(move(payee), move(coin));
+
+    // Every procedure must end in a `return`. The IR compiler is very literal:
+    // it directly translates the source it is given. It will not do fancy
+    // things like inserting missing `return`s.
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/move_getting_started_examples/multi_payment_script.mvir b/language/functional_tests/tests/testsuite/move_getting_started_examples/multi_payment_script.mvir
new file mode 100644
index 0000000000000..d847327bb6b47
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/move_getting_started_examples/multi_payment_script.mvir
@@ -0,0 +1,24 @@
+//! no-execute
+
+// Multiple payee example. This is written in a slightly verbose way to
+// emphasize the ability to split a `LibraCoin.T` resource. The more concise
+// way would be to use multiple calls to `LibraAccount.withdraw_from_sender`.
+
+import 0x0.LibraAccount;
+import 0x0.LibraCoin;
+main(payee1: address, amount1: u64, payee2: address, amount2: u64) {
+    let coin1: R#LibraCoin.T;
+    let coin2: R#LibraCoin.T;
+    let total: u64;
+
+    total = move(amount1) + copy(amount2);
+    coin1 = LibraAccount.withdraw_from_sender(move(total));
+    // This mutates `coin1`, which now has value `amount1`.
+    // `coin2` has value `amount2`.
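+    // (The arithmetic: `coin1` starts at `amount1 + amount2`; the withdraw on
+    // the next line splits `amount2` off into `coin2`, leaving exactly
+    // `amount1` in `coin1`, so value is only redistributed, never created.)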
+    coin2 = LibraCoin.withdraw(&mut coin1, move(amount2));
+
+    // Perform the payments
+    LibraAccount.deposit(move(payee1), move(coin1));
+    LibraAccount.deposit(move(payee2), move(coin2));
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/mutate_tests/mutate_borrow_field_ok.mvir b/language/functional_tests/tests/testsuite/mutate_tests/mutate_borrow_field_ok.mvir
new file mode 100644
index 0000000000000..460b2863a4874
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutate_tests/mutate_borrow_field_ok.mvir
@@ -0,0 +1,33 @@
+modules:
+
+module Test {
+    struct T{v: u64}
+
+    public new(g: u64): V#Self.T {
+        return T{v: move(g)};
+    }
+
+    public thousand(t : &mut V#Self.T) {
+        *(&mut move(t).v) = 1000;
+        return;
+    }
+
+    public value(this: &mut V#Self.T): u64 {
+        let y: &u64;
+        y = &move(this).v;
+        return *move(y);
+    }
+}
+
+script:
+import Transaction.Test;
+
+main() {
+    let x: V#Test.T;
+    let x_ref: u64;
+    x = Test.new(500);
+    Test.thousand(&mut x);
+    x_ref = Test.value(&mut x);
+    assert(move(x_ref) == 1000, 42);
+    return;
+}
\ No newline at end of file
diff --git a/language/functional_tests/tests/testsuite/mutate_tests/mutate_borrow_local_ok.mvir b/language/functional_tests/tests/testsuite/mutate_tests/mutate_borrow_local_ok.mvir
new file mode 100644
index 0000000000000..aec4119bbd2af
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutate_tests/mutate_borrow_local_ok.mvir
@@ -0,0 +1,7 @@
+main() {
+    let v: u64;
+    v = 5;
+    *&mut v = 0;
+    assert(copy(v) == 0, 42);
+    return;
+}
\ No newline at end of file
diff --git a/language/functional_tests/tests/testsuite/mutate_tests/mutate_copy_ok.mvir b/language/functional_tests/tests/testsuite/mutate_tests/mutate_copy_ok.mvir
new file mode 100644
index 0000000000000..4ade99661917e
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutate_tests/mutate_copy_ok.mvir
@@ -0,0 +1,11 @@
+main() {
+    let v: u64;
+    let ref_v: &mut u64;
+    v = 5;
+    ref_v = &mut v;
+    assert(*copy(ref_v) == 5, 42);
+    *copy(ref_v) = 42;
+    assert(*move(ref_v) == 42, 42);
+    assert(copy(v) == 42, 42);
+    return;
+}
\ No newline at end of file
diff --git a/language/functional_tests/tests/testsuite/mutate_tests/mutate_move_ok.mvir b/language/functional_tests/tests/testsuite/mutate_tests/mutate_move_ok.mvir
new file mode 100644
index 0000000000000..b05a0b67c6ad3
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutate_tests/mutate_move_ok.mvir
@@ -0,0 +1,36 @@
+modules:
+
+module B {
+    struct T{g: u64}
+
+    public new(v: u64): V#Self.T {
+        return T{g: move(v)};
+    }
+
+    public change(this: &mut V#Self.T) {
+        let g: &mut u64;
+        g = &mut move(this).g;
+        *move(g) = 3;
+        return;
+    }
+
+    public get(this: &mut V#Self.T): u64 {
+        let x: &u64;
+        x = &move(this).g;
+        return *move(x);
+    }
+}
+
+script:
+
+import Transaction.B;
+
+main() {
+    let x: V#B.T;
+    let y: u64;
+    x = B.new(1);
+    B.change(&mut x);
+    y = B.get(&mut x);
+    assert(move(y) == 3, 42);
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/mutate_tests/mutate_parens_ok.mvir b/language/functional_tests/tests/testsuite/mutate_tests/mutate_parens_ok.mvir
new file mode 100644
index 0000000000000..27ada0593b1aa
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutate_tests/mutate_parens_ok.mvir
@@ -0,0 +1,12 @@
+main() {
+    let v: u64;
+    let ref_v: &mut u64;
+    v = 0;
+    *(&mut v) = 5;
+    ref_v = &mut v;
+    assert(*copy(ref_v) == 5, 42);
+    *(copy(ref_v)) = 42;
+    assert(*move(ref_v) == 42, 42);
+    assert(copy(v) == 42, 42);
+    return;
+}
\ No newline at end of
file diff --git a/language/functional_tests/tests/testsuite/mutate_tests/two_mutable_ref.mvir b/language/functional_tests/tests/testsuite/mutate_tests/two_mutable_ref.mvir new file mode 100644 index 0000000000000..20f140ef09dc8 --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutate_tests/two_mutable_ref.mvir @@ -0,0 +1,22 @@ +modules: +module M { + struct Foo { + a: u64, + b: u64, + } + public create_mutable_field_addresses(addr: &mut V#Self.Foo) { + let a_ref: &mut u64; + let b_ref: &mut u64; + a_ref = &mut copy(addr).a; + b_ref = &mut copy(addr).b; + release(move(a_ref)); + release(move(b_ref)); + release(move(addr)); + return; + } +} + +script: +main() { + return; +} diff --git a/language/functional_tests/tests/testsuite/mutation/assign_field_after_local.mvir b/language/functional_tests/tests/testsuite/mutation/assign_field_after_local.mvir new file mode 100644 index 0000000000000..cb78ad08df958 --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/assign_field_after_local.mvir @@ -0,0 +1,34 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 1, err: WriteRefExistsBorrowError(8) } + +modules: +module Tester { + struct T{v: u64} + + public new(v: u64): V#Self.T { + return T{v: move(v)}; + } + + public replace(t: &mut V#Self.T) { + let t_v: &mut u64; + let new_t: V#Self.T; + + t_v = &mut copy(t).v; + new_t = Self.new(1); + *move(t) = move(new_t); + + *move(t_v) = 10000; + return; + } + +} + +script: +import Transaction.Tester; +main() { + let t: V#Tester.T; + let ref_t: &mut V#Tester.T; + t = Tester.new(0); + ref_t = &mut t; + Tester.replace(move(ref_t)); + return; +} diff --git a/language/functional_tests/tests/testsuite/mutation/assign_local_after_move.mvir b/language/functional_tests/tests/testsuite/mutation/assign_local_after_move.mvir new file mode 100644 index 0000000000000..7765fc144a30e --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/assign_local_after_move.mvir @@ -0,0 +1,13 @@ +// check: VerificationFailure([VerificationError { kind: FunctionDefinition, idx: 0, err: MoveLocExistsBorrowError(10) }]) + +main() { + let v: u64; + let ref_v: &mut u64; + let dead: u64; + v = 5; + ref_v = &mut v; + assert(*copy(ref_v) == 5, 42); + dead = move(v); + *move(ref_v) = 42; + return; +} diff --git a/language/functional_tests/tests/testsuite/mutation/assign_local_resource.mvir b/language/functional_tests/tests/testsuite/mutation/assign_local_resource.mvir new file mode 100644 index 0000000000000..eabdba028cfef --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/assign_local_resource.mvir @@ -0,0 +1,14 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: WriteRefResourceError(8) } + +import 0x0.LibraCoin; +main() { + let resource1: R#LibraCoin.T; + let resource_ref: &mut R#LibraCoin.T; + let resource2: R#LibraCoin.T; + + resource1 = LibraCoin.zero(); + resource_ref = &mut resource1; + resource2 = LibraCoin.zero(); + *move(resource_ref) = move(resource2); + return; +} diff --git a/language/functional_tests/tests/testsuite/mutation/assign_local_resource_twice.mvir b/language/functional_tests/tests/testsuite/mutation/assign_local_resource_twice.mvir new file mode 100644 index 0000000000000..92af47420570c --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/assign_local_resource_twice.mvir @@ -0,0 +1,11 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: MoveLocExistsBorrowError(4) } + +import 0x0.LibraCoin; +main() { + let resource1: R#LibraCoin.T; + let 
resource_ref: &mut R#LibraCoin.T;
+    resource1 = LibraCoin.zero();
+    resource_ref = &mut resource1;
+    *move(resource_ref) = move(resource1);
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/mutation/assign_local_struct.mvir b/language/functional_tests/tests/testsuite/mutation/assign_local_struct.mvir
new file mode 100644
index 0000000000000..7c60517095e75
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutation/assign_local_struct.mvir
@@ -0,0 +1,48 @@
+modules:
+module Tester {
+    struct T{v: u64}
+
+    public new(v: u64): V#Self.T {
+        return T{v: move(v)};
+    }
+
+    public replace(t: &mut V#Self.T) {
+        let t2: &mut V#Self.T;
+        let t_v: &mut u64;
+        let new_t: V#Self.T;
+        t2 = copy(t);
+        t_v = &mut copy(t2).v;
+        *move(t_v) = 10000;
+
+        new_t = Self.new(1);
+        *move(t2) = move(new_t);
+        release(move(t));
+        return;
+    }
+
+    public value(this: &mut V#Self.T): u64 {
+        let ref_v: &u64;
+        ref_v = &move(this).v;
+        return *move(ref_v);
+    }
+
+}
+
+script:
+import Transaction.Tester;
+main() {
+    let t: V#Tester.T;
+    let ref_t: &mut V#Tester.T;
+    let v_from_ref: u64;
+    let tt: &mut V#Tester.T;
+    let v_from_t: u64;
+    t = Tester.new(0);
+    ref_t = &mut t;
+    Tester.replace(copy(ref_t));
+    v_from_ref = Tester.value(move(ref_t));
+    tt = &mut t;
+    v_from_t = Tester.value(move(tt));
+    assert(copy(v_from_ref) == 1, 42);
+    assert(copy(v_from_t) == 1, 42);
+    return;
+}
\ No newline at end of file
diff --git a/language/functional_tests/tests/testsuite/mutation/assign_local_struct_invalidated.mvir b/language/functional_tests/tests/testsuite/mutation/assign_local_struct_invalidated.mvir
new file mode 100644
index 0000000000000..7b10d1c059808
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutation/assign_local_struct_invalidated.mvir
@@ -0,0 +1,57 @@
+// check: VerificationError { kind: FunctionDefinition, idx: 0, err: BorrowLocExistsBorrowError(6) }
+
+modules:
+module Tester {
+    struct T{v: u64}
+
+    public new(v: u64): V#Self.T {
+        return T{v: move(v)};
+    }
+
+    public replace(t: &mut V#Self.T) {
+        let t2: &mut V#Self.T;
+        let t_v: &mut u64;
+        let new_t: V#Self.T;
+        t2 = copy(t);
+        t_v = &mut copy(t2).v;
+        *move(t_v) = 10000;
+
+        new_t = Self.new(1);
+        *move(t2) = move(new_t);
+        release(move(t));
+        return;
+    }
+
+    public value(this: &mut V#Self.T): u64 {
+        let ref_v: &u64;
+        let r: u64;
+        ref_v = &copy(this).v;
+        r = *move(ref_v);
+        release(move(this));
+        return move(r);
+    }
+
+}
+
+script:
+import Transaction.Tester;
+main() {
+    let t: V#Tester.T;
+    let old_ref: &V#Tester.T;
+    let ref_t: &mut V#Tester.T;
+    let v_from_ref: u64;
+    let tt: &mut V#Tester.T;
+    let v_from_t: u64;
+    let no: u64;
+    t = Tester.new(0);
+    old_ref = &t;
+    ref_t = &mut t;
+    Tester.replace(copy(ref_t));
+    v_from_ref = Tester.value(move(ref_t));
+    tt = &mut t;
+    v_from_t = Tester.value(move(tt));
+    assert(copy(v_from_ref) == 1, 42);
+    assert(copy(v_from_t) == 1, 42);
+    no = Tester.value(move(old_ref));
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/mutation/assign_local_value.mvir b/language/functional_tests/tests/testsuite/mutation/assign_local_value.mvir
new file mode 100644
index 0000000000000..4ade99661917e
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutation/assign_local_value.mvir
@@ -0,0 +1,11 @@
+main() {
+    let v: u64;
+    let ref_v: &mut u64;
+    v = 5;
+    ref_v = &mut v;
+    assert(*copy(ref_v) == 5, 42);
+    *copy(ref_v) = 42;
+    assert(*move(ref_v) == 42, 42);
+    assert(copy(v) == 42, 42);
+    return;
+}
\ No newline at end of file
diff --git 
a/language/functional_tests/tests/testsuite/mutation/assign_resource_type.mvir b/language/functional_tests/tests/testsuite/mutation/assign_resource_type.mvir
new file mode 100644
index 0000000000000..49285f8d74058
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutation/assign_resource_type.mvir
@@ -0,0 +1,37 @@
+// check: VerificationError { kind: FunctionDefinition, idx: 2, err: WriteRefResourceError(5) }
+
+modules:
+module A {
+    import 0x0.LibraCoin;
+    resource T {fr: R#LibraCoin.T}
+
+    public new(y: R#LibraCoin.T): R#Self.T {
+        return T{fr: move(y)};
+    }
+    public destroy(a: R#Self.T) {
+        let fr: R#LibraCoin.T;
+        T { fr } = move(a);
+        LibraCoin.destroy_zero(move(fr));
+        return;
+    }
+    public t(this: &mut R#Self.T, y: R#LibraCoin.T) {
+        let x: &mut R#LibraCoin.T;
+        x = &mut move(this).fr;
+        *move(x) = move(y);
+        return;
+    }
+}
+
+script:
+import Transaction.A;
+import 0x0.LibraCoin;
+main() {
+    let z: R#LibraCoin.T;
+    let r: R#A.T;
+    z = LibraCoin.zero();
+    r = A.new(move(z));
+    z = LibraCoin.zero();
+    A.t(&mut r, move(z));
+    A.destroy(move(r));
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/mutation/assign_struct_field.mvir b/language/functional_tests/tests/testsuite/mutation/assign_struct_field.mvir
new file mode 100644
index 0000000000000..2a1a453b317ae
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutation/assign_struct_field.mvir
@@ -0,0 +1,43 @@
+// check: VerificationError { kind: FunctionDefinition, idx: 0, err: BorrowLocExistsBorrowError(14) }
+modules:
+module Tester {
+    struct T{v: u64}
+
+    public new(v: u64): V#Self.T {
+        return T{v: move(v)};
+    }
+
+    public update(t: &mut V#Self.T) {
+        let vref: &mut u64;
+        vref = &mut move(t).v;
+        *move(vref) = 1;
+        return;
+    }
+
+    public vref(this: &mut V#Self.T): &u64 {
+        let r: &u64;
+        r = &copy(this).v;
+        release(move(this));
+        return move(r);
+    }
+}
+
+script:
+import Transaction.Tester;
+main() {
+    let t: V#Tester.T;
+    let tt: &mut V#Tester.T;
+    let vref_before: &u64;
+    let ref_t: &mut V#Tester.T;
+    let vref_after: &u64;
+    t = Tester.new(0);
+    tt = &mut t;
+    vref_before = Tester.vref(move(tt));
+    assert(*copy(vref_before) == 0, 42);
+    ref_t = &mut t;
+    Tester.update(copy(ref_t));
+    vref_after = Tester.vref(move(ref_t));
+    assert(*move(vref_before) == 1, 42);
+    assert(*move(vref_after) == 1, 42);
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/mutation/destroy_resource_holder.mvir b/language/functional_tests/tests/testsuite/mutation/destroy_resource_holder.mvir
new file mode 100644
index 0000000000000..1d0584e9e9606
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutation/destroy_resource_holder.mvir
@@ -0,0 +1,26 @@
+modules:
+module A {
+    import 0x0.LibraCoin;
+    resource A { c: R#LibraCoin.T }
+    public new(c: R#LibraCoin.T): R#Self.A {
+        return A { c: move(c) };
+    }
+    public destroy_a(a: R#Self.A) {
+        let c: R#LibraCoin.T;
+        A { c } = move(a);
+        LibraCoin.destroy_zero(move(c));
+        return;
+    }
+}
+script:
+import Transaction.A;
+import 0x0.LibraCoin;
+main() {
+    let zero_resource: R#LibraCoin.T;
+    let s: R#A.A;
+    zero_resource = LibraCoin.zero();
+    s = A.new(move(zero_resource));
+    A.destroy_a(move(s));
+
+    return;
+}
\ No newline at end of file
diff --git a/language/functional_tests/tests/testsuite/mutation/mint_money.mvir b/language/functional_tests/tests/testsuite/mutation/mint_money.mvir
new file mode 100644
index 0000000000000..d5c6f2a3ee0cc
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutation/mint_money.mvir
@@ -0,0 +1,46 @@
+// 
check: no struct handle index 1 +// TODO is it possible to get this to compile to let the bytecode verifier complain? + +modules: +module Hack { + import 0x0.LibraCoin; + import 0x0.LibraAccount; + + resource T{money: R#LibraCoin.T} + public new(m: R#LibraCoin.T): R#Self.T { + return T{money: move(m)}; + } + public give_me_all_the_money(this: &mut R#Self.T, addr: address) { + let ref0: &mut R#LibraCoin.T; + let ref1: &mut u64; + let new_money: R#LibraCoin.T; + + ref0 = &mut copy(this).money; + ref1 = &mut copy(ref0).balance; + *move(ref1) = 1000000000000; + new_money = LibraCoin.withdraw(move(ref0), 1000000000000); + LibraAccount.deposit(move(addr), move(new_money)); + release(move(this)); + return; + } +} + +script: +import Transaction.Hack; +import 0x0.LibraCoin; +main() { + let zero_resource: R#LibraCoin.T; + let minter: R#Hack.T; + let addr1: address; + let minter_ref: &mut R#Hack.T; + + zero_resource = LibraCoin.zero(); + minter = Hack.new(move(zero_resource)); + addr1 = get_txn_sender(); + minter_ref = &mut minter; + Hack.give_me_all_the_money(move(minter_ref), move(addr1)); + + release(move(minter)); + + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/mutation/mut_borrow_from_imm_ref.mvir b/language/functional_tests/tests/testsuite/mutation/mut_borrow_from_imm_ref.mvir new file mode 100644 index 0000000000000..2ade2f351fe6e --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/mut_borrow_from_imm_ref.mvir @@ -0,0 +1,34 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 2, err: MoveLocUnavailableError(9) } +modules: +module Token { + resource T{value: u64} + public new(m: u64): R#Self.T { + return T{value: copy(m)}; + } + public destroy(t: R#Self.T) { + let value: u64; + T {value} = move(t); + return; + } + + public bump_value(this: &R#Self.T) { + let val: &u64; + let x: u64; + val = &move(this).value; + x = *move(val) + 1; + *move(val) = copy(x); + return; + } +} + +script: +import Transaction.Token; +main() { + let t: R#Token.T; + let tr: &R#Token.T; + t = Token.new(42); + tr = &t; + Token.bump_value(move(tr)); + Token.destroy(move(t)); + return; +} diff --git a/language/functional_tests/tests/testsuite/mutation/mut_call_from_get_resource.mvir b/language/functional_tests/tests/testsuite/mutation/mut_call_from_get_resource.mvir new file mode 100644 index 0000000000000..77c2cedb3a550 --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/mut_call_from_get_resource.mvir @@ -0,0 +1,65 @@ +modules: +module Token { + resource T {balance: u64} + + public new(balance: u64): R#Self.T { + return T{balance: copy(balance)}; + } + + public value(this: &R#Self.T): u64 { + let b: u64; + let b_ref: &u64; + b_ref = &move(this).balance; + b = *move(b_ref); + return move(b); + } + + public bump(this: &mut R#Self.T) { + let val: &mut u64; + let x: u64; + val = &mut move(this).balance; + x = *copy(val) + 1; + *move(val) = copy(x); + return; + } + + public get(addr: address): &mut R#Self.T { + let t_ref: &mut R#Self.T; + t_ref = borrow_global(move(addr)); + return move(t_ref); + } + + public publish(t: R#Self.T) { + move_to_sender(move(t)); + return; + } +} + +script: +import Transaction.Token; +main() { + let z: R#Token.T; + let addr1: address; + let struct1: &mut R#Token.T; + let imm_struct1: &R#Token.T; + let struct1_original_balance: u64; + let struct1_new_balance: u64; + + z = Token.new(0); + Token.publish(move(z)); + + addr1 = get_txn_sender(); + struct1 = Token.get(copy(addr1)); + + imm_struct1 = 
freeze(copy(struct1));
+    struct1_original_balance = Token.value(move(imm_struct1));
+    assert(copy(struct1_original_balance) == 0, 42);
+
+    Token.bump(copy(struct1));
+
+    imm_struct1 = freeze(move(struct1));
+    struct1_new_balance = Token.value(move(imm_struct1));
+    assert(copy(struct1_new_balance) == 1, 42);
+
+    return;
+}
\ No newline at end of file
diff --git a/language/functional_tests/tests/testsuite/mutation/mut_call_with_imm_ref.mvir b/language/functional_tests/tests/testsuite/mutation/mut_call_with_imm_ref.mvir
new file mode 100644
index 0000000000000..fbea4350ca48d
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutation/mut_call_with_imm_ref.mvir
@@ -0,0 +1,42 @@
+//check: VerificationError { kind: FunctionDefinition, idx: 2, err: CallTypeMismatchError(4) }
+
+modules:
+module Token {
+    resource T{value: u64}
+    public new(m: u64): R#Self.T {
+        return T{value: copy(m)};
+    }
+    public destroy(t: R#Self.T) {
+        let value: u64;
+        T {value} = move(t);
+        return;
+    }
+
+    public read_value(this: &R#Self.T): u64 {
+        let val: &u64;
+        val = &copy(this).value;
+        Self.bump_value(move(this));
+        return *move(val);
+    }
+
+    public bump_value(this: &mut R#Self.T) {
+        let val: &mut u64;
+        let x: u64;
+        val = &mut move(this).value;
+        x = *copy(val) + 1;
+        *move(val) = copy(x);
+        return;
+    }
+}
+
+script:
+import Transaction.Token;
+main() {
+    let t: R#Token.T;
+    let i: u64;
+    t = Token.new(42);
+    i = Token.read_value(&t);
+    Token.destroy(move(t));
+
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/mutation/mutate_resource_holder.mvir b/language/functional_tests/tests/testsuite/mutation/mutate_resource_holder.mvir
new file mode 100644
index 0000000000000..3365a3ae73d51
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutation/mutate_resource_holder.mvir
@@ -0,0 +1,41 @@
+// check: VerificationError { kind: FunctionDefinition, idx: 2, err: WriteRefResourceError(7) }
+
+modules:
+module A {
+    import 0x0.LibraCoin;
+    resource A { c: R#LibraCoin.T }
+    public new(c: R#LibraCoin.T): R#Self.A {
+        return A { c: move(c) };
+    }
+    public destroy_a(a: R#Self.A) {
+        let c: R#LibraCoin.T;
+        A { c } = move(a);
+        LibraCoin.destroy_zero(move(c));
+        return;
+    }
+    public mutate(a_ref: &mut R#Self.A) {
+        let ref: &mut R#LibraCoin.T;
+        let zero: R#LibraCoin.T;
+
+        ref = &mut move(a_ref).c;
+        zero = LibraCoin.zero();
+        *move(ref) = move(zero);
+
+        return;
+    }
+}
+
+script:
+import Transaction.A;
+import 0x0.LibraCoin;
+main() {
+    let zero_resource: R#LibraCoin.T;
+    let s: R#A.A;
+
+    zero_resource = LibraCoin.zero();
+    s = A.new(move(zero_resource));
+    A.mutate(&mut s);
+    A.destroy_a(move(s));
+
+    return;
+}
diff --git a/language/functional_tests/tests/testsuite/mutation/mutate_resource_holder_2.mvir b/language/functional_tests/tests/testsuite/mutation/mutate_resource_holder_2.mvir
new file mode 100644
index 0000000000000..2d5a049101794
--- /dev/null
+++ b/language/functional_tests/tests/testsuite/mutation/mutate_resource_holder_2.mvir
@@ -0,0 +1,42 @@
+// check: no struct handle index 1
+// TODO is it possible to get this to compile to let the bytecode verifier complain?
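+// Presumably this is rejected by the IR compiler rather than the verifier
+// because `balance` is a field of `LibraCoin.T`, and fields can only be
+// borrowed from inside the module that declares them; no struct handle for
+// the foreign field access can be emitted (hence "no struct handle index 1").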
+ +modules: +module A { + import 0x0.LibraCoin; + resource A { c: R#LibraCoin.T } + public new(c: R#LibraCoin.T): R#Self.A { + return A { c: move(c) }; + } + public destroy_a(a: R#Self.A) { + let c: R#LibraCoin.T; + A { c } = move(a); + LibraCoin.destroy_zero(move(c)); + return; + } + public mutate(a_ref: &mut R#Self.A) { + let ref: &mut R#LibraCoin.T; + let ref_balance: &mut u64; + + ref = &mut move(a_ref).c; + ref_balance = &mut move(ref).balance; + *move(ref_balance) = 100; + + return; + } +} + +script: +import Transaction.A; +import 0x0.LibraCoin; +main() { + let zero_resource: R#LibraCoin.T; + let s: R#A.A; + + zero_resource = LibraCoin.zero(); + s = A.new(move(zero_resource)); + A.mutate(&mut s); + A.destroy_a(move(s)); + + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/mutation/nested_mutate.mvir b/language/functional_tests/tests/testsuite/mutation/nested_mutate.mvir new file mode 100644 index 0000000000000..2efb3fb9d85a1 --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/nested_mutate.mvir @@ -0,0 +1,48 @@ +modules: +module B { + struct T{g: u64} + + public new(g: u64): V#Self.T { + return T{g: move(g)}; + } + + public t(this: &mut V#Self.T) { + let g: &mut u64; + g = &mut copy(this).g; + *copy(g) = 2; + assert(*move(g) == 2, 42); + release(move(this)); + return; + } +} + +module A { + import Transaction.B; + struct T{f: V#B.T} + + public new(f: V#B.T): V#Self.T { + return T{f: move(f)}; + } + + public t(this: &mut V#Self.T) { + let f: &mut V#B.T; + f = &mut copy(this).f; + B.t(move(f)); + release(move(this)); + return; + } +} + +script: +import Transaction.A; +import Transaction.B; +main() { + let b: V#B.T; + let x: V#A.T; + let x_ref: &mut V#A.T; + b = B.new(0); + x = A.new(move(b)); + x_ref = &mut x; + A.t(move(x_ref)); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/mutation/no_borrow_ref.mvir b/language/functional_tests/tests/testsuite/mutation/no_borrow_ref.mvir new file mode 100644 index 0000000000000..703376596818a --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/no_borrow_ref.mvir @@ -0,0 +1,10 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: BorrowLocReferenceError(5) } + +main() { + let v: u64; + let ref_v: &u64; + v = 5; + ref_v = &v; + release(&ref_v); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/mutation/read_field_after_assign_local.mvir b/language/functional_tests/tests/testsuite/mutation/read_field_after_assign_local.mvir new file mode 100644 index 0000000000000..09537c7043d4c --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/read_field_after_assign_local.mvir @@ -0,0 +1,34 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 1, err: WriteRefExistsBorrowError(8) } + +modules: +module Tester { + struct T{v: u64} + + public new(v: u64): V#Self.T { + return T{v: move(v)}; + } + + public replace(t: &mut V#Self.T) { + let t_v: &mut u64; + let new_t: V#Self.T; + + t_v = &mut copy(t).v; + new_t = Self.new(1); + *move(t) = move(new_t); + + assert(*move(t_v) == 1, 42); + return; + } + +} + +script: +import Transaction.Tester; +main() { + let t: V#Tester.T; + let ref_t: &mut V#Tester.T; + t = Tester.new(0); + ref_t = &mut t; + Tester.replace(move(ref_t)); + return; +} diff --git a/language/functional_tests/tests/testsuite/mutation/read_local_ref_after_assign.mvir 
b/language/functional_tests/tests/testsuite/mutation/read_local_ref_after_assign.mvir new file mode 100644 index 0000000000000..ac6231d7aaf1a --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/read_local_ref_after_assign.mvir @@ -0,0 +1,16 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: ReadRefExistsMutableBorrowError(16) } + +main() { + let v: u64; + let read_ref: &mut u64; + let assign_ref: &mut u64; + let no: u64; + v = 5; + read_ref = &mut v; + assign_ref = copy(read_ref); + *copy(assign_ref) = 0; + assert(*copy(assign_ref) == 0, 42); + no = *move(read_ref); + release(move(assign_ref)); + return; +} diff --git a/language/functional_tests/tests/testsuite/mutation/read_local_ref_after_move.mvir b/language/functional_tests/tests/testsuite/mutation/read_local_ref_after_move.mvir new file mode 100644 index 0000000000000..4998601b0e638 --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/read_local_ref_after_move.mvir @@ -0,0 +1,12 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: MoveLocExistsBorrowError(4) } + +main() { + let v: u64; + let ref_v: &mut u64; + let dead: u64; + v = 5; + ref_v = &mut v; + dead = move(v); + assert(*move(ref_v) == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/mutation/return_local_ref.mvir b/language/functional_tests/tests/testsuite/mutation/return_local_ref.mvir new file mode 100644 index 0000000000000..47ab501660c3f --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/return_local_ref.mvir @@ -0,0 +1,24 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: RetUnsafeToDestroyError(6) } + +modules: +module Tester { + public no(): &u64 { + let x: u64; + let x_ref: &u64; + x = 5; + x_ref = &x; + return move(x_ref); + } + +} + +script: +import Transaction.Tester; +main() { + let x: u64; + let r: &u64; + x = 100; + r = Tester.no(); + assert(*move(r) == 5, 42); + return; +} diff --git a/language/functional_tests/tests/testsuite/mutation/simple_mutate.mvir b/language/functional_tests/tests/testsuite/mutation/simple_mutate.mvir new file mode 100644 index 0000000000000..a885b367684c2 --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/simple_mutate.mvir @@ -0,0 +1,28 @@ +modules: +module A { + struct T{f: u64} + + public new(f: u64): V#Self.T { + return T{f: move(f)}; + } + + public t(this: &mut V#Self.T) { + let f: &mut u64; + f = &mut copy(this).f; + *copy(f) = 2; + assert(*move(f) == 2, 42); + release(move(this)); + return; + } +} + +script: +import Transaction.A; +main() { + let x: V#A.T; + let x_ref: &mut V#A.T; + x = A.new(0); + x_ref = &mut x; + A.t(move(x_ref)); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/mutation/unused_resource_holder.mvir b/language/functional_tests/tests/testsuite/mutation/unused_resource_holder.mvir new file mode 100644 index 0000000000000..7ecdc5f66ab38 --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/unused_resource_holder.mvir @@ -0,0 +1,30 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 0, err: RetUnsafeToDestroyError(5) } +modules: +module A { + import 0x0.LibraCoin; + + resource T{g: R#LibraCoin.T} + + public new(g: R#LibraCoin.T): R#Self.T { + return T{g: move(g)}; + } + + public destroy(a: R#Self.T) { + let c: R#LibraCoin.T; + T{g: c} = move(a); + LibraCoin.destroy_zero(move(c)); + return; + } +} + +script: +import Transaction.A; +import 0x0.LibraCoin; +main() { + let 
zero_resource: R#LibraCoin.T; + let s: R#A.T; + zero_resource = LibraCoin.zero(); + s = A.new(move(zero_resource)); + + return; +} diff --git a/language/functional_tests/tests/testsuite/mutation/use_after_move.mvir b/language/functional_tests/tests/testsuite/mutation/use_after_move.mvir new file mode 100644 index 0000000000000..5129f76e5a7c8 --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/use_after_move.mvir @@ -0,0 +1,48 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 1, err: WriteRefExistsBorrowError(10) } + +modules: +module B { + struct T{g: u64} + + public new(g: u64): V#Self.T { + return T { g: move(g) }; + } +} + +module A { + import Transaction.B; + struct T{value: V#B.T} + public new(m: V#B.T): V#Self.T { + return T{value: move(m)}; + } + + public t(this: &mut V#Self.T) { + let ref1: &mut V#B.T; + let ref2: &mut V#B.T; + let b2: V#B.T; + let x: V#B.T; + ref1 = &mut move(this).value; + ref2 = copy(ref1); + b2 = B.new(3); + *move(ref1) = move(b2); + + x = *move(ref2); + + return; + } +} + +script: +import Transaction.A; +import Transaction.B; +main() { + let b: V#B.T; + let a: V#A.T; + let a_ref: &mut V#A.T; + b = B.new(1); + a = A.new(move(b)); + a_ref = &mut a; + A.t(move(a_ref)); + + return; +} diff --git a/language/functional_tests/tests/testsuite/mutation/use_prefix_after_move.mvir b/language/functional_tests/tests/testsuite/mutation/use_prefix_after_move.mvir new file mode 100644 index 0000000000000..c6641fa6f85bb --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/use_prefix_after_move.mvir @@ -0,0 +1,46 @@ +modules: +module B { + struct T{g: u64} + + public new(g: u64): V#Self.T { + return T{g: move(g)}; + } + + public t(this: &mut V#Self.T) { + let g: &mut u64; + g = &mut move(this).g; + *move(g) = 3; + return; + } +} +module A { + import Transaction.B; + struct T{f: V#B.T} + + public new(f: V#B.T): V#Self.T { + return T{f: move(f)}; + } + + public t(this: &mut V#Self.T) { + let ref1: &mut V#B.T; + let ok: V#B.T; + ref1 = &mut move(this).f; + B.t(copy(ref1)); + ok = *move(ref1); + return; + } +} + +script: +import Transaction.A; +import Transaction.B; +main() { + let b: V#B.T; + let x: V#A.T; + let x_ref: &mut V#A.T; + b = B.new(0); + x = A.new(move(b)); + x_ref = &mut x; + A.t(move(x_ref)); + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/mutation/use_suffix_after_move.mvir b/language/functional_tests/tests/testsuite/mutation/use_suffix_after_move.mvir new file mode 100644 index 0000000000000..e421d96e623f5 --- /dev/null +++ b/language/functional_tests/tests/testsuite/mutation/use_suffix_after_move.mvir @@ -0,0 +1,50 @@ +// check: VerificationError { kind: FunctionDefinition, idx: 1, err: CallTypeMismatchError(4) } + +modules: +module B { + struct T{g: u64} + + public new(g: u64): V#Self.T { + return T{g: move(g)}; + } + + public t(this: &V#Self.T): &u64 { + let g: &u64; + g = &move(this).g; + return move(g); + } +} + +module A { + import Transaction.B; + struct T{f: V#B.T} + + public new(f: V#B.T): V#Self.T { + return T{f: move(f)}; + } + + public t(this: &mut V#Self.T) { + let ref1: &mut V#B.T; + let ref2: &u64; + let b2: V#B.T; + ref1 = &mut move(this).f; + ref2 = B.t(copy(ref1)); + b2 = B.new(3); + *move(ref1) = move(b2); + return; + } +} + +script: +import Transaction.A; +import Transaction.B; +main() { + let b: V#B.T; + let x: V#A.T; + let x_ref: &mut V#A.T; + b = B.new(0); + x = A.new(move(b)); + x_ref = &mut x; + A.t(move(x_ref)); + return; +} diff --git 
a/language/functional_tests/tests/testsuite/natives/check_native_keccak256.mvir b/language/functional_tests/tests/testsuite/natives/check_native_keccak256.mvir new file mode 100644 index 0000000000000..862dcf6443620 --- /dev/null +++ b/language/functional_tests/tests/testsuite/natives/check_native_keccak256.mvir @@ -0,0 +1,15 @@ +import 0x0.Hash; + +main() { + let input: bytearray; + let output: bytearray; + let expected_output: bytearray; + + input = b"616263"; + output = Hash.keccak256(copy(input)); + expected_output = b"4e03657aea45a94fc7d47ba826c8d667c0d1e6e33a64a036ec44f58fa12d6c45"; + + assert(move(output) == move(expected_output), 42); + + return; +} diff --git a/language/functional_tests/tests/testsuite/natives/check_native_ripemd160.mvir b/language/functional_tests/tests/testsuite/natives/check_native_ripemd160.mvir new file mode 100644 index 0000000000000..346c6f92c92ce --- /dev/null +++ b/language/functional_tests/tests/testsuite/natives/check_native_ripemd160.mvir @@ -0,0 +1,15 @@ +import 0x0.Hash; + +main() { + let input: bytearray; + let output: bytearray; + let expected_output: bytearray; + + input = b"616263"; + output = Hash.ripemd160(copy(input)); + expected_output = b"bb1be98c142444d7a56aa3981c3942a978e4dc33"; + + assert(move(output) == move(expected_output), 42); + + return; +} diff --git a/language/functional_tests/tests/testsuite/natives/check_native_sha2_256.mvir b/language/functional_tests/tests/testsuite/natives/check_native_sha2_256.mvir new file mode 100644 index 0000000000000..4482122f780da --- /dev/null +++ b/language/functional_tests/tests/testsuite/natives/check_native_sha2_256.mvir @@ -0,0 +1,15 @@ +import 0x0.Hash; + +main() { + let input: bytearray; + let output: bytearray; + let expected_output: bytearray; + + input = b"616263"; + output = Hash.sha2_256(copy(input)); + expected_output = b"ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"; + + assert(move(output) == move(expected_output), 42); + + return; +} diff --git a/language/functional_tests/tests/testsuite/natives/check_native_sha3_256.mvir b/language/functional_tests/tests/testsuite/natives/check_native_sha3_256.mvir new file mode 100644 index 0000000000000..5897b319bab5e --- /dev/null +++ b/language/functional_tests/tests/testsuite/natives/check_native_sha3_256.mvir @@ -0,0 +1,15 @@ +import 0x0.Hash; + +main() { + let input: bytearray; + let output: bytearray; + let expected_output: bytearray; + + input = b"616263"; + output = Hash.sha3_256(copy(input)); + expected_output = b"3a985da74fe225b2045c172d6bd390bd855f086e3e9d525b46bfe24511431532"; + + assert(move(output) == move(expected_output), 42); + + return; +} diff --git a/language/functional_tests/tests/testsuite/natives/signature_ed25519.mvir b/language/functional_tests/tests/testsuite/natives/signature_ed25519.mvir new file mode 100644 index 0000000000000..511e8f9f51d90 --- /dev/null +++ b/language/functional_tests/tests/testsuite/natives/signature_ed25519.mvir @@ -0,0 +1,24 @@ +import 0x0.Signature; + +main() { + let signature: bytearray; + let public_key: bytearray; + let message: bytearray; + + let output: bool; + let expected_output: bool; + + signature = b"62d6be393b8ec77fb2c12ff44ca8b5bd8bba83b805171bc99f0af3bdc619b20b8bd529452fe62dac022c80752af2af02fb610c20f01fb67a4d72789db2b8b703"; + public_key = b"7013b6ed7dde3cfb1251db1b04ae9cd7853470284085693590a75def645a926d"; + message = b"0000000000000000000000000000000000000000000000000000000000000000"; + + output = Signature.ed25519_verify(copy(signature), copy(public_key), 
copy(message)); + expected_output = true; + + assert(move(output) == move(expected_output), 42); + + return; +} + +// check: LinkerError +// TODO: fix it diff --git a/language/functional_tests/tests/testsuite/natives/signature_ed25519_bad_msg.mvir b/language/functional_tests/tests/testsuite/natives/signature_ed25519_bad_msg.mvir new file mode 100644 index 0000000000000..ce474e0d19f52 --- /dev/null +++ b/language/functional_tests/tests/testsuite/natives/signature_ed25519_bad_msg.mvir @@ -0,0 +1,24 @@ +import 0x0.Signature; + +main() { + let signature: bytearray; + let public_key: bytearray; + let message: bytearray; + + let output: bool; + let expected_output: bool; + + signature = b"62d6be393b8ec77fb2c12ff44ca8b5bd8bba83b805171bc99f0af3bdc619b20b8bd529452fe62dac022c80752af2af02fb610c20f01fb67a4d72789db2b8b703"; + public_key = b"7013b6ed7dde3cfb1251db1b04ae9cd7853470284085693590a75def645a926d"; + message = b"0000000000000000000000000000000000000000000000000000000000000001"; + + output = Signature.ed25519_verify(copy(signature), copy(public_key), copy(message)); + expected_output = false; + + assert(move(output) == move(expected_output), 42); + + return; +} + +// check: LinkerError +// TODO: fix it diff --git a/language/functional_tests/tests/testsuite/natives/signature_ed25519_bad_pk.mvir b/language/functional_tests/tests/testsuite/natives/signature_ed25519_bad_pk.mvir new file mode 100644 index 0000000000000..93b7149014816 --- /dev/null +++ b/language/functional_tests/tests/testsuite/natives/signature_ed25519_bad_pk.mvir @@ -0,0 +1,24 @@ +import 0x0.Signature; + +main() { + let signature: bytearray; + let public_key: bytearray; + let message: bytearray; + + let output: bool; + let expected_output: bool; + + signature = b"62d6be393b8ec77fb2c12ff44ca8b5bd8bba83b805171bc99f0af3bdc619b20b8bd529452fe62dac022c80752af2af02fb610c20f01fb67a4d72789db2b8b703"; + public_key = b"7013b6ed7dde3cfb1251db1b04ae9cd7853480284085693590a75def645a926d"; + message = b"0000000000000000000000000000000000000000000000000000000000000000"; + + output = Signature.ed25519_verify(copy(signature), copy(public_key), copy(message)); + expected_output = false; + + assert(move(output) == move(expected_output), 42); + + return; +} + +// check: LinkerError +// TODO: fix it diff --git a/language/functional_tests/tests/testsuite/natives/signature_ed25519_bad_sig.mvir b/language/functional_tests/tests/testsuite/natives/signature_ed25519_bad_sig.mvir new file mode 100644 index 0000000000000..f4d76546bf95a --- /dev/null +++ b/language/functional_tests/tests/testsuite/natives/signature_ed25519_bad_sig.mvir @@ -0,0 +1,24 @@ +import 0x0.Signature; + +main() { + let signature: bytearray; + let public_key: bytearray; + let message: bytearray; + + let output: bool; + let expected_output: bool; + + signature = b"62d6be393b8ec77fb2c12ff44ca8b5bd8bba83b895171bc99f0af3bdc619b20b8bd529452fe62dac022c80752af2af02fb610c20f01fb67a4d72789db2b8b703"; + public_key = b"7013b6ed7dde3cfb1251db1b04ae9cd7853470284085693590a75def645a926d"; + message = b"0000000000000000000000000000000000000000000000000000000000000000"; + + output = Signature.ed25519_verify(copy(signature), copy(public_key), copy(message)); + expected_output = false; + + assert(move(output) == move(expected_output), 42); + + return; +} + +// check: LinkerError +// TODO: fix it diff --git a/language/functional_tests/tests/testsuite/operators/arithmetic_operators.mvir b/language/functional_tests/tests/testsuite/operators/arithmetic_operators.mvir new file mode 100644 index 
0000000000000..96c25d2e0ddaa --- /dev/null +++ b/language/functional_tests/tests/testsuite/operators/arithmetic_operators.mvir @@ -0,0 +1,9 @@ +main() { + assert(1 + 2 == 3, 99); + assert(3 - 2 == 1, 100); + assert(2 * 3 == 6, 101); + assert(5 % 2 == 1, 102); + assert(6 / 2 == 3, 103); + + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/operators/bitwise_operators.mvir b/language/functional_tests/tests/testsuite/operators/bitwise_operators.mvir new file mode 100644 index 0000000000000..018deaccd52e1 --- /dev/null +++ b/language/functional_tests/tests/testsuite/operators/bitwise_operators.mvir @@ -0,0 +1,7 @@ +main() { + assert(2 | 2 == 2, 99); + assert(2 & 1 == 0, 100); + assert(2 ^ 1 == 3, 101); + + return; +} diff --git a/language/functional_tests/tests/testsuite/operators/boolean_not_non_boolean.mvir b/language/functional_tests/tests/testsuite/operators/boolean_not_non_boolean.mvir new file mode 100644 index 0000000000000..edeff072121d7 --- /dev/null +++ b/language/functional_tests/tests/testsuite/operators/boolean_not_non_boolean.mvir @@ -0,0 +1,8 @@ +main() { + let x: bool; + x = !0; + return; +} + +// check: VerificationError +// check: BooleanOpTypeMismatchError diff --git a/language/functional_tests/tests/testsuite/operators/boolean_operators.mvir b/language/functional_tests/tests/testsuite/operators/boolean_operators.mvir new file mode 100644 index 0000000000000..bd36fc40c1ff5 --- /dev/null +++ b/language/functional_tests/tests/testsuite/operators/boolean_operators.mvir @@ -0,0 +1,10 @@ +main() { + assert(true && false == false, 99); + assert(true || false == true, 100); + assert(!true == false, 101); + assert(!false == true, 102); + assert(!!true == true, 103); + assert(!!false == false, 104); + + return; +} \ No newline at end of file diff --git a/language/functional_tests/tests/testsuite/operators/comparison_operators.mvir b/language/functional_tests/tests/testsuite/operators/comparison_operators.mvir new file mode 100644 index 0000000000000..151edeff068f8 --- /dev/null +++ b/language/functional_tests/tests/testsuite/operators/comparison_operators.mvir @@ -0,0 +1,21 @@ +main() { + assert((1 == 1) == true, 99); + assert((1 == 2) == false, 99); + + assert((1 != 2) == true, 99); + assert((1 != 1) == false, 99); + + assert((2 > 1) == true, 99); + assert((1 > 1) == false, 99); + + assert((1 >= 1) == true, 99); + assert((0 >= 1) == false, 99); + + assert((1 < 2) == true, 99); + assert((2 < 1) == false, 99); + + assert((1 <= 1) == true, 99); + assert((2 <= 1) == false, 99); + + return; +} diff --git a/language/functional_tests/tests/testsuite/operators/division_by_zero.mvir b/language/functional_tests/tests/testsuite/operators/division_by_zero.mvir new file mode 100644 index 0000000000000..173404c90bd87 --- /dev/null +++ b/language/functional_tests/tests/testsuite/operators/division_by_zero.mvir @@ -0,0 +1,8 @@ +main() { + let x: u64; + x = 514 / 0; + return; +} + +// check: ArithmeticError +// TODO: check for division by zero once vm errors are improved diff --git a/language/functional_tests/tests/testsuite/operators/mod_by_zero.mvir b/language/functional_tests/tests/testsuite/operators/mod_by_zero.mvir new file mode 100644 index 0000000000000..cb3d9d6008887 --- /dev/null +++ b/language/functional_tests/tests/testsuite/operators/mod_by_zero.mvir @@ -0,0 +1,8 @@ +main() { + let x: u64; + x = 514 % 0; + return; +} + +// check: ArithmeticError +// TODO: check for division by zero once vm errors are improved diff --git 
a/language/functional_tests/tests/testsuite/operators/overflow_via_addition.mvir b/language/functional_tests/tests/testsuite/operators/overflow_via_addition.mvir new file mode 100644 index 0000000000000..620bb76893bfc --- /dev/null +++ b/language/functional_tests/tests/testsuite/operators/overflow_via_addition.mvir @@ -0,0 +1,8 @@ +main() { + let x: u64; + x = 18446744073709551615 + 1; + return; +} + +// check: ArithmeticError +// TODO: check for overflow once vm errors are improved diff --git a/language/functional_tests/tests/testsuite/operators/overflow_via_multiplication.mvir b/language/functional_tests/tests/testsuite/operators/overflow_via_multiplication.mvir new file mode 100644 index 0000000000000..4695473b9cda2 --- /dev/null +++ b/language/functional_tests/tests/testsuite/operators/overflow_via_multiplication.mvir @@ -0,0 +1,8 @@ +main() { + let x: u64; + x = 18446744073709551615 * 2; + return; +} + +// check: ArithmeticError +// TODO: check for overflow once vm errors are improved diff --git a/language/functional_tests/tests/testsuite/operators/underflow_via_subtraction.mvir b/language/functional_tests/tests/testsuite/operators/underflow_via_subtraction.mvir new file mode 100644 index 0000000000000..c87e0444a5851 --- /dev/null +++ b/language/functional_tests/tests/testsuite/operators/underflow_via_subtraction.mvir @@ -0,0 +1,8 @@ +main() { + let x: u64; + x = 0 - 1; + return; +} + +// check: ArithmeticError +// TODO: check for underflow once vm errors are improved diff --git a/language/functional_tests/tests/testsuite/payments/check_balance.mvir b/language/functional_tests/tests/testsuite/payments/check_balance.mvir new file mode 100644 index 0000000000000..1ef04fd94181e --- /dev/null +++ b/language/functional_tests/tests/testsuite/payments/check_balance.mvir @@ -0,0 +1,11 @@ +import 0x0.LibraAccount; + +main() { + let addr: address; + let struct1_original_balance: u64; + addr = get_txn_sender(); + struct1_original_balance = LibraAccount.balance(copy(addr)); + assert(copy(struct1_original_balance) > 10, 77); + + return; +} diff --git a/language/functional_tests/tests/testsuite/payments/check_balance_after_withdraw.mvir b/language/functional_tests/tests/testsuite/payments/check_balance_after_withdraw.mvir new file mode 100644 index 0000000000000..69fb70ee6d1b2 --- /dev/null +++ b/language/functional_tests/tests/testsuite/payments/check_balance_after_withdraw.mvir @@ -0,0 +1,25 @@ +import 0x0.LibraAccount; +import 0x0.LibraCoin; + +main() { + let addr: address; + let sender_original_balance: u64; + let five_coins: R#LibraCoin.T; + let five_coins_ref: &R#LibraCoin.T; + let five_coins_value: u64; + let sender_new_balance: u64; + + addr = get_txn_sender(); + sender_original_balance = LibraAccount.balance(copy(addr)); + five_coins = LibraAccount.withdraw_from_sender(5); + + five_coins_ref = &five_coins; + five_coins_value = LibraCoin.value(move(five_coins_ref)); + assert(move(five_coins_value) == 5, 66); + + sender_new_balance = LibraAccount.balance(copy(addr)); + assert(copy(sender_new_balance) == copy(sender_original_balance) - 5, 77); + LibraAccount.deposit(move(addr), move(five_coins)); + + return; +} diff --git a/language/functional_tests/tests/testsuite/payments/local_split_payment.mvir b/language/functional_tests/tests/testsuite/payments/local_split_payment.mvir new file mode 100644 index 0000000000000..5830264bbf66a --- /dev/null +++ b/language/functional_tests/tests/testsuite/payments/local_split_payment.mvir @@ -0,0 +1,44 @@ +import 0x0.LibraAccount; +import 0x0.LibraCoin; + 
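+// Withdraws ten coins from the sender, splits off five of them with LibraCoin.withdraw, deposits the two halves into two different accounts, and then checks all three balances. 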
+main() { + let addr1: address; + let addr2: address; + let addr3: address; + let struct1_original_balance: u64; + let struct2_original_balance: u64; + let struct3_original_balance: u64; + let five_coins1: R#LibraCoin.T; + let fc_ref: &mut R#LibraCoin.T; + let five_coins2: R#LibraCoin.T; + let struct1_new_balance: u64; + let struct2_new_balance: u64; + let struct3_new_balance: u64; + + addr1 = get_txn_sender(); + addr2 = 0x42; + addr3 = 0x43; + + struct1_original_balance = LibraAccount.balance(copy(addr1)); + struct2_original_balance = LibraAccount.balance(copy(addr2)); + struct3_original_balance = LibraAccount.balance(copy(addr3)); + + five_coins1 = LibraAccount.withdraw_from_sender(10); + fc_ref = &mut five_coins1; + five_coins2 = LibraCoin.withdraw(move(fc_ref), 5); + + LibraAccount.deposit(copy(addr2), move(five_coins1)); + LibraAccount.deposit(copy(addr3), move(five_coins2)); + + struct1_new_balance = LibraAccount.balance(copy(addr1)); + struct2_new_balance = LibraAccount.balance(copy(addr2)); + struct3_new_balance = LibraAccount.balance(copy(addr3)); + + assert(copy(struct1_new_balance) == copy(struct1_original_balance) - 10, 41); + assert(copy(struct2_new_balance) == copy(struct2_original_balance) + 5, 42); + assert(copy(struct3_new_balance) == copy(struct3_original_balance) + 5, 43); + return; +} + +// check: Execution(MissingData) +// TODO: current testing setup does not support this scenario diff --git a/language/functional_tests/tests/testsuite/payments/multi_payment.mvir b/language/functional_tests/tests/testsuite/payments/multi_payment.mvir new file mode 100644 index 0000000000000..55b02e11af1e0 --- /dev/null +++ b/language/functional_tests/tests/testsuite/payments/multi_payment.mvir @@ -0,0 +1,42 @@ +import 0x0.LibraAccount; +import 0x0.LibraCoin; + +main() { + let addr1: address; + let addr2: address; + let addr3: address; + let struct1_original_balance: u64; + let struct2_original_balance: u64; + let struct3_original_balance: u64; + let five_coins1: R#LibraCoin.T; + let five_coins2: R#LibraCoin.T; + let struct1_new_balance: u64; + let struct2_new_balance: u64; + let struct3_new_balance: u64; + + addr1 = get_txn_sender(); + addr2 = 0x42; + addr3 = 0x43; + + struct1_original_balance = LibraAccount.balance(copy(addr1)); + struct2_original_balance = LibraAccount.balance(copy(addr2)); + struct3_original_balance = LibraAccount.balance(copy(addr3)); + + five_coins1 = LibraAccount.withdraw_from_sender(5); + five_coins2 = LibraAccount.withdraw_from_sender(5); + LibraAccount.deposit(copy(addr2), move(five_coins1)); + LibraAccount.deposit(copy(addr3), move(five_coins2)); + + struct1_new_balance = LibraAccount.balance(copy(addr1)); + struct2_new_balance = LibraAccount.balance(copy(addr2)); + struct3_new_balance = LibraAccount.balance(copy(addr3)); + + assert(copy(struct1_new_balance) == copy(struct1_original_balance) - 10, 42); + assert(copy(struct2_new_balance) == copy(struct2_original_balance) + 5, 42); + assert(copy(struct3_new_balance) == copy(struct3_original_balance) + 5, 42); + + return; +} + +// check: Execution(MissingData) +// TODO: current testing setup does not support this scenario diff --git a/language/functional_tests/tests/testsuite/payments/peer_to_peer_payment.mvir b/language/functional_tests/tests/testsuite/payments/peer_to_peer_payment.mvir new file mode 100644 index 0000000000000..0eac0518c50fe --- /dev/null +++ b/language/functional_tests/tests/testsuite/payments/peer_to_peer_payment.mvir @@ -0,0 +1,26 @@ +import 0x0.LibraAccount; + +main() { + let sender_addr: 
address; + let recipient_addr: address; + let sender_original_balance: u64; + let recipient_original_balance: u64; + let sender_new_balance: u64; + let recipient_new_balance: u64; + + sender_addr = get_txn_sender(); + recipient_addr = 0x42; + sender_original_balance = LibraAccount.balance(copy(sender_addr)); + recipient_original_balance = LibraAccount.balance(copy(recipient_addr)); + LibraAccount.pay_from_sender(copy(recipient_addr), 5); + + sender_new_balance = LibraAccount.balance(move(sender_addr)); + recipient_new_balance = LibraAccount.balance(move(recipient_addr)); + assert(copy(sender_new_balance) == copy(sender_original_balance) - 5, 77); + assert(copy(recipient_new_balance) == copy(recipient_original_balance) + 5, 77); + + return; +} + +// check: Execution(MissingData) +// TODO: current testing setup does not support this scenario diff --git a/language/functional_tests/tests/testsuite/payments/withdraw_then_deposit_payment.mvir b/language/functional_tests/tests/testsuite/payments/withdraw_then_deposit_payment.mvir new file mode 100644 index 0000000000000..c62c15650719e --- /dev/null +++ b/language/functional_tests/tests/testsuite/payments/withdraw_then_deposit_payment.mvir @@ -0,0 +1,30 @@ +import 0x0.LibraAccount; +import 0x0.LibraCoin; + +main() { + let sender_addr: address; + let recipient_addr: address; + let sender_original_balance: u64; + let recipient_original_balance: u64; + let five_coins: R#LibraCoin.T; + let sender_new_balance: u64; + let recipient_new_balance: u64; + + sender_addr = get_txn_sender(); + recipient_addr = 0x42; + sender_original_balance = LibraAccount.balance(copy(sender_addr)); + recipient_original_balance = LibraAccount.balance(copy(recipient_addr)); + five_coins = LibraAccount.withdraw_from_sender(5); + LibraAccount.deposit(copy(recipient_addr), move(five_coins)); + + sender_new_balance = LibraAccount.balance(move(sender_addr)); + recipient_new_balance = LibraAccount.balance(move(recipient_addr)); + + assert(move(sender_new_balance) == move(sender_original_balance) - 5, 77); + assert(move(recipient_new_balance) == move(recipient_original_balance) + 5, 88); + + return; +} + +// check: Execution(MissingData) +// TODO: current testing setup does not support this scenario diff --git a/language/functional_tests/tests/testsuite/prologue/get_txn_sequence_number.mvir b/language/functional_tests/tests/testsuite/prologue/get_txn_sequence_number.mvir new file mode 100644 index 0000000000000..8df4ed13bea4a --- /dev/null +++ b/language/functional_tests/tests/testsuite/prologue/get_txn_sequence_number.mvir @@ -0,0 +1,17 @@ +import 0x0.LibraAccount; +main() { + let transaction_sequence_number: u64; + let sender: address; + let new_sequence_number: u64; + + transaction_sequence_number = get_txn_sequence_number(); + + assert(copy(transaction_sequence_number) == 0, 42); + + sender = get_txn_sender(); + new_sequence_number = LibraAccount.sequence_number(move(sender)); + + assert(copy(new_sequence_number) == copy(transaction_sequence_number), 42); + + return; +} diff --git a/language/functional_tests/tests/testsuite/publish/publish_duplicate_modules.mvir b/language/functional_tests/tests/testsuite/publish/publish_duplicate_modules.mvir new file mode 100644 index 0000000000000..4e923b9b91eb5 --- /dev/null +++ b/language/functional_tests/tests/testsuite/publish/publish_duplicate_modules.mvir @@ -0,0 +1,18 @@ +// Attempting to publish two modules with the same name should fail +modules: + +module Duplicate { + resource T1 { f: u64 } +} + +module Duplicate { + resource T2 { 
f: bool } +} + +script: +main() { + return; +} + +// check: VMExecutionFailure +// check: DuplicateModuleName diff --git a/language/functional_tests/tests/testsuite/publish/publish_empty_module.mvir b/language/functional_tests/tests/testsuite/publish/publish_empty_module.mvir new file mode 100644 index 0000000000000..6af738da79abf --- /dev/null +++ b/language/functional_tests/tests/testsuite/publish/publish_empty_module.mvir @@ -0,0 +1,5 @@ +modules: +script: +main() { + return; +} diff --git a/language/functional_tests/tests/testsuite/publish/publish_module_and_use.mvir b/language/functional_tests/tests/testsuite/publish/publish_module_and_use.mvir new file mode 100644 index 0000000000000..00607c28ea9e1 --- /dev/null +++ b/language/functional_tests/tests/testsuite/publish/publish_module_and_use.mvir @@ -0,0 +1,44 @@ +modules: +module MoneyHolder { + import 0x0.LibraCoin; + + resource T { money: R#LibraCoin.T } + + public new(m: R#LibraCoin.T): R#Self.T { + return T{ money: move(m) }; + } + + public value(this: &R#Self.T): u64 { + let ref: &R#LibraCoin.T; + let val: u64; + ref = &copy(this).money; + val = LibraCoin.value(move(ref)); + release(move(this)); + return move(val); + } + + public destroy_t(t: R#Self.T) { + let money: R#LibraCoin.T; + T{ money } = move(t); + LibraCoin.destroy_zero(move(money)); + return; + } +} + +script: +import Transaction.MoneyHolder; +import 0x0.LibraCoin; +main() { + let coin: R#LibraCoin.T; + let money_holder: R#MoneyHolder.T; + let money_holder_ref: &R#MoneyHolder.T; + let value: u64; + coin = LibraCoin.zero(); + money_holder = MoneyHolder.new(move(coin)); + money_holder_ref = &money_holder; + value = MoneyHolder.value(move(money_holder_ref)); + assert(copy(value) == 0, 42); + MoneyHolder.destroy_t(move(money_holder)); + + return; +} diff --git a/language/functional_tests/tests/testsuite/publish/publish_two_modules.mvir b/language/functional_tests/tests/testsuite/publish/publish_two_modules.mvir new file mode 100644 index 0000000000000..64eb3b6835041 --- /dev/null +++ b/language/functional_tests/tests/testsuite/publish/publish_two_modules.mvir @@ -0,0 +1,67 @@ +modules: +module MoneyHolder { + import 0x0.LibraCoin; + + resource T { money: R#LibraCoin.T } + + public new(m: R#LibraCoin.T): R#Self.T { + return T{ money: move(m) }; + } + + public value(this: &R#Self.T): u64 { + let ref: &R#LibraCoin.T; + let val: u64; + ref = &copy(this).money; + val = LibraCoin.value(move(ref)); + release(move(this)); + return move(val); + } + + public destroy_t(t: R#Self.T) { + let money: R#LibraCoin.T; + T{ money } = move(t); + LibraCoin.destroy_zero(move(money)); + return; + } +} + +module Bar { + struct T{baz: u64} + public new(m: u64): V#Self.T { + return T{baz: move(m)}; + } + public value(this: &V#Self.T): u64 { + let ref: &u64; + ref = &move(this).baz; + return *move(ref); + } +} +script: +import Transaction.MoneyHolder; +import Transaction.Bar; +import 0x0.LibraCoin; +main() { + let coin: R#LibraCoin.T; + let money_holder: R#MoneyHolder.T; + let money_holder_ref: &R#MoneyHolder.T; + let value: u64; + let v: u64; + let bar: V#Bar.T; + let bar_ref: &V#Bar.T; + let v2: u64; + + coin = LibraCoin.zero(); + money_holder = MoneyHolder.new(move(coin)); + money_holder_ref = &money_holder; + + value = MoneyHolder.value(move(money_holder_ref)); + assert(copy(value) == 0, 42); + MoneyHolder.destroy_t(move(money_holder)); + + v = 1; + bar = Bar.new(copy(v)); + bar_ref = &bar; + v2 = Bar.value(move(bar_ref)); + assert(copy(v) == copy(v2), 42); + return; +} diff --git 
a/language/functional_tests/tests/testsuite/publish/resources_are_distinct_by_published_account.mvir b/language/functional_tests/tests/testsuite/publish/resources_are_distinct_by_published_account.mvir new file mode 100644 index 0000000000000..be36edb83a56c --- /dev/null +++ b/language/functional_tests/tests/testsuite/publish/resources_are_distinct_by_published_account.mvir @@ -0,0 +1,55 @@ +modules: +module LibraAccount { + resource T{sequence_number: u64} + + public new(m: u64): R#Self.T { + return T{sequence_number: move(m)}; + } + + public sequence_number(this: &R#Self.T): u64 { + let ref: &u64; + ref = &copy(this).sequence_number; + release(move(this)); + return *move(ref); + } + + public get(account_addr: address): &mut R#Self.T { + let ref: &mut R#Self.T; + ref = borrow_global(move(account_addr)); + return move(ref); + } + + public publish(r: R#Self.T) { + move_to_sender(move(r)); + return; + } +} + +script: +import Transaction.LibraAccount as MyAccount; +import 0x0.LibraAccount; +main() { + let c1: R#MyAccount.T; + let sender: address; + let sequence_number: u64; + let fake_ref: &mut R#MyAccount.T; + let fake_value: u64; + + c1 = MyAccount.new(112); + MyAccount.publish(move(c1)); + + sender = get_txn_sender(); + + sequence_number = LibraAccount.sequence_number(copy(sender)); + assert(copy(sequence_number) == 0, 42); + + fake_ref = MyAccount.get(copy(sender)); + fake_ref = freeze(move(fake_ref)); + fake_value = MyAccount.sequence_number(move(fake_ref)); + assert(copy(fake_value) == 112, 43); + + return; +} + +// TODO: disabled test +// check: Cannot find function diff --git a/language/functional_tests/tests/testsuite/recursion/direct_recursion.mvir b/language/functional_tests/tests/testsuite/recursion/direct_recursion.mvir new file mode 100644 index 0000000000000..6afbb15ecdcbc --- /dev/null +++ b/language/functional_tests/tests/testsuite/recursion/direct_recursion.mvir @@ -0,0 +1,42 @@ +modules: +module Math { + + public sum_(n: u64, acc: u64): u64 { + let new_n: u64; + let new_acc: u64; + let new_sum: u64; + + if (copy(n) == 0) { + return move(acc); + } + + new_n = copy(n) - 1; + new_acc = move(acc) + move(n); + new_sum = Self.sum_(move(new_n), move(new_acc)); + + return move(new_sum); + } + + public sum(n: u64): u64 { + let result: u64; + result = Self.sum_(move(n), 0); + return move(result); + } + +} + +script: +import Transaction.Math; +main() { + let sum1: u64; + let sum2: u64; + + sum1 = Math.sum(5); + assert(move(sum1) == 15, 66); + + sum2 = Math.sum(7); + assert(move(sum2) == 28, 67); + + return; +} + diff --git a/language/functional_tests/tests/testsuite/recursion/mutual_recursion.mvir b/language/functional_tests/tests/testsuite/recursion/mutual_recursion.mvir new file mode 100644 index 0000000000000..6450cc03f63c8 --- /dev/null +++ b/language/functional_tests/tests/testsuite/recursion/mutual_recursion.mvir @@ -0,0 +1,47 @@ +modules: +module Math { + + public even(n: u64): bool { + let is_pred_odd: bool; + if (copy(n) == 0) { + return true; + } + + is_pred_odd = Self.odd(move(n) - 1); + return move(is_pred_odd); + } + + public odd(n: u64): bool { + let is_pred_even: bool; + if (copy(n) == 0) { + return false; + } + + is_pred_even = Self.even(move(n) - 1); + return move(is_pred_even); + } +} + +script: +import Transaction.Math; +main() { + let zero_even: bool; + let zero_odd: bool; + let ten_even: bool; + let ten_odd: bool; + + zero_even = Math.even(0); + assert(move(zero_even), 50); + + zero_odd = Math.odd(0); + assert(!move(zero_odd), 51); + + ten_even = Math.even(10); + 
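// even(10) alternates through odd(9), even(8), ..., odd(1) down to even(0), which returns true + 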
assert(move(ten_even), 52); + + ten_odd = Math.odd(10); + assert(!move(ten_odd), 53); + + return; +} + diff --git a/language/stdlib/Cargo.toml b/language/stdlib/Cargo.toml new file mode 100644 index 0000000000000..3fd5814bddb91 --- /dev/null +++ b/language/stdlib/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "stdlib" +version = "0.1.0" +edition = "2018" +authors = ["Libra Association <opensource@libra.org>"] +license = "Apache-2.0" +publish = false + +[dependencies] +compiler = { path = "../compiler" } +types = { path = "../../types" } +lazy_static = "1.3.0" diff --git a/language/stdlib/modules/hash.mvir b/language/stdlib/modules/hash.mvir new file mode 100644 index 0000000000000..851056bb889be --- /dev/null +++ b/language/stdlib/modules/hash.mvir @@ -0,0 +1,6 @@ +module Hash { + native public keccak256(data: bytearray): bytearray; + native public ripemd160(data: bytearray): bytearray; + native public sha2_256(data: bytearray): bytearray; + native public sha3_256(data: bytearray): bytearray; +} diff --git a/language/stdlib/modules/libra_account.mvir b/language/stdlib/modules/libra_account.mvir new file mode 100644 index 0000000000000..112f1813471c4 --- /dev/null +++ b/language/stdlib/modules/libra_account.mvir @@ -0,0 +1,303 @@ +// The module for the account resource that governs every Libra account +module LibraAccount { + import 0x0.LibraCoin; + import 0x0.Hash; + + // Every Libra account has a LibraAccount.T resource + resource T { + // The coins stored in this account + balance: R#LibraCoin.T, + // The current authentication key. + // This can be different from the key used to create the account + authentication_key: bytearray, + // The current sequence number. + // Incremented by one each time a transaction is submitted + sequence_number: u64, + // TEMPORARY the current count for the number of sent events for this account + // The events system is being overhauled and this will be replaced + sent_events_count: u64, + // TEMPORARY the current count for the number of received events for this account + // The events system is being overhauled and this will be replaced + received_events_count: u64 + } + + // Message for sent events + struct SentPaymentEvent { + // The address that was paid + payee: address, + // The amount of LibraCoin.T sent + amount: u64, + } + + // Message for received events + struct ReceivedPaymentEvent { + // The address that sent the coin + payer: address, + // The amount of LibraCoin.T received + amount: u64, + } + + // Creates a new LibraAccount.T + // Invoked by the `create_account` builtin + make(auth_key: bytearray): R#Self.T { + let zero_balance: R#LibraCoin.T; + zero_balance = LibraCoin.zero(); + return T { + balance: move(zero_balance), + authentication_key: move(auth_key), + sequence_number: 0, + sent_events_count: 0, + received_events_count: 0, + }; + } + + // Deposits the `to_deposit` coin into the `payee`'s account + public deposit(payee: address, to_deposit: R#LibraCoin.T) { + let deposit_value: u64; + let payee_account_ref: &mut R#Self.T; + let sender: address; + let sender_account_ref: &mut R#Self.T; + let sent_event: V#Self.SentPaymentEvent; + let received_event: V#Self.ReceivedPaymentEvent; + + // Check that the `to_deposit` coin is non-zero + deposit_value = LibraCoin.value(&to_deposit); + assert(copy(deposit_value) > 0, 7); + + // Load the sender's account + sender = get_txn_sender(); + sender_account_ref = borrow_global(copy(sender)); + // Log a send event + sent_event = SentPaymentEvent { payee: copy(payee), amount: copy(deposit_value) }; + // TEMPORARY 
The events system is being overhauled and this will be replaced by something + // more principled in the future + emit_event(&mut move(sender_account_ref).sent_events_count, b"73656E745F6576656E74735F636F756E74", move(sent_event)); + + // Load the payee's account + payee_account_ref = borrow_global(move(payee)); + // Deposit the `to_deposit` coin + LibraCoin.deposit(&mut copy(payee_account_ref).balance, move(to_deposit)); + // Log a received event + received_event = ReceivedPaymentEvent { payer: move(sender), amount: move(deposit_value) }; + // TEMPORARY The events system is being overhauled and this will be replaced by something + // more principled in the future + emit_event(&mut move(payee_account_ref).received_events_count, b"72656365697665645F6576656E74735F636F756E74", move(received_event)); + return; + } + + // mint_to_address can only be called by accounts with MintCapability (see LibraCoin), + // and those accounts will be charged for gas. If those accounts don't have enough gas to pay + // for the transaction cost, they will fail minting. + // However, those accounts can also mint to themselves, so that is a decent workaround + public mint_to_address(payee: address, amount: u64) { + let mint_capability_ref: &R#LibraCoin.MintCapability; + let coin: R#LibraCoin.T; + let payee_account_ref: &mut R#Self.T; + let payee_exists: bool; + let sender: address; + + // Mint the coin + mint_capability_ref = LibraCoin.borrow_sender_mint_capability(); + coin = LibraCoin.mint(copy(amount), move(mint_capability_ref)); + + // Create an account if it does not exist + payee_exists = exists(copy(payee)); + if (!move(payee_exists)) { + sender = get_txn_sender(); + Self.create_new_account(copy(payee), 0); + } + + // Load the `payee`'s account and deposit the minted `coin` + payee_account_ref = borrow_global(move(payee)); + LibraCoin.deposit(&mut move(payee_account_ref).balance, move(coin)); + return; + } + + // Helper to withdraw `amount` from the given `account` and return the resulting LibraCoin.T + withdraw_from_account(account: &mut R#Self.T, amount: u64): R#LibraCoin.T { + let to_withdraw: R#LibraCoin.T; + to_withdraw = LibraCoin.withdraw(&mut move(account).balance, copy(amount)); + return move(to_withdraw); + } + + // Withdraw `amount` LibraCoin.T from the transaction sender's account + public withdraw_from_sender(amount: u64): R#LibraCoin.T { + let sender: address; + let sender_account: &mut R#Self.T; + let to_withdraw: R#LibraCoin.T; + + // Load the sender + sender = get_txn_sender(); + sender_account = borrow_global(move(sender)); + + // Withdraw the coin + to_withdraw = Self.withdraw_from_account(move(sender_account), move(amount)); + return move(to_withdraw); + } + + // Withdraw `amount` LibraCoin.T from the transaction sender's account and send the coin + // to the `payee` address + // Creates the `payee` account if it does not exist + public pay_from_sender(payee: address, amount: u64) { + let to_pay: R#LibraCoin.T; + let payee_exists: bool; + payee_exists = exists(copy(payee)); + if (move(payee_exists)) { + to_pay = Self.withdraw_from_sender(move(amount)); + Self.deposit(move(payee), move(to_pay)); + } else { + Self.create_new_account(move(payee), move(amount)); + } + return; + } + + // Rotate the transaction sender's authentication key + // The new key will be used for signing future transactions + public rotate_authentication_key(new_authentication_key: bytearray) { + let sender: address; + let sender_account: &mut R#Self.T; + sender = get_txn_sender(); + sender_account = 
borrow_global(move(sender)); + *(&mut move(sender_account).authentication_key) = move(new_authentication_key); + return; + } + + // Creates a new account at `fresh_address` with the `initial_balance` deducted from the + // transaction sender's account + public create_new_account(fresh_address: address, initial_balance: u64) { + create_account(copy(fresh_address)); + if (copy(initial_balance) > 0) { + Self.pay_from_sender(move(fresh_address), move(initial_balance)); + } + return; + } + + // Helper to return the u64 value of the `balance` field for the given `account` + balance_for_account(account: &R#Self.T): u64 { + let balance_value: u64; + balance_value = LibraCoin.value(&move(account).balance); + return move(balance_value); + } + + // Return the current balance of the LibraCoin.T in LibraAccount.T at `addr` + public balance(addr: address): u64 { + let payee_account: &mut R#Self.T; + let imm_payee_account: &R#Self.T; + let balance_amount: u64; + payee_account = borrow_global(move(addr)); + imm_payee_account = freeze(move(payee_account)); + balance_amount = Self.balance_for_account(move(imm_payee_account)); + return move(balance_amount); + } + + // Helper to return the sequence number field for the given `account` + sequence_number_for_account(account: &R#Self.T): u64 { + return *(&move(account).sequence_number); + } + + // Return the current sequence number at `addr` + public sequence_number(addr: address): u64 { + let account_ref: &mut R#Self.T; + let imm_ref: &R#Self.T; + let sequence_number_value: u64; + account_ref = borrow_global(move(addr)); + imm_ref = freeze(move(account_ref)); + sequence_number_value = Self.sequence_number_for_account(move(imm_ref)); + return move(sequence_number_value); + } + + // Checks if an account exists at `check_addr` + public exists(check_addr: address): bool { + let is_present: bool; + is_present = exists(move(check_addr)); + return move(is_present); + } + + // The prologue is invoked at the beginning of every transaction + // It verifies: + // - The account's auth key matches the transaction's public key + // - That the account has enough balance to pay for all of the gas + // - That the sequence number matches the transaction's sequence number + prologue() { + let transaction_sender: address; + let transaction_sender_exists: bool; + let sender_account: &mut R#Self.T; + let imm_sender_account: &R#Self.T; + let sender_public_key: bytearray; + let public_key_hash: bytearray; + let gas_price: u64; + let gas_units: u64; + let gas_fee: u64; + let balance_amount: u64; + let sequence_number_value: u64; + let transaction_sequence_number_value: u64; + + transaction_sender = get_txn_sender(); + + // FUTURE: Make these error codes sequential + // Verify that the transaction sender's account exists + transaction_sender_exists = exists(copy(transaction_sender)); + assert(move(transaction_sender_exists), 5); + + // Load the transaction sender's account + sender_account = borrow_global(copy(transaction_sender)); + + // Check that the transaction's public key matches the account's current auth key + sender_public_key = get_txn_public_key(); + public_key_hash = Hash.sha3_256(move(sender_public_key)); + assert(move(public_key_hash) == *(&copy(sender_account).authentication_key), 2); + + // Check that the account has enough balance for all of the gas + gas_price = get_txn_gas_unit_price(); + gas_units = get_txn_max_gas_units(); + gas_fee = move(gas_price) * move(gas_units); + imm_sender_account = freeze(copy(sender_account)); + balance_amount = 
Self.balance_for_account(move(imm_sender_account)); + assert(move(balance_amount) >= move(gas_fee), 6); + + // Check that the transaction sequence number matches the sequence number of the account + sequence_number_value = *(&mut move(sender_account).sequence_number); + transaction_sequence_number_value = get_txn_sequence_number(); + assert(copy(transaction_sequence_number_value) >= copy(sequence_number_value), 3); + assert(move(transaction_sequence_number_value) == move(sequence_number_value), 4); + return; + } + + // The epilogue is invoked at the end of transactions. + // It collects gas and bumps the sequence number + epilogue() { + let transaction_sender: address; + let sender_account: &mut R#Self.T; + let imm_sender_account: &R#Self.T; + let gas_price: u64; + let gas_units_remaining: u64; + let starting_gas_units: u64; + let gas_fee_amount: u64; + let balance_amount: u64; + let gas_fee: R#LibraCoin.T; + let transaction_sequence_number_value: u64; + + transaction_sender = get_txn_sender(); + + // Load the transaction sender's account + sender_account = borrow_global(copy(transaction_sender)); + + // Charge for gas + gas_price = get_txn_gas_unit_price(); + starting_gas_units = get_txn_max_gas_units(); + gas_units_remaining = get_gas_remaining(); + gas_fee_amount = move(gas_price) * (move(starting_gas_units) - move(gas_units_remaining)); + imm_sender_account = freeze(copy(sender_account)); + balance_amount = Self.balance_for_account(move(imm_sender_account)); + assert(move(balance_amount) >= copy(gas_fee_amount), 6); + + gas_fee = Self.withdraw_from_account(copy(sender_account), move(gas_fee_amount)); + LibraCoin.TODO_REMOVE_burn_gas_fee(move(gas_fee)); + + // Bump the sequence number + transaction_sequence_number_value = get_txn_sequence_number(); + *(&mut move(sender_account).sequence_number) = move(transaction_sequence_number_value) + 1; + return; + } + +} diff --git a/language/stdlib/modules/libra_coin.mvir b/language/stdlib/modules/libra_coin.mvir new file mode 100644 index 0000000000000..d7fbbb784a9a4 --- /dev/null +++ b/language/stdlib/modules/libra_coin.mvir @@ -0,0 +1,112 @@ +module LibraCoin { + // A resource representing the Libra coin + resource T { + // The value of the coin. May be zero + value: u64, + } + + // A resource that grants access to `LibraCoin.mint`. Only the Association account has one. + resource MintCapability {} + + // Return a reference to the MintCapability published under the sender's account. Fails if the + // sender does not have a MintCapability. + // Since only the Association account has a mint capability, this will only succeed if it is + // invoked by a transaction sent by that account. + public borrow_sender_mint_capability(): &R#Self.MintCapability { + let sender: address; + let capability_ref: &mut R#Self.MintCapability; + let capability_immut_ref: &R#Self.MintCapability; + + sender = get_txn_sender(); + capability_ref = borrow_global(move(sender)); + capability_immut_ref = freeze(move(capability_ref)); + return move(capability_immut_ref); + } + + // Mint a new LibraCoin.T worth `value`. The caller must have a reference to a MintCapability. + // Only the Association account can acquire such a reference, and it can do so only via + // `borrow_sender_mint_capability` + public mint(value: u64, capability: &R#Self.MintCapability): R#Self.T { + release(move(capability)); + return T{value: move(value)}; + } + + // This procedure is private and thus can only be called by the VM internally. 
It is used only + // during genesis writeset creation to give a single MintCapability to the Association account. + grant_mint_capability() { + move_to_sender(MintCapability{}); + return; + } + + // Create a new LibraCoin.T with a value of 0 + public zero(): R#Self.T { + return T{value: 0}; + } + + // Public accessor for the value of a coin + public value(coin_ref: &R#Self.T): u64 { + return *&move(coin_ref).value; + } + + // Splits the given coin into two and returns them both + // It leverages `Self.withdraw` for any verifications of the values + public split(coin: R#Self.T, amount: u64): R#Self.T * R#Self.T { + let other: R#Self.T; + other = Self.withdraw(&mut coin, move(amount)); + return move(coin), move(other); + } + + // "Divides" the given coin into two, where the original coin is modified in place + // The original coin will have value = original value - `amount` + // The new coin will have a value = `amount` + // Fails if the coin's value is less than `amount` + public withdraw(coin_ref: &mut R#Self.T, amount: u64): R#Self.T { + let value: u64; + + // Check that `amount` does not exceed the coin's value + value = *(&mut copy(coin_ref).value); + assert(copy(value) >= copy(amount), 10); + + // Split the coin + *(&mut move(coin_ref).value) = move(value) - copy(amount); + return T{value: move(amount)}; + } + + // Merges two coins and returns a new coin whose value is equal to the sum of the two inputs + public join(coin1: R#Self.T, coin2: R#Self.T): R#Self.T { + Self.deposit(&mut coin1, move(coin2)); + return move(coin1); + } + + // "Merges" the two coins + // The coin passed in by reference will have a value equal to the sum of the two coins + // The `check` coin is consumed in the process + public deposit(coin_ref: &mut R#Self.T, check: R#Self.T) { + let value: u64; + let check_value: u64; + + value = *(&mut copy(coin_ref).value); + T { value: check_value } = move(check); + *(&mut move(coin_ref).value) = move(value) + move(check_value); + return; + } + + // Destroy a coin + // Fails if the value is non-zero + // The amount of LibraCoin.T in the system is a tightly controlled property, + // so you cannot "burn" any non-zero amount of LibraCoin.T + public destroy_zero(coin: R#Self.T) { + let value: u64; + T { value } = move(coin); + assert(move(value) == 0, 11); + return; + } + + // Temporary procedure that is called to burn off the collected gas fee + // In the future this will be replaced by the actual mechanism for collecting gas + public TODO_REMOVE_burn_gas_fee(coin: R#Self.T) { + let value: u64; + T { value } = move(coin); + return; + } +} diff --git a/language/stdlib/modules/signature.mvir b/language/stdlib/modules/signature.mvir new file mode 100644 index 0000000000000..4d92baa589202 --- /dev/null +++ b/language/stdlib/modules/signature.mvir @@ -0,0 +1,3 @@ +module Signature { + native public ed25519_verify(signature: bytearray, public_key: bytearray, message: bytearray): bool; +} diff --git a/language/stdlib/modules/validator_set.mvir b/language/stdlib/modules/validator_set.mvir new file mode 100644 index 0000000000000..785547434afb3 --- /dev/null +++ b/language/stdlib/modules/validator_set.mvir @@ -0,0 +1,76 @@ +module ValidatorSet { + resource T { + // Size of the validator set. Currently the serialization format enforces lexicographic ordering by field + // name, so the field names must be chosen carefully to get the intended deserialization order on the validator side. + // Validators need to interpret it and use the first `array_size` fields as the keys and ignore the rest. 
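+ // For example, if `array_size` is 3, validators take key0 through key2 as the validator set and ignore key3 through key9. 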
+ // TODO: Replace this with generic array once Move supports generic collections. + array_size: u64, + // Pubkey list. Since we don't have arrays for now, use hard-coded fields instead. + key0: V#Self.ValidatorPublicKeys, + key1: V#Self.ValidatorPublicKeys, + key2: V#Self.ValidatorPublicKeys, + key3: V#Self.ValidatorPublicKeys, + key4: V#Self.ValidatorPublicKeys, + key5: V#Self.ValidatorPublicKeys, + key6: V#Self.ValidatorPublicKeys, + key7: V#Self.ValidatorPublicKeys, + key8: V#Self.ValidatorPublicKeys, + key9: V#Self.ValidatorPublicKeys, + } + + struct ValidatorPublicKeys { + account_address: address, + consensus_public_key: bytearray, + network_identity_public_key: bytearray, + network_signing_public_key: bytearray, + } + + publish_validator_set( + size: u64, + key0: V#Self.ValidatorPublicKeys, + key1: V#Self.ValidatorPublicKeys, + key2: V#Self.ValidatorPublicKeys, + key3: V#Self.ValidatorPublicKeys, + key4: V#Self.ValidatorPublicKeys, + key5: V#Self.ValidatorPublicKeys, + key6: V#Self.ValidatorPublicKeys, + key7: V#Self.ValidatorPublicKeys, + key8: V#Self.ValidatorPublicKeys, + key9: V#Self.ValidatorPublicKeys + ) { + let set: R#Self.T; + // We only support at most 10 validator keys for now. Will get rid of that + assert(copy(size) <= 10, 42); + set = T { + array_size: move(size), + key0: move(key0), + key1: move(key1), + key2: move(key2), + key3: move(key3), + key4: move(key4), + key5: move(key5), + key6: move(key6), + key7: move(key7), + key8: move(key8), + key9: move(key9), + }; + move_to_sender(move(set)); + return; + } + + make_new_validator_key( + account_address: address, + consensus_public_key: bytearray, + network_signing_public_key: bytearray, + network_identity_public_key: bytearray + ): V#Self.ValidatorPublicKeys { + let key: V#Self.ValidatorPublicKeys; + key = ValidatorPublicKeys { + account_address: move(account_address), + consensus_public_key: move(consensus_public_key), + network_signing_public_key: move(network_signing_public_key), + network_identity_public_key: move(network_identity_public_key), + }; + return move(key); + } +} diff --git a/language/stdlib/natives/Cargo.toml b/language/stdlib/natives/Cargo.toml new file mode 100644 index 0000000000000..3418024a26514 --- /dev/null +++ b/language/stdlib/natives/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "move_ir_natives" +version = "0.1.0" +authors = ["Libra Association <opensource@libra.org>"] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +bitcoin_hashes = "0.3.0" +nextgen_crypto = { path = "../../../crypto/nextgen_crypto" } +failure = { path = "../../../common/failure_ext", package = "failure_ext" } +tiny-keccak = "1.4.2" +types = { path = "../../../types" } diff --git a/language/stdlib/natives/src/dispatch.rs b/language/stdlib/natives/src/dispatch.rs new file mode 100644 index 0000000000000..085bd9818efd5 --- /dev/null +++ b/language/stdlib/natives/src/dispatch.rs @@ -0,0 +1,65 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use crate::{hash, signature}; +pub use failure::Error; +use failure::*; +use types::byte_array::ByteArray; +pub type Result<T> = ::std::result::Result<T, Error>; + +pub enum NativeReturnType { + ByteArray(ByteArray), + Bool(bool), +} + +pub struct CostedReturnType { + cost: u64, + return_value: NativeReturnType, +} + +impl CostedReturnType { + pub fn new(cost: u64, return_value: NativeReturnType) -> Self { + CostedReturnType { cost, return_value } + } + + pub fn cost(&self) -> u64 { + self.cost + } + + pub fn get_return_value(self) -> 
NativeReturnType {
+        self.return_value
+    }
+}
+
+pub trait StackAccessor {
+    fn get_byte_array(&mut self) -> Result<ByteArray>;
+}
+
+pub fn dispatch_native_call<T: StackAccessor>(
+    accessor: T,
+    module_name: &str,
+    function_name: &str,
+) -> Result<CostedReturnType> {
+    match module_name {
+        "Hash" => match function_name {
+            "keccak256" => hash::native_keccak_256(accessor),
+            "ripemd160" => hash::native_ripemd_160(accessor),
+            "sha2_256" => hash::native_sha2_256(accessor),
+            "sha3_256" => hash::native_sha3_256(accessor),
+            &_ => bail!(
+                "Unknown native function `{}.{}'",
+                module_name,
+                function_name
+            ),
+        },
+        "Signature" => match function_name {
+            "ed25519_verify" => signature::native_ed25519_signature_verification(accessor),
+            &_ => bail!(
+                "Unknown native function `{}.{}'",
+                module_name,
+                function_name
+            ),
+        },
+        &_ => bail!("Unknown native module {}", module_name),
+    }
+}
diff --git a/language/stdlib/natives/src/hash.rs b/language/stdlib/natives/src/hash.rs
new file mode 100644
index 0000000000000..5066c85ae75e6
--- /dev/null
+++ b/language/stdlib/natives/src/hash.rs
@@ -0,0 +1,62 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::dispatch::{CostedReturnType, NativeReturnType, Result, StackAccessor};
+use bitcoin_hashes::{hash160, sha256, Hash};
+use std::borrow::Borrow;
+use tiny_keccak::Keccak;
+use types::byte_array::ByteArray;
+
+const HASH_LENGTH: usize = 32;
+const KECCAK_COST: u64 = 30;
+const RIPEMD_COST: u64 = 35;
+const SHA2_COST: u64 = 30;
+const SHA3_COST: u64 = 30;
+
+pub fn native_keccak_256<T: StackAccessor>(mut accessor: T) -> Result<CostedReturnType> {
+    let mut hash = [0u8; HASH_LENGTH];
+    let mut keccak = Keccak::new_keccak256();
+
+    let hash_arg = accessor.get_byte_array()?;
+    let native_cost = KECCAK_COST * hash_arg.len() as u64;
+
+    keccak.update(hash_arg.as_bytes());
+    keccak.finalize(&mut hash);
+
+    let native_return = NativeReturnType::ByteArray(ByteArray::new(hash.to_vec()));
+    Ok(CostedReturnType::new(native_cost, native_return))
+}
+
+pub fn native_ripemd_160<T: StackAccessor>(mut accessor: T) -> Result<CostedReturnType> {
+    let hash_arg = accessor.get_byte_array()?;
+    let native_cost = RIPEMD_COST * hash_arg.len() as u64;
+    let hash = hash160::Hash::hash(hash_arg.as_bytes());
+    let hash_ref: &[u8] = hash.borrow();
+    let native_return = NativeReturnType::ByteArray(ByteArray::new(hash_ref.to_vec()));
+
+    Ok(CostedReturnType::new(native_cost, native_return))
+}
+
+pub fn native_sha2_256<T: StackAccessor>(mut accessor: T) -> Result<CostedReturnType> {
+    let hash_arg = accessor.get_byte_array()?;
+    let native_cost = SHA2_COST * hash_arg.len() as u64;
+    let hash = sha256::Hash::hash(hash_arg.as_bytes());
+    let hash_ref: &[u8] = hash.borrow();
+    let native_return = NativeReturnType::ByteArray(ByteArray::new(hash_ref.to_vec()));
+
+    Ok(CostedReturnType::new(native_cost, native_return))
+}
+
+pub fn native_sha3_256<T: StackAccessor>(mut accessor: T) -> Result<CostedReturnType> {
+    let mut hash = [0u8; HASH_LENGTH];
+    let mut keccak = Keccak::new_sha3_256();
+
+    let hash_arg = accessor.get_byte_array()?;
+    let native_cost = SHA3_COST * hash_arg.len() as u64;
+
+    keccak.update(hash_arg.as_bytes());
+    keccak.finalize(&mut hash);
+
+    let native_return = NativeReturnType::ByteArray(ByteArray::new(hash.to_vec()));
+    Ok(CostedReturnType::new(native_cost, native_return))
+}
diff --git a/language/stdlib/natives/src/lib.rs b/language/stdlib/natives/src/lib.rs
new file mode 100644
index 0000000000000..14413b7616514
--- /dev/null
+++ b/language/stdlib/natives/src/lib.rs
@@ -0,0 +1,6 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod dispatch;
+pub mod
hash;
+pub mod signature;
diff --git a/language/stdlib/natives/src/signature.rs b/language/stdlib/natives/src/signature.rs
new file mode 100644
index 0000000000000..68d8725fe198c
--- /dev/null
+++ b/language/stdlib/natives/src/signature.rs
@@ -0,0 +1,33 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use crate::dispatch::{CostedReturnType, NativeReturnType, Result, StackAccessor};
+use nextgen_crypto::{ed25519, traits::*};
+use std::convert::TryFrom;
+
+// TODO: Talk to Crypto to determine these costs
+const ED25519_COST: u64 = 35;
+
+pub fn native_ed25519_signature_verification<T: StackAccessor>(
+    mut accessor: T,
+) -> Result<CostedReturnType> {
+    let signature = accessor.get_byte_array()?;
+    let pubkey = accessor.get_byte_array()?;
+    let msg = accessor.get_byte_array()?;
+
+    let native_cost = ED25519_COST * msg.len() as u64;
+
+    let sig = ed25519::Ed25519Signature::try_from(signature.as_bytes())?;
+    let pk = ed25519::Ed25519PublicKey::try_from(pubkey.as_bytes())?;
+
+    match sig.verify_arbitrary_msg(msg.as_bytes(), &pk) {
+        Ok(()) => Ok(CostedReturnType::new(
+            native_cost,
+            NativeReturnType::Bool(true),
+        )),
+        Err(_) => Ok(CostedReturnType::new(
+            native_cost,
+            NativeReturnType::Bool(false),
+        )),
+    }
+}
diff --git a/language/stdlib/src/lib.rs b/language/stdlib/src/lib.rs
new file mode 100644
index 0000000000000..7bdc93adb16a7
--- /dev/null
+++ b/language/stdlib/src/lib.rs
@@ -0,0 +1,5 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod stdlib;
+pub mod transaction_scripts;
diff --git a/language/stdlib/src/mod.rs b/language/stdlib/src/mod.rs
new file mode 100644
index 0000000000000..4d22869503ad0
--- /dev/null
+++ b/language/stdlib/src/mod.rs
@@ -0,0 +1,6 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+#![allow(clippy::module_inception)]
+pub mod stdlib;
+pub mod transaction_scripts;
diff --git a/language/stdlib/src/stdlib.rs b/language/stdlib/src/stdlib.rs
new file mode 100644
index 0000000000000..864bb08e27696
--- /dev/null
+++ b/language/stdlib/src/stdlib.rs
@@ -0,0 +1,45 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+use compiler::parser::{ast::ModuleDefinition, parse_module};
+use lazy_static::lazy_static;
+
+macro_rules! make_module_definition {
+    ($source_path: literal) => {{
+        let struct_body = include_str!($source_path);
+        parse_module(struct_body).unwrap()
+    }};
+}
+
+lazy_static!
{ + static ref ACCOUNT_MODULE: ModuleDefinition = + make_module_definition!("../modules/libra_account.mvir"); + static ref COIN_MODULE: ModuleDefinition = + make_module_definition!("../modules/libra_coin.mvir"); + static ref NATIVE_HASH_MODULE: ModuleDefinition = + make_module_definition!("../modules/hash.mvir"); + static ref SIGNATURE_MODULE: ModuleDefinition = + make_module_definition!("../modules/signature.mvir"); + static ref VALIDATOR_SET_MODULE: ModuleDefinition = + make_module_definition!("../modules/validator_set.mvir"); +} + +pub fn account_module() -> ModuleDefinition { + ACCOUNT_MODULE.clone() +} + +pub fn coin_module() -> ModuleDefinition { + COIN_MODULE.clone() +} + +pub fn native_hash_module() -> ModuleDefinition { + NATIVE_HASH_MODULE.clone() +} + +pub fn signature_module() -> ModuleDefinition { + SIGNATURE_MODULE.clone() +} + +pub fn validator_set_module() -> ModuleDefinition { + VALIDATOR_SET_MODULE.clone() +} diff --git a/language/stdlib/src/transaction_scripts.rs b/language/stdlib/src/transaction_scripts.rs new file mode 100644 index 0000000000000..41ca9fe3182a9 --- /dev/null +++ b/language/stdlib/src/transaction_scripts.rs @@ -0,0 +1,33 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +use compiler::parser::{ast::Program, parse_program}; +use lazy_static::lazy_static; + +lazy_static! { + pub static ref PEER_TO_PEER_TRANSFER_TXN_BODY: Program = { + let txn_body = include_str!("../transaction_scripts/peer_to_peer_transfer.mvir"); + parse_program(txn_body).unwrap() + }; +} + +lazy_static! { + pub static ref CREATE_ACCOUNT_TXN_BODY: Program = { + let txn_body = include_str!("../transaction_scripts/create_account.mvir"); + parse_program(txn_body).unwrap() + }; +} + +lazy_static! { + pub static ref ROTATE_AUTHENTICATION_KEY_TXN_BODY: Program = { + let txn_body = include_str!("../transaction_scripts/rotate_authentication_key.mvir"); + parse_program(txn_body).unwrap() + }; +} + +lazy_static! 
{ + pub static ref MINT_TXN_BODY: Program = { + let txn_body = include_str!("../transaction_scripts/mint.mvir"); + parse_program(txn_body).unwrap() + }; +} diff --git a/language/stdlib/transaction_scripts/create_account.mvir b/language/stdlib/transaction_scripts/create_account.mvir new file mode 100644 index 0000000000000..7800374a668dc --- /dev/null +++ b/language/stdlib/transaction_scripts/create_account.mvir @@ -0,0 +1,5 @@ +import 0x0.LibraAccount; +main (fresh_address: address, initial_amount: u64) { + LibraAccount.create_new_account(move(fresh_address), move(initial_amount)); + return; +} diff --git a/language/stdlib/transaction_scripts/mint.mvir b/language/stdlib/transaction_scripts/mint.mvir new file mode 100644 index 0000000000000..014ffc4b42dfb --- /dev/null +++ b/language/stdlib/transaction_scripts/mint.mvir @@ -0,0 +1,6 @@ +import 0x0.LibraAccount; +import 0x0.LibraCoin; +main(payee: address, amount: u64) { + LibraAccount.mint_to_address(move(payee), move(amount)); + return; +} diff --git a/language/stdlib/transaction_scripts/peer_to_peer_transfer.mvir b/language/stdlib/transaction_scripts/peer_to_peer_transfer.mvir new file mode 100644 index 0000000000000..a5f35e244040d --- /dev/null +++ b/language/stdlib/transaction_scripts/peer_to_peer_transfer.mvir @@ -0,0 +1,5 @@ +import 0x0.LibraAccount; +main (payee: address, amount: u64) { + LibraAccount.pay_from_sender(move(payee), move(amount)); + return; +} diff --git a/language/stdlib/transaction_scripts/placeholder_script.mvir b/language/stdlib/transaction_scripts/placeholder_script.mvir new file mode 100644 index 0000000000000..7527970d3d838 --- /dev/null +++ b/language/stdlib/transaction_scripts/placeholder_script.mvir @@ -0,0 +1,3 @@ +main() { + return; +} diff --git a/language/stdlib/transaction_scripts/rotate_authentication_key.mvir b/language/stdlib/transaction_scripts/rotate_authentication_key.mvir new file mode 100644 index 0000000000000..0d713060b9783 --- /dev/null +++ b/language/stdlib/transaction_scripts/rotate_authentication_key.mvir @@ -0,0 +1,5 @@ +import 0x0.LibraAccount; +main (new_key: bytearray) { + LibraAccount.rotate_authentication_key(move(new_key)); + return; +} diff --git a/language/test.sh b/language/test.sh new file mode 100755 index 0000000000000..a914d2c3f02fa --- /dev/null +++ b/language/test.sh @@ -0,0 +1,17 @@ +#!/bin/bash -e + +# Copyright (c) The Libra Core Contributors +# SPDX-License-Identifier: Apache-2.0 + +# This script runs `cargo test` for each crate in the subdir + +base_cmd="cargo test" +dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +while read line; do + echo $line; + dirline=$(realpath $(dirname $line)); + cmd="cd $dirline && $base_cmd" + echo $cmd; + # Run the cmd in a subshell since it switches directories. 
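+  # Each iteration therefore gets its own subshell, so the `cd` does not leak into later loops.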
+  (eval $cmd)
+done < <(find "$dir" -name 'Cargo.toml')
diff --git a/language/vm/Cargo.toml b/language/vm/Cargo.toml
new file mode 100644
index 0000000000000..16fdbd376877e
--- /dev/null
+++ b/language/vm/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "vm"
+version = "0.1.0"
+authors = ["Libra Association <opensource@libra.org>"]
+license = "Apache-2.0"
+publish = false
+edition = "2018"
+
+[dependencies]
+byteorder = "1.3.1"
+hex = "0.3.2"
+proptest = "0.9"
+proptest-derive = "0.1.1"
+crypto = { path = "../../crypto/legacy_crypto" }
+failure = { path = "../../common/failure_ext", package = "failure_ext" }
+proptest_helpers = { path = "../../common/proptest_helpers" }
+types = { path = "../../types" }
diff --git a/language/vm/README.md b/language/vm/README.md
new file mode 100644
index 0000000000000..2582cfefe6450
--- /dev/null
+++ b/language/vm/README.md
@@ -0,0 +1,58 @@
+---
+id: vm
+title: Virtual Machine
+custom_edit_url: https://github.com/libra/libra/edit/master/language/vm/README.md
+---
+
+# MoveVM Core
+
+The MoveVM executes transactions expressed in the Move bytecode. There are
+two main crates: the core VM and the VM runtime. The VM core contains the low-level
+data types for the VM - mostly the file format and abstractions over it. A gas
+metering logical abstraction is also defined there.
+
+## Overview
+
+The MoveVM is a stack machine with a static type system. The MoveVM honors
+the specification of the Move language through a mix of file format,
+verification (see the [bytecode verifier README](https://github.com/libra/libra/blob/master/language/bytecode_verifier/README.md)),
+and runtime constraints. The structure of the file format allows the
+definition of modules, types (resources and unrestricted types), and
+functions. Code is expressed via bytecode instructions, which may have
+references to external functions and types. The file format also imposes
+certain invariants of the language such as opaque types and private fields.
+From the file format definition it should be clear that modules define a
+scope/namespace for functions and types. Types are opaque given that all fields
+are private, and types carry no functions or methods.
+
+## Implementation Details
+
+The MoveVM core crate provides the definition of the file format and all
+utilities related to the file format:
+* A simple Rust abstraction over the file format
+  (`libra/language/vm/src/file_format.rs`) and the bytecodes. These Rust
+  structures are widely used in the code base.
+* Serialization and deserialization of the file format. These define the
+  on-chain binary representation of the code.
+* Some pretty printing functionalities.
+* A proptest infrastructure for the file format.
+* The gas cost/synthesis infrastructure.
+
+The `CompiledModule` and `CompiledScript` definitions in
+`libra/language/vm/src/file_format.rs` are the top-level structs for a Move
+*Module* or *Transaction Script*, respectively. These structs provide a
+simple abstraction over the file format. Additionally, a set of
+[*Views*](https://github.com/libra/libra/blob/master/language/vm/src/views.rs) are defined to easily navigate and inspect
+`CompiledModule`s and `CompiledScript`s.
+
+## Folder Structure
+
+```
+.
+β”œβ”€β”€ cost_synthesis # Infrastructure for gas cost synthesis +β”œβ”€β”€ src # VM core files +β”œβ”€β”€ tests # Proptests +β”œβ”€β”€ vm_genesis # Helpers to generate a genesis block, the initial state of the blockchain +└── vm_runtime # Interpreter and runtime data types (see README in that folder) +``` + diff --git a/language/vm/cost_synthesis/Cargo.toml b/language/vm/cost_synthesis/Cargo.toml new file mode 100644 index 0000000000000..4b8f2815a6dbe --- /dev/null +++ b/language/vm/cost_synthesis/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "cost_synthesis" +version = "0.1.0" +authors = ["Libra Association "] +license = "Apache-2.0" +publish = false +edition = "2018" + +[dependencies] +rand = "0.6.5" +lazy_static = "1.3.0" + +types = { path = "../../../types" } +vm = { path = "../" } +vm_runtime = { path = "../vm_runtime" } +vm_genesis = { path = "../vm_genesis" } +vm_cache_map = { path = "../vm_runtime/vm_cache_map" } +move_ir_natives = { path = "../../stdlib/natives" } + +[dev-dependencies] +hex = "0.3.2" + +[features] +default = ["vm_runtime/instruction_synthesis"] diff --git a/language/vm/cost_synthesis/src/bin/main.rs b/language/vm/cost_synthesis/src/bin/main.rs new file mode 100644 index 0000000000000..e7e6fcac90efb --- /dev/null +++ b/language/vm/cost_synthesis/src/bin/main.rs @@ -0,0 +1,187 @@ +// Copyright (c) The Libra Core Contributors +// SPDX-License-Identifier: Apache-2.0 + +//! This performs instruction cost synthesis for the various bytecode instructions that we have. We +//! separate instructions into three sets: +//! * Global-memory independent instructions; +//! * Global-memory dependent instructions; and +//! * Native operations. +use cost_synthesis::{ + module_generator::ModuleGenerator, natives::StackAccessorMocker, + stack_generator::RandomStackGenerator, vm_runner::FakeDataCache, with_loaded_vm, +}; +use move_ir_natives::hash; +use std::{collections::HashMap, time::Instant}; +use vm::{ + errors::VMErrorKind, + file_format::{ + AddressPoolIndex, ByteArrayPoolIndex, Bytecode, FieldDefinitionIndex, + FunctionDefinitionIndex, FunctionHandleIndex, StringPoolIndex, StructDefinitionIndex, + }, + transaction_metadata::TransactionMetadata, +}; +use vm_cache_map::Arena; +use vm_genesis::STDLIB_MODULES; +use vm_runtime::{ + code_cache::module_cache::{ModuleCache, VMModuleCache}, + loaded_data::function::{FunctionRef, FunctionReference}, + txn_executor::TransactionExecutor, +}; + +const MAX_STACK_SIZE: u64 = 100; +const NUM_ITERS: u64 = 1000; + +fn stack_instructions() { + use Bytecode::*; + let stack_opcodes: Vec = vec![ + ReadRef, + WriteRef, + ReleaseRef, + FreezeRef, + BorrowField(FieldDefinitionIndex::new(0)), + CopyLoc(0), + MoveLoc(0), + BorrowLoc(0), + StLoc(0), + Unpack(StructDefinitionIndex::new(0)), + Pack(StructDefinitionIndex::new(0)), + Call(FunctionHandleIndex::new(0)), + CreateAccount, + Sub, + Ret, + Add, + Mul, + Mod, + Div, + BitOr, + BitAnd, + Xor, + Or, + And, + Eq, + Neq, + Lt, + Gt, + Le, + Ge, + Assert, + LdFalse, + LdTrue, + LdConst(0), + LdStr(StringPoolIndex::new(0)), + LdByteArray(ByteArrayPoolIndex::new(0)), + LdAddr(AddressPoolIndex::new(0)), + BrFalse(0), + BrTrue(0), + Branch(0), + Pop, + GetTxnGasUnitPrice, + GetTxnMaxGasUnits, + GetGasRemaining, + GetTxnSenderAddress, + GetTxnSequenceNumber, + GetTxnPublicKey, + ]; + + let mod_gen: ModuleGenerator = ModuleGenerator::new(NUM_ITERS as u16, 3); + with_loaded_vm! 
(mod_gen => vm, loaded_module, module_cache);
+    let costs: HashMap<Bytecode, u128> = stack_opcodes
+        .into_iter()
+        .map(|instruction| {
+            println!("Running: {:?}", instruction);
+            let stack_gen = RandomStackGenerator::new(
+                &loaded_module,
+                &module_cache,
+                &instruction,
+                MAX_STACK_SIZE,
+                NUM_ITERS,
+            );
+            let instr_cost: u128 = stack_gen
+                .map(|stack_state| {
+                    let instr = RandomStackGenerator::stack_transition(
+                        &mut vm.execution_stack,
+                        stack_state,
+                    );
+                    let before = Instant::now();
+                    let ignore = vm.execute_block(&[instr], 0);
+                    let time = before.elapsed().as_nanos();
+                    // Check to make sure we didn't error. Need to special case the assertion
+                    // bytecode.
+                    if instruction != Bytecode::Assert {
+                        // We want any errors here to bubble up to us with the actual VM error.
+                        ignore.unwrap().unwrap();
+                    } else {
+                        // In the case of the Assert bytecode we only want to make sure that we
+                        // don't have a VMInvariantViolation error, and then make sure that any
+                        // error generated was an assertion failure.
+                        match ignore.unwrap() {
+                            Ok(_) => (),
+                            Err(err) => match err.err {
+                                VMErrorKind::AssertionFailure(_) => (),
+                                _ => panic!("Assertion bytecode failed"),
+                            },
+                        }
+                    }
+                    time
+                })
+                .sum();
+            let average_time = instr_cost / u128::from(NUM_ITERS);
+            (instruction, average_time)
+        })
+        .collect();
+
+    println!("---------------------------------------------------------------------------");
+    for (instr, cost) in costs {
+        println!("{:?}: {}", instr, cost);
+    }
+    println!("---------------------------------------------------------------------------");
+}
+
+macro_rules! bench_native {
+    ($name:expr, $function:path, $table:ident) => {
+        let mut stack_access = StackAccessorMocker::new();
+        let time_byte_mapping = (1..512)
+            .map(|i| {
+                stack_access.set_hash_length(i);
+                let time = (0..NUM_ITERS).fold(0, |acc, _| {
+                    stack_access.next_bytearray();
+                    let before = Instant::now();
+                    let _ = $function(&mut stack_access).unwrap();
+                    acc + before.elapsed().as_nanos()
+                });
+                let time = time / u128::from(NUM_ITERS);
+                (time, i as u64)
+            })
+            .collect::<Vec<_>>();
+        let time_per_byte = time_byte_mapping
+            .into_iter()
+            .fold(0, |acc, (time, bytes)| acc + (time / u128::from(bytes)))
+            / 512;
+        $table.insert($name, time_per_byte);
+    };
+}
+
+fn natives() {
+    let mut cost_table = HashMap::new();
+    bench_native!("keccak_256", hash::native_keccak_256, cost_table);
+    bench_native!("ripemd_160", hash::native_ripemd_160, cost_table);
+    bench_native!("native_sha2_256", hash::native_sha2_256, cost_table);
+    bench_native!("native_sha3_256", hash::native_sha3_256, cost_table);
+    println!("------------------------ NATIVES ------------------------------------------");
+    for (instr, cost) in cost_table {
+        println!("{:?}: {}", instr, cost);
+    }
+    println!("---------------------------------------------------------------------------");
+}
+
+pub fn main() {
+    natives();
+    stack_instructions();
+}
+
+// Instructions left to implement:
+// BorrowGlobal(StructDefinitionIndex),
+// Exists(StructDefinitionIndex),
+// MoveFrom(StructDefinitionIndex),
+// MoveToSender(StructDefinitionIndex),
+// EmitEvent, <- Not yet/until it's implemented
diff --git a/language/vm/cost_synthesis/src/bytecode_specifications/frame_transition_info.rs b/language/vm/cost_synthesis/src/bytecode_specifications/frame_transition_info.rs
new file mode 100644
index 0000000000000..6f3871d33cc1f
--- /dev/null
+++ b/language/vm/cost_synthesis/src/bytecode_specifications/frame_transition_info.rs
@@ -0,0 +1,43 @@
+// Copyright (c) The Libra Core Contributors
+//
SPDX-License-Identifier: Apache-2.0
+
+//! Frame transition rules for the execution stack.
+use vm::file_format::{Bytecode, FunctionDefinitionIndex};
+use vm_runtime::{
+    code_cache::module_cache::ModuleCache,
+    execution_stack::ExecutionStack,
+    loaded_data::{
+        function::{FunctionRef, FunctionReference},
+        loaded_module::LoadedModule,
+    },
+};
+
+fn should_push_frame(instr: &Bytecode) -> bool {
+    *instr == Bytecode::Ret
+}
+
+/// Certain instructions require specific frame configurations. In particular, Ret requires that
+/// there be at least one frame on the stack, and for others (e.g. Call) we expect a specific frame
+/// (that we have chosen previously) to be at the top of the frame stack. This function makes sure
+/// that the execution stack has the required number of frames and/or the requested frame at the top.
+pub(crate) fn frame_transitions<'alloc, 'txn, P>(
+    stk: &mut ExecutionStack<'alloc, 'txn, P>,
+    instr: &Bytecode,
+    module_info: (&'txn LoadedModule, Option<FunctionDefinitionIndex>),
+) where
+    'alloc: 'txn,
+    P: ModuleCache<'alloc>,
+{
+    let module = module_info.0;
+    if should_push_frame(instr) {
+        let empty_frame = FunctionRef::new(module, FunctionDefinitionIndex::new(0))
+            .expect("[Frame Transition] Unable to build dummy function reference.");
+        stk.push_frame(empty_frame)
+    }
+
+    if let Some(function_idx) = module_info.1 {
+        let frame = FunctionRef::new(module, function_idx)
+            .expect("[Frame Transition] Unable to build specified function reference.");
+        stk.push_frame(frame);
+    }
+}
diff --git a/language/vm/cost_synthesis/src/bytecode_specifications/mod.rs b/language/vm/cost_synthesis/src/bytecode_specifications/mod.rs
new file mode 100644
index 0000000000000..cb6ef491e5c34
--- /dev/null
+++ b/language/vm/cost_synthesis/src/bytecode_specifications/mod.rs
@@ -0,0 +1,5 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+pub mod frame_transition_info;
+pub mod stack_transition_info;
diff --git a/language/vm/cost_synthesis/src/bytecode_specifications/stack_transition_info.rs b/language/vm/cost_synthesis/src/bytecode_specifications/stack_transition_info.rs
new file mode 100644
index 0000000000000..f74aed0071bef
--- /dev/null
+++ b/language/vm/cost_synthesis/src/bytecode_specifications/stack_transition_info.rs
@@ -0,0 +1,271 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Encodes the typed value-stack transitions for the VM bytecode.
+//!
+//! This file encodes the typed stack transitions for the bytecode according to
+//! the rules that are outlined in the bytecode file format.
+//! Each instruction's type transition is encoded as
+//! [(source_tys_list, output_tys_list), ..., (source_tys_list, output_tys_list)].
+//! We encode each instruction with the number of arguments (and the valid types
+//! that these arguments can take), and the number of (type) outputs from that instruction.
+use lazy_static::lazy_static;
+use std::{boxed::Box, u8::MAX};
+use vm::file_format::{Bytecode, SignatureToken, StructHandleIndex, TypeSignature};
+
+const MAX_INSTRUCTION_LEN: u8 = MAX;
+
+/// Represents a type that a given position can be -- fixed, or variable.
+///
+/// We want to be able to represent that a given set of type positions can hold multiple types,
+/// possibly irrespective of the other types around them (i.e. that the type can vary
+/// independently); these are represented as `Variable` types. Likewise, we want to be able to
+/// say that a type is `Fixed`.
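+///
+/// For example, in the transitions below `Add` consumes two operands `Fixed` at `u64`, while
+/// `Pop` takes a single `Variable` operand that may be any of the base types (see
+/// `call_details`).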
+#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub enum SignatureTy { + Fixed(TypeSignature), + Variable(Vec, u8), +} + +/// Holds the type details (input/output) for a bytecode instruction. +pub struct CallDetails { + /// The types of the input arguments expected on the value stack. + pub in_args: Vec, + + /// The types of the outputs placed at the top of the value stack after the execution of the + /// instruction. + pub out_args: Vec, +} + +impl SignatureTy { + /// This method returns the underlying type(s) of the `SignatureTy`. It discards the + /// information on whether the types returned were `Fixed` or `Variable`. + pub fn underlying(self) -> Vec { + match self { + SignatureTy::Fixed(t) => vec![t], + SignatureTy::Variable(tys, _) => tys, + } + } + + /// A predicate returning whether the type is fixed. + pub fn is_fixed(&self) -> bool { + match self { + SignatureTy::Fixed(_) => true, + SignatureTy::Variable(_, _) => false, + } + } + + /// A predicate returning whether the type is variable. + pub fn is_variable(&self) -> bool { + !self.is_fixed() + } +} + +// The base types in our system. +// +// For any instruction that can take a stack value of any base type +// we represent it as a variable type over all base types for the bytecode. +lazy_static! { + static ref BASE_SIG_TOKENS: Vec = vec![ + SignatureToken::Bool, + SignatureToken::U64, + SignatureToken::String, + SignatureToken::ByteArray, + SignatureToken::Address, + // Bogus struct handle index, but it's fine since we disregard this in the generation of + // instruction arguments. + SignatureToken::Struct(StructHandleIndex::new(0)), + ]; +} + +fn variable_ty_of_sig_tok(tok: Vec, len: u8) -> SignatureTy { + let typs = tok.into_iter().map(TypeSignature).collect(); + SignatureTy::Variable(typs, len) +} + +fn ty_of_sig_tok(tok: SignatureToken) -> SignatureTy { + SignatureTy::Fixed(TypeSignature(tok)) +} + +fn simple_ref_of_sig_tok(tok: SignatureToken) -> SignatureTy { + SignatureTy::Fixed(TypeSignature(SignatureToken::Reference(Box::new(tok)))) +} + +fn simple_ref_of_sig_toks(toks: Vec) -> SignatureTy { + let len = toks.len() as u8; + let types = toks + .into_iter() + .map(|tok| TypeSignature(SignatureToken::Reference(Box::new(tok)))) + .collect(); + SignatureTy::Variable(types, len) +} + +fn ref_values(num: u64) -> Vec { + (0..num) + .map(|_| simple_ref_of_sig_toks(BASE_SIG_TOKENS.clone())) + .collect() +} + +fn ref_resources(num: u64) -> Vec { + (0..num) + .map(|_| simple_ref_of_sig_tok(SignatureToken::Struct(StructHandleIndex::new(0)))) + .collect() +} + +fn bools(num: u64) -> Vec { + (0..num) + .map(|_| ty_of_sig_tok(SignatureToken::Bool)) + .collect() +} + +fn u64s(num: u64) -> Vec { + (0..num) + .map(|_| ty_of_sig_tok(SignatureToken::U64)) + .collect() +} + +fn simple_addrs(num: u64) -> Vec { + (0..num) + .map(|_| ty_of_sig_tok(SignatureToken::Address)) + .collect() +} + +fn strs(num: u64) -> Vec { + (0..num) + .map(|_| ty_of_sig_tok(SignatureToken::String)) + .collect() +} + +fn byte_arrays(num: u64) -> Vec { + (0..num) + .map(|_| ty_of_sig_tok(SignatureToken::ByteArray)) + .collect() +} + +fn values(num: u64) -> Vec { + (0..num) + .map(|_| variable_ty_of_sig_tok(BASE_SIG_TOKENS.clone(), BASE_SIG_TOKENS.len() as u8)) + .collect() +} + +fn resources(num: u64) -> Vec { + (0..num) + .map(|_| ty_of_sig_tok(SignatureToken::Struct(StructHandleIndex::new(0)))) + .collect() +} + +fn empty() -> Vec { + vec![] +} + +fn type_transitions(args: Vec<(Vec, Vec)>) -> Vec { + args.into_iter() + .map(|(in_args, out_args)| CallDetails { 
in_args, out_args }) + .collect() +} + +macro_rules! type_transition { + ($($e1:expr => $e2:expr),+) => { + type_transitions(vec![ $(($e1,$e2)),+ ]) + }; +} + +/// Given an instruction `op` return back the type-level stack only call details. +pub fn call_details(op: &Bytecode) -> Vec { + match op { + Bytecode::Add + | Bytecode::Sub + | Bytecode::Mul + | Bytecode::Mod + | Bytecode::Div + | Bytecode::BitOr + | Bytecode::BitAnd + | Bytecode::Xor => type_transition! { u64s(2) => u64s(1) }, + Bytecode::Eq | Bytecode::Neq => type_transition! { values(2) => bools(1) }, + Bytecode::Pop => type_transition! { + values(1) => empty(), + ref_values(1) => empty(), + ref_resources(1) => empty() + }, + Bytecode::LdConst(_) => type_transition! { empty() => u64s(1) }, + Bytecode::LdAddr(_) => type_transition! { empty() => simple_addrs(1) }, + Bytecode::LdByteArray(_) => type_transition! { empty() => byte_arrays(1) }, + Bytecode::LdStr(_) => type_transition! { empty() => strs(1) }, + Bytecode::LdFalse | Bytecode::LdTrue => type_transition! { empty() => bools(1) }, + Bytecode::BrTrue(_) | Bytecode::BrFalse(_) => { + type_transition! { bools(1) => empty() } + } + Bytecode::Assert => { + let mut arg_tys = u64s(1); + arg_tys.append(&mut bools(1)); + type_transition! { arg_tys => empty() } + } + Bytecode::Branch(_) => type_transition! { empty() => empty() }, + Bytecode::StLoc(_) => type_transition! { resources(1) => empty(), values(1) => empty() }, + Bytecode::CopyLoc(_) => type_transition! { + empty() => values(1), + empty() => ref_values(1), + empty() => ref_resources(1) + }, + Bytecode::MoveLoc(_) => type_transition! { + empty() => values(1), + empty() => resources(1), + empty() => ref_values(1), + empty() => ref_resources(1) + }, + Bytecode::BorrowLoc(_) | Bytecode::BorrowField(_) => { + type_transition! { empty() => ref_values(1), empty() => ref_resources(1) } + } + Bytecode::ReadRef => type_transition! { ref_values(1) => values(1) }, + Bytecode::WriteRef => { + let mut input_tys = values(1); + input_tys.append(&mut ref_values(1)); + type_transition! { + input_tys => empty() + } + } + Bytecode::Lt | Bytecode::Gt | Bytecode::Le | Bytecode::Ge => { + type_transition! { u64s(2) => bools(1) } + } + Bytecode::And | Bytecode::Or => type_transition! { bools(2) => bools(1) }, + Bytecode::Not => type_transition! { bools(1) => bools(1) }, + Bytecode::Ret => type_transition! { + values(1) => empty(), + resources(1) => empty(), + ref_values(1) => empty(), + ref_resources(1) => empty() + }, + Bytecode::Pack(_) | Bytecode::Call(_) => { + let possible_tys = BASE_SIG_TOKENS.clone(); + type_transition! { + vec![variable_ty_of_sig_tok( + possible_tys.clone(), + MAX_INSTRUCTION_LEN, + )] => vec![variable_ty_of_sig_tok(possible_tys, 1)] + } + } + Bytecode::Unpack(_) => { + let possible_tys = BASE_SIG_TOKENS.clone(); + type_transition! { + vec![variable_ty_of_sig_tok( + possible_tys.clone(), + 1, + )] => vec![variable_ty_of_sig_tok(possible_tys, MAX_INSTRUCTION_LEN)] + } + } + Bytecode::GetTxnGasUnitPrice + | Bytecode::GetTxnSequenceNumber + | Bytecode::GetTxnMaxGasUnits + | Bytecode::GetGasRemaining => type_transition! { empty() => u64s(1) }, + Bytecode::GetTxnSenderAddress => type_transition! { empty() => simple_addrs(1) }, + Bytecode::Exists(_) => type_transition! { simple_addrs(1) => bools(1) }, + Bytecode::BorrowGlobal(_) => type_transition! { simple_addrs(1) => ref_values(1) }, + Bytecode::ReleaseRef => type_transition! { ref_values(1) => empty() }, + Bytecode::MoveFrom(_) => type_transition! 
{ simple_addrs(1) => values(1) },
+        Bytecode::MoveToSender(_) => type_transition! { values(1) => empty() },
+        Bytecode::CreateAccount => type_transition! { simple_addrs(1) => empty() },
+        Bytecode::GetTxnPublicKey => type_transition! { empty() => byte_arrays(1) },
+        Bytecode::FreezeRef => type_transition! { ref_values(1) => ref_values(1) },
+        Bytecode::EmitEvent => unimplemented!(),
+    }
+}
diff --git a/language/vm/cost_synthesis/src/common.rs b/language/vm/cost_synthesis/src/common.rs
new file mode 100644
index 0000000000000..8256cd8f3d042
--- /dev/null
+++ b/language/vm/cost_synthesis/src/common.rs
@@ -0,0 +1,30 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Defines constants and types that are used throughout cost synthesis.
+use vm::file_format::TableIndex;
+use vm_runtime::value::Local;
+
+/// The maximum number of fields that will be generated for any struct.
+pub const MAX_FIELDS: usize = 10;
+
+/// The maximum size that generated byte arrays can be.
+pub const BYTE_ARRAY_MAX_SIZE: usize = 64;
+
+/// The maximum size that a generated string can be.
+pub const MAX_STRING_SIZE: usize = 32;
+
+/// The maximum number of locals that can be defined within a generated function definition.
+pub const MAX_NUM_LOCALS: usize = 10;
+
+/// The maximum number of arguments to generated function definitions.
+pub const MAX_FUNCTION_CALL_SIZE: usize = 8;
+
+/// The default index to use when we need to have a frame on the execution stack.
+///
+/// We are always guaranteed to have at least one function definition in a generated module. We can
+/// therefore always count on having a function definition at index 0.
+pub const DEFAULT_FUNCTION_IDX: TableIndex = 0;
+
+/// The type of the value stack.
+pub type Stack = Vec<Local>;
diff --git a/language/vm/cost_synthesis/src/lib.rs b/language/vm/cost_synthesis/src/lib.rs
new file mode 100644
index 0000000000000..8592eb7441ffb
--- /dev/null
+++ b/language/vm/cost_synthesis/src/lib.rs
@@ -0,0 +1,12 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+#![feature(box_syntax, box_patterns)]
+#[macro_use]
+
+pub mod module_generator;
+mod bytecode_specifications;
+mod common;
+pub mod natives;
+pub mod stack_generator;
+pub mod vm_runner;
diff --git a/language/vm/cost_synthesis/src/module_generator.rs b/language/vm/cost_synthesis/src/module_generator.rs
new file mode 100644
index 0000000000000..ac71f93893786
--- /dev/null
+++ b/language/vm/cost_synthesis/src/module_generator.rs
@@ -0,0 +1,385 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Logic for random valid module and module universe generation.
+//!
+//! This module contains the logic for generating random valid modules and valid (rooted) module
+//! universes. Note that we do not generate valid function bodies for the functions that are
+//! generated -- any function bodies that are generated are simply non-semantic sequences of
+//! instructions used to check the BrTrue, BrFalse, and Branch instructions.
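+//!
+//! For example (illustrative only), a generator can be driven as a plain iterator:
+//!
+//! ```ignore
+//! // Tables with at least 16 entries per module; a universe of 3 modules.
+//! let generator = ModuleGenerator::new(16, 3);
+//! for module in generator {
+//!     // Each `module` is a `CompiledModule` that may call into earlier ones.
+//! }
+//! ```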
+use crate::common::*; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use std::collections::HashMap; +use types::{account_address::AccountAddress, byte_array::ByteArray, language_storage::CodeKey}; +use vm::{ + access::*, + file_format::{ + AddressPoolIndex, Bytecode, CodeUnit, CompiledModule, FieldDefinition, + FieldDefinitionIndex, FunctionDefinition, FunctionHandle, FunctionHandleIndex, + FunctionSignature, FunctionSignatureIndex, LocalsSignature, LocalsSignatureIndex, + MemberCount, ModuleHandle, ModuleHandleIndex, SignatureToken, StringPoolIndex, + StructDefinition, StructHandle, StructHandleIndex, TableIndex, TypeSignature, + TypeSignatureIndex, + }, +}; + +/// A wrapper around a `CompiledModule` containing information needed for generation. +/// +/// Contains a source of pseudo-randomness along with a table of the modules that are known and can +/// be called into -- these are all modules that have previously been generated by the same +/// instance of the `ModuleBuilder`. +/// +/// The call graph between generated modules forms a rooted DAG based at the current +/// `CompiledModule` being generated. +pub struct ModuleBuilder { + /// The source of randomness used across the modules that we generate. + gen: StdRng, + + /// The current module being built. + module: CompiledModule, + + /// The minimum size of the tables in the generated module. + table_size: TableIndex, + + /// Other modules that we know, and that we can generate calls type references into. Indexed by + /// their address and name (i.e. the module's `CodeKey`). + known_modules: HashMap, +} + +impl ModuleBuilder { + /// Create a new module builder with generated module tables of size `table_size`. + pub fn new(table_size: TableIndex) -> Self { + let seed: [u8; 32] = [0; 32]; + Self { + gen: StdRng::from_seed(seed), + module: Self::default_module_with_types(), + table_size, + known_modules: HashMap::new(), + } + } + + /// Display the current module being generated. + pub fn display(&self) { + println!("{:#?}", self.module) + } + + fn with_account_addresses(&mut self) { + let mut addrs = (0..self.table_size) + .map(|_| AccountAddress::random()) + .collect(); + self.module.address_pool.append(&mut addrs); + } + + fn with_strings(&mut self) { + let mut strs = (0..self.table_size) + .map(|_| { + let len = self.gen.gen_range(1, MAX_STRING_SIZE); + (0..len).map(|_| self.gen.gen::()).collect() + }) + .collect(); + self.module.string_pool.append(&mut strs); + } + + fn with_bytearrays(&mut self) { + self.module.byte_array_pool = (0..self.table_size) + .map(|_| { + let len = self.gen.gen_range(1, BYTE_ARRAY_MAX_SIZE); + let bytes = (0..len).map(|_| self.gen.gen::()).collect(); + ByteArray::new(bytes) + }) + .collect(); + } + + // Add the functions with locals given by the first part of the tuple, and with function + // signature `FunctionSignature`. + fn with_functions(&mut self, sigs: Vec<(Vec, FunctionSignature)>) { + let mut names: Vec = sigs + .iter() + .enumerate() + .map(|(i, _)| format!("func{}", i)) + .collect(); + // Grab the offset before adding the generated names to the string pool; we'll need this + // later on when we generate the function handles in order to know where we should have the + // functions point to for their name. 
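+        // That is, the name of generated function `i` will live at string pool index
+        // `offset + i`, which is exactly what the `FunctionHandle`s constructed below record.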
+ let offset = self.module.string_pool.len(); + let function_sig_offset = self.module.function_signatures.len(); + self.module.string_pool.append(&mut names); + + self.module.function_defs = sigs + .iter() + .enumerate() + .map(|(i, _)| FunctionDefinition { + function: FunctionHandleIndex::new(i as u16), + flags: CodeUnit::PUBLIC, + code: CodeUnit { + max_stack_size: 20, + locals: LocalsSignatureIndex(i as u16), + // Random nonsense to pad this out. We won't look at this at all, just + // non-empty is all that matters. + code: vec![Bytecode::Sub, Bytecode::Sub, Bytecode::Add, Bytecode::Ret], + }, + }) + .collect(); + + self.module.function_handles = sigs + .iter() + .enumerate() + .map(|(i, _)| FunctionHandle { + name: StringPoolIndex::new((i + offset) as u16), + signature: FunctionSignatureIndex::new((i + function_sig_offset) as u16), + module: ModuleHandleIndex::new(0), + }) + .collect(); + let (local_sigs, mut function_sigs): (Vec<_>, Vec<_>) = sigs.into_iter().unzip(); + self.module.function_signatures.append(&mut function_sigs); + self.module + .locals_signatures + .append(&mut local_sigs.into_iter().map(LocalsSignature).collect()); + } + + // Generate `table_size` number of structs. Note that this will not generate nested structs. + // The overall logic of this function follows very similarly to that for function generation. + fn with_structs(&mut self) { + // Generate struct names. + let mut names: Vec = (0..self.table_size) + .map(|i| format!("struct{}", i)) + .collect(); + let offset = self.module.string_pool.len() as TableIndex; + self.module.string_pool.append(&mut names); + + // Generate the field definitions and struct definitions at the same time + for struct_idx in 0..self.table_size { + // Generate a random amount of fields for each struct. Each struct must have at least + // one field. + let num_fields = self.gen.gen_range(1, MAX_FIELDS); + + // Generate the struct def. This generates pointers into the module's `field_defs` that + // are not generated just yet -- we do this beforehand so that we can grab the starting + // index into the module's `field_defs` table before we generate the struct's fields. + let struct_def = StructDefinition { + struct_handle: StructHandleIndex(struct_idx), + field_count: num_fields as MemberCount, + fields: FieldDefinitionIndex::new(self.module.field_defs.len() as TableIndex), + }; + self.module.struct_defs.push(struct_def); + + // Generate the fields for the struct. + for _ in 0..num_fields { + let struct_handle_idx = StructHandleIndex::new(struct_idx); + // Pick a random base type (non-reference) + let typ_idx = TypeSignatureIndex::new( + self.gen + .gen_range(0, self.module.type_signatures.len() as TableIndex), + ); + // Pick a random name. + let str_pool_idx = StringPoolIndex::new( + self.gen + .gen_range(0, self.module.string_pool.len() as TableIndex), + ); + let field_def = FieldDefinition { + struct_: struct_handle_idx, + name: str_pool_idx, + signature: typ_idx, + }; + self.module.field_defs.push(field_def); + } + } + + // Generate the struct handles. This needs to be in sync with the names that we generated + // earlier at the start of this function. + self.module.struct_handles = (0..self.table_size) + .map(|struct_idx| StructHandle { + module: ModuleHandleIndex::new(0), + name: StringPoolIndex::new((struct_idx + offset) as TableIndex), + is_resource: self.gen.gen_bool(1.0 / 2.0), + }) + .collect(); + } + + // Generate `table_size` number of functions in the underlying module. 
It does this by
+    // generating a bunch of random locals type signatures (Vec<SignatureToken>) and
+    // FunctionSignatures. We then call `with_functions` with this generated type info.
+    fn with_random_functions(&mut self) {
+        use SignatureToken::*;
+        // The base signature tokens that we can use for our types.
+        let sig_toks = vec![Bool, U64, String, ByteArray, Address];
+        // Generate a bunch of random function signatures over these types.
+        let functions = (0..self.table_size)
+            .map(|_| {
+                let num_locals = self.gen.gen_range(1, MAX_NUM_LOCALS);
+                let num_args = self.gen.gen_range(1, MAX_FUNCTION_CALL_SIZE);
+
+                let locals = (0..num_locals)
+                    .map(|_| {
+                        let index = self.gen.gen_range(0, sig_toks.len());
+                        sig_toks[index].clone()
+                    })
+                    .collect();
+
+                let args = (0..num_args)
+                    .map(|_| {
+                        let index = self.gen.gen_range(0, sig_toks.len());
+                        sig_toks[index].clone()
+                    })
+                    .collect();
+
+                // Generate the function signature. We don't care about the return type of the
+                // function, so we don't generate any types, and default to saying that it returns
+                // the unit type.
+                let function_sig = FunctionSignature {
+                    arg_types: args,
+                    return_types: vec![],
+                };
+
+                (locals, function_sig)
+            })
+            .collect();
+
+        self.with_cross_calls();
+        self.with_functions(functions);
+    }
+
+    fn with_cross_calls(&mut self) {
+        let module_table_size = self.module.module_handles.len();
+        if module_table_size < 2 {
+            return;
+        }
+
+        // We have half/half inter- and intra-module calls.
+        let number_of_cross_calls = self.table_size;
+        for _ in 0..number_of_cross_calls {
+            let non_self_module_handle_idx =
+                self.gen.gen_range(1, module_table_size) as TableIndex;
+            let callee_module_handle = self
+                .module
+                .module_handle_at(ModuleHandleIndex::new(non_self_module_handle_idx));
+            let address = *self.module.address_at(callee_module_handle.address);
+            let name = self.module.string_at(callee_module_handle.name);
+            let code_key = CodeKey::new(address, name.to_string());
+            let callee_module = self
+                .known_modules
+                .get(&code_key)
+                .expect("[Module Lookup] Unable to get module from known_modules.");
+
+            let callee_function_handle_idx =
+                self.gen.gen_range(0, callee_module.function_handles.len()) as TableIndex;
+            let callee_function_handle = callee_module
+                .function_handle_at(FunctionHandleIndex::new(callee_function_handle_idx));
+            let callee_type_sig = callee_module
+                .function_signature_at(callee_function_handle.signature)
+                .clone();
+            let callee_name = callee_module
+                .string_at(callee_function_handle.name)
+                .to_string();
+            let callee_name_idx = self.module.string_pool.len() as TableIndex;
+            let callee_type_sig_idx = self.module.function_signatures.len() as TableIndex;
+            let func_handle = FunctionHandle {
+                module: ModuleHandleIndex::new(non_self_module_handle_idx),
+                name: StringPoolIndex::new(callee_name_idx),
+                signature: FunctionSignatureIndex::new(callee_type_sig_idx),
+            };
+
+            self.module.string_pool.push(callee_name);
+            self.module.function_signatures.push(callee_type_sig);
+            self.module.function_handles.push(func_handle);
+        }
+    }
+
+    // Add the modules identified by their code keys to the module handles of the underlying
+    // CompiledModule.
+    fn with_callee_modules(&mut self) {
+        // Add the SELF module
+        let module_name: String = (0..10).map(|_| self.gen.gen::<char>()).collect();
+        self.module.string_pool.insert(0, module_name);
+        self.module.address_pool.insert(0, AccountAddress::random());
+        // Recall that we inserted the module name at index 0 in the string pool.
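+        // Note: because the SELF module sits at handle index 0, `with_cross_calls` above can
+        // safely draw callee handles from index 1 onward.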
+        let self_module_handle = ModuleHandle {
+            address: AddressPoolIndex::new(0),
+            name: StringPoolIndex::new(0),
+        };
+        self.module.module_handles.insert(0, self_module_handle);
+
+        let (mut names, mut addresses) = self
+            .known_modules
+            .keys()
+            .map(|key| (key.name().clone(), key.address()))
+            .unzip();
+
+        let address_pool_offset = self.module.address_pool.len() as TableIndex;
+        let string_pool_offset = self.module.string_pool.len() as TableIndex;
+        // Add the strings and addresses to the pool
+        self.module.string_pool.append(&mut names);
+        self.module.address_pool.append(&mut addresses);
+
+        let mut module_handles = (0..self.known_modules.len())
+            .map(|i| {
+                let i = i as TableIndex;
+                ModuleHandle {
+                    address: AddressPoolIndex::new(address_pool_offset + i),
+                    name: StringPoolIndex::new(string_pool_offset + i),
+                }
+            })
+            .collect();
+        self.module.module_handles.append(&mut module_handles);
+    }
+
+    /// This method builds and then materializes the underlying module skeleton. It then swaps in a
+    /// new module skeleton, adds the generated module to the `known_modules`, and returns
+    /// the generated module.
+    pub fn materialize(&mut self) -> CompiledModule {
+        self.with_callee_modules();
+        self.with_account_addresses();
+        self.with_strings();
+        self.with_bytearrays();
+        self.with_random_functions();
+        self.with_structs();
+        let module = std::mem::replace(&mut self.module, Self::default_module_with_types());
+        self.known_modules
+            .insert(module.self_code_key(), module.clone());
+        module
+    }
+
+    // This method generates a default (empty) `CompiledModule` but with base types. This way we
+    // can point to them when generating structs/functions etc.
+    fn default_module_with_types() -> CompiledModule {
+        use SignatureToken::*;
+        let mut module = CompiledModule::default();
+        module.type_signatures = vec![Bool, U64, String, ByteArray, Address]
+            .into_iter()
+            .map(TypeSignature)
+            .collect();
+        module
+    }
+}
+
+/// A wrapper around a `ModuleBuilder` for building module universes.
+///
+/// The `ModuleBuilder` is already designed to build module universes, but the size of that
+/// universe is unspecified and cannot be iterated over. This is a simple wrapper around the
+/// builder that allows the implementation of the `Iterator` trait over it.
+pub struct ModuleGenerator {
+    module_builder: ModuleBuilder,
+    iters: u64,
+}
+
+impl ModuleGenerator {
+    /// Create a new `ModuleGenerator` where each generated module has at least `table_size`
+    /// elements in each table, and where `iters` many modules are generated.
+    pub fn new(table_size: TableIndex, iters: u64) -> Self {
+        Self {
+            module_builder: ModuleBuilder::new(table_size),
+            iters,
+        }
+    }
+}
+
+impl Iterator for ModuleGenerator {
+    type Item = CompiledModule;
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.iters == 0 {
+            None
+        } else {
+            self.iters -= 1;
+            Some(self.module_builder.materialize())
+        }
+    }
+}
diff --git a/language/vm/cost_synthesis/src/natives.rs b/language/vm/cost_synthesis/src/natives.rs
new file mode 100644
index 0000000000000..3f6ab26cc3e64
--- /dev/null
+++ b/language/vm/cost_synthesis/src/natives.rs
@@ -0,0 +1,62 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Logic for generating valid stack states for native function calls.
+//!
+//! This implements a `StackAccessor` that generates random bytearrays of a user-defined length. We
+//! then use this ability to run the native functions with different bytearray lengths in the
+//! generated synthesis binary.
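+//!
+//! For example (illustrative only), benchmarking a native hash function with the mocker looks
+//! roughly like:
+//!
+//! ```ignore
+//! let mut mocker = StackAccessorMocker::new();
+//! mocker.set_hash_length(64); // generate 64-byte inputs
+//! mocker.next_bytearray();    // refresh the input before each call
+//! let costed = move_ir_natives::hash::native_sha3_256(&mut mocker).unwrap();
+//! let _cost = costed.cost();
+//! ```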
+use move_ir_natives::dispatch::{Result, StackAccessor};
+use rand::{rngs::StdRng, Rng, SeedableRng};
+use types::byte_array::ByteArray;
+
+/// A wrapper around data used to generate random valid bytearrays that replicate the semantics of
+/// a `StackAccessor`.
+pub struct StackAccessorMocker {
+    gen: StdRng,
+    /// The current byte array. Set to `None` if there is no bytearray.
+    curr_byte_array: Option<ByteArray>,
+    /// The length of bytearrays that will be generated.
+    pub hash_length: usize,
+}
+
+impl StackAccessorMocker {
+    /// Builds a new stack accessor mocker. The user is responsible for later setting the length
+    /// of, and generating, the underlying bytearray.
+    pub fn new() -> Self {
+        let seed: [u8; 32] = [0; 32];
+        Self {
+            gen: StdRng::from_seed(seed),
+            hash_length: 1,
+            curr_byte_array: None,
+        }
+    }
+
+    /// Set the bytearray length that will be generated in calls to `next_bytearray`.
+    pub fn set_hash_length(&mut self, len: usize) {
+        self.hash_length = len;
+    }
+
+    /// Set the `curr_byte_array` field to a freshly generated bytearray.
+    ///
+    /// The user is responsible for calling this between subsequent calls to `get_byte_array` in
+    /// the `StackAccessor` trait.
+    pub fn next_bytearray(&mut self) {
+        let bytes: Vec<u8> = (0..self.hash_length)
+            .map(|_| self.gen.gen::<u8>())
+            .collect();
+        self.curr_byte_array = Some(ByteArray::new(bytes))
+    }
+}
+
+impl Default for StackAccessorMocker {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl StackAccessor for &mut StackAccessorMocker {
+    fn get_byte_array(&mut self) -> Result<ByteArray> {
+        Ok(std::mem::replace(&mut self.curr_byte_array, None).unwrap())
+    }
+}
diff --git a/language/vm/cost_synthesis/src/stack_generator.rs b/language/vm/cost_synthesis/src/stack_generator.rs
new file mode 100644
index 0000000000000..d4008a1f72c15
--- /dev/null
+++ b/language/vm/cost_synthesis/src/stack_generator.rs
@@ -0,0 +1,735 @@
+// Copyright (c) The Libra Core Contributors
+// SPDX-License-Identifier: Apache-2.0
+
+//! Random valid stack state generation.
+//!
+//! This module encodes the stack evolution logic that is used to evolve the execution stack from
+//! one iteration to the next while synthesizing the cost of an instruction.
+use crate::{
+    bytecode_specifications::{frame_transition_info::frame_transitions, stack_transition_info::*},
+    common::*,
+};
+use rand::{rngs::StdRng, Rng, SeedableRng};
+use std::collections::HashMap;
+use types::{account_address::AccountAddress, byte_array::ByteArray, language_storage::CodeKey};
+use vm::{
+    access::*,
+    assert_ok,
+    file_format::{
+        AddressPoolIndex, ByteArrayPoolIndex, Bytecode, CodeOffset, FieldDefinitionIndex,
+        FunctionDefinition, FunctionDefinitionIndex, FunctionHandleIndex, LocalIndex, ModuleHandle,
+        SignatureToken, StringPoolIndex, StructDefinition, StructDefinitionIndex,
+        StructHandleIndex, TableIndex,
+    },
+    internals::ModuleIndex,
+};
+use vm_runtime::{
+    code_cache::module_cache::ModuleCache, execution_stack::ExecutionStack,
+    loaded_data::loaded_module::LoadedModule, value::*,
+};
+
+/// Specifies the data to be applied to the execution stack for the next valid stack state.
+///
+/// This contains directives and information that is used to evolve the execution stack of the VM
+/// in the `stack_transition_function`.
+pub struct StackState<'txn> {
+    /// The loaded module that should be loaded as the "current module" in the VM and the function
+    /// index for the frame that we expect to be at the top of the execution stack. Set to `None`
+    /// if no particular frame is required by the instruction.
+    pub module_info: (&'txn LoadedModule, Option<FunctionDefinitionIndex>),
+
+    /// The value stack that is required for the instruction.
+    pub stack: Stack,
+
+    /// A copy of the instruction. This will later be used to call into the VM.
+    pub instr: Bytecode,
+
+    /// For certain instructions the cost is variable on the size of the data being loaded. This
+    /// holds the size of data that was generated so this can be taken into account when
+    /// determining the cost per byte.
+    pub size: u64,
+
+    /// A sparse mapping of local index to local value for the current function frame. This will
+    /// be applied to the execution stack later on in the `stack_transition_function`.
+    pub local_mapping: HashMap<LocalIndex, Local>,
+}
+
+impl<'txn> StackState<'txn> {
+    /// Create a new stack state with the passed-in information.
+    pub fn new(
+        module_info: (&'txn LoadedModule, Option<FunctionDefinitionIndex>),
+        stack: Stack,
+        instr: Bytecode,
+        size: u64,
+        local_mapping: HashMap<LocalIndex, Local>,
+    ) -> Self {
+        Self {
+            module_info,
+            stack,
+            instr,
+            size,
+            local_mapping,
+        }
+    }
+}
+
+/// A wrapper around the instruction being synthesized. Holds the internal state that is
+/// used to generate random valid stack states.
+pub struct RandomStackGenerator<'alloc, 'txn>
+where
+    'alloc: 'txn,
+{
+    /// The number of iterations that this instruction will be run for. Used to implement the
+    /// `Iterator` trait.
+    iters: u64,
+
+    /// The source of pseudo-randomness.
+    gen: StdRng,
+
+    /// Certain instructions require indices into the various tables within the module.
+    /// We store a reference to the loaded module context that we are currently in so that we can
+    /// generate valid references into these tables. When generating a module universe this is the
+    /// root module that has pointers to all other modules.
+    root_module: &'txn LoadedModule,
+
+    /// The module cache for all of the other modules in the universe. We need this in order to
+    /// resolve struct and function handles to modules other than the root module.
+    module_cache: &'txn ModuleCache<'alloc>,
+
+    /// The bytecode instruction for which stack states are generated.
+    op: Bytecode,
+
+    /// The maximum size of the generated value stack.
+    max_stack_size: u64,
+
+    /// Cursor into the string pool. Used for the generation of random strings.
+    string_pool_index: u64,
+
+    /// Cursor into the address pool. Used for the generation of random addresses. We use this
+    /// since we need the addresses to be unique for e.g. CreateAccount, and we don't want a
+    /// mutable reference into the underlying `root_module`.
+    address_pool_index: u64,
+
+    /// A reverse lookup table to find the struct definition for a struct handle. Needed for
+    /// generating an inhabitant for a struct SignatureToken. This is lazily populated.
+    struct_handle_table: HashMap<CodeKey, HashMap<String, StructDefinitionIndex>>,
+
+    /// A reverse lookup table for each code module that allows us to resolve function handles to
+    /// function definitions. Also lazily populated.
+    function_handle_table: HashMap<CodeKey, HashMap<String, FunctionDefinitionIndex>>,
+}
+
+impl<'alloc, 'txn> RandomStackGenerator<'alloc, 'txn>
+where
+    'alloc: 'txn,
+{
+    /// Create a new random stack state generator.
+    ///
+    /// It initializes each of the internal resolution tables for structs and function handles to
+    /// be empty.
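+    ///
+    /// For example (mirroring the synthesis binary above):
+    ///
+    /// ```ignore
+    /// let gen = RandomStackGenerator::new(
+    ///     &loaded_module, &module_cache, &Bytecode::Add, 100, 1000,
+    /// );
+    /// ```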
+ pub fn new( + root_module: &'txn LoadedModule, + module_cache: &'txn ModuleCache<'alloc>, + op: &Bytecode, + max_stack_size: u64, + iters: u64, + ) -> Self { + let seed: [u8; 32] = [0; 32]; + Self { + gen: StdRng::from_seed(seed), + op: op.clone(), + max_stack_size, + root_module, + module_cache, + iters, + string_pool_index: iters, + address_pool_index: iters, + struct_handle_table: HashMap::new(), + function_handle_table: HashMap::new(), + } + } + + fn to_code_key(&self, module_handle: &ModuleHandle) -> CodeKey { + let address = *self.root_module.module.address_at(module_handle.address); + let name = self.root_module.module.string_at(module_handle.name); + CodeKey::new(address, name.to_string()) + } + + // Determines if the instruction gets its type/instruction info from the stack type + // transitions, or from the type signatures available in the module(s). + fn is_module_specific_op(&self) -> bool { + use Bytecode::*; + match self.op { + Unpack(_) | Pack(_) | Call(_) => true, + CopyLoc(_) | MoveLoc(_) | StLoc(_) | BorrowLoc(_) | BorrowField(_) => true, + _ => false, + } + } + + // Certain operations are only valid if their values come from module-specific data. In + // particular, CreateLibraAccount. But, they may eventually be more of these as well. + fn points_to_module_data(&self) -> bool { + use Bytecode::*; + match self.op { + CreateAccount => true, + _ => false, + } + } + + fn next_int(&mut self, stk: &[Local]) -> u64 { + if self.op == Bytecode::Sub && !stk.is_empty() { + let peek: Option = stk.last().expect("[Next Integer] The impossible happened: the value stack became empty while still full.").clone().value().expect("[Next Integer] Invalid integer stack value encountered when peeking at the generated stack.").into(); + self.gen.gen_range( + 0, + peek.expect("[Next Integer] Unable to cast peeked stack value to an integer."), + ) + } else { + u64::from(self.gen.gen_range(0, u32::max_value())) + } + } + + fn next_bool(&mut self) -> bool { + // Flip a coin + self.gen.gen_bool(1.0 / 2.0) + } + + fn next_bytearray(&mut self) -> ByteArray { + let len: usize = self.gen.gen_range(1, BYTE_ARRAY_MAX_SIZE); + let bytes: Vec = (0..len).map(|_| self.gen.gen::()).collect(); + ByteArray::new(bytes) + } + + // Strings and addresses are already randomly generated in the module that we create these + // pools from so we simply pop off from them. This assumes that the module was generated with + // at least `self.iters` number of strings and addresses. In the case where we are just padding + // the stack, or where the instructions semantics don't require having an address in the + // address pool, we don't waste our pools and generate a random value. 
+    fn next_str(&mut self, is_padding: bool) -> String {
+        if !self.points_to_module_data() || is_padding {
+            let len: usize = self.gen.gen_range(1, MAX_STRING_SIZE);
+            (0..len).map(|_| self.gen.gen::<char>()).collect::<String>()
+        } else {
+            let string =
+                self.root_module.module.string_pool[self.string_pool_index as usize].clone();
+            self.string_pool_index = self
+                .string_pool_index
+                .checked_sub(1)
+                .expect("Exhausted strings in string pool");
+            string
+        }
+    }
+
+    fn next_addr(&mut self, is_padding: bool) -> AccountAddress {
+        if !self.points_to_module_data() || is_padding {
+            AccountAddress::random()
+        } else {
+            let address = self.root_module.module.address_pool[self.address_pool_index as usize];
+            self.address_pool_index = self
+                .address_pool_index
+                .checked_sub(1)
+                .expect("Exhausted account addresses in address pool");
+            address
+        }
+    }
+
+    // Returns an index in [1, bound); assumes `bound > 1`.
+    fn next_bounded_index(&mut self, bound: TableIndex) -> TableIndex {
+        self.gen.gen_range(1, bound)
+    }
+
+    fn next_string_idx(&mut self) -> StringPoolIndex {
+        let len = self.root_module.module.string_pool.len();
+        StringPoolIndex::new(self.gen.gen_range(0, len) as TableIndex)
+    }
+
+    fn next_address_idx(&mut self) -> AddressPoolIndex {
+        let len = self.root_module.module.address_pool.len();
+        AddressPoolIndex::new(self.gen.gen_range(0, len) as TableIndex)
+    }
+
+    fn next_bytearray_idx(&mut self) -> ByteArrayPoolIndex {
+        let len = self.root_module.module.byte_array_pool.len();
+        ByteArrayPoolIndex::new(self.gen.gen_range(0, len) as TableIndex)
+    }
+
+    fn next_function_handle_idx(&mut self) -> FunctionHandleIndex {
+        let table_idx =
+            self.next_bounded_index(self.root_module.module.function_handles.len() as TableIndex);
+        FunctionHandleIndex::new(table_idx)
+    }
+
+    fn next_stack_value(&mut self, stk: &[Local], is_padding: bool) -> Local {
+        match self.gen.gen_range(0, 5) {
+            0 => Local::u64(self.next_int(stk)),
+            1 => Local::bool(self.next_bool()),
+            2 => Local::string(self.next_str(is_padding)),
+            3 => Local::bytearray(self.next_bytearray()),
+            _ => Local::address(self.next_addr(is_padding)),
+        }
+    }
+
+    // Pick a random function, and a random local within that function. Then generate an
+    // inhabitant for that local's type.
+    fn next_local_state(
+        &mut self,
+    ) -> (
+        &'txn LoadedModule,
+        LocalIndex,
+        FunctionDefinitionIndex,
+        Local,
+    ) {
+        // We pick a random function from the module in which to store the local
+        let function_handle_idx = self.next_function_handle_idx();
+        let (module, function_definition, function_def_idx) =
+            self.resolve_function_handle(function_handle_idx);
+        let type_sig = &module
+            .module
+            .locals_signature_at(function_definition.code.locals)
+            .0;
+        // Pick a random local within that function in which we'll store the local
+        let local_index = self.gen.gen_range(0, type_sig.len());
+        let type_tok = type_sig[local_index].clone();
+        let stack_local = self.resolve_to_value(type_tok, &[]);
+        (
+            module,
+            local_index as LocalIndex,
+            function_def_idx,
+            stack_local,
+        )
+    }
+
+    fn fill_instruction_arg(&mut self) -> Bytecode {
+        use Bytecode::*;
+        // For branching instructions we need to know the size of the code within the top frame
+        // on the execution stack (the frame that the instruction will be executing in) so that
+        // we don't jump off the end of the function. Because of this, we need to get the frame
+        // that we're in first.
+        // Since we only generate one instruction at a time, for branching instructions we know
+        // that we won't be pushing any (non-default) frames onto the execution stack -- and
+        // therefore that the function at `DEFAULT_FUNCTION_IDX` will be the top frame on the
+        // execution stack. Because of this, we can safely pick the default function as our frame
+        // here.
+        let function_idx = FunctionDefinitionIndex::new(DEFAULT_FUNCTION_IDX);
+        let frame_len = self
+            .root_module
+            .module
+            .function_def_at(function_idx)
+            .code
+            .code
+            .len();
+        match self.op {
+            BrTrue(_) => {
+                let index = self.next_bounded_index(frame_len as TableIndex);
+                BrTrue(index as CodeOffset)
+            }
+            BrFalse(_) => {
+                let index = self.next_bounded_index(frame_len as TableIndex);
+                BrFalse(index as CodeOffset)
+            }
+            Branch(_) => {
+                let index = self.next_bounded_index(frame_len as TableIndex);
+                Branch(index as CodeOffset)
+            }
+            LdConst(_) => LdConst(self.next_int(&[])),
+            LdStr(_) => LdStr(self.next_string_idx()),
+            LdByteArray(_) => LdByteArray(self.next_bytearray_idx()),
+            LdAddr(_) => LdAddr(self.next_address_idx()),
+            _ => self.op.clone(),
+        }
+    }
+
+    fn resolve_struct_handle(
+        &mut self,
+        struct_handle_index: StructHandleIndex,
+    ) -> (
+        &'txn LoadedModule,
+        &'txn StructDefinition,
+        StructDefinitionIndex,
+    ) {
+        let struct_handle = self
+            .root_module
+            .module
+            .struct_handle_at(struct_handle_index);
+        let struct_name = self.root_module.string_at(struct_handle.name);
+        let module_handle = self
+            .root_module
+            .module
+            .module_handle_at(struct_handle.module);
+        let code_key = self.to_code_key(module_handle);
+        let module = self
+            .module_cache
+            .get_loaded_module(&code_key)
+            .expect("[Module Lookup] Error while looking up module")
+            .expect("[Module Lookup] Unable to find module");
+        let struct_def_idx = if self.struct_handle_table.contains_key(&code_key) {
+            self.struct_handle_table
+                .get(&code_key)
+                .expect("[Struct Definition Lookup] Unable to get struct handles for module")
+                .get(struct_name)
+        } else {
+            let entry = self.struct_handle_table.entry(code_key).or_insert_with(|| {
+                module
+                    .module
+                    .struct_defs()
+                    .enumerate()
+                    .map(|(struct_def_index, struct_def)| {
+                        let handle = module.module.struct_handle_at(struct_def.struct_handle);
+                        let name = module.module.string_at(handle.name).to_string();
+                        (
+                            name,
+                            StructDefinitionIndex::new(struct_def_index as TableIndex),
+                        )
+                    })
+                    .collect()
+            });
+            entry.get(struct_name)
+        }
+        .expect("[Struct Definition Lookup] Unable to get struct definition for struct handle");
+
+        let struct_def = module.module.struct_def_at(*struct_def_idx);
+        (module, struct_def, *struct_def_idx)
+    }
+
+    fn resolve_function_handle(
+        &mut self,
+        function_handle_index: FunctionHandleIndex,
+    ) -> (
+        &'txn LoadedModule,
+        &'txn FunctionDefinition,
+        FunctionDefinitionIndex,
+    ) {
+        let function_handle = self
+            .root_module
+            .module
+            .function_handle_at(function_handle_index);
+        let function_name = self.root_module.string_at(function_handle.name);
+        let module_handle = self
+            .root_module
+            .module
+            .module_handle_at(function_handle.module);
+        let code_key = self.to_code_key(module_handle);
+        let module = self
+            .module_cache
+            .get_loaded_module(&code_key)
+            .expect("[Module Lookup] Error while looking up module")
+            .expect("[Module Lookup] Unable to find module");
+        let function_def_idx = if self.function_handle_table.contains_key(&code_key) {
+            *self
+                .function_handle_table
+                .get(&code_key)
+                .expect("[Function Definition Lookup] Unable to get function handles for module")
+                .get(function_name)
+                .expect(
+                    "[Function Definition Lookup] Unable to get function definition for function handle",
+                )
+        } else {
+            let entry = self
+                .function_handle_table
+                .entry(code_key)
+                .or_insert_with(|| {
+                    module
+                        .module
+                        .function_defs()
+                        .enumerate()
+                        .map(|(function_def_index, function_def)| {
+                            let handle = module.module.function_handle_at(function_def.function);
+                            let name = module.module.string_at(handle.name).to_string();
+                            (
+                                name,
+                                FunctionDefinitionIndex::new(function_def_index as TableIndex),
+                            )
+                        })
+                        .collect()
+                });
+            *entry.get(function_name).expect(
+                "[Function Definition Lookup] Unable to get function definition for function handle",
+            )
+        };
+
+        let function_def = module.module.function_def_at(function_def_idx);
+        (module, function_def, function_def_idx)
+    }
+
+    // Build an inhabitant of the type given by `sig_token`. We pass the current stack state in
+    // since for certain instructions (e.g. `Sub`) we need to generate pairs of numbers whose
+    // subtraction does not underflow.
+    fn resolve_to_value(&mut self, sig_token: SignatureToken, stk: &[Local]) -> Local {
+        match sig_token {
+            SignatureToken::Bool => Local::bool(self.next_bool()),
+            SignatureToken::U64 => Local::u64(self.next_int(stk)),
+            SignatureToken::String => Local::string(self.next_str(false)),
+            SignatureToken::Address => Local::address(self.next_addr(false)),
+            SignatureToken::Reference(box sig) | SignatureToken::MutableReference(box sig) => {
+                let underlying_value = self.resolve_to_value(sig, stk);
+                underlying_value
+                    .borrow_local()
+                    .expect("Unable to generate valid reference value")
+            }
+            SignatureToken::ByteArray => Local::bytearray(self.next_bytearray()),
+            SignatureToken::Struct(struct_handle_idx) => {
+                assert!(self.root_module.module.struct_defs().len() > 1);
+                let struct_definition = self
+                    .root_module
+                    .module
+                    .struct_def_at(self.resolve_struct_handle(struct_handle_idx).2);
+                let num_fields = struct_definition.field_count as usize;
+                let index = struct_definition.fields.into_index();
+                let fields = &self.root_module.module.field_defs[index..index + num_fields];
+                let mutvals = fields
+                    .iter()
+                    .map(|field| {
+                        self.resolve_to_value(
+                            self.root_module
+                                .module
+                                .type_signature_at(field.signature)
+                                .0
+                                .clone(),
+                            stk,
+                        )
+                        .value()
+                        .expect("[Struct Generation] Unable to get underlying value for generated struct field.")
+                    })
+                    .collect();
+                Local::struct_(mutvals)
+            }
+        }
+    }
+
+    // Generate the starting state of the stack based upon the type transition in the call info
+    // table.
+    fn generate_from_type(&mut self, typ: SignatureTy, stk: &[Local]) -> Local {
+        // If the underlying type is a variable type, then we can choose any type that we want.
+        let typ = if typ.is_variable() {
+            let underlying_type = typ.underlying();
+            let index = self.gen.gen_range(0, underlying_type.len());
+            underlying_type[index].clone()
+        } else {
+            typ.underlying()
+                .first()
+                .expect("Unable to get underlying type for sigty in generate_from_type")
+                .clone()
+        };
+        self.resolve_to_value(typ.0, stk)
+    }
+
+    // Certain instructions require specific stack states; e.g. Pack() requires the correct number
+    // and type of locals to already exist at the top of the value stack when the instruction is
+    // encountered. We therefore need to generate the stack state for certain operations based
+    // _not_ on their call info, but on the possible calls that we could have in the module/other
+    // modules that we are aware of.
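+    //
+    // For example, when synthesizing `Pack` below we pick a random struct definition from the
+    // root module and push one freshly generated inhabitant per field, so the instruction sees
+    // exactly the operands its signature demands.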
+    fn generate_from_module_info(&mut self) -> StackState<'txn> {
+        use Bytecode::*;
+        match self.op {
+            Call(_) => {
+                let function_handle_idx = self.next_function_handle_idx();
+                let function_idx = self.resolve_function_handle(function_handle_idx).2;
+                let function_handle = self
+                    .root_module
+                    .module
+                    .function_handle_at(function_handle_idx);
+                let function_sig = self
+                    .root_module
+                    .module
+                    .function_signature_at(function_handle.signature);
+                let stack = function_sig.arg_types.clone().into_iter().fold(
+                    Vec::new(),
+                    |mut acc, sig_tok| {
+                        acc.push(self.resolve_to_value(sig_tok, &acc));
+                        acc
+                    },
+                );
+                let size = stack.iter().fold(0, |acc, local| local.size() + acc);
+                StackState::new(
+                    (self.root_module, Some(function_idx)),
+                    self.random_pad(stack),
+                    Call(function_handle_idx),
+                    size,
+                    HashMap::new(),
+                )
+            }
+            Pack(_struct_def_idx) => {
+                let struct_def_bound = self.root_module.module.struct_defs.len() as TableIndex;
+                let random_struct_idx =
+                    StructDefinitionIndex::new(self.next_bounded_index(struct_def_bound));
+                let struct_definition = self.root_module.module.struct_def_at(random_struct_idx);
+                let num_fields = struct_definition.field_count as usize;
+                let index = struct_definition.fields.into_index();
+                let fields = &self.root_module.module.field_defs[index..index + num_fields];
+                let stack: Stack = fields
+                    .iter()
+                    .map(|field| {
+                        let ty = self
+                            .root_module
+                            .module
+                            .type_signature_at(field.signature)
+                            .0
+                            .clone();
+                        self.resolve_to_value(ty, &[])
+                    })
+                    .collect();
+                let size = stack.iter().fold(0, |acc, local| local.size() + acc);
+                StackState::new(
+                    (self.root_module, None),
+                    self.random_pad(stack),
+                    Pack(random_struct_idx),
+                    size,
+                    HashMap::new(),
+                )
+            }
+            Unpack(_struct_def_idx) => {
+                let struct_def_bound = self.root_module.module.struct_defs.len() as TableIndex;
+                let random_struct_idx =
+                    StructDefinitionIndex::new(self.next_bounded_index(struct_def_bound));
+                let struct_handle_idx = self
+                    .root_module
+                    .module
+                    .struct_def_at(random_struct_idx)
+                    .struct_handle;
+                let struct_stack =
+                    self.resolve_to_value(SignatureToken::Struct(struct_handle_idx), &[]);
+                let size = struct_stack.size() as u64;
+                StackState::new(
+                    (self.root_module, None),
+                    self.random_pad(vec![struct_stack]),
+                    Unpack(random_struct_idx),
+                    size,
+                    HashMap::new(),
+                )
+            }
+            BorrowField(_) => {
+                // First grab a random struct
+                let struct_def_bound = self.root_module.module.struct_defs.len() as TableIndex;
+                let random_struct_idx =
+                    StructDefinitionIndex::new(self.next_bounded_index(struct_def_bound));
+                let struct_definition = self.root_module.module.struct_def_at(random_struct_idx);
+                let num_fields = struct_definition.field_count;
+                // Grab a random field within that struct to borrow
+                let field_index = self.gen.gen_range(0, num_fields);
+                let struct_stack = self.resolve_to_value(
+                    SignatureToken::Reference(Box::new(SignatureToken::Struct(
+                        struct_definition.struct_handle,
+                    ))),
+                    &[],
+                );
+                let field_size = struct_stack
+                    .borrow_field(u32::from(field_index))
+                    .expect("[BorrowField] Unable to borrow field of generated struct to get field size.")
+                    .size();
+                StackState::new(
+                    (self.root_module, None),
+                    self.random_pad(vec![struct_stack]),
+                    BorrowField(FieldDefinitionIndex::new(field_index)),
+                    field_size,
+                    HashMap::new(),
+                )
+            }
+            StLoc(_) => {
+                let (module, local_idx, function_idx, stack_local) = self.next_local_state();
+                let size = stack_local.size();
+                StackState::new(
+                    (module, Some(function_idx)),
+                    self.random_pad(vec![stack_local]),
+                    StLoc(local_idx as LocalIndex),
+                    size,
+                    HashMap::new(),
+                )
+            }
+            CopyLoc(_) | MoveLoc(_) | BorrowLoc(_) => {
+                let (module, local_idx, function_idx, frame_local) = self.next_local_state();
+                let size = frame_local.size();
+                let mut locals_mapping = HashMap::new();
+                locals_mapping.insert(local_idx as LocalIndex, frame_local);
+                // Rebuild the op with the chosen local index, keeping the original variant so
+                // that MoveLoc and BorrowLoc are not mistakenly synthesized as CopyLoc.
+                let instr = match self.op {
+                    CopyLoc(_) => CopyLoc(local_idx as LocalIndex),
+                    MoveLoc(_) => MoveLoc(local_idx as LocalIndex),
+                    _ => BorrowLoc(local_idx as LocalIndex),
+                };
+                StackState::new(
+                    (module, Some(function_idx)),
+                    self.random_pad(Vec::new()),
+                    instr,
+                    size,
+                    locals_mapping,
+                )
+            }
+            _ => unimplemented!(),
+        }
+    }
+
+    // Take the stack, and then randomly pad it up to the given stack limit.
+    fn random_pad(&mut self, mut stk: Stack) -> Stack {
+        // The max amount we can pad while staying within the stack size limit. (Assumes
+        // `max_stack_size` leaves room for at least one padding value.)
+        let len = stk.len() as u64;
+        let max_pad_amt = if len > self.max_stack_size {
+            1
+        } else {
+            self.max_stack_size - (stk.len() as u64)
+        };
+        let rand_len = self.gen.gen_range(1, max_pad_amt);
+        // Generate the random stack prefix
+        let mut stk_prefix: Vec<_> = (0..rand_len)
+            .map(|_| self.next_stack_value(&stk, true))
+            .collect();
+        stk_prefix.append(&mut stk);
+        stk_prefix
+    }
+
+    /// Generate a new valid random stack state. Returns `None` once `iters` many stacks have
+    /// been produced with this instance.
+    pub fn next_stack(&mut self) -> Option<StackState<'txn>> {
+        if self.iters == 0 {
+            return None;
+        }
+        self.iters -= 1;
+        Some(if self.is_module_specific_op() {
+            self.generate_from_module_info()
+        } else {
+            let info = call_details(&self.op);
+            // Pick a random input/output argument configuration for the opcode
+            let index = self.gen.gen_range(0, info.len());
+            let args = info[index].in_args.clone();
+            let starting_stack: Stack = args.into_iter().fold(Vec::new(), |mut acc, x| {
+                // We pass in a context since we need to enforce certain relationships between
+                // stack values for valid execution.
+                acc.push(self.generate_from_type(x, &acc));
+                acc
+            });
+            let size = starting_stack.iter().fold(0, |acc, x| acc + x.size());
+            StackState::new(
+                (self.root_module, None),
+                self.random_pad(starting_stack),
+                self.fill_instruction_arg(),
+                size,
+                HashMap::new(),
+            )
+        })
+    }
+
+    /// Applies the `stack_state` to the VM's execution stack.
+    ///
+    /// We don't use the `instr` in the stack state within this function. We therefore pull it out
+    /// since we are grabbing ownership of the other fields of the struct and return it to be
+    /// used elsewhere.
+    pub fn stack_transition