Skip to content

Commit

Permalink
cargo fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
morganthomas committed Apr 2, 2024
1 parent 1767364 commit 1d428c0
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 30 deletions.
4 changes: 1 addition & 3 deletions basic/tests/test_static_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use p3_fri::{TwoAdicFriPcs, TwoAdicFriPcsConfig};
use valida_alu_u32::add::{Add32Instruction, MachineWithAdd32Chip};
use valida_basic::BasicMachine;
use valida_cpu::{
BneInstruction, Imm32Instruction, Load32Instruction,
MachineWithCpuChip, StopInstruction,
BneInstruction, Imm32Instruction, Load32Instruction, MachineWithCpuChip, StopInstruction,
};
use valida_machine::{
FixedAdviceProvider, Instruction, InstructionWord, Machine, MachineProof, Operands, ProgramROM,
Expand All @@ -32,7 +31,6 @@ use rand::thread_rng;
use valida_machine::StarkConfigImpl;
use valida_machine::__internal::p3_commit::ExtensionMmcs;


#[test]
fn prove_static_data() {
// _start:
Expand Down
40 changes: 25 additions & 15 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ impl Parse for MachineFields {
}
}

#[proc_macro_derive(Machine, attributes(machine_fields, bus, chip, static_data_chip, instruction))]
#[proc_macro_derive(
Machine,
attributes(machine_fields, bus, chip, static_data_chip, instruction)
)]
pub fn machine_derive(input: TokenStream) -> TokenStream {
let ast = syn::parse(input).unwrap();
impl_machine(&ast)
Expand Down Expand Up @@ -61,12 +64,15 @@ fn impl_machine(machine: &syn::DeriveInput) -> TokenStream {
.expect("Invalid machine_fields attribute, expected #[machine_fields(<Val>)]");
let val = &machine_fields.val;

let static_data_chip: Option<Ident> =
chips
.iter()
.filter(|f| f.attrs.iter().any(|a| a.path.is_ident("static_data_chip")))
.map(|f| f.ident.clone().expect("static data chip requires an identifier"))
.next();
let static_data_chip: Option<Ident> = chips
.iter()
.filter(|f| f.attrs.iter().any(|a| a.path.is_ident("static_data_chip")))
.map(|f| {
f.ident
.clone()
.expect("static data chip requires an identifier")
})
.next();

let name = &machine.ident;
let run = run_method(machine, &instructions, &val, &static_data_chip);
Expand Down Expand Up @@ -134,7 +140,12 @@ fn chip_methods(chip: &Field) -> TokenStream2 {
}
}

fn run_method(machine: &syn::DeriveInput, instructions: &[&Field], val: &Ident, static_data_chip: &Option<Ident>) -> TokenStream2 {
fn run_method(
machine: &syn::DeriveInput,
instructions: &[&Field],
val: &Ident,
static_data_chip: &Option<Ident>,
) -> TokenStream2 {
let name = &machine.ident;
let (_, ty_generics, _) = machine.generics.split_for_impl();

Expand All @@ -150,13 +161,12 @@ fn run_method(machine: &syn::DeriveInput, instructions: &[&Field], val: &Ident,
})
.collect::<TokenStream2>();

let init_static_data: TokenStream2 =
match static_data_chip {
Some(static_data_chip) => quote!{
self.initialize_memory();
},
None => quote!{},
};
let init_static_data: TokenStream2 = match static_data_chip {
Some(static_data_chip) => quote! {
self.initialize_memory();
},
None => quote! {},
};

quote! {
fn run<Adv: ::valida_machine::AdviceProvider>(&mut self, program: &ProgramROM<i32>, advice: &mut Adv) {
Expand Down
26 changes: 17 additions & 9 deletions memory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ where
// // than the length of the table (capped at 2^29)
// Self::insert_dummy_reads(&mut ops);

let mut rows = self.static_data
let mut rows = self
.static_data
.iter()
.enumerate()
.map(|(n, (addr, value))| self.static_data_to_row(n, *addr, *value))
Expand All @@ -140,7 +141,7 @@ where
let ops_rows = ops
.par_iter()
.enumerate()
.map(|(n, (clk, op))| self.op_to_row(n0+n, *clk as usize, *op))
.map(|(n, (clk, op))| self.op_to_row(n0 + n, *clk as usize, *op))
.collect::<Vec<_>>();
rows.extend(ops_rows.clone());

Expand All @@ -150,8 +151,10 @@ where
// Make sure the table length is a power of two
rows.resize(rows.len().next_power_of_two(), padding_row);

let trace =
RowMajorMatrix::new(rows.clone().into_iter().flatten().collect::<Vec<_>>(), NUM_MEM_COLS);
let trace = RowMajorMatrix::new(
rows.clone().into_iter().flatten().collect::<Vec<_>>(),
NUM_MEM_COLS,
);

trace
}
Expand Down Expand Up @@ -225,7 +228,12 @@ impl MemoryChip {
row
}

fn static_data_to_row<F: PrimeField>(&self, n: usize, addr: u32, value: Word<u8>) -> [F; NUM_MEM_COLS] {
fn static_data_to_row<F: PrimeField>(
&self,
n: usize,
addr: u32,
value: Word<u8>,
) -> [F; NUM_MEM_COLS] {
let mut row = [F::zero(); NUM_MEM_COLS];
let cols: &mut MemoryCols<F> = unsafe { transmute(&mut row) };
cols.is_static_initial = F::one();
Expand Down Expand Up @@ -352,14 +360,14 @@ impl MemoryChip {

// Set trace values
for i in 0..(ops.len() - 1) {
rows[i0+i][MEM_COL_MAP.diff] = diff[i];
rows[i0+i][MEM_COL_MAP.diff_inv] = diff_inv[i];
rows[i0+i][MEM_COL_MAP.counter_mult] = mult[i];
rows[i0 + i][MEM_COL_MAP.diff] = diff[i];
rows[i0 + i][MEM_COL_MAP.diff_inv] = diff_inv[i];
rows[i0 + i][MEM_COL_MAP.counter_mult] = mult[i];

let addr = ops[i].1.get_address();
let addr_next = ops[i + 1].1.get_address();
if addr_next - addr != 0 {
rows[i0+i][MEM_COL_MAP.addr_not_equal] = F::one();
rows[i0 + i][MEM_COL_MAP.addr_not_equal] = F::one();
}
}

Expand Down
10 changes: 7 additions & 3 deletions static_data/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

extern crate alloc;

use crate::columns::{StaticDataCols, NUM_STATIC_DATA_COLS, STATIC_DATA_COL_MAP};
Expand Down Expand Up @@ -53,7 +52,9 @@ where
SC: StarkConfig,
{
fn generate_trace(&self, machine: &M) -> RowMajorMatrix<SC::Val> {
let mut rows = self.cells.iter()
let mut rows = self
.cells
.iter()
.map(|(addr, value)| {
let mut row = [SC::Val::zero(); NUM_STATIC_DATA_COLS];
let cols: &mut StaticDataCols<SC::Val> = unsafe { transmute(&mut row) };
Expand All @@ -65,7 +66,10 @@ where
})
.flatten()
.collect::<Vec<_>>();
rows.resize(rows.len().next_power_of_two() * NUM_STATIC_DATA_COLS, SC::Val::zero());
rows.resize(
rows.len().next_power_of_two() * NUM_STATIC_DATA_COLS,
SC::Val::zero(),
);
RowMajorMatrix::new(rows, NUM_STATIC_DATA_COLS)
}

Expand Down

0 comments on commit 1d428c0

Please sign in to comment.