Skip to content

Commit

Permalink
Add math & bitwise ops to U256 type 2 (FuelLabs#2445)
Browse files Browse the repository at this point in the history
* most changes

* tests

* visibility changes

* visibility changes 2

* requested changes

* fixes

* formatter

Co-authored-by: tyshkor <[email protected]>
Co-authored-by: Alex Hansen <[email protected]>
  • Loading branch information
3 people authored Aug 19, 2022
1 parent 75be099 commit 22a4cc6
Show file tree
Hide file tree
Showing 14 changed files with 476 additions and 2 deletions.
271 changes: 270 additions & 1 deletion sway-lib-std/src/u256.sw
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
library u256;

use core::num::*;

use ::result::Result;
use ::u128::U128;
use ::assert::assert;

/// Left shift a u64 and preserve the overflow amount if any
fn lsh_with_carry(word: u64, shift_amount: u64) -> (u64, u64) {
let right_shift_amount = 64 - shift_amount;
let carry = word >> right_shift_amount;
let shifted = word << shift_amount;
(shifted, carry)
}

/// Right shift a u64 and preserve the overflow amount if any
fn rsh_with_carry(word: u64, shift_amount: u64) -> (u64, u64) {
let left_shift_amount = 64 - shift_amount;
let carry = word << left_shift_amount;
let shifted = word >> shift_amount;
(shifted, carry)
}

/// The 256-bit unsigned integer type.
/// Represented as four 64-bit components: `(a, b, c, d)`, where `value = (a << 192) + (b << 128) + (c << 64) + d`.
Expand Down Expand Up @@ -89,4 +106,256 @@ impl U256 {
pub fn bits() -> u32 {
256
}

/// Get 4 64 bit words from a single U256 value.
fn decompose(self) -> (u64, u64, u64, u64) {
(self.a, self.b, self.c, self.d)
}
}

impl core::ops::Ord for U256 {
fn gt(self, other: Self) -> bool {
self.a > other.a || (self.a == other.a && self.b > other.b || (self.b == other.b && self.c > other.c || (self.c == other.c && self.d > other.d)))
}

fn lt(self, other: Self) -> bool {
self.a < other.a || (self.a == other.a && self.b < other.b || (self.b == other.b && self.c < other.c || (self.c == other.c && self.d < other.d)))
}
}

impl core::ops::BitwiseAnd for U256 {
fn binary_and(self, other: Self) -> Self {
let(value_word_1, value_word_2, value_word_3, value_word_4) = self.decompose();
let(other_word_1, other_word_2, other_word_3, other_word_4) = other.decompose();
let word_1 = value_word_1 & other_word_1;
let word_2 = value_word_2 & other_word_2;
let word_3 = value_word_3 & other_word_3;
let word_4 = value_word_4 & other_word_4;
~U256::from(word_1, word_2, word_3, word_4)
}
}

impl core::ops::BitwiseOr for U256 {
fn binary_or(self, other: Self) -> Self {
let(value_word_1, value_word_2, value_word_3, value_word_4) = self.decompose();
let(other_word_1, other_word_2, other_word_3, other_word_4) = other.decompose();
let word_1 = value_word_1 | other_word_1;
let word_2 = value_word_2 | other_word_2;
let word_3 = value_word_3 | other_word_3;
let word_4 = value_word_4 | other_word_4;
~U256::from(word_1, word_2, word_3, word_4)
}
}

impl core::ops::BitwiseXor for U256 {
fn binary_xor(self, other: Self) -> Self {
let(value_word_1, value_word_2, value_word_3, value_word_4) = self.decompose();
let(other_word_1, other_word_2, other_word_3, other_word_4) = other.decompose();
let word_1 = value_word_1 ^ other_word_1;
let word_2 = value_word_2 ^ other_word_2;
let word_3 = value_word_3 ^ other_word_3;
let word_4 = value_word_4 ^ other_word_4;
~U256::from(word_1, word_2, word_3, word_4)
}
}

impl core::ops::Shiftable for U256 {
fn lsh(self, shift_amount: u64) -> Self {
let(word_1, word_2, word_3, word_4) = self.decompose();
let mut w1 = 0;
let mut w2 = 0;
let mut w3 = 0;
let mut w4 = 0;

let w = shift_amount / 64; // num of whole words to shift in addition to b
let b = shift_amount % 64; // num of bits to shift within each word

if w == 0 {
let(shifted_2, carry_2) = lsh_with_carry(word_2, b);
w1 = (word_1 << b) + carry_2;
let(shifted_3, carry_3) = lsh_with_carry(word_3, b);
w2 = shifted_2 + carry_3;
let(shifted_4, carry_4) = lsh_with_carry(word_4, b);
w3 = shifted_3 + carry_4;
w4 = shifted_4;
} else if w == 1 {
let(shifted_3, carry_3) = lsh_with_carry(word_3, b);
w1 = (word_2 << b) + carry_3;
let(shifted_4, carry_4) = lsh_with_carry(word_4, b);
w2 = shifted_3 + carry_4;
w3 = shifted_4;
} else if w == 2 {
let(shifted_4, carry_4) = lsh_with_carry(word_4, b);
w1 = (word_3 << b) + carry_4;
w2 = shifted_4;
} else if w == 3 {
w1 = word_4 << b;
}

~U256::from(w1, w2, w3, w4)
}

fn rsh(self, shift_amount: u64) -> Self {
let(word_1, word_2, word_3, word_4) = self.decompose();
let mut w1 = 0;
let mut w2 = 0;
let mut w3 = 0;
let mut w4 = 0;

let w = shift_amount / 64; // num of whole words to shift in addition to b
let b = shift_amount % 64; // num of bits to shift within each word

if w == 0 {
let(shifted_3, carry_3) = rsh_with_carry(word_3, b);
w4 = (word_4 >> b) + carry_3;
let(shifted_2, carry_2) = rsh_with_carry(word_2, b);
w3 = shifted_3 + carry_2;
let(shifted_1, carry_1) = rsh_with_carry(word_1, b);
w2 = shifted_2 + carry_1;
w1 = shifted_1;
} else if w == 1 {
let(shifted_2, carry_2) = rsh_with_carry(word_2, b);
w4 = (word_3 >> b) + carry_2;
let(shifted_1, carry_1) = rsh_with_carry(word_1, b);
w3 = shifted_2 + carry_1;
w2 = shifted_1;
} else if w == 2 {
let(shifted_1, carry_1) = rsh_with_carry(word_1, b);
w4 = (word_2 >> b) + carry_1;
w3 = shifted_1;
} else if w == 3 {
w4 = word_1 >> b;
};

~U256::from(w1, w2, w3, w4)
}
}

impl core::ops::Add for U256 {
/// Add a U256 to a U256. Panics on overflow.
fn add(self, other: Self) -> Self {
let(word_1, word_2, word_3, word_4) = self.decompose();
let(other_word_1, other_word_2, other_word_3, other_word_4) = other.decompose();

let mut overflow = 0;
let mut local_res = ~U128::from(0, word_4) + ~U128::from(0, other_word_4);
let result_d = local_res.lower;
overflow = local_res.upper;

local_res = ~U128::from(0, word_3) + ~U128::from(0, other_word_3) + ~U128::from(0, overflow);
let result_c = local_res.lower;
overflow = local_res.upper;

local_res = ~U128::from(0, word_2) + ~U128::from(0, other_word_2) + ~U128::from(0, overflow);
let result_b = local_res.lower;
overflow = local_res.upper;

local_res = ~U128::from(0, word_1) + ~U128::from(0, other_word_1) + ~U128::from(0, overflow);
let result_a = local_res.lower;
// panic on overflow
assert(local_res.upper == 0);
~U256::from(result_a, result_b, result_c, result_d)
}
}

impl core::ops::Subtract for U256 {
/// Subtract a U256 from a U256. Panics of overflow.
fn subtract(self, other: Self) -> Self {
// If trying to subtract a larger number, panic.
assert(!(self < other));

let(word_1, word_2, word_3, word_4) = self.decompose();
let(other_word_1, other_word_2, other_word_3, other_word_4) = other.decompose();

let mut result_a = word_1 - other_word_1;

let mut result_b = 0;
if word_2 < other_word_2 {
result_b = ~u64::max() - (other_word_2 - word_2 - 1);
result_a -= 1;
} else {
result_b = word_2 - other_word_2;
}

let mut result_c = 0;
if word_3 < other_word_3 {
result_c = ~u64::max() - (other_word_3 - word_3 - 1);
result_b -= 1;
} else {
result_c = word_3 - other_word_3;
}

let mut result_d = 0;
if word_4 < other_word_4 {
result_d = ~u64::max() - (other_word_4 - word_4 - 1);
result_c -= 1;
} else {
result_d = word_4 - other_word_4;
}

~U256::from(result_a, result_b, result_c, result_d)
}
}

impl core::ops::Multiply for U256 {
/// Multiply a U256 with a U256. Panics on overflow.
fn multiply(self, other: Self) -> Self {
let zero = ~U256::from(0, 0, 0, 0);
let one = ~U256::from(0, 0, 0, 1);

let mut total = zero;

let mut i = 256 - 1;

while true {
total <<= 1;
if (other & (one << i)) != zero {
total = total + self;
}

if i == 0 {
break;
}

i -= 1;
}

total
}
}

impl core::ops::Divide for U256 {
/// Divide a U256 by a U256. Panics if divisor is zero.
fn divide(self, divisor: Self) -> Self {
let zero = ~U256::from(0, 0, 0, 0);
let one = ~U256::from(0, 0, 0, 1);

assert(divisor != zero);

let mut quotient = ~U256::from(0, 0, 0, 0);
let mut remainder = ~U256::from(0, 0, 0, 0);

let mut i = 256 - 1;

while true {
quotient <<= 1;
remainder <<= 1;

let m = self & (one << i);
remainder = remainder | ((self & (one << i)) >> i);
// TODO use >= once OrdEq can be implemented.
if remainder > divisor || remainder == divisor {
remainder -= divisor;
quotient = quotient | one;
}

if i == 0 {
break;
}

i -= 1;
}

quotient
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
out
target
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[[package]]
name = 'core'
source = 'path+from-root-9D3054CE9B0DD894'
dependencies = []

[[package]]
name = 'std'
source = 'path+from-root-9D3054CE9B0DD894'
dependencies = ['core']

[[package]]
name = 'u256_div_test'
source = 'root'
dependencies = ['std']
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[project]
authors = ["Fuel Labs <[email protected]>"]
entry = "main.sw"
license = "Apache-2.0"
name = "u256_div_test"

[dependencies]
std = { path = "../../../../../../../sway-lib-std" }
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
script;

use std::assert::assert;
use std::u256::U256;
use core::num::*;

fn main() -> bool {
let zero = ~U256::from(0, 0, 0, 0);
let one = ~U256::from(0, 0, 0, 1);
let two = ~U256::from(0, 0, 0, 2);
let max_u64 = ~U256::from(0, 0, 0, ~u64::max());

let div_max_two = max_u64 / two;
assert(div_max_two.c == 0);
assert(div_max_two.d == ~u64::max() >> 1);

// Product of u64::MAX and u64::MAX.
let mut dividend = ~U256::from(0, 0, ~u64::max(), 1);
let mut res = dividend / max_u64;
assert(res == ~U256::from(0, 0, 1, 0));

dividend = ~U256::from(~u64::max(), 0, 0, 0);
let mut res = dividend / max_u64;
assert(res == ~U256::from(1, 0, 0, 0));

true
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
out
target
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[[package]]
name = 'core'
source = 'path+from-root-333FD703A60081C3'
dependencies = []

[[package]]
name = 'std'
source = 'path+from-root-333FD703A60081C3'
dependencies = ['core']

[[package]]
name = 'u256_mul_test'
source = 'root'
dependencies = ['std']
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[project]
authors = ["Fuel Labs <[email protected]>"]
entry = "main.sw"
license = "Apache-2.0"
name = "u256_mul_test"

[dependencies]
std = { path = "../../../../../../../sway-lib-std" }
Loading

0 comments on commit 22a4cc6

Please sign in to comment.