Prepare to call jit from block machine. #2098

Merged · 31 commits · Dec 12, 2024

Commits
fbdabca
Use contiguous data for finalized rows.
chriseth Nov 13, 2024
71d2ebd
Remove last row.
chriseth Nov 13, 2024
f66c5c2
Fix extend.
chriseth Nov 14, 2024
934ebb1
clippy
chriseth Nov 14, 2024
19d544f
remove unused functiion
chriseth Nov 14, 2024
6b83811
Allow non-sorted column IDs.
chriseth Nov 14, 2024
80c24f9
Jit driver and machine.
chriseth Nov 14, 2024
5528a4f
reserve block
chriseth Nov 15, 2024
4d4111c
revie
chriseth Nov 15, 2024
0076d2c
review
chriseth Nov 15, 2024
3c577c1
review
chriseth Nov 15, 2024
4e83830
Add "remove_last_row"
chriseth Nov 15, 2024
9f3130a
Re-added one case.
chriseth Nov 15, 2024
1aadfe0
Merge branch 'improve_finalizable' into call_jit_from_block
chriseth Nov 15, 2024
056a904
jitedijit
chriseth Nov 15, 2024
f2fee46
jitjit
chriseth Nov 15, 2024
9415cc3
Remove try_remove_last_row.
chriseth Nov 15, 2024
ec5d110
compact data ref.
chriseth Nov 15, 2024
7cbbb8d
Remove unwrap.
chriseth Nov 15, 2024
4490dcd
Cleanup
chriseth Nov 15, 2024
7ffdad9
Merge remote-tracking branch 'origin/improve_finalizable' into call_j…
chriseth Nov 15, 2024
09863e8
cleanup
chriseth Nov 15, 2024
cf51774
cleanup
chriseth Nov 15, 2024
825d0eb
Merge remote-tracking branch 'origin/main' into call_jit_from_block
chriseth Nov 15, 2024
d470fd7
Remove type constraints.
chriseth Nov 22, 2024
b8864c8
Avoid dyn iter.
chriseth Nov 22, 2024
ca38415
Correct typo.
chriseth Nov 22, 2024
d47f695
Update executor/src/witgen/data_structures/finalizable_data.rs
chriseth Dec 9, 2024
d482821
Review comments.
chriseth Dec 9, 2024
68be580
Merge remote-tracking branch 'origin/main' into call_jit_from_block
chriseth Dec 10, 2024
46860a1
Merge remote-tracking branch 'origin/main' into call_jit_from_block
chriseth Dec 12, 2024
1 change: 1 addition & 0 deletions executor/Cargo.toml
Expand Up @@ -15,6 +15,7 @@ powdr-parser-util.workspace = true
powdr-pil-analyzer.workspace = true
powdr-jit-compiler.workspace = true

auto_enums = "0.8.5"
itertools = "0.13"
log = { version = "0.4.17" }
rayon = "1.7.0"
121 changes: 112 additions & 9 deletions executor/src/witgen/data_structures/finalizable_data.rs
@@ -3,6 +3,7 @@ use std::{
ops::{Index, IndexMut},
};

use auto_enums::auto_enum;
use bit_vec::BitVec;
use itertools::Itertools;
use powdr_ast::analyzed::{PolyID, PolynomialType};
@@ -13,7 +14,7 @@ use crate::witgen::rows::Row;
/// Sequence of rows of field elements, stored in a compact form.
/// Optimized for contiguous column IDs, but works with any combination.
#[derive(Clone)]
struct CompactData<T: FieldElement> {
pub struct CompactData<T> {
/// The ID of the first column used in the table.
first_column_id: u64,
/// The length of a row in the table.
@@ -26,7 +27,7 @@ struct CompactData<T: FieldElement> {

impl<T: FieldElement> CompactData<T> {
/// Creates a new empty compact data storage.
fn new(column_ids: &[PolyID]) -> Self {
pub fn new(column_ids: &[PolyID]) -> Self {
let col_id_range = column_ids.iter().map(|id| id.id).minmax();
let (first_column_id, last_column_id) = col_id_range.into_option().unwrap();
Self {
@@ -37,28 +38,28 @@ impl<T: FieldElement> CompactData<T> {
}
}

fn is_empty(&self) -> bool {
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}

/// Returns the number of stored rows.
fn len(&self) -> usize {
pub fn len(&self) -> usize {
self.data.len() / self.column_count
}

/// Truncates the data to `len` rows.
fn truncate(&mut self, len: usize) {
pub fn truncate(&mut self, len: usize) {
self.data.truncate(len * self.column_count);
self.known_cells.truncate(len * self.column_count);
}

fn clear(&mut self) {
pub fn clear(&mut self) {
self.data.clear();
self.known_cells.clear();
}

/// Appends a non-finalized row to the data, turning it into a finalized row.
fn push(&mut self, row: Row<T>) {
pub fn push(&mut self, row: Row<T>) {
self.data.reserve(self.column_count);
self.known_cells.reserve(self.column_count);
for col_id in self.first_column_id..(self.first_column_id + self.column_count as u64) {
@@ -75,11 +76,69 @@ impl<T: FieldElement> CompactData<T> {
}
}

fn get(&self, row: usize, col: u64) -> (T, bool) {
pub fn append_new_rows(&mut self, count: usize) {
self.data
.resize(self.data.len() + count * self.column_count, T::zero());
self.known_cells.grow(count * self.column_count, false);
}

fn index(&self, row: usize, col: u64) -> usize {
let col = col - self.first_column_id;
let idx = row * self.column_count + col as usize;
row * self.column_count + col as usize
}

pub fn get(&self, row: usize, col: u64) -> (T, bool) {
let idx = self.index(row, col);
(self.data[idx], self.known_cells[idx])
}

pub fn set(&mut self, row: usize, col: u64, value: T) {
let idx = self.index(row, col);
assert!(!self.known_cells[idx] || self.data[idx] == value);
Collaborator: Does this mean that we sometimes set the same value twice?

Collaborator: Or that there's a default value?

chriseth (Member, Author): It's OK to set the value twice if it's the same value. We can maybe make this more strict later, though. The "default" value is "not known"; internally it will be represented by zero.

self.data[idx] = value;
self.known_cells.set(idx, true);
}

pub fn known_values_in_row(&self, row: usize) -> impl Iterator<Item = (u64, &T)> {
(0..self.column_count).filter_map(move |i| {
let col = self.first_column_id + i as u64;
let idx = self.index(row, col);
self.known_cells[idx].then(|| {
let col_id = self.first_column_id + i as u64;
(col_id, &self.data[idx])
})
})
}
}

/// A mutable reference into CompactData that is meant to be used
/// only for a certain block of rows, starting from row index zero.
/// It allows negative row indices as well.
pub struct CompactDataRef<'a, T> {
data: &'a mut CompactData<T>,
row_offset: usize,
}

impl<'a, T: FieldElement> CompactDataRef<'a, T> {
/// Creates a new reference to the data, supplying the offset of the row
/// that is supposed to be "row zero".
pub fn new(data: &'a mut CompactData<T>, row_offset: usize) -> Self {
Self { data, row_offset }
}

pub fn get(&self, row: i32, col: u32) -> T {
chriseth (Member, Author): Not sure if the 32-bit values here provide a performance advantage. Any opinions?

Collaborator: Instead of get and set, can this implement Index and IndexMut?

chriseth (Member, Author): I actually don't think we will use those functions later on. The current interface used in #2071 uses direct memory slices.

let (v, known) = self.data.get(self.inner_row(row), col as u64);
assert!(known);
v
}

pub fn set(&mut self, row: i32, col: u32, value: T) {
self.data.set(self.inner_row(row), col as u64, value);
}

fn inner_row(&self, row: i32) -> usize {
(row + self.row_offset as i32) as usize
}
}

/// A data structure that stores witness data.
@@ -215,6 +274,38 @@ impl<T: FieldElement> FinalizableData<T> {
}
}

/// Returns an iterator over the values known in that row together with the PolyIDs.
#[auto_enum(Iterator)]
pub fn known_values_in_row(&self, row: usize) -> impl Iterator<Item = (PolyID, T)> + '_ {
match self.location_of_row(row) {
Location::PreFinalized(local) => {
let row = &self.pre_finalized_data[local];
self.column_ids
.iter()
.filter_map(move |id| row.value(id).map(|v| (*id, v)))
}
Location::Finalized(local) => {
self.finalized_data
.known_values_in_row(local)
.map(|(id, v)| {
(
PolyID {
id,
ptype: PolynomialType::Committed,
},
*v,
)
})
}
Location::PostFinalized(local) => {
let row = &self.post_finalized_data[local];
self.column_ids
.iter()
.filter_map(move |id| row.value(id).map(|v| (*id, v)))
}
}
}

pub fn last(&self) -> Option<&Row<T>> {
match self.location_of_last_row()? {
Location::PreFinalized(local) => self.pre_finalized_data.get(local),
@@ -283,6 +374,18 @@ impl<T: FieldElement> FinalizableData<T> {
}
}

/// Appends a given amount of new finalized rows set to zero and "unknown".
/// Returns a `CompactDataRef` that is built so that its "row zero" is the
/// first newly appended row.
///
/// Panics if there are any non-finalized rows at the end.
pub fn append_new_finalized_rows(&mut self, count: usize) -> CompactDataRef<'_, T> {
assert!(self.post_finalized_data.is_empty());
let row_zero = self.finalized_data.len();
self.finalized_data.append_new_rows(count);
CompactDataRef::new(&mut self.finalized_data, row_zero)
}

/// Takes all data out of the [FinalizableData] and returns it as a list of columns.
/// Columns are represented as a tuple of:
/// - A list of values
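Aside (not part of the PR): to illustrate the storage layout behind CompactData/CompactDataRef and the "set the same value twice" discussion above, here is a small self-contained Rust sketch. FlatTable and FlatTableRef are simplified stand-ins (plain u64 values and a Vec<bool> instead of a field element type and a BitVec), not the actual powdr API.

// Simplified model of the CompactData/CompactDataRef idea (not the real powdr types).
// Values are stored row-major in one flat vector; a parallel "known" mask marks which
// cells have been set. A cell may be set twice only if the value does not change.

struct FlatTable {
    first_column_id: u64,
    column_count: usize,
    data: Vec<u64>,   // stand-in for the field element type T
    known: Vec<bool>, // stand-in for the BitVec of known cells
}

impl FlatTable {
    fn new(first_column_id: u64, column_count: usize) -> Self {
        Self { first_column_id, column_count, data: Vec::new(), known: Vec::new() }
    }

    fn append_new_rows(&mut self, count: usize) {
        // New rows start as zero and "unknown", matching the default described in the review thread.
        self.data.resize(self.data.len() + count * self.column_count, 0);
        self.known.resize(self.known.len() + count * self.column_count, false);
    }

    fn index(&self, row: usize, col: u64) -> usize {
        row * self.column_count + (col - self.first_column_id) as usize
    }

    fn set(&mut self, row: usize, col: u64, value: u64) {
        let idx = self.index(row, col);
        // Re-setting a known cell is only allowed if the value is unchanged.
        assert!(!self.known[idx] || self.data[idx] == value);
        self.data[idx] = value;
        self.known[idx] = true;
    }

    fn get(&self, row: usize, col: u64) -> (u64, bool) {
        let idx = self.index(row, col);
        (self.data[idx], self.known[idx])
    }
}

// Mirrors CompactDataRef: a view whose "row zero" is some row inside the table, so callers
// (e.g. JIT-generated witgen code) can use block-relative, possibly negative row indices.
struct FlatTableRef<'a> {
    table: &'a mut FlatTable,
    row_offset: usize,
}

impl<'a> FlatTableRef<'a> {
    fn set(&mut self, row: i32, col: u64, value: u64) {
        let row = (row + self.row_offset as i32) as usize;
        self.table.set(row, col, value);
    }
}

fn main() {
    let mut table = FlatTable::new(10, 3); // columns with IDs 10, 11, 12
    table.append_new_rows(4);
    let mut block = FlatTableRef { table: &mut table, row_offset: 2 };
    block.set(-1, 11, 7); // writes to absolute row 1
    block.set(0, 10, 42); // writes to absolute row 2
    assert_eq!(table.get(2, 10), (42, true));
    assert_eq!(table.get(3, 12), (0, false)); // still "unknown", default zero
}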
63 changes: 63 additions & 0 deletions executor/src/witgen/jit/jit_processor.rs
@@ -0,0 +1,63 @@
use bit_vec::BitVec;
use powdr_number::FieldElement;

use crate::witgen::{
data_structures::finalizable_data::CompactDataRef,
machines::{LookupCell, MachineParts},
util::try_to_simple_poly,
EvalError, FixedData, MutableState, QueryCallback,
};

pub struct JitProcessor<'a, T: FieldElement> {
_fixed_data: &'a FixedData<'a, T>,
parts: MachineParts<'a, T>,
_block_size: usize,
latch_row: usize,
}

impl<'a, T: FieldElement> JitProcessor<'a, T> {
pub fn new(
fixed_data: &'a FixedData<'a, T>,
parts: MachineParts<'a, T>,
block_size: usize,
latch_row: usize,
) -> Self {
JitProcessor {
_fixed_data: fixed_data,
parts,
_block_size: block_size,
latch_row,
}
}

pub fn can_answer_lookup(&self, _identity_id: u64, _known_inputs: &BitVec) -> bool {
// TODO call the JIT compiler here.
false
}

pub fn process_lookup_direct<'c, 'd, Q: QueryCallback<T>>(
&self,
_mutable_state: &MutableState<'a, T, Q>,
connection_id: u64,
values: Vec<LookupCell<'c, T>>,
mut data: CompactDataRef<'d, T>,
) -> Result<bool, EvalError<T>> {
// Transfer inputs.
let right = self.parts.connections[&connection_id].right;
for (e, v) in right.expressions.iter().zip(&values) {
match v {
LookupCell::Input(&v) => {
let col = try_to_simple_poly(e).unwrap();
data.set(self.latch_row as i32, col.poly_id.id as u32, v);
}
LookupCell::Output(_) => {}
}
}

// Just some code here to avoid "unused" warnings.
// This code will not be called as long as `can_answer_lookup` returns false.
data.get(self.latch_row as i32, 0);

unimplemented!();
}
}
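Aside (not part of the PR): a minimal sketch of the input-transfer step that process_lookup_direct performs before any (future) JIT-generated code runs. The LookupCell enum and transfer_inputs helper below are simplified stand-ins; the real code resolves column IDs via try_to_simple_poly and writes into a CompactDataRef at the latch row.

// Sketch of transferring known lookup inputs into the latch row of a freshly
// reserved block, using simplified stand-in types.

enum LookupCell<'a> {
    Input(&'a u64),       // value known by the caller
    Output(&'a mut u64),  // to be filled in by the machine
}

fn transfer_inputs(
    latch_row: i32,
    // one (column id, cell) pair per expression on the right-hand side of the connection
    cells: Vec<(u32, LookupCell)>,
    mut set: impl FnMut(i32, u32, u64),
) {
    for (col_id, cell) in cells {
        if let LookupCell::Input(&v) = cell {
            // Known inputs are written into the latch row of the new block;
            // outputs would be read back after the generated witgen code has run.
            set(latch_row, col_id, v);
        }
    }
}

fn main() {
    let a = 5u64;
    let mut b = 0u64;
    let mut written = Vec::new();
    transfer_inputs(
        3,
        vec![(0, LookupCell::Input(&a)), (1, LookupCell::Output(&mut b))],
        |row, col, v| written.push((row, col, v)),
    );
    let expected: Vec<(i32, u32, u64)> = vec![(3, 0, 5)];
    assert_eq!(written, expected); // only the Input cell is transferred
}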
1 change: 1 addition & 0 deletions executor/src/witgen/jit/mod.rs
@@ -1,2 +1,3 @@
mod affine_symbolic_expression;
pub(crate) mod jit_processor;
mod symbolic_expression;
56 changes: 47 additions & 9 deletions executor/src/witgen/machines/block_machine.rs
@@ -12,6 +12,7 @@ use crate::witgen::block_processor::BlockProcessor;
use crate::witgen::data_structures::finalizable_data::FinalizableData;
use crate::witgen::data_structures::multiplicity_counter::MultiplicityCounter;
use crate::witgen::data_structures::mutable_state::MutableState;
use crate::witgen::jit::jit_processor::JitProcessor;
use crate::witgen::processor::{OuterQuery, Processor, SolverState};
use crate::witgen::rows::{Row, RowIndex, RowPair};
use crate::witgen::sequence_iterator::{
@@ -72,6 +73,9 @@ pub struct BlockMachine<'a, T: FieldElement> {
/// Cache that states the order in which to evaluate identities
/// to make progress most quickly.
processing_sequence_cache: ProcessingSequenceCache,
/// The JIT processor for this machine, i.e. the component that tries to generate
/// witgen code based on which elements of the connection are known.
jit_processor: JitProcessor<'a, T>,
name: String,
multiplicity_counter: MultiplicityCounter,
}
@@ -132,6 +136,7 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
latch_row,
parts.identities.len(),
),
jit_processor: JitProcessor::new(fixed_data, parts.clone(), block_size, latch_row),
})
}
}
@@ -356,12 +361,6 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
RowIndex::from_i64(self.rows() as i64 - 1, self.degree)
}

fn get_row(&self, row: RowIndex) -> &Row<T> {
// The first block is a dummy block corresponding to rows (-block_size, 0),
// so we have to add the block size to the row index.
&self.data[(row + self.block_size).into()]
}

fn process_plookup_internal<'b, Q: QueryCallback<T>>(
&mut self,
mutable_state: &MutableState<'a, T, Q>,
@@ -372,8 +371,18 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {

log::trace!("Start processing block machine '{}'", self.name());
log::trace!("Left values of lookup:");
for l in &outer_query.left {
log::trace!(" {}", l);
if log::log_enabled!(log::Level::Trace) {
for l in &outer_query.left {
log::trace!(" {}", l);
}
}

let known_inputs = outer_query.left.iter().map(|e| e.is_constant()).collect();
if self
.jit_processor
.can_answer_lookup(identity_id, &known_inputs)
{
return self.process_lookup_via_jit(mutable_state, identity_id, outer_query);
}

// TODO this assumes we are always using the same lookup for this machine.
@@ -431,6 +440,35 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
}
}

fn process_lookup_via_jit<'b, Q: QueryCallback<T>>(
&mut self,
mutable_state: &MutableState<'a, T, Q>,
identity_id: u64,
outer_query: OuterQuery<'a, 'b, T>,
) -> EvalResult<'a, T> {
let mut input_output_data = vec![T::zero(); outer_query.left.len()];
let values = outer_query.prepare_for_direct_lookup(&mut input_output_data);

assert!(
(self.rows() + self.block_size as DegreeType) < self.degree,
"Block machine is full (this should have been checked before)"
);
self.data
.finalize_range(self.first_in_progress_row..self.data.len());
self.first_in_progress_row = self.data.len() + self.block_size;
//TODO can we properly access the last row of the dummy block?
let data = self.data.append_new_finalized_rows(self.block_size);

let success =
self.jit_processor
.process_lookup_direct(mutable_state, identity_id, values, data)?;
assert!(success);

Ok(outer_query
.direct_lookup_to_eval_result(input_output_data)?
.report_side_effect())
}

fn process<'b, Q: QueryCallback<T>>(
&self,
mutable_state: &MutableState<'a, T, Q>,
@@ -481,7 +519,7 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
new_block
.get_mut(0)
.unwrap()
.merge_with(self.get_row(self.last_row_index()))
.merge_with_values(self.data.known_values_in_row(self.data.len() - 1))
.map_err(|_| {
EvalError::Generic(
"Block machine overwrites existing value with different value!".to_string(),
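Aside (not part of the PR): a control-flow sketch of the block machine's new dispatch — derive the pattern of known inputs from the constant entries of the caller's lookup arguments, ask the JIT processor whether it can answer that pattern, and only then reserve a fresh block of finalized rows. All types below are simplified stand-ins; input/output transfer and error handling are omitted.

// Simplified dispatch between the JIT path and the existing run-time solver.

struct JitProcessor;

impl JitProcessor {
    fn can_answer_lookup(&self, _identity_id: u64, _known_inputs: &[bool]) -> bool {
        // Placeholder: the real implementation will ask the JIT compiler
        // whether generated code exists for this pattern of known inputs.
        false
    }
}

struct BlockMachine {
    jit: JitProcessor,
    block_size: usize,
    rows_used: usize,
    degree: usize,
}

impl BlockMachine {
    fn process_lookup(&mut self, identity_id: u64, left_is_constant: Vec<bool>) -> &'static str {
        // 1. Which arguments of the lookup are already known?
        let known_inputs = left_is_constant;

        // 2. Only take the JIT path if generated code can answer this exact pattern.
        if self.jit.can_answer_lookup(identity_id, &known_inputs) {
            // 3. The machine must still have room for one more block.
            assert!(self.rows_used + self.block_size < self.degree);
            // 4. Reserve a fresh zeroed block; the JIT code writes directly into it.
            self.rows_used += self.block_size;
            return "jit";
        }
        // Otherwise fall back to the existing run-time solver.
        "runtime solver"
    }
}

fn main() {
    let mut machine = BlockMachine { jit: JitProcessor, block_size: 4, rows_used: 0, degree: 64 };
    // With can_answer_lookup returning false, the fallback path is taken.
    assert_eq!(machine.process_lookup(0, vec![true, false]), "runtime solver");
}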
6 changes: 3 additions & 3 deletions executor/src/witgen/machines/fixed_lookup_machine.rs
@@ -234,7 +234,7 @@ impl<'a, T: FieldElement> FixedLookup<'a, T> {

fn process_range_check(
&self,
rows: &RowPair<'_, 'a, T>,
rows: &RowPair<'_, '_, T>,
lhs: &AffineExpression<AlgebraicVariable<'a>, T>,
rhs: AlgebraicVariable<'a>,
) -> EvalResult<'a, T> {
@@ -317,9 +317,9 @@ impl<'a, T: FieldElement> Machine<'a, T> for FixedLookup<'a, T> {
self.process_plookup_internal(mutable_state, identity_id, caller_rows, outer_query, right)
}

fn process_lookup_direct<'b, 'c, Q: QueryCallback<T>>(
fn process_lookup_direct<'c, Q: QueryCallback<T>>(
&mut self,
_mutable_state: &'b MutableState<'a, T, Q>,
_mutable_state: &MutableState<'a, T, Q>,
identity_id: u64,
values: &mut [LookupCell<'c, T>],
) -> Result<bool, EvalError<T>> {