From 05318544014d15dd02f90ff720eb8757cc3d3a1a Mon Sep 17 00:00:00 2001 From: Vaivaswatha N Date: Fri, 14 Apr 2023 21:15:20 +0530 Subject: [PATCH] memcpyopt: Ensure source local isn't clobbered before the new memcpy (#4422) ## Description When combining a load+store into a `memcpy`, the optimization was missing a check that the source isn't clobbered (stored to) between the load and the store. This change adds that check; with it in place, the optimization is now enabled for copy types too. Issue #4345 ## Checklist - [x] I have linked to any relevant issues. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have updated the documentation where relevant (API docs, the reference, and the Sway book). - [x] I have added tests that prove my fix is effective or that my feature works. - [x] I have added (or requested a maintainer to add) the necessary `Breaking*` or `New Feature` labels where relevant. - [x] I have done my best to ensure that my PR adheres to [the Fuel Labs Code Review Standards](https://github.com/FuelLabs/rfcs/blob/master/text/code-standards/external-contributors.md). - [x] I have requested a review from the relevant team or maintainers. 
--------- Co-authored-by: Mohammad Fawaz --- sway-ir/src/block.rs | 13 +- sway-ir/src/instruction.rs | 8 +- sway-ir/src/optimize/arg_demotion.rs | 60 +++-- sway-ir/src/optimize/memcpyopt.rs | 217 ++++++++++++------ sway-ir/src/optimize/misc_demotion.rs | 21 +- sway-ir/src/verify.rs | 16 +- .../array_of_structs_caller/src/main.sw | 2 +- .../call_basic_storage/src/main.sw | 2 +- .../nested_struct_args_caller/src/main.sw | 2 +- .../storage_access_caller/src/main.sw | 2 +- .../ir_generation/tests/simple_contract.sw | 4 +- test/src/ir_generation/tests/smo.sw | 5 +- 12 files changed, 243 insertions(+), 109 deletions(-) diff --git a/sway-ir/src/block.rs b/sway-ir/src/block.rs index dba996f7c8e..ba677e5b36c 100644 --- a/sway-ir/src/block.rs +++ b/sway-ir/src/block.rs @@ -41,7 +41,7 @@ pub struct BlockContent { pub preds: FxHashSet, } -#[derive(Debug, Clone, DebugWithContext)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, DebugWithContext)] pub struct BlockArgument { /// The block of which this is an argument. pub block: Block, @@ -138,6 +138,17 @@ impl Block { idx } + pub fn set_arg(&self, context: &mut Context, arg: Value) { + match context.values[arg.0].value { + ValueDatum::Argument(BlockArgument { block, idx, ty: _ }) + if block == *self && idx < context.blocks[self.0].args.len() => + { + context.blocks[self.0].args[idx] = arg; + } + _ => panic!("Inconsistent block argument being set"), + } + } + /// Add a block argument, asserts that `arg` is suitable here. 
pub fn add_arg(&self, context: &mut Context, arg: Value) { match context.values[arg.0].value { diff --git a/sway-ir/src/instruction.rs b/sway-ir/src/instruction.rs index fe45b19d192..9a445caed6b 100644 --- a/sway-ir/src/instruction.rs +++ b/sway-ir/src/instruction.rs @@ -442,7 +442,7 @@ impl Instruction { indices.iter_mut().for_each(replace); } Instruction::IntToPtr(value, _) => replace(value), - Instruction::Load(_) => (), + Instruction::Load(ptr) => replace(ptr), Instruction::MemCopyBytes { dst_val_ptr, src_val_ptr, @@ -461,8 +461,12 @@ impl Instruction { Instruction::Nop => (), Instruction::PtrToInt(value, _) => replace(value), Instruction::Ret(ret_val, _) => replace(ret_val), - Instruction::Store { stored_val, .. } => { + Instruction::Store { + stored_val, + dst_val_ptr, + } => { replace(stored_val); + replace(dst_val_ptr); } Instruction::FuelVm(fuel_vm_instr) => match fuel_vm_instr { diff --git a/sway-ir/src/optimize/arg_demotion.rs b/sway-ir/src/optimize/arg_demotion.rs index 27e1226065d..5976c132f75 100644 --- a/sway-ir/src/optimize/arg_demotion.rs +++ b/sway-ir/src/optimize/arg_demotion.rs @@ -89,17 +89,6 @@ fn fn_arg_demotion(context: &mut Context, function: Function) -> Result { - if let ValueDatum::Argument(BlockArgument { ty, .. }) = - &mut $context.values[$arg_val.0].value - { - *ty = $new_ty - } - }; -} - fn demote_fn_signature(context: &mut Context, function: &Function, arg_idcs: &[(usize, Type)]) { // Change the types of the arg values in place to their pointer counterparts. let entry_block = function.get_entry_block(context); @@ -108,18 +97,27 @@ fn demote_fn_signature(context: &mut Context, function: &Function, arg_idcs: &[( .map(|(arg_idx, arg_ty)| { let ptr_ty = Type::new_ptr(context, *arg_ty); - // Update the function signature. - let fn_args = &context.functions[function.0].arguments; - let (_name, fn_arg_val) = &fn_args[*arg_idx]; - set_arg_type!(context, fn_arg_val, ptr_ty); - - // Update the entry block signature. 
+ // Create a new block arg, same as the old one but with a different type. let blk_arg_val = entry_block .get_arg(context, *arg_idx) .expect("Entry block args should be mirror of function args."); - set_arg_type!(context, blk_arg_val, ptr_ty); + let ValueDatum::Argument(block_arg) = context.values[blk_arg_val.0].value else { + panic!("Block argument is not of right Value kind"); + }; + let new_blk_arg_val = Value::new_argument( + context, + BlockArgument { + ty: ptr_ty, + ..block_arg + }, + ); + + // Set both function and block arg to the new one. + entry_block.set_arg(context, new_blk_arg_val); + let (_name, fn_arg_val) = &mut context.functions[function.0].arguments[*arg_idx]; + *fn_arg_val = new_blk_arg_val; - *fn_arg_val + (blk_arg_val, new_blk_arg_val) }) .collect::>(); @@ -127,12 +125,12 @@ fn demote_fn_signature(context: &mut Context, function: &Function, arg_idcs: &[( let arg_val_pairs = old_arg_vals .into_iter() .rev() - .map(|old_arg_val| { - let new_arg_val = Value::new_instruction(context, Instruction::Load(old_arg_val)); + .map(|(old_arg_val, new_arg_val)| { + let load_from_new_arg = Value::new_instruction(context, Instruction::Load(new_arg_val)); context.blocks[entry_block.0] .instructions - .insert(0, new_arg_val); - (old_arg_val, new_arg_val) + .insert(0, load_from_new_arg); + (old_arg_val, load_from_new_arg) }) .collect::>(); @@ -237,9 +235,21 @@ fn demote_block_signature(context: &mut Context, function: &Function, block: Blo .rev() .map(|(_arg_idx, arg_val, arg_ty)| { let ptr_ty = Type::new_ptr(context, *arg_ty); - set_arg_type!(context, arg_val, ptr_ty); - let load_val = Value::new_instruction(context, Instruction::Load(*arg_val)); + // Create a new block arg, same as the old one but with a different type. 
+ let ValueDatum::Argument(block_arg) = context.values[arg_val.0].value else { + panic!("Block argument is not of right Value kind"); + }; + let new_blk_arg_val = Value::new_argument( + context, + BlockArgument { + ty: ptr_ty, + ..block_arg + }, + ); + block.set_arg(context, new_blk_arg_val); + + let load_val = Value::new_instruction(context, Instruction::Load(new_blk_arg_val)); let block_instrs = &mut context.blocks[block.0].instructions; block_instrs.insert(0, load_val); diff --git a/sway-ir/src/optimize/memcpyopt.rs b/sway-ir/src/optimize/memcpyopt.rs index 33764ea2a93..3013190bf72 100644 --- a/sway-ir/src/optimize/memcpyopt.rs +++ b/sway-ir/src/optimize/memcpyopt.rs @@ -4,8 +4,8 @@ use rustc_hash::{FxHashMap, FxHashSet}; use crate::{ - AnalysisResults, Block, Context, Function, Instruction, IrError, LocalVar, Pass, - PassMutability, ScopedPass, Value, ValueDatum, + AnalysisResults, Block, BlockArgument, Context, Function, Instruction, IrError, LocalVar, Pass, + PassMutability, ScopedPass, Type, Value, ValueDatum, }; pub const MEMCPYOPT_NAME: &str = "memcpyopt"; @@ -31,6 +31,37 @@ pub fn mem_copy_opt( Ok(modified) } +#[derive(Eq, PartialEq, Copy, Clone, Hash)] +enum Symbol { + Local(LocalVar), + Arg(BlockArgument), +} + +impl Symbol { + pub fn get_type(&self, context: &Context) -> Type { + match self { + Symbol::Local(l) => l.get_type(context), + Symbol::Arg(ba) => ba.ty, + } + } + + pub fn _get_name(&self, context: &Context, function: Function) -> String { + match self { + Symbol::Local(l) => function.lookup_local_name(context, l).unwrap().clone(), + Symbol::Arg(ba) => format!("{}[{}]", ba.block.get_label(context), ba.idx), + } + } +} + +fn get_symbol(context: &Context, val: Value) -> Option { + match context.values[val.0].value { + ValueDatum::Instruction(Instruction::GetLocal(local)) => Some(Symbol::Local(local)), + ValueDatum::Instruction(Instruction::GetElemPtr { base, .. 
}) => get_symbol(context, base), + ValueDatum::Argument(b) => Some(Symbol::Arg(b)), + _ => None, + } +} + struct InstInfo { // The block in which an instruction is block: Block, @@ -42,25 +73,17 @@ struct InstInfo { /// a data-flow analysis. Until then, we do a safe approximation, /// restricting to when every related instruction is in the same block. fn local_copy_prop(context: &mut Context, function: Function) -> Result { - let mut loads_map = FxHashMap::>::default(); - let mut stores_map = FxHashMap::>::default(); + let mut loads_map = FxHashMap::>::default(); + let mut stores_map = FxHashMap::>::default(); let mut instr_info_map = FxHashMap::::default(); - let mut asm_uses = FxHashSet::::default(); - - fn get_local(context: &Context, val: Value) -> Option { - match val.get_instruction(context) { - Some(Instruction::GetLocal(local)) => Some(*local), - Some(Instruction::GetElemPtr { base, .. }) => get_local(context, *base), - _ => None, - } - } + let mut asm_uses = FxHashSet::::default(); for (pos, (block, inst)) in function.instruction_iter(context).enumerate() { let info = || InstInfo { block, pos }; let inst_e = inst.get_instruction(context).unwrap(); match inst_e { Instruction::Load(src_val_ptr) => { - if let Some(local) = get_local(context, *src_val_ptr) { + if let Some(local) = get_symbol(context, *src_val_ptr) { loads_map .entry(local) .and_modify(|loads| loads.push(inst)) @@ -69,7 +92,7 @@ fn local_copy_prop(context: &mut Context, function: Function) -> Result { - if let Some(local) = get_local(context, *dst_val_ptr) { + if let Some(local) = get_symbol(context, *dst_val_ptr) { stores_map .entry(local) .and_modify(|stores| stores.push(inst)) @@ -80,7 +103,7 @@ fn local_copy_prop(context: &mut Context, function: Function) -> Result { for arg in args { if let Some(arg) = arg.initializer { - if let Some(local) = get_local(context, arg) { + if let Some(local) = get_symbol(context, arg) { asm_uses.insert(local); } } @@ -91,7 +114,7 @@ fn 
local_copy_prop(context: &mut Context, function: Function) -> Result::default(); - let candidates: FxHashMap = function + let candidates: FxHashMap = function .instruction_iter(context) .enumerate() .filter_map(|(pos, (block, instr_val))| { @@ -104,7 +127,7 @@ fn local_copy_prop(context: &mut Context, function: Function) -> Result Result Result Result, - src_local: &LocalVar, - ) -> Option { + fn closure(candidates: &FxHashMap, src_local: &Symbol) -> Option { candidates .get(src_local) .map(|replace_with| closure(candidates, replace_with).unwrap_or(*replace_with)) } + + // If the source is an Arg, we replace uses of destination with Arg. + // otherwise (`get_local`), we replace the local symbol in-place. + enum ReplaceWith { + InPlaceLocal(LocalVar), + Value(Value), + } + // Because we can't borrow context for both iterating and replacing, do it in 2 steps. let replaces: Vec<_> = function .instruction_iter(context) .filter_map(|(_block, value)| match value.get_instruction(context) { Some(Instruction::GetLocal(local)) => { - closure(&candidates, local).map(|replace_with| (value, *local, replace_with)) + closure(&candidates, &Symbol::Local(*local)).map(|replace_with| { + ( + value, + match replace_with { + Symbol::Local(local) => ReplaceWith::InPlaceLocal(local), + Symbol::Arg(ba) => { + ReplaceWith::Value(ba.block.get_arg(context, ba.idx).unwrap()) + } + }, + ) + }) } _ => None, }) .collect(); - for (value, redundant_var, replacement_var) in replaces.into_iter() { - // Be sure to propagate the mutability of the original local variable to the copy. 
- if redundant_var.is_mutable(context) { - replacement_var.set_mutable(context, true); + + let mut value_replace = FxHashMap::::default(); + for (value, replace_with) in replaces.into_iter() { + match replace_with { + ReplaceWith::InPlaceLocal(replacement_var) => { + let Some(Instruction::GetLocal(redundant_var)) = value.get_instruction(context) else { + panic!("earlier match now fails"); + }; + if redundant_var.is_mutable(context) { + replacement_var.set_mutable(context, true); + } + value.replace( + context, + ValueDatum::Instruction(Instruction::GetLocal(replacement_var)), + ) + } + ReplaceWith::Value(replace_with) => { + value_replace.insert(value, replace_with); + } } - value.replace( - context, - ValueDatum::Instruction(Instruction::GetLocal(replacement_var)), - ); } + function.replace_values(context, &value_replace, None); // Delete stores to the replaced local. let blocks: Vec = function.block_iter(context).collect(); @@ -194,13 +246,63 @@ fn local_copy_prop(context: &mut Context, function: Function) -> Result bool { + let mut iter = store_block + .instruction_iter(context) + .rev() + .skip_while(|i| i != &store_val); + assert!(iter.next().unwrap() == store_val); + + // Scan backwards till we encounter load_val, checking if + // any store aliases with src_ptr. + let mut worklist: Vec<(Block, Box>)> = + vec![(store_block, Box::new(iter))]; + let mut visited = FxHashSet::default(); + 'next_job: while !worklist.is_empty() { + let (block, iter) = worklist.pop().unwrap(); + visited.insert(block); + for inst in iter { + if inst == load_val || inst == store_val { + // We don't need to go beyond either the source load or the candidate store. 
+ continue 'next_job; + } + if let Some(Instruction::Store { + dst_val_ptr, + stored_val: _, + }) = inst.get_instruction(context) + { + if get_symbol(context, *dst_val_ptr) == get_symbol(context, src_ptr) { + return true; + } + } + } + for pred in block.pred_iter(context) { + if !visited.contains(pred) { + worklist.push(( + *pred, + Box::new(pred.instruction_iter(context).rev().skip_while(|_| false)), + )); + } + } + } + + false +} + fn load_store_to_memcopy(context: &mut Context, function: Function) -> Result { // Find any `store`s of `load`s. These can be replaced with `mem_copy` and are especially // important for non-copy types on architectures which don't support loading them. let candidates = function .instruction_iter(context) - .filter_map(|(block, instr_val)| { - instr_val + .filter_map(|(block, store_instr_val)| { + store_instr_val .get_instruction(context) .and_then(|instr| { // Is the instruction a Store? @@ -211,45 +313,32 @@ fn load_store_to_memcopy(context: &mut Context, function: Function) -> Result>(); @@ -257,7 +346,7 @@ fn load_store_to_memcopy(context: &mut Context, function: Function) -> Result Result Result Result<(), IrError> { for function in &module.functions { - self.verify_function(module, &self.functions[function.0])?; + self.verify_function(module, function)?; } Ok(()) } @@ -61,8 +61,20 @@ impl Context { fn verify_function( &self, cur_module: &ModuleContent, - function: &FunctionContent, + function: &Function, ) -> Result<(), IrError> { + let entry_block = function.get_entry_block(self); + // Ensure that the entry block arguments are same as function arguments. 
+ if function.num_args(self) != entry_block.num_args(self) { + return Err(IrError::VerifyBlockArgMalformed); + } + for ((_, func_arg), block_arg) in function.args_iter(self).zip(entry_block.arg_iter(self)) { + if func_arg != block_arg { + return Err(IrError::VerifyBlockArgMalformed); + } + } + + let function = &self.functions[function.0]; for block in &function.blocks { self.verify_block(cur_module, function, &self.blocks[block.0])?; } diff --git a/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/array_of_structs_caller/src/main.sw b/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/array_of_structs_caller/src/main.sw index b77255b1cfe..ee6b0330d8c 100644 --- a/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/array_of_structs_caller/src/main.sw +++ b/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/array_of_structs_caller/src/main.sw @@ -4,7 +4,7 @@ use array_of_structs_abi::{Id, TestContract, Wrapper}; use std::hash::sha256; fn main() -> u64 { - let addr = abi(TestContract, 0x511edec57a18fe8fccb55dd5668b45ab51a8ecf638107243c1b831ba96184714); + let addr = abi(TestContract, 0xa5d354e58efd316c2eb3f4b273a2143e7d534952fae5e191e057b27403ae829e); let input = [Wrapper { id: Id { diff --git a/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/call_basic_storage/src/main.sw b/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/call_basic_storage/src/main.sw index 969397734cb..657f158c191 100644 --- a/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/call_basic_storage/src/main.sw +++ b/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/call_basic_storage/src/main.sw @@ -2,7 +2,7 @@ script; use basic_storage_abi::{BasicStorage, Quad}; fn main() -> u64 { - let addr = abi(BasicStorage, 0x2fdecddd593b29cab5760ce8333979f341248a4e89d257cb869a89acd74201fa); + let addr = 
abi(BasicStorage, 0xc98246b75472af9196be66b65a979ebbe5cd5975d1331f18ce2cda66a9819379); let key = 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff; let value = 4242; diff --git a/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/nested_struct_args_caller/src/main.sw b/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/nested_struct_args_caller/src/main.sw index b1613711d7f..93a5b044d6f 100644 --- a/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/nested_struct_args_caller/src/main.sw +++ b/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/nested_struct_args_caller/src/main.sw @@ -3,7 +3,7 @@ script; use nested_struct_args_abi::*; fn main() -> bool { - let contract_id = 0xf17db9ebfbf5470fb3955d6b86c038658e4a6016a28f8c1d64957fdee891001b; + let contract_id = 0xc07c133be5867020f483c34e045c9162867d6b40accc022525e50c048d17d679; let caller = abi(NestedStructArgs, contract_id); let param_one = StructOne { diff --git a/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/storage_access_caller/src/main.sw b/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/storage_access_caller/src/main.sw index e76ffb4fef8..5acc117878c 100644 --- a/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/storage_access_caller/src/main.sw +++ b/test/src/e2e_vm_tests/test_programs/should_pass/require_contract_deployment/storage_access_caller/src/main.sw @@ -4,7 +4,7 @@ use storage_access_abi::*; use std::hash::sha256; fn main() -> bool { - let contract_id = 0x215f26a1817d55ecff703908053bd4379ea9e78052395bc42429078046e816fc; + let contract_id = 0x9ec7ac3cce15b1b75adabd598a374e96b606471a56229486a2d0ca26c08ed994; let caller = abi(StorageAccess, contract_id); // Test initializers diff --git a/test/src/ir_generation/tests/simple_contract.sw b/test/src/ir_generation/tests/simple_contract.sw index 
29b466c9600..175c79609ff 100644 --- a/test/src/ir_generation/tests/simple_contract.sw +++ b/test/src/ir_generation/tests/simple_contract.sw @@ -33,8 +33,8 @@ impl Test for Contract { // ::check-ir:: // check: contract { -// check: fn get_b256<42123b96>($ID $MD: ptr b256) -> ptr b256, -// check: fn get_s($ID $MD: u64, $ID $MD: ptr b256) -> ptr { u64, b256 } +// check: fn get_b256<42123b96>($ID: ptr b256) -> ptr b256, +// check: fn get_s($ID $MD: u64, $ID: ptr b256) -> ptr { u64, b256 } // check: fn get_u64<9890aef4>($ID $MD: u64) -> u64 // ::check-asm:: diff --git a/test/src/ir_generation/tests/smo.sw b/test/src/ir_generation/tests/smo.sw index 1d3004a0451..c941d07c7fd 100644 --- a/test/src/ir_generation/tests/smo.sw +++ b/test/src/ir_generation/tests/smo.sw @@ -15,9 +15,8 @@ fn main() { // Match the first one where data is initialised. // check: get_local ptr u64, data -// Match the second one where we read it back. +// Match the second one where we read it back, as a mem_copy_val later on // check: $(data_ptr=$VAL) = get_local ptr u64, data -// check: $(data_val=$VAL) = load $data_ptr // check: $(temp_ptr=$VAL) = get_local ptr { b256, u64, u64 }, $(=__anon_\d+) @@ -33,7 +32,7 @@ fn main() { // check: $(idx_2=$VAL) = const u64 2 // check: $(field_2_ptr=$VAL) = get_elem_ptr $temp_ptr, ptr u64, $idx_2 -// check: store $data_val to $field_2_ptr +// check: mem_copy_val $field_2_ptr, $data_ptr // check: $(oi_ptr=$VAL) = get_local ptr u64, output_index // check: $(oi=$VAL) = load $oi_ptr