diff --git a/sway-core/src/asm_generation/data_section.rs b/sway-core/src/asm_generation/data_section.rs index 283a1531af4..5a0c8319d6e 100644 --- a/sway-core/src/asm_generation/data_section.rs +++ b/sway-core/src/asm_generation/data_section.rs @@ -1,6 +1,6 @@ use crate::asm_generation::from_ir::ir_type_size_in_bytes; -use sway_ir::{AggregateContent, Constant, ConstantValue, Context, Type}; +use sway_ir::{Constant, ConstantValue, Context}; use std::fmt::{self, Write}; @@ -50,28 +50,25 @@ impl Entry { let size = Some(ir_type_size_in_bytes(context, &constant.ty) as usize); // Is this constant a tagged union? - if let Type::Struct(struct_agg) = &constant.ty { - if let AggregateContent::FieldTypes(field_tys) = struct_agg.get_content(context) { - if field_tys.len() == 2 - && matches!( - (field_tys[0], field_tys[1]), - (Type::Uint(_), Type::Union(_)) - ) - { - // OK, this looks very much like a tagged union enum, which is the only place - // we use unions (otherwise we should be generalising this a bit more). - if let ConstantValue::Struct(els) = &constant.value { - if els.len() == 2 { - let tag_entry = Entry::from_constant(context, &els[0]); - - // Here's the special case. We need to get the size of the union and - // attach it to this constant entry which will be one of the variants. - let mut val_entry = Entry::from_constant(context, &els[1]); - val_entry.size = ir_type_size_in_bytes(context, &field_tys[1]) as usize; - - // Return here from our special case. - return Entry::new_collection(vec![tag_entry, val_entry], size); - } + if constant.ty.is_struct(context) { + let field_tys = constant.ty.get_field_types(context); + if field_tys.len() == 2 + && field_tys[0].is_uint(context) + && field_tys[1].is_union(context) + { + // OK, this looks very much like a tagged union enum, which is the only place + // we use unions (otherwise we should be generalising this a bit more). + if let ConstantValue::Struct(els) = &constant.value { + if els.len() == 2 { + let tag_entry = Entry::from_constant(context, &els[0]); + + // Here's the special case. We need to get the size of the union and + // attach it to this constant entry which will be one of the variants. + let mut val_entry = Entry::from_constant(context, &els[1]); + val_entry.size = ir_type_size_in_bytes(context, &field_tys[1]) as usize; + + // Return here from our special case. + return Entry::new_collection(vec![tag_entry, val_entry], size); } } } diff --git a/sway-core/src/asm_generation/evm/evm_asm_builder.rs b/sway-core/src/asm_generation/evm/evm_asm_builder.rs index 4a42904acef..5c1ea73eda3 100644 --- a/sway-core/src/asm_generation/evm/evm_asm_builder.rs +++ b/sway-core/src/asm_generation/evm/evm_asm_builder.rs @@ -480,7 +480,7 @@ impl<'ir> EvmAsmBuilder<'ir> { &mut self, instr_val: &Value, array: &Value, - ty: &Aggregate, + ty: &Type, index_val: &Value, ) { todo!(); @@ -506,7 +506,7 @@ impl<'ir> EvmAsmBuilder<'ir> { &mut self, instr_val: &Value, array: &Value, - ty: &Aggregate, + ty: &Type, value: &Value, index_val: &Value, ) { diff --git a/sway-core/src/asm_generation/from_ir.rs b/sway-core/src/asm_generation/from_ir.rs index b42765c427d..c91fc013e8b 100644 --- a/sway-core/src/asm_generation/from_ir.rs +++ b/sway-core/src/asm_generation/from_ir.rs @@ -196,40 +196,26 @@ pub enum StateAccessType { } pub(crate) fn ir_type_size_in_bytes(context: &Context, ty: &Type) -> u64 { - match ty { - Type::Unit | Type::Bool | Type::Uint(_) => 8, - Type::Slice => 16, - Type::B256 => 32, - Type::String(n) => size_bytes_round_up_to_word_alignment!(n), - Type::Array(aggregate) => { - if let AggregateContent::ArrayType(el_ty, cnt) = aggregate.get_content(context) { - cnt * ir_type_size_in_bytes(context, el_ty) - } else { - unreachable!("Wrong content for array.") - } - } - Type::Struct(aggregate) => { - if let AggregateContent::FieldTypes(field_tys) = aggregate.get_content(context) { - // Sum up all the field sizes. - field_tys - .iter() - .map(|field_ty| ir_type_size_in_bytes(context, field_ty)) - .sum() - } else { - unreachable!("Wrong content for struct.") - } + match ty.get_content(context) { + TypeContent::Unit | TypeContent::Bool | TypeContent::Uint(_) => 8, + TypeContent::Slice => 16, + TypeContent::B256 => 32, + TypeContent::String(n) => size_bytes_round_up_to_word_alignment!(*n), + TypeContent::Array(el_ty, cnt) => cnt * ir_type_size_in_bytes(context, el_ty), + TypeContent::Struct(field_tys) => { + // Sum up all the field sizes. + field_tys + .iter() + .map(|field_ty| ir_type_size_in_bytes(context, field_ty)) + .sum() } - Type::Union(aggregate) => { - if let AggregateContent::FieldTypes(field_tys) = aggregate.get_content(context) { - // Find the max size for field sizes. - field_tys - .iter() - .map(|field_ty| ir_type_size_in_bytes(context, field_ty)) - .max() - .unwrap_or(0) - } else { - unreachable!("Wrong content for union.") - } + TypeContent::Union(field_tys) => { + // Find the max size for field sizes. + field_tys + .iter() + .map(|field_ty| ir_type_size_in_bytes(context, field_ty)) + .max() + .unwrap_or(0) } } } @@ -240,44 +226,40 @@ pub(crate) fn aggregate_idcs_to_field_layout( ty: &Type, idcs: &[u64], ) -> ((u64, u64), Type) { - idcs.iter() - .fold(((0, 0), *ty), |((offs, _), ty), idx| match ty { - Type::Struct(aggregate) => { - let idx = *idx as usize; - let field_types = &aggregate.get_content(context).field_types(); - let field_type = field_types[idx]; - let field_offs_in_bytes = field_types - .iter() - .take(idx) - .map(|field_ty| ir_type_size_in_bytes(context, field_ty)) - .sum::(); - let field_size_in_bytes = ir_type_size_in_bytes(context, &field_type); - + idcs.iter().fold(((0, 0), *ty), |((offs, _), ty), idx| { + if ty.is_struct(context) { + let idx = *idx as usize; + let field_types = ty.get_field_types(context); + let field_type = field_types[idx]; + let field_offs_in_bytes = field_types + .iter() + .take(idx) + .map(|field_ty| ir_type_size_in_bytes(context, field_ty)) + .sum::(); + let field_size_in_bytes = ir_type_size_in_bytes(context, &field_type); + + ( ( - ( - offs + size_bytes_in_words!(field_offs_in_bytes), - field_size_in_bytes, - ), - field_type, - ) - } - - Type::Union(aggregate) => { - let idx = *idx as usize; - let field_type = aggregate.get_content(context).field_types()[idx]; - let union_size_in_bytes = ir_type_size_in_bytes(context, &ty); - let field_size_in_bytes = ir_type_size_in_bytes(context, &field_type); - - // The union fields are at offset (union_size - variant_size) due to left padding. + offs + size_bytes_in_words!(field_offs_in_bytes), + field_size_in_bytes, + ), + field_type, + ) + } else if ty.is_union(context) { + let idx = *idx as usize; + let field_type = ty.get_field_types(context)[idx]; + let union_size_in_bytes = ir_type_size_in_bytes(context, &ty); + let field_size_in_bytes = ir_type_size_in_bytes(context, &field_type); + // The union fields are at offset (union_size - variant_size) due to left padding. + ( ( - ( - offs + size_bytes_in_words!(union_size_in_bytes - field_size_in_bytes), - field_size_in_bytes, - ), - field_type, - ) - } - - _otherwise => panic!("Attempt to access field in non-aggregate."), - }) + offs + size_bytes_in_words!(union_size_in_bytes - field_size_in_bytes), + field_size_in_bytes, + ), + field_type, + ) + } else { + panic!("Attempt to access field in non-aggregate.") + } + }) } diff --git a/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs b/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs index 3cd3caebad5..39d0f7a81b8 100644 --- a/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs +++ b/sway-core/src/asm_generation/fuel/fuel_asm_builder.rs @@ -483,7 +483,7 @@ impl<'ir> FuelAsmBuilder<'ir> { fn compile_bitcast(&mut self, instr_val: &Value, bitcast_val: &Value, to_type: &Type) { let val_reg = self.value_to_register(bitcast_val); - let reg = if let Type::Bool = to_type { + let reg = if to_type.is_bool(self.context) { // This may not be necessary if we just treat a non-zero value as 'true'. let res_reg = self.reg_seqr.next(); self.cur_bytecode.push(Op { @@ -704,7 +704,7 @@ impl<'ir> FuelAsmBuilder<'ir> { &mut self, instr_val: &Value, array: &Value, - ty: &Aggregate, + ty: &Type, index_val: &Value, ) { // Base register should pointer to some stack allocated memory. @@ -719,7 +719,7 @@ impl<'ir> FuelAsmBuilder<'ir> { let instr_reg = self.reg_seqr.next(); let owning_span = self.md_mgr.val_to_span(self.context, *instr_val); - let elem_type = ty.get_elem_type(self.context).unwrap(); + let elem_type = ty.get_array_elem_type(self.context).unwrap(); let elem_size = ir_type_size_in_bytes(self.context, &elem_type); if self.is_copy_type(&elem_type) { self.cur_bytecode.push(Op { @@ -992,7 +992,7 @@ impl<'ir> FuelAsmBuilder<'ir> { &mut self, instr_val: &Value, array: &Value, - ty: &Aggregate, + ty: &Type, value: &Value, index_val: &Value, ) { @@ -1006,7 +1006,7 @@ impl<'ir> FuelAsmBuilder<'ir> { let owning_span = self.md_mgr.val_to_span(self.context, *instr_val); - let elem_type = ty.get_elem_type(self.context).unwrap(); + let elem_type = ty.get_array_elem_type(self.context).unwrap(); let elem_size = ir_type_size_in_bytes(self.context, &elem_type); if self.is_copy_type(&elem_type) { self.cur_bytecode.push(Op { @@ -1105,7 +1105,7 @@ impl<'ir> FuelAsmBuilder<'ir> { // Account for the padding if the final field type is a union and the value we're trying to // insert is smaller than the size of the union (i.e. we're inserting a small variant). - if matches!(field_type, Type::Union(_)) { + if field_type.is_union(self.context) { let field_size_in_words = size_bytes_in_words!(field_size_in_bytes); assert!(field_size_in_words >= value_size_in_words); insert_offs += field_size_in_words - value_size_in_words; @@ -1224,7 +1224,7 @@ impl<'ir> FuelAsmBuilder<'ir> { } Storage::Stack(word_offs) => { let base_reg = self.locals_base_reg().clone(); - if self.is_copy_type(local_var.get_type(self.context)) { + if self.is_copy_type(&local_var.get_type(self.context)) { // Value can fit in a register, so we load the value. if word_offs > compiler_constants::TWELVE_BITS { let offs_reg = self.reg_seqr.next(); @@ -1407,7 +1407,7 @@ impl<'ir> FuelAsmBuilder<'ir> { fn compile_ret_from_entry(&mut self, instr_val: &Value, ret_val: &Value, ret_type: &Type) { let owning_span = self.md_mgr.val_to_span(self.context, *instr_val); - if ret_type.eq(self.context, &Type::Unit) { + if ret_type.is_unit(self.context) { // Unit returns should always be zero, although because they can be omitted from // functions, the register is sometimes uninitialized. Manually return zero in this // case. @@ -1430,7 +1430,7 @@ impl<'ir> FuelAsmBuilder<'ir> { } else { // If the type is not a copy type then we use RETD to return data. let size_reg = self.reg_seqr.next(); - if ret_type.eq(self.context, &Type::Slice) { + if ret_type.is_slice(self.context) { // If this is a slice then return what it points to. self.cur_bytecode.push(Op { opcode: Either::Left(VirtualOp::LW( @@ -1555,8 +1555,8 @@ impl<'ir> FuelAsmBuilder<'ir> { access_type: StateAccessType, ) -> CompileResult<()> { // Make sure that both val and key are pointers to B256. - assert!(matches!(val.get_type(self.context), Some(Type::B256))); - assert!(matches!(key.get_type(self.context), Some(Type::B256))); + assert!(val.get_type(self.context).is(Type::is_b256, self.context)); + assert!(key.get_type(self.context).is(Type::is_b256, self.context)); let owning_span = self.md_mgr.val_to_span(self.context, *instr_val); let key_var = self.resolve_ptr(key); @@ -1567,7 +1567,7 @@ impl<'ir> FuelAsmBuilder<'ir> { // Not expecting an offset here nor a pointer cast assert!(offset == 0); - assert!(var_ty.eq(self.context, &Type::B256)); + assert!(var_ty.is_b256(self.context)); let val_reg = if matches!( val.get_instruction(self.context), @@ -1585,7 +1585,7 @@ impl<'ir> FuelAsmBuilder<'ir> { } let (local_val, local_val_ty, _offset) = local_val.value.unwrap(); // Expect the ptr_ty for val to also be B256 - assert!(local_val_ty.eq(self.context, &Type::B256)); + assert!(local_val_ty.is_b256(self.context)); match self.ptr_map.get(&local_val) { Some(Storage::Stack(val_offset)) => { let base_reg = self.locals_base_reg().clone(); @@ -1629,7 +1629,7 @@ impl<'ir> FuelAsmBuilder<'ir> { fn compile_state_load_word(&mut self, instr_val: &Value, key: &Value) -> CompileResult<()> { // Make sure that the key is a pointers to B256. - assert!(matches!(key.get_type(self.context), Some(Type::B256))); + assert!(key.get_type(self.context).is(Type::is_b256, self.context)); let key_var = self.resolve_ptr(key); if key_var.value.is_none() { @@ -1639,7 +1639,7 @@ impl<'ir> FuelAsmBuilder<'ir> { // Not expecting an offset here nor a pointer cast assert!(offset == 0); - assert!(var_ty.eq(self.context, &Type::B256)); + assert!(var_ty.is_b256(self.context)); let load_reg = self.reg_seqr.next(); let owning_span = self.md_mgr.val_to_span(self.context, *instr_val); @@ -1678,13 +1678,12 @@ impl<'ir> FuelAsmBuilder<'ir> { key: &Value, ) -> CompileResult<()> { // Make sure that key is a pointer to B256. - assert!(matches!(key.get_type(self.context), Some(Type::B256))); + assert!(key.get_type(self.context).is(Type::is_b256, self.context)); // Make sure that store_val is a U64 value. - assert!(matches!( - store_val.get_type(self.context), - Some(Type::Uint(64)) - )); + assert!(store_val + .get_type(self.context) + .is(Type::is_uint64, self.context)); let store_reg = self.value_to_register(store_val); // Expect the get_ptr here to have type b256 and offset = 0??? @@ -1699,7 +1698,7 @@ impl<'ir> FuelAsmBuilder<'ir> { // Not expecting an offset here nor a pointer cast assert!(offset == 0); - assert!(key_var_ty.eq(self.context, &Type::B256)); + assert!(key_var_ty.is_b256(self.context)); let owning_span = self.md_mgr.val_to_span(self.context, *instr_val); match self.ptr_map.get(&key_var) { @@ -1742,15 +1741,16 @@ impl<'ir> FuelAsmBuilder<'ir> { let word_offs = *word_offs; let store_type = local_var.get_type(self.context); let store_size_in_words = - size_bytes_in_words!(ir_type_size_in_bytes(self.context, store_type)); - if self.is_copy_type(store_type) { + size_bytes_in_words!(ir_type_size_in_bytes(self.context, &store_type)); + if self.is_copy_type(&store_type) { let base_reg = self.locals_base_reg().clone(); // A single word can be stored with SW. - let is_aggregate_var = matches!( - local_var.get_type(self.context), - Type::Array(_) | Type::Struct(_) | Type::Union(_) - ); + let local_var_ty = local_var.get_type(self.context); + let is_aggregate_var = local_var_ty.is_array(self.context) + || local_var_ty.is_struct(self.context) + || local_var_ty.is_union(self.context); + let stored_reg = if !is_aggregate_var { // stored_reg is a value. stored_reg @@ -1873,7 +1873,7 @@ impl<'ir> FuelAsmBuilder<'ir> { } pub(crate) fn is_copy_type(&self, ty: &Type) -> bool { - matches!(ty, Type::Unit | Type::Bool | Type::Uint(_)) + ty.is_unit(self.context) || ty.is_bool(self.context) | ty.is_uint(self.context) } fn resolve_ptr(&mut self, ptr_val: &Value) -> CompileResult<(LocalVar, Type, u64)> { @@ -1882,7 +1882,7 @@ impl<'ir> FuelAsmBuilder<'ir> { match ptr_val.get_instruction(self.context) { // Return the local variable with its type and an offset of 0. Some(Instruction::GetLocal(local_var)) => ok( - (*local_var, *local_var.get_type(self.context), 0), + (*local_var, local_var.get_type(self.context), 0), warnings, errors, ), diff --git a/sway-core/src/asm_generation/fuel/functions.rs b/sway-core/src/asm_generation/fuel/functions.rs index 48d81196689..315b13b5d98 100644 --- a/sway-core/src/asm_generation/fuel/functions.rs +++ b/sway-core/src/asm_generation/fuel/functions.rs @@ -596,32 +596,34 @@ impl<'ir> FuelAsmBuilder<'ir> { .insert_data_value(Entry::from_constant(self.context, constant)); self.ptr_map.insert(*ptr, Storage::Data(data_id)); } else { - match ptr.get_type(self.context) { - Type::Unit | Type::Bool | Type::Uint(_) => { + let ptr_ty = ptr.get_type(self.context); + match ptr_ty.get_content(self.context) { + TypeContent::Unit | TypeContent::Bool | TypeContent::Uint(_) => { self.ptr_map.insert(*ptr, Storage::Stack(stack_base)); stack_base += 1; } - Type::Slice => { + TypeContent::Slice => { self.ptr_map.insert(*ptr, Storage::Stack(stack_base)); stack_base += 2; } - Type::B256 => { + TypeContent::B256 => { // XXX Like strings, should we just reserve space for a pointer? self.ptr_map.insert(*ptr, Storage::Stack(stack_base)); stack_base += 4; } - Type::String(n) => { + TypeContent::String(n) => { // Strings are always constant and used by reference, so we only store the // pointer on the stack. self.ptr_map.insert(*ptr, Storage::Stack(stack_base)); stack_base += size_bytes_round_up_to_word_alignment!(n) } - ty @ (Type::Array(_) | Type::Struct(_) | Type::Union(_)) => { + TypeContent::Array(..) | TypeContent::Struct(_) | TypeContent::Union(_) => { // Store this aggregate at the current stack base. self.ptr_map.insert(*ptr, Storage::Stack(stack_base)); // Reserve space by incrementing the base. - stack_base += size_bytes_in_words!(ir_type_size_in_bytes(self.context, ty)); + stack_base += + size_bytes_in_words!(ir_type_size_in_bytes(self.context, &ptr_ty)); } }; } diff --git a/sway-core/src/ir_generation/compile.rs b/sway-core/src/ir_generation/compile.rs index dc31600dedc..4da121e11cb 100644 --- a/sway-core/src/ir_generation/compile.rs +++ b/sway-core/src/ir_generation/compile.rs @@ -509,7 +509,7 @@ fn compile_fn_with_args( // Need to copy ref-type return values to the 'out' parameter. ret_val = compiler.compile_copy_to_last_arg(context, ret_val, None); } - if ret_type.eq(context, &Type::Unit) { + if ret_type.is_unit(context) { ret_val = Constant::get_unit(context); } compiler.current_block.ins(context).ret(ret_val, ret_type); diff --git a/sway-core/src/ir_generation/const_eval.rs b/sway-core/src/ir_generation/const_eval.rs index 5c09c4e898c..8b4d83cc948 100644 --- a/sway-core/src/ir_generation/const_eval.rs +++ b/sway-core/src/ir_generation/const_eval.rs @@ -184,7 +184,7 @@ fn const_eval_typed_expr( expr: &ty::TyExpression, ) -> Result, CompileError> { Ok(match &expr.expression { - ty::TyExpressionVariant::Literal(l) => Some(convert_literal_to_constant(l)), + ty::TyExpressionVariant::Literal(l) => Some(convert_literal_to_constant(lookup.context, l)), ty::TyExpressionVariant::FunctionApplication { arguments, function_decl_id, @@ -252,10 +252,16 @@ fn const_eval_typed_expr( // We couldn't evaluate all fields to a constant. return Ok(None); } - get_aggregate_for_types(lookup.type_engine, lookup.context, &field_typs) - .map_or(None, |aggregate| { - Some(Constant::new_struct(&aggregate, field_vals)) - }) + get_aggregate_for_types(lookup.type_engine, lookup.context, &field_typs).map_or( + None, + |struct_ty| { + Some(Constant::new_struct( + lookup.context, + struct_ty.get_field_types(lookup.context), + field_vals, + )) + }, + ) } ty::TyExpressionVariant::Tuple { fields } => { let (mut field_typs, mut field_vals): (Vec<_>, Vec<_>) = (vec![], vec![]); @@ -270,10 +276,16 @@ fn const_eval_typed_expr( // We couldn't evaluate all fields to a constant. return Ok(None); } - create_tuple_aggregate(lookup.type_engine, lookup.context, field_typs) - .map_or(None, |aggregate| { - Some(Constant::new_struct(&aggregate, field_vals)) - }) + create_tuple_aggregate(lookup.type_engine, lookup.context, field_typs).map_or( + None, + |tuple_ty| { + Some(Constant::new_struct( + lookup.context, + tuple_ty.get_field_types(lookup.context), + field_vals, + )) + }, + ) } ty::TyExpressionVariant::Array { contents } => { let (mut element_typs, mut element_vals): (Vec<_>, Vec<_>) = (vec![], vec![]); @@ -305,8 +317,12 @@ fn const_eval_typed_expr( element_type_id, element_typs.len().try_into().unwrap(), ) - .map_or(None, |aggregate| { - Some(Constant::new_array(&aggregate, element_vals)) + .map_or(None, |array_ty| { + Some(Constant::new_array( + lookup.context, + array_ty.get_array_elem_type(lookup.context).unwrap(), + element_vals, + )) }) } ty::TyExpressionVariant::EnumInstantiation { @@ -317,11 +333,11 @@ fn const_eval_typed_expr( } => { let aggregate = create_enum_aggregate(lookup.type_engine, lookup.context, &enum_decl.variants); - if let Ok(aggregate) = aggregate { - let tag_value = Constant::new_uint(64, *tag as u64); + if let Ok(enum_ty) = aggregate { + let tag_value = Constant::new_uint(lookup.context, 64, *tag as u64); let mut fields: Vec = vec![tag_value]; match contents { - None => fields.push(Constant::new_unit()), + None => fields.push(Constant::new_unit(lookup.context)), Some(subexpr) => { let eval_expr = const_eval_typed_expr(lookup, known_consts, subexpr)?; eval_expr.into_iter().for_each(|enum_val| { @@ -329,7 +345,11 @@ fn const_eval_typed_expr( }) } } - Some(Constant::new_struct(&aggregate, fields)) + Some(Constant::new_struct( + lookup.context, + enum_ty.get_field_types(lookup.context), + fields, + )) } else { None } diff --git a/sway-core/src/ir_generation/convert.rs b/sway-core/src/ir_generation/convert.rs index 5d8b9f5d41f..335f6809973 100644 --- a/sway-core/src/ir_generation/convert.rs +++ b/sway-core/src/ir_generation/convert.rs @@ -7,7 +7,7 @@ use crate::{ use super::types::{create_enum_aggregate, create_tuple_aggregate}; use sway_error::error::CompileError; -use sway_ir::{Aggregate, Constant, Context, Type, Value}; +use sway_ir::{Constant, Context, Type, Value}; use sway_types::span::Span; pub(super) fn convert_literal_to_value(context: &mut Context, ast_literal: &Literal) -> Value { @@ -29,17 +29,20 @@ pub(super) fn convert_literal_to_value(context: &mut Context, ast_literal: &Lite } } -pub(super) fn convert_literal_to_constant(ast_literal: &Literal) -> Constant { +pub(super) fn convert_literal_to_constant( + context: &mut Context, + ast_literal: &Literal, +) -> Constant { match ast_literal { // All integers are `u64`. See comment above. - Literal::U8(n) => Constant::new_uint(64, *n as u64), - Literal::U16(n) => Constant::new_uint(64, *n as u64), - Literal::U32(n) => Constant::new_uint(64, *n as u64), - Literal::U64(n) => Constant::new_uint(64, *n), - Literal::Numeric(n) => Constant::new_uint(64, *n), - Literal::String(s) => Constant::new_string(s.as_str().as_bytes().to_vec()), - Literal::Boolean(b) => Constant::new_bool(*b), - Literal::B256(bs) => Constant::new_b256(*bs), + Literal::U8(n) => Constant::new_uint(context, 64, *n as u64), + Literal::U16(n) => Constant::new_uint(context, 64, *n as u64), + Literal::U32(n) => Constant::new_uint(context, 64, *n as u64), + Literal::U64(n) => Constant::new_uint(context, 64, *n), + Literal::Numeric(n) => Constant::new_uint(context, 64, *n), + Literal::String(s) => Constant::new_string(context, s.as_str().as_bytes().to_vec()), + Literal::Boolean(b) => Constant::new_bool(context, *b), + Literal::B256(bs) => Constant::new_b256(context, *bs), } } @@ -89,11 +92,11 @@ fn convert_resolved_type( Ok(match ast_type { // All integers are `u64`, see comment in convert_literal_to_value() above. - TypeInfo::UnsignedInteger(_) => Type::Uint(64), - TypeInfo::Numeric => Type::Uint(64), - TypeInfo::Boolean => Type::Bool, - TypeInfo::B256 => Type::B256, - TypeInfo::Str(n) => Type::String(n.val() as u64), + TypeInfo::UnsignedInteger(_) => Type::get_uint64(context), + TypeInfo::Numeric => Type::get_uint64(context), + TypeInfo::Boolean => Type::get_bool(context), + TypeInfo::B256 => Type::get_b256(context), + TypeInfo::Str(n) => Type::new_string(context, n.val() as u64), TypeInfo::Struct { fields, .. } => super::types::get_aggregate_for_types( type_engine, context, @@ -102,33 +105,28 @@ fn convert_resolved_type( .map(|field| field.type_id) .collect::>() .as_slice(), - ) - .map(Type::Struct)?, + )?, TypeInfo::Enum { variant_types, .. } => { - create_enum_aggregate(type_engine, context, variant_types).map(Type::Struct)? + create_enum_aggregate(type_engine, context, variant_types)? } TypeInfo::Array(elem_type, length) => { let elem_type = convert_resolved_typeid(type_engine, context, &elem_type.type_id, span)?; - Type::Array(Aggregate::new_array( - context, - elem_type, - length.val() as u64, - )) + Type::new_array(context, elem_type, length.val() as u64) } TypeInfo::Tuple(fields) => { if fields.is_empty() { // XXX We've removed Unit from the core compiler, replaced with an empty Tuple. // Perhaps the same should be done for the IR, although it would use an empty // aggregate which might not make as much sense as a dedicated Unit type. - Type::Unit + Type::get_unit(context) } else { let new_fields = fields.iter().map(|x| x.type_id).collect(); - create_tuple_aggregate(type_engine, context, new_fields).map(Type::Struct)? + create_tuple_aggregate(type_engine, context, new_fields)? } } - TypeInfo::RawUntypedPtr => Type::Uint(64), - TypeInfo::RawUntypedSlice => Type::Slice, + TypeInfo::RawUntypedPtr => Type::get_uint64(context), + TypeInfo::RawUntypedSlice => Type::get_slice(context), // Unsupported types which shouldn't exist in the AST after type checking and // monomorphisation. diff --git a/sway-core/src/ir_generation/function.rs b/sway-core/src/ir_generation/function.rs index 4b4cf95229c..d087c7a0e87 100644 --- a/sway-core/src/ir_generation/function.rs +++ b/sway-core/src/ir_generation/function.rs @@ -360,7 +360,7 @@ impl<'eng> FnCompiler<'eng> { self.compile_intrinsic_function(context, md_mgr, kind, ast_expr.span.clone()) } ty::TyExpressionVariant::AbiName(_) => { - Ok(Value::new_constant(context, Constant::new_unit())) + Ok(Value::new_constant(context, Constant::new_unit(context))) } ty::TyExpressionVariant::UnsafeDowncast { exp, variant } => { self.compile_unsafe_downcast(context, md_mgr, exp, variant) @@ -445,7 +445,7 @@ impl<'eng> FnCompiler<'eng> { // Local variable for the key let key_var = compiler .function - .new_local_var(context, alias_key_name, Type::B256, None) + .new_local_var(context, alias_key_name, Type::get_b256(context), None) .map_err(|ir_error| { CompileError::InternalOwned(ir_error.to_string(), Span::dummy()) })?; @@ -640,11 +640,12 @@ impl<'eng> FnCompiler<'eng> { self.compile_expression(context, md_mgr, &number_of_slots_exp)?; let span_md_idx = md_mgr.span_to_md(context, &span); let key_var = store_key_in_local_mem(self, context, key_value, span_md_idx)?; + let b256_ty = Type::get_b256(context); // For quad word, the IR instructions take in a pointer rather than a raw u64. let val_ptr = self .current_block .ins(context) - .int_to_ptr(val_value, Type::B256) + .int_to_ptr(val_value, b256_ty) .add_metadatum(context, span_md_idx); match kind { Intrinsic::StateLoadQuad => Ok(self @@ -767,9 +768,13 @@ impl<'eng> FnCompiler<'eng> { // - The first field is a `b256` that contains the `recipient` // - The second field is a `u64` that contains the message ID // - The third field contains the actual user data - let field_types = [Type::B256, Type::Uint(64), user_message_type]; + let field_types = [ + Type::get_b256(context), + Type::get_uint64(context), + user_message_type, + ]; let recipient_and_message_aggregate = - Aggregate::new_struct(context, field_types.to_vec()); + Type::new_struct(context, field_types.to_vec()); // Step 3: construct a local pointer for the recipient and message data struct let recipient_and_message_aggregate_local_name = self.lexical_map.insert_anon(); @@ -778,7 +783,7 @@ impl<'eng> FnCompiler<'eng> { .new_local_var( context, recipient_and_message_aggregate_local_name, - Type::Struct(recipient_and_message_aggregate), + recipient_and_message_aggregate, None, ) .map_err(|ir_error| { @@ -938,9 +943,11 @@ impl<'eng> FnCompiler<'eng> { let merge_val_arg_idx = final_block.new_arg( context, - lhs_val - .get_type(context) - .unwrap_or_else(|| rhs_val.get_type(context).unwrap_or(Type::Unit)), + lhs_val.get_type(context).unwrap_or_else(|| { + rhs_val + .get_type(context) + .unwrap_or_else(|| Type::get_unit(context)) + }), false, ); @@ -999,7 +1006,7 @@ impl<'eng> FnCompiler<'eng> { 1 => { // The single arg doesn't need to be put into a struct. let arg0 = compiled_args[0]; - + let u64_ty = Type::get_uint64(context); if self .type_engine .get(ast_args[0].1.return_type) @@ -1007,7 +1014,7 @@ impl<'eng> FnCompiler<'eng> { { self.current_block .ins(context) - .bitcast(arg0, Type::Uint(64)) + .bitcast(arg0, u64_ty) .add_metadatum(context, span_md_idx) } else { // Copy this value to a new location. This is quite inefficient but we need to @@ -1039,7 +1046,7 @@ impl<'eng> FnCompiler<'eng> { .iter() .filter_map(|val| val.get_type(context)) .collect::>(); - let user_args_struct_aggregate = Aggregate::new_struct(context, field_types); + let user_args_struct_aggregate = Type::new_struct(context, field_types); // New local pointer for the struct to hold all user arguments let user_args_struct_local_name = self @@ -1050,7 +1057,7 @@ impl<'eng> FnCompiler<'eng> { .new_local_var( context, user_args_struct_local_name, - Type::Struct(user_args_struct_aggregate), + user_args_struct_aggregate, None, ) .map_err(|ir_error| { @@ -1088,9 +1095,14 @@ impl<'eng> FnCompiler<'eng> { // Now handle the contract address and the selector. The contract address is just // as B256 while the selector is a [u8; 4] which we have to convert to a U64. - let ra_struct_aggregate = Aggregate::new_struct( + let ra_struct_aggregate = Type::new_struct( context, - [Type::B256, Type::Uint(64), Type::Uint(64)].to_vec(), + [ + Type::get_b256(context), + Type::get_uint64(context), + Type::get_uint64(context), + ] + .to_vec(), ); let ra_struct_var = self @@ -1098,7 +1110,7 @@ impl<'eng> FnCompiler<'eng> { .new_local_var( context, self.lexical_map.insert_anon(), - Type::Struct(ra_struct_aggregate), + ra_struct_aggregate, None, ) .map_err(|ir_error| CompileError::InternalOwned(ir_error.to_string(), Span::dummy()))?; @@ -1358,9 +1370,11 @@ impl<'eng> FnCompiler<'eng> { // Add a single argument to merge_block that merges true_value and false_value. let merge_val_arg_idx = merge_block.new_arg( context, - true_value - .get_type(context) - .unwrap_or_else(|| false_value.get_type(context).unwrap_or(Type::Unit)), + true_value.get_type(context).unwrap_or_else(|| { + false_value + .get_type(context) + .unwrap_or_else(|| Type::get_unit(context)) + }), false, ); if !true_block_end.is_terminated(context) { @@ -1392,7 +1406,7 @@ impl<'eng> FnCompiler<'eng> { &exp.return_type, &exp.span, )? { - Type::Struct(aggregate) => aggregate, + ty if ty.is_struct(context) => ty, _ => { return Err(CompileError::Internal( "Enum type for `unsafe downcast` is not an enum.", @@ -1423,7 +1437,7 @@ impl<'eng> FnCompiler<'eng> { &exp.return_type, &exp.span, )? { - Type::Struct(aggregate) => aggregate, + ty if ty.is_struct(context) => ty, _ => { return Err(CompileError::Internal("Expected enum type here.", exp.span)); } @@ -1540,7 +1554,9 @@ impl<'eng> FnCompiler<'eng> { name: &str, span_md_idx: Option, ) -> Result { - let need_to_load = |ty: &Type| matches!(ty, Type::Unit | Type::Bool | Type::Uint(_)); + let need_to_load = |ty: &Type, context: &Context| { + ty.is_unit(context) || ty.is_bool(context) || ty.is_uint(context) + }; // We need to check the symbol map first, in case locals are shadowing the args, other // locals or even constants. @@ -1558,7 +1574,7 @@ impl<'eng> FnCompiler<'eng> { .is_copy_type() && fn_param.unwrap().is_reference && fn_param.unwrap().is_mutable; - if !is_ref_primitive && need_to_load(var.get_type(context)) { + if !is_ref_primitive && need_to_load(&var.get_type(context), context) { Ok(self .current_block .ins(context) @@ -1631,7 +1647,7 @@ impl<'eng> FnCompiler<'eng> { // We can have empty aggregates, especially arrays, which shouldn't be initialised, but // otherwise use a store. - let var_ty = *local_var.get_type(context); + let var_ty = local_var.get_type(context); if ir_type_size_in_bytes(context, &var_ty) > 0 { let local_val = self .current_block @@ -1681,7 +1697,7 @@ impl<'eng> FnCompiler<'eng> { // We can have empty aggregates, especially arrays, which shouldn't be initialised, but // otherwise use a store. - let var_ty = *local_var.get_type(context); + let var_ty = local_var.get_type(context); if ir_type_size_in_bytes(context, &var_ty) > 0 { let local_val = self .current_block @@ -1754,7 +1770,7 @@ impl<'eng> FnCompiler<'eng> { } let ty = match val.get_type(context).unwrap() { - Type::Array(aggregate) => aggregate, + ty if ty.is_array(context) => ty, _otherwise => { let spans = ast_reassignment .lhs_indices @@ -1797,7 +1813,7 @@ impl<'eng> FnCompiler<'eng> { )?; let ty = match val.get_type(context).unwrap() { - Type::Struct(aggregate) => aggregate, + ty if ty.is_struct(context) => ty, _otherwise => { let spans = ast_reassignment .lhs_indices @@ -1873,17 +1889,17 @@ impl<'eng> FnCompiler<'eng> { // A zero length array is a pointer to nothing, which is still supported by Sway. // We're unable to get the type though it's irrelevant because it can't be indexed, so // we'll just use Unit. - Type::Unit + Type::get_unit(context) } else { convert_resolved_typeid_no_span(self.type_engine, context, &contents[0].return_type)? }; - let aggregate = Aggregate::new_array(context, elem_type, contents.len() as u64); + let aggregate = Type::new_array(context, elem_type, contents.len() as u64); // Compile each element and insert it immediately. let temp_name = self.lexical_map.insert_anon(); let array_var = self .function - .new_local_var(context, temp_name, Type::Array(aggregate), None) + .new_local_var(context, temp_name, aggregate, None) .map_err(|ir_error| CompileError::InternalOwned(ir_error.to_string(), Span::dummy()))?; let mut array_value = self .current_block @@ -1932,12 +1948,14 @@ impl<'eng> FnCompiler<'eng> { array_expr_span, ) }) - } else if let Some((Type::Array(agg), _)) = array_val.get_argument_type_and_byref(context) { + } else if let Some((agg, _)) = array_val + .get_argument_type_and_byref(context) + .filter(|(ty, _)| ty.is_array(context)) + { Ok(agg) - } else if let Some(Constant { - ty: Type::Array(agg), - .. - }) = array_val.get_constant(context) + } else if let Some(Constant { ty: agg, .. }) = array_val + .get_constant(context) + .filter(|c| c.ty.is_array(context)) { Ok(*agg) } else { @@ -1961,11 +1979,11 @@ impl<'eng> FnCompiler<'eng> { Some(self), index_expr, ) { - let (_, count) = aggregate.get_content(context).array_type(); - if constant_value >= *count { + let count = aggregate.get_array_len(context).unwrap(); + if constant_value >= count { return Err(CompileError::ArrayOutOfBounds { index: constant_value, - count: *count, + count, span: index_expr_span, }); } @@ -2015,7 +2033,7 @@ impl<'eng> FnCompiler<'eng> { let temp_name = self.lexical_map.insert_anon(); let struct_var = self .function - .new_local_var(context, temp_name, Type::Struct(aggregate), None) + .new_local_var(context, temp_name, aggregate, None) .map_err(|ir_error| CompileError::InternalOwned(ir_error.to_string(), Span::dummy()))?; let agg_value = self .current_block @@ -2051,13 +2069,14 @@ impl<'eng> FnCompiler<'eng> { "Unsupported instruction as struct value for field expression. {instruction:?}"), ast_struct_expr_span) }) - } else if let Some((Type::Struct(agg), _)) = struct_val.get_argument_type_and_byref(context) + } else if let Some((agg, _)) = struct_val + .get_argument_type_and_byref(context) + .filter(|(ty, _)| ty.is_struct(context)) { Ok(agg) - } else if let Some(Constant { - ty: Type::Struct(agg), - .. - }) = struct_val.get_constant(context) + } else if let Some(Constant { ty: agg, .. }) = struct_val + .get_constant(context) + .filter(|c| c.ty.is_struct(context)) { Ok(*agg) } else { @@ -2121,7 +2140,7 @@ impl<'eng> FnCompiler<'eng> { let temp_name = self.lexical_map.insert_anon(); let enum_var = self .function - .new_local_var(context, temp_name, Type::Struct(aggregate), None) + .new_local_var(context, temp_name, aggregate, None) .map_err(|ir_error| CompileError::InternalOwned(ir_error.to_string(), Span::dummy()))?; let enum_val = self .current_block @@ -2137,26 +2156,22 @@ impl<'eng> FnCompiler<'eng> { // If the struct representing the enum has only one field, then that field is basically the // tag and all the variants must have unit types, hence the absence of the union. // Therefore, there is no need for another `insert_value` instruction here. - match aggregate.get_content(context) { - AggregateContent::FieldTypes(field_tys) => { - Ok(if field_tys.len() == 1 { - agg_value - } else { - match &contents { - None => agg_value, - Some(te) => { - // Insert the value too. - let contents_value = self.compile_expression(context, md_mgr, te)?; - self.current_block - .ins(context) - .insert_value(agg_value, aggregate, contents_value, vec![1]) - .add_metadatum(context, span_md_idx) - } - } - }) + let field_tys = aggregate.get_field_types(context); + Ok(if field_tys.len() == 1 { + agg_value + } else { + match &contents { + None => agg_value, + Some(te) => { + // Insert the value too. + let contents_value = self.compile_expression(context, md_mgr, te)?; + self.current_block + .ins(context) + .insert_value(agg_value, aggregate, contents_value, vec![1]) + .add_metadatum(context, span_md_idx) + } } - _ => unreachable!("Wrong content for struct."), - } + }) } fn compile_tuple_expr( @@ -2187,11 +2202,11 @@ impl<'eng> FnCompiler<'eng> { init_types.push(init_type); } - let aggregate = Aggregate::new_struct(context, init_types); + let aggregate = Type::new_struct(context, init_types); let temp_name = self.lexical_map.insert_anon(); let tuple_var = self .function - .new_local_var(context, temp_name, Type::Struct(aggregate), None) + .new_local_var(context, temp_name, aggregate, None) .map_err(|ir_error| { CompileError::InternalOwned(ir_error.to_string(), Span::dummy()) })?; @@ -2223,14 +2238,13 @@ impl<'eng> FnCompiler<'eng> { span: Span, ) -> Result { let tuple_value = self.compile_expression(context, md_mgr, tuple)?; - if let Type::Struct(aggregate) = - convert_resolved_typeid(self.type_engine, context, &tuple_type, &span)? - { + let ty = convert_resolved_typeid(self.type_engine, context, &tuple_type, &span)?; + if ty.is_struct(context) { let span_md_idx = md_mgr.span_to_md(context, &span); Ok(self .current_block .ins(context) - .extract_value(tuple_value, aggregate, vec![idx as u64]) + .extract_value(tuple_value, ty, vec![idx as u64]) .add_metadatum(context, span_md_idx)) } else { Err(CompileError::Internal( @@ -2333,11 +2347,11 @@ impl<'eng> FnCompiler<'eng> { span_md_idx: Option, ) -> Result { match ty { - Type::Struct(aggregate) => { + ty if ty.is_struct(context) => { let temp_name = self.lexical_map.insert_anon(); let struct_var = self .function - .new_local_var(context, temp_name, Type::Struct(*aggregate), None) + .new_local_var(context, temp_name, *ty, None) .map_err(|ir_error| { CompileError::InternalOwned(ir_error.to_string(), Span::dummy()) })?; @@ -2347,7 +2361,7 @@ impl<'eng> FnCompiler<'eng> { .get_local(struct_var) .add_metadatum(context, span_md_idx); - let fields = aggregate.get_content(context).field_types().clone(); + let fields = ty.get_field_types(context); for (field_idx, field_type) in fields.into_iter().enumerate() { let field_idx = field_idx as u64; @@ -2368,7 +2382,7 @@ impl<'eng> FnCompiler<'eng> { struct_val = self .current_block .ins(context) - .insert_value(struct_val, *aggregate, val_to_insert, vec![field_idx]) + .insert_value(struct_val, *ty, val_to_insert, vec![field_idx]) .add_metadatum(context, span_md_idx); } Ok(struct_val) @@ -2386,7 +2400,7 @@ impl<'eng> FnCompiler<'eng> { // Local pointer for the key let key_var = self .function - .new_local_var(context, alias_key_name, Type::B256, None) + .new_local_var(context, alias_key_name, Type::get_b256(context), None) .map_err(|ir_error| { CompileError::InternalOwned(ir_error.to_string(), Span::dummy()) })?; @@ -2409,31 +2423,32 @@ impl<'eng> FnCompiler<'eng> { .store(key_val, const_key) .add_metadatum(context, span_md_idx); - match ty { - Type::Array(_) => Err(CompileError::Internal( + match ty.get_content(context) { + TypeContent::Array(..) => Err(CompileError::Internal( "Arrays in storage have not been implemented yet.", Span::dummy(), )), - Type::Slice => Err(CompileError::Internal( + TypeContent::Slice => Err(CompileError::Internal( "Slices in storage have not been implemented yet.", Span::dummy(), )), - Type::B256 => { + TypeContent::B256 => { self.compile_b256_storage_read(context, ix, indices, &key_val, span_md_idx) } - Type::Bool | Type::Uint(_) => { + TypeContent::Bool | TypeContent::Uint(_) => { self.compile_uint_or_bool_storage_read(context, &key_val, ty, span_md_idx) } - Type::String(_) | Type::Union(_) => self.compile_union_or_string_storage_read( - context, - ix, - indices, - &key_val, - ty, - span_md_idx, - ), - Type::Struct(_) => unreachable!("structs are already handled!"), - Type::Unit => { + TypeContent::String(_) | TypeContent::Union(_) => self + .compile_union_or_string_storage_read( + context, + ix, + indices, + &key_val, + ty, + span_md_idx, + ), + TypeContent::Struct(_) => unreachable!("structs are already handled!"), + TypeContent::Unit => { Ok(Constant::get_unit(context).add_metadatum(context, span_md_idx)) } } @@ -2453,8 +2468,8 @@ impl<'eng> FnCompiler<'eng> { span_md_idx: Option, ) -> Result<(), CompileError> { match ty { - Type::Struct(aggregate) => { - let fields = aggregate.get_content(context).field_types().clone(); + ty if ty.is_struct(context) => { + let fields = ty.get_field_types(context); for (field_idx, field_type) in fields.into_iter().enumerate() { let field_idx = field_idx as u64; @@ -2466,7 +2481,7 @@ impl<'eng> FnCompiler<'eng> { let rhs = self .current_block .ins(context) - .extract_value(rhs, *aggregate, vec![field_idx]) + .extract_value(rhs, *ty, vec![field_idx]) .add_metadatum(context, span_md_idx); self.compile_storage_write( @@ -2494,7 +2509,7 @@ impl<'eng> FnCompiler<'eng> { // Local pointer for the key let key_var = self .function - .new_local_var(context, alias_key_name, Type::B256, None) + .new_local_var(context, alias_key_name, Type::get_b256(context), None) .map_err(|ir_error| { CompileError::InternalOwned(ir_error.to_string(), Span::dummy()) })?; @@ -2517,16 +2532,16 @@ impl<'eng> FnCompiler<'eng> { .store(key_val, const_key) .add_metadatum(context, span_md_idx); - match ty { - Type::Array(_) => Err(CompileError::Internal( + match ty.get_content(context) { + TypeContent::Array(..) => Err(CompileError::Internal( "Arrays in storage have not been implemented yet.", Span::dummy(), )), - Type::Slice => Err(CompileError::Internal( + TypeContent::Slice => Err(CompileError::Internal( "Slices in storage have not been implemented yet.", Span::dummy(), )), - Type::B256 => self.compile_b256_storage_write( + TypeContent::B256 => self.compile_b256_storage_write( context, ix, indices, @@ -2534,20 +2549,21 @@ impl<'eng> FnCompiler<'eng> { rhs, span_md_idx, ), - Type::Bool | Type::Uint(_) => { + TypeContent::Bool | TypeContent::Uint(_) => { self.compile_uint_or_bool_storage_write(context, &key_val, rhs, span_md_idx) } - Type::String(_) | Type::Union(_) => self.compile_union_or_string_storage_write( - context, - ix, - indices, - &key_val, - ty, - rhs, - span_md_idx, - ), - Type::Struct(_) => unreachable!("structs are already handled!"), - Type::Unit => Ok(()), + TypeContent::String(_) | TypeContent::Union(_) => self + .compile_union_or_string_storage_write( + context, + ix, + indices, + &key_val, + ty, + rhs, + span_md_idx, + ), + TypeContent::Struct(_) => unreachable!("structs are already handled!"), + TypeContent::Unit => Ok(()), } } } @@ -2582,12 +2598,13 @@ impl<'eng> FnCompiler<'eng> { rhs: Value, span_md_idx: Option, ) -> Result<(), CompileError> { + let u64_ty = Type::get_uint64(context); // `state_store_word` requires a `u64`. Cast the value to store to // `u64` first before actually storing. let rhs_u64 = self .current_block .ins(context) - .bitcast(rhs, Type::Uint(64)) + .bitcast(rhs, u64_ty) .add_metadatum(context, span_md_idx); self.current_block .ins(context) @@ -2615,7 +2632,7 @@ impl<'eng> FnCompiler<'eng> { // Local pointer to hold the B256 let local_var = self .function - .new_local_var(context, alias_value_name, Type::B256, None) + .new_local_var(context, alias_value_name, Type::get_b256(context), None) .map_err(|ir_error| CompileError::InternalOwned(ir_error.to_string(), Span::dummy()))?; // Convert the local pointer created to a value using get_ptr @@ -2653,7 +2670,7 @@ impl<'eng> FnCompiler<'eng> { // Local pointer to hold the B256 let local_var = self .function - .new_local_var(context, alias_value_name, Type::B256, None) + .new_local_var(context, alias_value_name, Type::get_b256(context), None) .map_err(|ir_error| CompileError::InternalOwned(ir_error.to_string(), Span::dummy()))?; // Convert the local pointer created to a value using get_ptr @@ -2703,11 +2720,7 @@ impl<'eng> FnCompiler<'eng> { // Create an array of `b256` that will hold the value to store into storage // or the value loaded from storage. The array has to fit the whole type. let number_of_elements = (ir_type_size_in_bytes(context, r#type) + 31) / 32; - let b256_array_type = Type::Array(Aggregate::new_array( - context, - Type::B256, - number_of_elements, - )); + let b256_array_type = Type::new_array(context, Type::get_b256(context), number_of_elements); // Local pointer to hold the array of b256s let local_var = self @@ -2726,6 +2739,7 @@ impl<'eng> FnCompiler<'eng> { .ins(context) .cast_ptr(local_val, *r#type, 0) .add_metadatum(context, span_md_idx); + let b256_ty = Type::get_b256(context); if number_of_elements > 0 { // Get the b256 from the array at index iter @@ -2737,7 +2751,7 @@ impl<'eng> FnCompiler<'eng> { let indexed_value_val_b256 = self .current_block .ins(context) - .cast_ptr(value_val_b256, Type::B256, 0) + .cast_ptr(value_val_b256, b256_ty, 0) .add_metadatum(context, span_md_idx); let count_value = convert_literal_to_value(context, &Literal::U64(number_of_elements)); @@ -2776,11 +2790,7 @@ impl<'eng> FnCompiler<'eng> { // Create an array of `b256` that will hold the value to store into storage // or the value loaded from storage. The array has to fit the whole type. let number_of_elements = (ir_type_size_in_bytes(context, r#type) + 31) / 32; - let b256_array_type = Type::Array(Aggregate::new_array( - context, - Type::B256, - number_of_elements, - )); + let b256_array_type = Type::new_array(context, Type::get_b256(context), number_of_elements); // Local pointer to hold the array of b256s let local_var = self @@ -2807,6 +2817,7 @@ impl<'eng> FnCompiler<'eng> { .store(final_val, rhs) .add_metadatum(context, span_md_idx); + let b256_ty = Type::get_b256(context); if number_of_elements > 0 { // Get the b256 from the array at index iter let value_ptr_val_b256 = self @@ -2817,7 +2828,7 @@ impl<'eng> FnCompiler<'eng> { let indexed_value_ptr_val_b256 = self .current_block .ins(context) - .cast_ptr(value_ptr_val_b256, Type::B256, 0) + .cast_ptr(value_ptr_val_b256, b256_ty, 0) .add_metadatum(context, span_md_idx); // Finally, just call state_load_quad_word/state_store_quad_word diff --git a/sway-core/src/ir_generation/storage.rs b/sway-core/src/ir_generation/storage.rs index 94585ce1815..df698d429b9 100644 --- a/sway-core/src/ir_generation/storage.rs +++ b/sway-core/src/ir_generation/storage.rs @@ -9,7 +9,7 @@ use crate::{ use sway_ir::{ constant::{Constant, ConstantValue}, context::Context, - irtype::{AggregateContent, Type}, + irtype::Type, }; use sway_types::state::StateIndex; @@ -63,13 +63,13 @@ pub fn serialize_to_storage_slots( ty: &Type, indices: &[usize], ) -> Vec { - match (&ty, &constant.value) { - (_, ConstantValue::Undef) => vec![], - (Type::Unit, ConstantValue::Unit) => vec![StorageSlot::new( + match &constant.value { + ConstantValue::Undef => vec![], + ConstantValue::Unit if ty.is_unit(context) => vec![StorageSlot::new( get_storage_key(ix, indices), Bytes32::new([0; 32]), )], - (Type::Bool, ConstantValue::Bool(b)) => { + ConstantValue::Bool(b) if ty.is_bool(context) => { vec![StorageSlot::new( get_storage_key(ix, indices), Bytes32::new( @@ -84,7 +84,7 @@ pub fn serialize_to_storage_slots( ), )] } - (Type::Uint(_), ConstantValue::Uint(n)) => { + ConstantValue::Uint(n) if ty.is_uint(context) => { vec![StorageSlot::new( get_storage_key(ix, indices), Bytes32::new( @@ -98,39 +98,36 @@ pub fn serialize_to_storage_slots( ), )] } - (Type::B256, ConstantValue::B256(b)) => { + ConstantValue::B256(b) if ty.is_b256(context) => { vec![StorageSlot::new( get_storage_key(ix, indices), Bytes32::new(*b), )] } - (Type::Array(_), ConstantValue::Array(_a)) => { + ConstantValue::Array(_a) if ty.is_array(context) => { unimplemented!("Arrays in storage have not been implemented yet.") } - (Type::Struct(aggregate), ConstantValue::Struct(vec)) => { - match aggregate.get_content(context) { - AggregateContent::FieldTypes(field_tys) => vec - .iter() - .zip(field_tys.iter()) - .enumerate() - .flat_map(|(i, (f, ty))| { - serialize_to_storage_slots( - f, - context, - ix, - ty, - &indices - .iter() - .cloned() - .chain(vec![i].iter().cloned()) - .collect::>(), - ) - }) - .collect(), - _ => unreachable!("Wrong content for struct."), - } + ConstantValue::Struct(vec) if ty.is_struct(context) => { + let field_tys = ty.get_field_types(context); + vec.iter() + .zip(field_tys.iter()) + .enumerate() + .flat_map(|(i, (f, ty))| { + serialize_to_storage_slots( + f, + context, + ix, + ty, + &indices + .iter() + .cloned() + .chain(vec![i].iter().cloned()) + .collect::>(), + ) + }) + .collect() } - (Type::Union(_), _) | (Type::String(_), _) => { + _ if ty.is_string(context) || ty.is_union(context) => { // Serialize the constant data in words and add zero words until the number of words // is a multiple of 4. This is useful because each storage slot is 4 words. let mut packed = serialize_to_words(constant, context, ty); @@ -164,10 +161,10 @@ pub fn serialize_to_storage_slots( /// words and add left padding up to size of `ty`. /// pub fn serialize_to_words(constant: &Constant, context: &Context, ty: &Type) -> Vec { - match (&ty, &constant.value) { - (_, ConstantValue::Undef) => vec![], - (Type::Unit, ConstantValue::Unit) => vec![Bytes8::new([0; 8])], - (Type::Bool, ConstantValue::Bool(b)) => { + match &constant.value { + ConstantValue::Undef => vec![], + ConstantValue::Unit if ty.is_unit(context) => vec![Bytes8::new([0; 8])], + ConstantValue::Bool(b) if ty.is_bool(context) => { vec![Bytes8::new( [0; 7] .iter() @@ -178,15 +175,15 @@ pub fn serialize_to_words(constant: &Constant, context: &Context, ty: &Type) -> .unwrap(), )] } - (Type::Uint(_), ConstantValue::Uint(n)) => { + ConstantValue::Uint(n) if ty.is_uint(context) => { vec![Bytes8::new(n.to_be_bytes())] } - (Type::B256, ConstantValue::B256(b)) => Vec::from_iter( + ConstantValue::B256(b) if ty.is_b256(context) => Vec::from_iter( (0..4) .into_iter() .map(|i| Bytes8::new(b[8 * i..8 * i + 8].try_into().unwrap())), ), - (Type::String(_), ConstantValue::String(s)) => { + ConstantValue::String(s) if ty.is_string(context) => { // Turn the bytes into serialized words (Bytes8). let mut s = s.clone(); s.extend(vec![0; ((s.len() + 7) / 8) * 8 - s.len()]); @@ -202,20 +199,17 @@ pub fn serialize_to_words(constant: &Constant, context: &Context, ty: &Type) -> ) })) } - (Type::Array(_), ConstantValue::Array(_)) => { + ConstantValue::Array(_) if ty.is_array(context) => { unimplemented!("Arrays in storage have not been implemented yet.") } - (Type::Struct(aggregate), ConstantValue::Struct(vec)) => { - match aggregate.get_content(context) { - AggregateContent::FieldTypes(field_tys) => vec - .iter() - .zip(field_tys.iter()) - .flat_map(|(f, ty)| serialize_to_words(f, context, ty)) - .collect(), - _ => unreachable!("Wrong content for struct."), - } + ConstantValue::Struct(vec) if ty.is_struct(context) => { + let field_tys = ty.get_field_types(context); + vec.iter() + .zip(field_tys.iter()) + .flat_map(|(f, ty)| serialize_to_words(f, context, ty)) + .collect() } - (Type::Union(_), _) => { + _ if ty.is_union(context) => { let value_size_in_words = ir_type_size_in_bytes(context, ty) / 8; let constant_size_in_words = ir_type_size_in_bytes(context, &constant.ty) / 8; assert!(value_size_in_words >= constant_size_in_words); diff --git a/sway-core/src/ir_generation/types.rs b/sway-core/src/ir_generation/types.rs index d63daa8a8bf..308e645595b 100644 --- a/sway-core/src/ir_generation/types.rs +++ b/sway-core/src/ir_generation/types.rs @@ -7,14 +7,14 @@ use crate::{ use super::convert::convert_resolved_typeid_no_span; use sway_error::error::CompileError; -use sway_ir::{Aggregate, Context, Type}; +use sway_ir::{Context, Type}; use sway_types::span::Spanned; pub(super) fn create_enum_aggregate( type_engine: &TypeEngine, context: &mut Context, variants: &[ty::TyEnumVariant], -) -> Result { +) -> Result { // Create the enum aggregate first. NOTE: single variant enums don't need an aggregate but are // getting one here anyway. They don't need to be a tagged union either. let field_types: Vec<_> = variants @@ -25,11 +25,12 @@ pub(super) fn create_enum_aggregate( // Enums where all the variants are unit types don't really need the union. Only a tag is // needed. For consistency, and to keep enums as reference types, we keep the tag in an // Aggregate. - Ok(if field_types.iter().all(|f| matches!(f, Type::Unit)) { - Aggregate::new_struct(context, vec![Type::Uint(64)]) + Ok(if field_types.iter().all(|f| f.is_unit(context)) { + Type::new_struct(context, vec![Type::get_uint64(context)]) } else { - let enum_aggregate = Aggregate::new_struct(context, field_types); - Aggregate::new_struct(context, vec![Type::Uint(64), Type::Union(enum_aggregate)]) + let u64_ty = Type::get_uint64(context); + let union_ty = Type::new_union(context, field_types); + Type::new_struct(context, vec![u64_ty, union_ty]) }) } @@ -37,13 +38,13 @@ pub(super) fn create_tuple_aggregate( type_engine: &TypeEngine, context: &mut Context, fields: Vec, -) -> Result { +) -> Result { let field_types = fields .into_iter() .map(|ty_id| convert_resolved_typeid_no_span(type_engine, context, &ty_id)) .collect::, CompileError>>()?; - Ok(Aggregate::new_struct(context, field_types)) + Ok(Type::new_struct(context, field_types)) } pub(super) fn create_array_aggregate( @@ -51,21 +52,21 @@ pub(super) fn create_array_aggregate( context: &mut Context, element_type_id: TypeId, count: u64, -) -> Result { +) -> Result { let element_type = convert_resolved_typeid_no_span(type_engine, context, &element_type_id)?; - Ok(Aggregate::new_array(context, element_type, count)) + Ok(Type::new_array(context, element_type, count)) } pub(super) fn get_aggregate_for_types( type_engine: &TypeEngine, context: &mut Context, type_ids: &[TypeId], -) -> Result { +) -> Result { let types = type_ids .iter() .map(|ty_id| convert_resolved_typeid_no_span(type_engine, context, ty_id)) .collect::, CompileError>>()?; - Ok(Aggregate::new_struct(context, types)) + Ok(Type::new_struct(context, types)) } pub(super) fn get_struct_name_field_index_and_type( diff --git a/sway-core/src/lib.rs b/sway-core/src/lib.rs index c013ef41366..113a53e4ed8 100644 --- a/sway-core/src/lib.rs +++ b/sway-core/src/lib.rs @@ -29,7 +29,7 @@ use std::collections::HashMap; use std::path::{Path, PathBuf}; use std::sync::Arc; use sway_error::handler::{ErrorEmitted, Handler}; -use sway_ir::{call_graph, Context, Function, Instruction, Kind, Module, Type, Value}; +use sway_ir::{call_graph, Context, Function, Instruction, Kind, Module, Value}; pub use semantic_analysis::namespace::{self, Namespace}; pub mod types; @@ -595,7 +595,7 @@ pub fn inline_function_calls( arg_val .get_argument_type_and_byref(ctx) .map(|(ty, by_ref)| { - by_ref || !matches!(ty, Type::Unit | Type::Bool | Type::Uint(_)) + by_ref || !(ty.is_unit(ctx) | ty.is_bool(ctx) | ty.is_uint(ctx)) }) .unwrap_or(false) }) { diff --git a/sway-ir/src/constant.rs b/sway-ir/src/constant.rs index 86d02544017..7bb29383366 100644 --- a/sway-ir/src/constant.rs +++ b/sway-ir/src/constant.rs @@ -1,11 +1,6 @@ //! [`Constant`] is a typed constant value. -use crate::{ - context::Context, - irtype::{Aggregate, Type}, - pretty::DebugWithContext, - value::Value, -}; +use crate::{context::Context, irtype::Type, pretty::DebugWithContext, value::Value}; /// A [`Type`] and constant value, including [`ConstantValue::Undef`] for uninitialized constants. #[derive(Debug, Clone, DebugWithContext)] @@ -28,51 +23,51 @@ pub enum ConstantValue { } impl Constant { - pub fn new_unit() -> Self { + pub fn new_unit(context: &Context) -> Self { Constant { - ty: Type::Unit, + ty: Type::get_unit(context), value: ConstantValue::Unit, } } - pub fn new_bool(b: bool) -> Self { + pub fn new_bool(context: &Context, b: bool) -> Self { Constant { - ty: Type::Bool, + ty: Type::get_bool(context), value: ConstantValue::Bool(b), } } - pub fn new_uint(nbits: u8, n: u64) -> Self { + pub fn new_uint(context: &mut Context, nbits: u8, n: u64) -> Self { Constant { - ty: Type::Uint(nbits), + ty: Type::new_uint(context, nbits), value: ConstantValue::Uint(n), } } - pub fn new_b256(bytes: [u8; 32]) -> Self { + pub fn new_b256(context: &Context, bytes: [u8; 32]) -> Self { Constant { - ty: Type::B256, + ty: Type::get_b256(context), value: ConstantValue::B256(bytes), } } - pub fn new_string(string: Vec) -> Self { + pub fn new_string(context: &mut Context, string: Vec) -> Self { Constant { - ty: Type::String(string.len() as u64), + ty: Type::new_string(context, string.len() as u64), value: ConstantValue::String(string), } } - pub fn new_array(aggregate: &Aggregate, elems: Vec) -> Self { + pub fn new_array(context: &mut Context, elm_ty: Type, elems: Vec) -> Self { Constant { - ty: Type::Array(*aggregate), + ty: Type::new_array(context, elm_ty, elems.len() as u64), value: ConstantValue::Array(elems), } } - pub fn new_struct(aggregate: &Aggregate, fields: Vec) -> Self { + pub fn new_struct(context: &mut Context, field_tys: Vec, fields: Vec) -> Self { Constant { - ty: Type::Struct(*aggregate), + ty: Type::new_struct(context, field_tys), value: ConstantValue::Struct(fields), } } @@ -85,46 +80,39 @@ impl Constant { } pub fn get_unit(context: &mut Context) -> Value { - Value::new_constant(context, Constant::new_unit()) + let new_const = Constant::new_unit(context); + Value::new_constant(context, new_const) } pub fn get_bool(context: &mut Context, value: bool) -> Value { - Value::new_constant(context, Constant::new_bool(value)) + let new_const = Constant::new_bool(context, value); + Value::new_constant(context, new_const) } pub fn get_uint(context: &mut Context, nbits: u8, value: u64) -> Value { - Value::new_constant(context, Constant::new_uint(nbits, value)) + let new_const = Constant::new_uint(context, nbits, value); + Value::new_constant(context, new_const) } pub fn get_b256(context: &mut Context, value: [u8; 32]) -> Value { - Value::new_constant(context, Constant::new_b256(value)) + let new_const = Constant::new_b256(context, value); + Value::new_constant(context, new_const) } pub fn get_string(context: &mut Context, value: Vec) -> Value { - Value::new_constant(context, Constant::new_string(value)) + let new_const = Constant::new_string(context, value); + Value::new_constant(context, new_const) } /// `value` must be created as an array constant first, using [`Constant::new_array()`]. pub fn get_array(context: &mut Context, value: Constant) -> Value { - assert!(matches!( - value, - Constant { - ty: Type::Array(_), - .. - } - )); + assert!(value.ty.is_array(context)); Value::new_constant(context, value) } /// `value` must be created as a struct constant first, using [`Constant::new_struct()`]. pub fn get_struct(context: &mut Context, value: Constant) -> Value { - assert!(matches!( - value, - Constant { - ty: Type::Struct(_), - .. - } - )); + assert!(value.ty.is_struct(context)); Value::new_constant(context, value) } diff --git a/sway-ir/src/context.rs b/sway-ir/src/context.rs index 064d8af311c..7b1625d5aca 100644 --- a/sway-ir/src/context.rs +++ b/sway-ir/src/context.rs @@ -7,31 +7,51 @@ //! It is passed around as a mutable reference to many of the Sway-IR APIs. use generational_arena::Arena; +use rustc_hash::FxHashMap; use crate::{ - asm::AsmBlockContent, block::BlockContent, function::FunctionContent, irtype::AggregateContent, + asm::AsmBlockContent, block::BlockContent, function::FunctionContent, local_var::LocalVarContent, metadata::Metadatum, module::ModuleContent, module::ModuleIterator, - value::ValueContent, + value::ValueContent, Type, TypeContent, }; /// The main IR context handle. /// /// Every module, function, block and value is stored here. Some aggregate metadata is also /// managed by the context. -#[derive(Default)] pub struct Context { pub(crate) modules: Arena, pub(crate) functions: Arena, pub(crate) blocks: Arena, pub(crate) values: Arena, pub(crate) local_vars: Arena, - pub(crate) aggregates: Arena, + pub(crate) types: Arena, + pub(crate) type_map: FxHashMap, pub(crate) asm_blocks: Arena, pub(crate) metadata: Arena, next_unique_sym_tag: u64, } +impl Default for Context { + fn default() -> Self { + let mut def = Self { + modules: Default::default(), + functions: Default::default(), + blocks: Default::default(), + values: Default::default(), + local_vars: Default::default(), + types: Default::default(), + type_map: Default::default(), + asm_blocks: Default::default(), + metadata: Default::default(), + next_unique_sym_tag: Default::default(), + }; + Type::create_basic_types(&mut def); + def + } +} + impl Context { /// Return an interator for every module in this context. pub fn module_iter(&self) -> ModuleIterator { diff --git a/sway-ir/src/instruction.rs b/sway-ir/src/instruction.rs index a54db7cb9d4..a6c11355292 100644 --- a/sway-ir/src/instruction.rs +++ b/sway-ir/src/instruction.rs @@ -15,7 +15,7 @@ use crate::{ block::Block, context::Context, function::Function, - irtype::{Aggregate, Type}, + irtype::Type, local_var::LocalVar, pretty::DebugWithContext, value::{Value, ValueDatum}, @@ -74,13 +74,13 @@ pub enum Instruction { /// Reading a specific element from an array. ExtractElement { array: Value, - ty: Aggregate, + ty: Type, index_val: Value, }, /// Reading a specific field from (nested) structs. ExtractValue { aggregate: Value, - ty: Aggregate, + ty: Type, indices: Vec, }, /// Umbrella instruction variant for FuelVM-specific instructions @@ -90,14 +90,14 @@ pub enum Instruction { /// Writing a specific value to an array. InsertElement { array: Value, - ty: Aggregate, + ty: Type, value: Value, index_val: Value, }, /// Writing a specific value to a (nested) struct field. InsertValue { aggregate: Value, - ty: Aggregate, + ty: Type, value: Value, indices: Vec, }, @@ -226,21 +226,25 @@ impl Instruction { /// `Ret` do not have a type. pub fn get_type(&self, context: &Context) -> Option { match self { - Instruction::AddrOf(_) => Some(Type::Uint(64)), + Instruction::AddrOf(_) => Some(Type::get_uint64(context)), Instruction::AsmBlock(asm_block, _) => Some(asm_block.get_type(context)), Instruction::BinaryOp { arg1, .. } => arg1.get_type(context), Instruction::BitCast(_, ty) => Some(*ty), Instruction::Call(function, _) => Some(context.functions[function.0].return_type), Instruction::CastPtr(_val, ty, _offs) => Some(*ty), - Instruction::Cmp(..) => Some(Type::Bool), + Instruction::Cmp(..) => Some(Type::get_bool(context)), Instruction::ContractCall { return_type, .. } => Some(*return_type), - Instruction::ExtractElement { ty, .. } => ty.get_elem_type(context), - Instruction::ExtractValue { ty, indices, .. } => ty.get_field_type(context, indices), - Instruction::FuelVm(FuelVmInstruction::GetStorageKey) => Some(Type::B256), - Instruction::FuelVm(FuelVmInstruction::Gtf { .. }) => Some(Type::Uint(64)), - Instruction::FuelVm(FuelVmInstruction::Log { .. }) => Some(Type::Unit), - Instruction::FuelVm(FuelVmInstruction::ReadRegister(_)) => Some(Type::Uint(64)), - Instruction::FuelVm(FuelVmInstruction::StateLoadWord(_)) => Some(Type::Uint(64)), + Instruction::ExtractElement { ty, .. } => ty.get_array_elem_type(context), + Instruction::ExtractValue { ty, indices, .. } => ty.get_indexed_type(context, indices), + Instruction::FuelVm(FuelVmInstruction::GetStorageKey) => Some(Type::get_b256(context)), + Instruction::FuelVm(FuelVmInstruction::Gtf { .. }) => Some(Type::get_uint64(context)), + Instruction::FuelVm(FuelVmInstruction::Log { .. }) => Some(Type::get_unit(context)), + Instruction::FuelVm(FuelVmInstruction::ReadRegister(_)) => { + Some(Type::get_uint64(context)) + } + Instruction::FuelVm(FuelVmInstruction::StateLoadWord(_)) => { + Some(Type::get_uint64(context)) + } Instruction::InsertElement { array, .. } => array.get_type(context), Instruction::InsertValue { aggregate, .. } => aggregate.get_type(context), Instruction::Load(ptr_val) => match &context.values[ptr_val.0].value { @@ -250,7 +254,7 @@ impl Instruction { }, // These can be recursed to via Load, so we return the pointer type. - Instruction::GetLocal(local_var) => Some(*local_var.get_type(context)), + Instruction::GetLocal(local_var) => Some(local_var.get_type(context)), // Used to re-interpret an integer as a pointer to some type so return the pointer type. Instruction::IntToPtr(_, ty) => Some(*ty), @@ -261,12 +265,18 @@ impl Instruction { Instruction::FuelVm(FuelVmInstruction::Revert(..)) => None, Instruction::Ret(..) => None, - Instruction::FuelVm(FuelVmInstruction::Smo { .. }) => Some(Type::Unit), - Instruction::FuelVm(FuelVmInstruction::StateLoadQuadWord { .. }) => Some(Type::Unit), - Instruction::FuelVm(FuelVmInstruction::StateStoreQuadWord { .. }) => Some(Type::Unit), - Instruction::FuelVm(FuelVmInstruction::StateStoreWord { .. }) => Some(Type::Unit), - Instruction::MemCopy { .. } => Some(Type::Unit), - Instruction::Store { .. } => Some(Type::Unit), + Instruction::FuelVm(FuelVmInstruction::Smo { .. }) => Some(Type::get_unit(context)), + Instruction::FuelVm(FuelVmInstruction::StateLoadQuadWord { .. }) => { + Some(Type::get_unit(context)) + } + Instruction::FuelVm(FuelVmInstruction::StateStoreQuadWord { .. }) => { + Some(Type::get_unit(context)) + } + Instruction::FuelVm(FuelVmInstruction::StateStoreWord { .. }) => { + Some(Type::get_unit(context)) + } + Instruction::MemCopy { .. } => Some(Type::get_unit(context)), + Instruction::Store { .. } => Some(Type::get_unit(context)), // No-op is also no-type. Instruction::Nop => None, @@ -274,37 +284,20 @@ impl Instruction { } /// Some [`Instruction`]s may have struct arguments. Return it if so for this instruction. - pub fn get_aggregate(&self, context: &Context) -> Option { - match self { - Instruction::Call(func, _args) => match &context.functions[func.0].return_type { - Type::Array(aggregate) => Some(*aggregate), - Type::Struct(aggregate) => Some(*aggregate), - _otherwise => None, - }, - Instruction::GetLocal(local_var) => match local_var.get_type(context) { - Type::Array(aggregate) => Some(*aggregate), - Type::Struct(aggregate) => Some(*aggregate), - _otherwise => None, - }, - Instruction::ExtractElement { ty, .. } => { - ty.get_elem_type(context).and_then(|ty| match ty { - Type::Array(nested_aggregate) => Some(nested_aggregate), - Type::Struct(nested_aggregate) => Some(nested_aggregate), - _otherwise => None, - }) - } - Instruction::ExtractValue { ty, indices, .. } => { - // This array is a field in a struct or element in an array. - ty.get_field_type(context, indices).and_then(|ty| match ty { - Type::Array(nested_aggregate) => Some(nested_aggregate), - Type::Struct(nested_aggregate) => Some(nested_aggregate), - _otherwise => None, - }) + pub fn get_aggregate(&self, context: &Context) -> Option { + let ty = match self { + Instruction::Call(func, _args) => Some(context.functions[func.0].return_type), + Instruction::GetLocal(local_var) => Some(local_var.get_type(context)), + Instruction::ExtractElement { ty, .. } => ty.get_array_elem_type(context), + Instruction::ExtractValue { ty, indices, .. } => + // This array is a field in a struct or element in an array. + { + ty.get_indexed_type(context, indices) } - // Unknown aggregate instruction. Adding these as we come across them... _otherwise => None, - } + }; + ty.filter(|ty| ty.is_array(context) || ty.is_struct(context)) } pub fn get_operands(&self) -> Vec { @@ -784,7 +777,7 @@ impl<'a> InstructionInserter<'a> { ) } - pub fn extract_element(self, array: Value, ty: Aggregate, index_val: Value) -> Value { + pub fn extract_element(self, array: Value, ty: Type, index_val: Value) -> Value { make_instruction!( self, Instruction::ExtractElement { @@ -795,7 +788,7 @@ impl<'a> InstructionInserter<'a> { ) } - pub fn extract_value(self, aggregate: Value, ty: Aggregate, indices: Vec) -> Value { + pub fn extract_value(self, aggregate: Value, ty: Type, indices: Vec) -> Value { make_instruction!( self, Instruction::ExtractValue { @@ -821,13 +814,7 @@ impl<'a> InstructionInserter<'a> { make_instruction!(self, Instruction::GetLocal(local_var)) } - pub fn insert_element( - self, - array: Value, - ty: Aggregate, - value: Value, - index_val: Value, - ) -> Value { + pub fn insert_element(self, array: Value, ty: Type, value: Value, index_val: Value) -> Value { make_instruction!( self, Instruction::InsertElement { @@ -842,7 +829,7 @@ impl<'a> InstructionInserter<'a> { pub fn insert_value( self, aggregate: Value, - ty: Aggregate, + ty: Type, value: Value, indices: Vec, ) -> Value { diff --git a/sway-ir/src/irtype.rs b/sway-ir/src/irtype.rs index 3b9e57a5e4e..73fbada7770 100644 --- a/sway-ir/src/irtype.rs +++ b/sway-ir/src/irtype.rs @@ -11,190 +11,315 @@ use crate::{context::Context, pretty::DebugWithContext}; -#[derive(Debug, Clone, Copy, DebugWithContext)] -pub enum Type { +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, DebugWithContext)] +pub struct Type(pub generational_arena::Index); + +#[derive(Debug, Clone, DebugWithContext, Hash, PartialEq, Eq)] +pub enum TypeContent { Unit, Bool, Uint(u8), B256, String(u64), - Array(Aggregate), - Union(Aggregate), - Struct(Aggregate), + Array(Type, u64), + Union(Vec), + Struct(Vec), Slice, } impl Type { + fn get_or_create_unique_type(context: &mut Context, t: TypeContent) -> Type { + // Trying to avoiding cloning t unless we're creating a new type. + #[allow(clippy::map_entry)] + if !context.type_map.contains_key(&t) { + let new_type = Type(context.types.insert(t.clone())); + context.type_map.insert(t, new_type); + new_type + } else { + context.type_map.get(&t).copied().unwrap() + } + } + + /// Get Type if it already exists. + pub fn get_type(context: &Context, t: &TypeContent) -> Option { + context.type_map.get(t).copied() + } + + pub fn create_basic_types(context: &mut Context) { + Self::get_or_create_unique_type(context, TypeContent::Unit); + Self::get_or_create_unique_type(context, TypeContent::Bool); + Self::get_or_create_unique_type(context, TypeContent::Uint(8)); + Self::get_or_create_unique_type(context, TypeContent::Uint(32)); + Self::get_or_create_unique_type(context, TypeContent::Uint(64)); + Self::get_or_create_unique_type(context, TypeContent::B256); + Self::get_or_create_unique_type(context, TypeContent::Slice); + } + + /// Get the content for this Type. + pub fn get_content<'a>(&self, context: &'a Context) -> &'a TypeContent { + &context.types[self.0] + } + + /// Get unit type + pub fn get_unit(context: &Context) -> Type { + Self::get_type(context, &TypeContent::Unit).expect("create_basic_types not called") + } + + /// Get bool type + pub fn get_bool(context: &Context) -> Type { + Self::get_type(context, &TypeContent::Bool).expect("create_basic_types not called") + } + + /// New unsigned integer type + pub fn new_uint(context: &mut Context, width: u8) -> Type { + Self::get_or_create_unique_type(context, TypeContent::Uint(width)) + } + + /// New u8 type + pub fn get_uint8(context: &Context) -> Type { + Self::get_type(context, &TypeContent::Uint(8)).expect("create_basic_types not called") + } + + /// New u32 type + pub fn get_uint32(context: &Context) -> Type { + Self::get_type(context, &TypeContent::Uint(32)).expect("create_basic_types not called") + } + + /// New u64 type + pub fn get_uint64(context: &Context) -> Type { + Self::get_type(context, &TypeContent::Uint(64)).expect("create_basic_types not called") + } + + /// Get unsigned integer type + pub fn get_uint(context: &Context, width: u8) -> Option { + Self::get_type(context, &TypeContent::Uint(width)) + } + + /// Get B256 type + pub fn get_b256(context: &Context) -> Type { + Self::get_type(context, &TypeContent::B256).expect("create_basic_types not called") + } + + /// Get string type + pub fn new_string(context: &mut Context, len: u64) -> Type { + Self::get_or_create_unique_type(context, TypeContent::String(len)) + } + + /// Get array type + pub fn new_array(context: &mut Context, elm_ty: Type, len: u64) -> Type { + Self::get_or_create_unique_type(context, TypeContent::Array(elm_ty, len)) + } + + /// Get union type + pub fn new_union(context: &mut Context, fields: Vec) -> Type { + Self::get_or_create_unique_type(context, TypeContent::Union(fields)) + } + + /// Get struct type + pub fn new_struct(context: &mut Context, fields: Vec) -> Type { + Self::get_or_create_unique_type(context, TypeContent::Struct(fields)) + } + + /// Get pointer type + pub fn get_slice(context: &mut Context) -> Type { + Self::get_type(context, &TypeContent::Slice).expect("create_basic_types not called") + } + /// Return a string representation of type, used for printing. pub fn as_string(&self, context: &Context) -> String { - let sep_types_str = |agg_content: &AggregateContent, sep: &str| { + let sep_types_str = |agg_content: &Vec, sep: &str| { agg_content - .field_types() .iter() .map(|ty| ty.as_string(context)) .collect::>() .join(sep) }; - match self { - Type::Unit => "()".into(), - Type::Bool => "bool".into(), - Type::Uint(nbits) => format!("u{}", nbits), - Type::B256 => "b256".into(), - Type::String(n) => format!("string<{}>", n), - Type::Array(agg) => { - let (ty, cnt) = &context.aggregates[agg.0].array_type(); + match self.get_content(context) { + TypeContent::Unit => "()".into(), + TypeContent::Bool => "bool".into(), + TypeContent::Uint(nbits) => format!("u{}", nbits), + TypeContent::B256 => "b256".into(), + TypeContent::String(n) => format!("string<{}>", n), + TypeContent::Array(ty, cnt) => { format!("[{}; {}]", ty.as_string(context), cnt) } - Type::Union(agg) => { - let agg_content = &context.aggregates[agg.0]; - format!("( {} )", sep_types_str(agg_content, " | ")) + TypeContent::Union(agg) => { + format!("( {} )", sep_types_str(agg, " | ")) } - Type::Struct(agg) => { - let agg_content = &context.aggregates[agg.0]; - format!("{{ {} }}", sep_types_str(agg_content, ", ")) + TypeContent::Struct(agg) => { + format!("{{ {} }}", sep_types_str(agg, ", ")) } - Type::Slice => "slice".into(), + TypeContent::Slice => "slice".into(), } } - /// Compare a type to this one for equivalence. We're unable to use `PartialEq` as we need the - /// `Context` to compare structs and arrays. + /// Compare a type to this one for equivalence. + /// `PartialEq` does not take into account the special case for Unions below. pub fn eq(&self, context: &Context, other: &Type) -> bool { - match (self, other) { - (Type::Unit, Type::Unit) => true, - (Type::Bool, Type::Bool) => true, - (Type::Uint(l), Type::Uint(r)) => l == r, - (Type::B256, Type::B256) => true, - (Type::String(l), Type::String(r)) => l == r, - - (Type::Array(l), Type::Array(r)) => l.is_equivalent(context, r), - (Type::Struct(l), Type::Struct(r)) => l.is_equivalent(context, r), + match (self.get_content(context), other.get_content(context)) { + (TypeContent::Unit, TypeContent::Unit) => true, + (TypeContent::Bool, TypeContent::Bool) => true, + (TypeContent::Uint(l), TypeContent::Uint(r)) => l == r, + (TypeContent::B256, TypeContent::B256) => true, + (TypeContent::String(l), TypeContent::String(r)) => l == r, + (TypeContent::Array(l, llen), TypeContent::Array(r, rlen)) => { + llen == rlen && l.eq(context, r) + } + (TypeContent::Struct(l), TypeContent::Struct(r)) + | (TypeContent::Union(l), TypeContent::Union(r)) => { + l.len() == r.len() && l.iter().zip(r.iter()).all(|(l, r)| l.eq(context, r)) + } // Unions are special. We say unions are equivalent to any of their variant types. - (Type::Union(l), Type::Union(r)) => l.is_equivalent(context, r), - (l, r @ Type::Union(_)) => r.eq(context, l), - (Type::Union(l), r) => context.aggregates[l.0] - .field_types() - .iter() - .any(|field_ty| r.eq(context, field_ty)), + (_, TypeContent::Union(_)) => other.eq(context, self), + (TypeContent::Union(l), _) => l.iter().any(|field_ty| other.eq(context, field_ty)), - (Type::Slice, Type::Slice) => true, + (TypeContent::Slice, TypeContent::Slice) => true, _ => false, } } -} -/// A collection of [`Type`]s. -/// -/// XXX I've added Array as using Aggregate in the hope ExtractValue could be used just like with -/// struct aggregates, but it turns out we need ExtractElement (which takes an index Value). So -/// Aggregate can be a 'struct' or 'array' but you only ever use them with Struct and Array types -/// and with ExtractValue and ExtractElement... so they're orthogonal and we can simplify aggregate -/// again to be only for structs. -/// -/// But also to keep Type as Copy we need to put the Array meta into another copy type (rather than -/// recursing with Box, effectively a different Aggregate. This could be OK though, still -/// simpler that what we have here. -/// -/// NOTE: `Aggregate` derives `Eq` (and `PartialEq`) so that it can also derive `Hash`. But we must -/// be careful not to use `==` or `!=` to compare `Aggregate` for equivalency -- i.e., to check -/// that they represent the same collection of types. Instead the `is_equivalent()` method is -/// provided. XXX Perhaps `Hash` should be impl'd directly without `Eq` if possible? + /// Is bool type + pub fn is_bool(&self, context: &Context) -> bool { + matches!(*self.get_content(context), TypeContent::Bool) + } -#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, DebugWithContext)] -pub struct Aggregate(#[in_context(aggregates)] pub generational_arena::Index); + /// Is unit type + pub fn is_unit(&self, context: &Context) -> bool { + matches!(*self.get_content(context), TypeContent::Unit) + } -#[doc(hidden)] -#[derive(Debug, Clone, DebugWithContext)] -pub enum AggregateContent { - ArrayType(Type, u64), - FieldTypes(Vec), -} + /// Is unsigned integer type + pub fn is_uint(&self, context: &Context) -> bool { + matches!(*self.get_content(context), TypeContent::Uint(_)) + } -impl Aggregate { - /// Return a new struct specific aggregate. - pub fn new_struct(context: &mut Context, field_types: Vec) -> Self { - Aggregate( - context - .aggregates - .insert(AggregateContent::FieldTypes(field_types)), - ) + /// Is u8 type + pub fn is_uint8(&self, context: &Context) -> bool { + matches!(*self.get_content(context), TypeContent::Uint(8)) } - /// Returna new array specific aggregate. - pub fn new_array(context: &mut Context, element_type: Type, count: u64) -> Self { - Aggregate( - context - .aggregates - .insert(AggregateContent::ArrayType(element_type, count)), - ) + /// Is u32 type + pub fn is_uint32(&self, context: &Context) -> bool { + matches!(*self.get_content(context), TypeContent::Uint(32)) } - /// Tests whether an aggregate has the same sub-types. - pub fn is_equivalent(&self, context: &Context, other: &Aggregate) -> bool { - context.aggregates[self.0].eq(context, &context.aggregates[other.0]) + /// Is u64 type + pub fn is_uint64(&self, context: &Context) -> bool { + matches!(*self.get_content(context), TypeContent::Uint(64)) } - /// Get a reference to the [`AggregateContent`] for this aggregate. - pub fn get_content<'a>(&self, context: &'a Context) -> &'a AggregateContent { - &context.aggregates[self.0] + /// Is unsigned integer type of specific width + pub fn is_uint_of(&self, context: &Context, width: u8) -> bool { + matches!(*self.get_content(context), TypeContent::Uint(width_) if width == width_) + } + + /// Is B256 type + pub fn is_b256(&self, context: &Context) -> bool { + matches!(*self.get_content(context), TypeContent::B256) + } + + /// Is string type + pub fn is_string(&self, context: &Context) -> bool { + matches!(*self.get_content(context), TypeContent::String(_)) + } + + /// Is array type + pub fn is_array(&self, context: &Context) -> bool { + matches!(*self.get_content(context), TypeContent::Array(..)) + } + + /// Is union type + pub fn is_union(&self, context: &Context) -> bool { + matches!(*self.get_content(context), TypeContent::Union(_)) + } + + /// Is struct type + pub fn is_struct(&self, context: &Context) -> bool { + matches!(*self.get_content(context), TypeContent::Struct(_)) + } + + /// Returns true if this is a slice type. + pub fn is_slice(&self, context: &Context) -> bool { + matches!(*self.get_content(context), TypeContent::Slice) + } + + /// Get width of an integer type. + pub fn get_uint_width(&self, context: &Context) -> Option { + if let TypeContent::Uint(width) = self.get_content(context) { + Some(*width) + } else { + None + } } - /// Get the type of (nested) aggregate fields, if found. If an index is into a `Union` then it - /// will get the type of the indexed variant. - pub fn get_field_type(&self, context: &Context, indices: &[u64]) -> Option { - indices.iter().fold(Some(Type::Struct(*self)), |ty, idx| { - ty.and_then(|ty| match ty { - Type::Struct(agg) | Type::Union(agg) => context.aggregates[agg.0] - .field_types() - .get(*idx as usize) - .cloned(), + /// What's the type of the struct value indexed by indices. + pub fn get_indexed_type(&self, context: &Context, indices: &[u64]) -> Option { + if indices.is_empty() { + return None; + } - // Trying to index a non-aggregate. - _otherwise => None, + indices.iter().fold(Some(*self), |ty, idx| { + ty.and_then(|ty| { + ty.get_field_type(context, *idx) + .or_else(|| ty.get_array_elem_type(context)) }) }) } + pub fn get_field_type(&self, context: &Context, idx: u64) -> Option { + if let TypeContent::Struct(agg) | TypeContent::Union(agg) = self.get_content(context) { + agg.get(idx as usize).cloned() + } else { + // Trying to index a non-aggregate. + None + } + } + /// Get the type of the array element, if applicable. - pub fn get_elem_type(&self, context: &Context) -> Option { - if let AggregateContent::ArrayType(ty, _) = context.aggregates[self.0] { + pub fn get_array_elem_type(&self, context: &Context) -> Option { + if let TypeContent::Array(ty, _) = *self.get_content(context) { Some(ty) } else { None } } -} -impl AggregateContent { - pub fn field_types(&self) -> &Vec { - match self { - AggregateContent::FieldTypes(types) => types, - AggregateContent::ArrayType(..) => panic!("Getting field types from array aggregate."), + /// Get the length of the array , if applicable. + pub fn get_array_len(&self, context: &Context) -> Option { + if let TypeContent::Array(_, n) = *self.get_content(context) { + Some(n) + } else { + None } } - pub fn array_type(&self) -> (&Type, &u64) { - match self { - AggregateContent::FieldTypes(..) => panic!("Getting array type from fields aggregate."), - AggregateContent::ArrayType(ty, cnt) => (ty, cnt), + /// Get the length of a string + pub fn get_string_len(&self, context: &Context) -> Option { + if let TypeContent::String(n) = *self.get_content(context) { + Some(n) + } else { + None } } - /// Tests whether an aggregate has the same sub-types. - pub fn eq(&self, context: &Context, other: &AggregateContent) -> bool { - match (self, other) { - (AggregateContent::FieldTypes(l_tys), AggregateContent::FieldTypes(r_tys)) => l_tys - .iter() - .zip(r_tys.iter()) - .all(|(l, r)| l.eq(context, r)), - ( - AggregateContent::ArrayType(l_ty, l_cnt), - AggregateContent::ArrayType(r_ty, r_cnt), - ) => l_cnt == r_cnt && l_ty.eq(context, r_ty), - - _ => false, + /// Get the type of each field of a struct Type. Empty vector otherwise. + pub fn get_field_types(&self, context: &Context) -> Vec { + match self.get_content(context) { + TypeContent::Struct(fields) | TypeContent::Union(fields) => fields.clone(), + _ => vec![], } } } + +/// A helper to check if an Option value is of a particular Type. +pub trait TypeOption { + fn is(&self, pred: fn(&Type, &Context) -> bool, context: &Context) -> bool; +} + +impl TypeOption for Option { + fn is(&self, pred: fn(&Type, &Context) -> bool, context: &Context) -> bool { + self.filter(|ty| pred(ty, context)).is_some() + } +} diff --git a/sway-ir/src/local_var.rs b/sway-ir/src/local_var.rs index 4c15fb03eb5..f97d0546795 100644 --- a/sway-ir/src/local_var.rs +++ b/sway-ir/src/local_var.rs @@ -22,8 +22,8 @@ impl LocalVar { } /// Return the type of this local variable. - pub fn get_type<'a>(&self, context: &'a Context) -> &'a Type { - &context.local_vars[self.0].ty + pub fn get_type(&self, context: &Context) -> Type { + context.local_vars[self.0].ty } /// Return the initializer for this local variable. diff --git a/sway-ir/src/optimize/constants.rs b/sway-ir/src/optimize/constants.rs index 180c7ac568c..9bd4dc18443 100644 --- a/sway-ir/src/optimize/constants.rs +++ b/sway-ir/src/optimize/constants.rs @@ -117,7 +117,7 @@ fn combine_cmp(context: &mut Context, function: &Function) -> bool { // Replace this `cmp` instruction with a constant. inst_val.replace( context, - ValueDatum::Constant(Constant::new_bool(cn_replace)), + ValueDatum::Constant(Constant::new_bool(context, cn_replace)), ); block.remove_instruction(context, inst_val); true diff --git a/sway-ir/src/optimize/inline.rs b/sway-ir/src/optimize/inline.rs index 0c631bba44c..e23e83abcc0 100644 --- a/sway-ir/src/optimize/inline.rs +++ b/sway-ir/src/optimize/inline.rs @@ -89,28 +89,22 @@ pub fn is_small_fn( ) -> impl Fn(&Context, &Function, &Value) -> bool { fn count_type_elements(context: &Context, ty: &Type) -> usize { // This is meant to just be a heuristic rather than be super accurate. - match ty { - Type::Unit - | Type::Bool - | Type::Uint(_) - | Type::B256 - | Type::String(_) - | Type::Slice => 1, - Type::Array(aggregate) => { - let (ty, sz) = context.aggregates[aggregate.0].array_type(); - count_type_elements(context, ty) * *sz as usize - } - Type::Union(aggregate) => context.aggregates[aggregate.0] - .field_types() + if ty.is_array(context) { + count_type_elements(context, &ty.get_array_elem_type(context).unwrap()) + * ty.get_array_len(context).unwrap() as usize + } else if ty.is_union(context) { + ty.get_field_types(context) .iter() .map(|ty| count_type_elements(context, ty)) .max() - .unwrap_or(1), - Type::Struct(aggregate) => context.aggregates[aggregate.0] - .field_types() + .unwrap_or(1) + } else if ty.is_struct(context) { + ty.get_field_types(context) .iter() .map(|ty| count_type_elements(context, ty)) - .sum(), + .sum() + } else { + 1 } } @@ -125,7 +119,7 @@ pub fn is_small_fn( .map(|max_stack_size_count| { function .locals_iter(context) - .map(|(_name, ptr)| count_type_elements(context, ptr.get_type(context))) + .map(|(_name, ptr)| count_type_elements(context, &ptr.get_type(context))) .sum::() <= max_stack_size_count }) diff --git a/sway-ir/src/optimize/mem2reg.rs b/sway-ir/src/optimize/mem2reg.rs index 6f7989e8795..3b204723206 100644 --- a/sway-ir/src/optimize/mem2reg.rs +++ b/sway-ir/src/optimize/mem2reg.rs @@ -33,10 +33,11 @@ fn filter_usable_locals(context: &mut Context, function: &Function) -> HashSet = function .locals_iter(context) - .filter(|(_, var)| match var.get_type(context) { - Type::Unit | Type::Bool => true, - Type::Uint(n) => *n <= 64, - _ => false, + .filter(|(_, var)| { + let ty = var.get_type(context); + ty.is_unit(context) + || ty.is_bool(context) + || (ty.is_uint(context) && ty.get_uint_width(context).unwrap() <= 64) }) .map(|(name, _)| name.clone()) .collect(); @@ -153,7 +154,7 @@ pub fn promote_to_registers(context: &mut Context, function: &Function) -> Resul { match get_validate_local_var(context, function, &dst_val) { Some((local, var)) if safe_locals.contains(&local) => { - worklist.push((local, *var.get_type(context), block)); + worklist.push((local, var.get_type(context), block)); } _ => (), } diff --git a/sway-ir/src/parser.rs b/sway-ir/src/parser.rs index a70b058772d..8c7cf2bb4b9 100644 --- a/sway-ir/src/parser.rs +++ b/sway-ir/src/parser.rs @@ -599,7 +599,7 @@ mod ir_builder { error::IrError, function::Function, instruction::{Instruction, Predicate, Register}, - irtype::{Aggregate, Type}, + irtype::Type, local_var::LocalVar, metadata::{MetadataIndex, Metadatum}, module::{Kind, Module}, @@ -783,29 +783,22 @@ mod ir_builder { impl IrAstTy { fn to_ir_type(&self, context: &mut Context) -> Type { match self { - IrAstTy::Unit => Type::Unit, - IrAstTy::Bool => Type::Bool, - IrAstTy::U64 => Type::Uint(64), - IrAstTy::B256 => Type::B256, - IrAstTy::String(n) => Type::String(*n), - IrAstTy::Array(..) => Type::Array(self.to_ir_aggregate_type(context)), - IrAstTy::Union(_) => Type::Union(self.to_ir_aggregate_type(context)), - IrAstTy::Struct(_) => Type::Struct(self.to_ir_aggregate_type(context)), - } - } - - fn to_ir_aggregate_type(&self, context: &mut Context) -> Aggregate { - match self { + IrAstTy::Unit => Type::get_unit(context), + IrAstTy::Bool => Type::get_bool(context), + IrAstTy::U64 => Type::get_uint64(context), + IrAstTy::B256 => Type::get_b256(context), + IrAstTy::String(n) => Type::new_string(context, *n), IrAstTy::Array(el_ty, count) => { let el_ty = el_ty.to_ir_type(context); - Aggregate::new_array(context, el_ty, *count) + Type::new_array(context, el_ty, *count) } - IrAstTy::Struct(tys) | IrAstTy::Union(tys) => { + IrAstTy::Union(tys) => { let tys = tys.iter().map(|ty| ty.to_ir_type(context)).collect(); - Aggregate::new_struct(context, tys) + Type::new_union(context, tys) } - _otherwise => { - unreachable!("Converting non aggregate IR AST type to IR aggregate type.") + IrAstTy::Struct(tys) => { + let tys = tys.iter().map(|ty| ty.to_ir_type(context)).collect(); + Type::new_struct(context, tys) } } } @@ -1127,7 +1120,7 @@ mod ir_builder { .add_metadatum(context, opt_metadata) } IrAstOperation::ExtractElement(aval, ty, idx) => { - let ir_ty = ty.to_ir_aggregate_type(context); + let ir_ty = ty.to_ir_type(context); block .ins(context) .extract_element( @@ -1138,7 +1131,7 @@ mod ir_builder { .add_metadatum(context, opt_metadata) } IrAstOperation::ExtractValue(val, ty, idcs) => { - let ir_ty = ty.to_ir_aggregate_type(context); + let ir_ty = ty.to_ir_type(context); block .ins(context) .extract_value(*val_map.get(&val).unwrap(), ir_ty, idcs) @@ -1157,7 +1150,7 @@ mod ir_builder { .gtf(*val_map.get(&index).unwrap(), tx_field_id) .add_metadatum(context, opt_metadata), IrAstOperation::InsertElement(aval, ty, val, idx) => { - let ir_ty = ty.to_ir_aggregate_type(context); + let ir_ty = ty.to_ir_type(context); block .ins(context) .insert_element( @@ -1169,7 +1162,7 @@ mod ir_builder { .add_metadatum(context, opt_metadata) } IrAstOperation::InsertValue(aval, ty, ival, idcs) => { - let ir_ty = ty.to_ir_aggregate_type(context); + let ir_ty = ty.to_ir_type(context); block .ins(context) .insert_value( diff --git a/sway-ir/src/printer.rs b/sway-ir/src/printer.rs index f388d9d1b08..abedcbb60fc 100644 --- a/sway-ir/src/printer.rs +++ b/sway-ir/src/printer.rs @@ -13,7 +13,6 @@ use crate::{ context::Context, function::{Function, FunctionContent}, instruction::{FuelVmInstruction, Instruction, Predicate, Register}, - irtype::Type, metadata::{MetadataIndex, Metadatum}, module::{Kind, ModuleContent}, value::{Value, ValueContent, ValueDatum}, @@ -495,7 +494,7 @@ fn instruction_to_doc<'a>( "{} = extract_element {}, {}, {}", namer.name(context, ins_value), namer.name(context, array), - Type::Array(*ty).as_string(context), + ty.as_string(context), namer.name(context, index_val), )) .append(md_namer.md_idx_to_doc(context, metadata)), @@ -509,7 +508,7 @@ fn instruction_to_doc<'a>( "{} = extract_value {}, {}, ", namer.name(context, ins_value), namer.name(context, aggregate), - Type::Struct(*ty).as_string(context), + ty.as_string(context), )) .append(Doc::list_sep( indices @@ -687,7 +686,7 @@ fn instruction_to_doc<'a>( "{} = insert_element {}, {}, {}, {}", namer.name(context, ins_value), namer.name(context, array), - Type::Array(*ty).as_string(context), + ty.as_string(context), namer.name(context, value), namer.name(context, index_val), )) @@ -705,7 +704,7 @@ fn instruction_to_doc<'a>( "{} = insert_value {}, {}, {}, ", namer.name(context, ins_value), namer.name(context, aggregate), - Type::Struct(*ty).as_string(context), + ty.as_string(context), namer.name(context, value), )) .append(Doc::list_sep( diff --git a/sway-ir/src/verify.rs b/sway-ir/src/verify.rs index c7e85add8ea..079f78e1a9b 100644 --- a/sway-ir/src/verify.rs +++ b/sway-ir/src/verify.rs @@ -9,12 +9,12 @@ use crate::{ error::IrError, function::{Function, FunctionContent}, instruction::{FuelVmInstruction, Instruction, Predicate}, - irtype::{Aggregate, Type}, + irtype::Type, local_var::LocalVar, metadata::{MetadataIndex, Metadatum}, module::ModuleContent, value::{Value, ValueDatum}, - BinaryOpKind, BlockArgument, BranchToWithArgs, + BinaryOpKind, BlockArgument, BranchToWithArgs, TypeOption, }; impl Context { @@ -206,7 +206,7 @@ impl<'a> InstructionVerifier<'a> { number_of_slots, } => self.verify_state_load_store( dst_val, - &Type::B256, + Type::get_b256(self.context), key, number_of_slots, )?, @@ -294,7 +294,7 @@ impl<'a> InstructionVerifier<'a> { let arg2_ty = arg2 .get_type(self.context) .ok_or(IrError::VerifyBinaryOpIncorrectArgType)?; - if !arg1_ty.eq(self.context, &arg2_ty) || !matches!(arg1_ty, Type::Uint(_)) { + if !arg1_ty.eq(self.context, &arg2_ty) || !arg1_ty.is_uint(self.context) { return Err(IrError::VerifyBinaryOpIncorrectArgType); } @@ -356,12 +356,15 @@ impl<'a> InstructionVerifier<'a> { } fn verify_cast_ptr(&self, val: &Value, ty: &Type) -> Result<(), IrError> { - if matches!( - val.get_type(self.context), - Some(Type::Unit | Type::Bool | Type::Uint(_)) - ) { + let non_pointer_type = |ty: &Type, context: &Context| { + ty.is_unit(context) | ty.is_bool(context) | ty.is_uint(context) + }; + if val + .get_type(self.context) + .is(non_pointer_type, self.context) + { Err(IrError::VerifyPtrCastFromNonPointer) - } else if matches!(ty, Type::Unit | Type::Bool | Type::Uint(_)) { + } else if non_pointer_type(ty, self.context) { Err(IrError::VerifyPtrCastToNonPointer) } else { // Just going to throw this assert in here. `cast_ptr` is a temporary measure and this @@ -401,7 +404,10 @@ impl<'a> InstructionVerifier<'a> { true_block: &BranchToWithArgs, false_block: &BranchToWithArgs, ) -> Result<(), IrError> { - if !matches!(cond_val.get_type(self.context), Some(Type::Bool)) { + if !cond_val + .get_type(self.context) + .is(Type::is_bool, self.context) + { Err(IrError::VerifyConditionExprNotABool) } else if !self.cur_function.blocks.contains(&true_block.block) { Err(IrError::VerifyBranchToMissingBlock( @@ -428,23 +434,21 @@ impl<'a> InstructionVerifier<'a> { lhs_value.get_type(self.context), rhs_value.get_type(self.context), ) { - (Some(lhs_ty), Some(rhs_ty)) => match (lhs_ty, rhs_ty) { - (Type::Uint(lhs_nbits), Type::Uint(rhs_nbits)) => { - if lhs_nbits != rhs_nbits { - Err(IrError::VerifyCmpTypeMismatch( - lhs_ty.as_string(self.context), - rhs_ty.as_string(self.context), - )) - } else { - Ok(()) - } + (Some(lhs_ty), Some(rhs_ty)) => { + if !lhs_ty.eq(self.context, &rhs_ty) { + Err(IrError::VerifyCmpTypeMismatch( + lhs_ty.as_string(self.context), + rhs_ty.as_string(self.context), + )) + } else if lhs_ty.is_bool(self.context) || lhs_ty.is_uint(self.context) { + Ok(()) + } else { + Err(IrError::VerifyCmpBadTypes( + lhs_ty.as_string(self.context), + rhs_ty.as_string(self.context), + )) } - (Type::Bool, Type::Bool) => Ok(()), - _otherwise => Err(IrError::VerifyCmpBadTypes( - lhs_ty.as_string(self.context), - rhs_ty.as_string(self.context), - )), - }, + } _otherwise => Err(IrError::VerifyCmpUnknownTypes), } } @@ -460,36 +464,40 @@ impl<'a> InstructionVerifier<'a> { // user args. // - The coins and gas must be u64s. // - The asset_id must be a B256 - if let Some(Type::Struct(agg)) = params.get_type(self.context) { - let fields = self.context.aggregates[agg.0].field_types(); - if fields.len() != 3 - || !fields[0].eq(self.context, &Type::B256) - || !fields[1].eq(self.context, &Type::Uint(64)) - || !fields[2].eq(self.context, &Type::Uint(64)) - { - Err(IrError::VerifyContractCallBadTypes("params".to_owned())) - } else { - Ok(()) - } - } else { + let fields = params + .get_type(self.context) + .map_or_else(std::vec::Vec::new, |ty| ty.get_field_types(self.context)); + if fields.len() != 3 + || !fields[0].is_b256(self.context) + || !fields[1].is_uint64(self.context) + || !fields[2].is_uint64(self.context) + { Err(IrError::VerifyContractCallBadTypes("params".to_owned())) + } else { + Ok(()) } .and_then(|_| { - if let Some(Type::Uint(64)) = coins.get_type(self.context) { + if coins + .get_type(self.context) + .is(Type::is_uint64, self.context) + { Ok(()) } else { Err(IrError::VerifyContractCallBadTypes("coins".to_owned())) } }) .and_then(|_| { - if let Some(Type::B256) = asset_id.get_type(self.context) { + if asset_id + .get_type(self.context) + .is(Type::is_b256, self.context) + { Ok(()) } else { Err(IrError::VerifyContractCallBadTypes("asset_id".to_owned())) } }) .and_then(|_| { - if let Some(Type::Uint(64)) = gas.get_type(self.context) { + if gas.get_type(self.context).is(Type::is_uint64, self.context) { Ok(()) } else { Err(IrError::VerifyContractCallBadTypes("gas".to_owned())) @@ -500,14 +508,17 @@ impl<'a> InstructionVerifier<'a> { fn verify_extract_element( &self, array: &Value, - ty: &Aggregate, + ty: &Type, index_val: &Value, ) -> Result<(), IrError> { match array.get_type(self.context) { - Some(Type::Array(ary_ty)) => { - if !ary_ty.is_equivalent(self.context, ty) { + Some(ary_ty) if ary_ty.is_array(self.context) => { + if !ary_ty.eq(self.context, ty) { Err(IrError::VerifyAccessElementInconsistentTypes) - } else if !matches!(index_val.get_type(self.context), Some(Type::Uint(_))) { + } else if !index_val + .get_type(self.context) + .is(Type::is_uint, self.context) + { Err(IrError::VerifyAccessElementNonIntIndex) } else { Ok(()) @@ -520,14 +531,14 @@ impl<'a> InstructionVerifier<'a> { fn verify_extract_value( &self, aggregate: &Value, - ty: &Aggregate, + ty: &Type, indices: &[u64], ) -> Result<(), IrError> { match aggregate.get_type(self.context) { - Some(Type::Struct(agg_ty)) | Some(Type::Union(agg_ty)) => { - if !agg_ty.is_equivalent(self.context, ty) { + Some(agg_ty) if agg_ty.is_struct(self.context) || agg_ty.is_union(self.context) => { + if !agg_ty.eq(self.context, ty) { Err(IrError::VerifyAccessValueInconsistentTypes) - } else if ty.get_field_type(self.context, indices).is_none() { + } else if ty.get_indexed_type(self.context, indices).is_none() { Err(IrError::VerifyAccessValueInvalidIndices) } else { Ok(()) @@ -552,7 +563,7 @@ impl<'a> InstructionVerifier<'a> { fn verify_gtf(&self, index: &Value, _tx_field_id: &u64) -> Result<(), IrError> { // We should perhaps verify that _tx_field_id fits in a twelve bit immediate - if !matches!(index.get_type(self.context), Some(Type::Uint(_))) { + if !index.get_type(self.context).is(Type::is_uint, self.context) { Err(IrError::VerifyInvalidGtfIndexType) } else { Ok(()) @@ -562,20 +573,23 @@ impl<'a> InstructionVerifier<'a> { fn verify_insert_element( &self, array: &Value, - ty: &Aggregate, + ty: &Type, value: &Value, index_val: &Value, ) -> Result<(), IrError> { match array.get_type(self.context) { - Some(Type::Array(ary_ty)) => { - if !ary_ty.is_equivalent(self.context, ty) { + Some(ary_ty) if ary_ty.is_array(self.context) => { + if !ary_ty.eq(self.context, ty) { Err(IrError::VerifyAccessElementInconsistentTypes) } else if self.opt_ty_not_eq( - &ty.get_elem_type(self.context), + &ty.get_array_elem_type(self.context), &value.get_type(self.context), ) { Err(IrError::VerifyInsertElementOfIncorrectType) - } else if !matches!(index_val.get_type(self.context), Some(Type::Uint(_))) { + } else if !index_val + .get_type(self.context) + .is(Type::is_uint, self.context) + { Err(IrError::VerifyAccessElementNonIntIndex) } else { Ok(()) @@ -588,16 +602,16 @@ impl<'a> InstructionVerifier<'a> { fn verify_insert_value( &self, aggregate: &Value, - ty: &Aggregate, + ty: &Type, value: &Value, idcs: &[u64], ) -> Result<(), IrError> { match aggregate.get_type(self.context) { - Some(Type::Struct(str_ty)) => { - if !str_ty.is_equivalent(self.context, ty) { + Some(str_ty) if str_ty.is_struct(self.context) => { + if !str_ty.eq(self.context, ty) { Err(IrError::VerifyAccessValueInconsistentTypes) } else { - let field_ty = ty.get_field_type(self.context, idcs); + let field_ty = ty.get_indexed_type(self.context, idcs); if field_ty.is_none() { Err(IrError::VerifyAccessValueInvalidIndices) } else if self.opt_ty_not_eq(&field_ty, &value.get_type(self.context)) { @@ -617,7 +631,7 @@ impl<'a> InstructionVerifier<'a> { let val_ty = value .get_type(self.context) .ok_or(IrError::VerifyIntToPtrUnknownSourceType)?; - if !matches!(val_ty, Type::Uint(64)) { + if !val_ty.is_uint64(self.context) { return Err(IrError::VerifyIntToPtrFromNonIntegerType( val_ty.as_string(self.context), )); @@ -643,7 +657,10 @@ impl<'a> InstructionVerifier<'a> { } fn verify_log(&self, log_val: &Value, log_ty: &Type, log_id: &Value) -> Result<(), IrError> { - if !matches!(log_id.get_type(self.context), Some(Type::Uint(64))) { + if !log_id + .get_type(self.context) + .is(Type::is_uint64, self.context) + { return Err(IrError::VerifyLogId); } @@ -705,7 +722,7 @@ impl<'a> InstructionVerifier<'a> { } fn verify_revert(&self, val: &Value) -> Result<(), IrError> { - if !matches!(val.get_type(self.context), Some(Type::Uint(64))) { + if !val.get_type(self.context).is(Type::is_uint64, self.context) { Err(IrError::VerifyRevertCodeBadType) } else { Ok(()) @@ -721,9 +738,11 @@ impl<'a> InstructionVerifier<'a> { ) -> Result<(), IrError> { // Check that the first operand is a struct with the first field being a `b256` // representing the recipient address - if let Some(Type::Struct(agg)) = recipient_and_message.get_type(self.context) { - let fields = self.context.aggregates[agg.0].field_types(); - if fields.is_empty() || !fields[0].eq(self.context, &Type::B256) { + if let Some(fields) = recipient_and_message + .get_type(self.context) + .map(|ty| ty.get_field_types(self.context)) + { + if fields.is_empty() || !fields[0].is_b256(self.context) { return Err(IrError::VerifySmoRecipientBadType); } } else { @@ -731,17 +750,26 @@ impl<'a> InstructionVerifier<'a> { } // Check that the second operand is a `u64` representing the message size. - if !matches!(message_size.get_type(self.context), Some(Type::Uint(64))) { + if !message_size + .get_type(self.context) + .is(Type::is_uint64, self.context) + { return Err(IrError::VerifySmoMessageSize); } // Check that the third operand is a `u64` representing the output index. - if !matches!(output_index.get_type(self.context), Some(Type::Uint(64))) { + if !output_index + .get_type(self.context) + .is(Type::is_uint64, self.context) + { return Err(IrError::VerifySmoOutputIndex); } // Check that the fourth operand is a `u64` representing the amount of coins being sent. - if !matches!(coins.get_type(self.context), Some(Type::Uint(64))) { + if !coins + .get_type(self.context) + .is(Type::is_uint64, self.context) + { return Err(IrError::VerifySmoCoins); } @@ -751,17 +779,20 @@ impl<'a> InstructionVerifier<'a> { fn verify_state_load_store( &self, dst_val: &Value, - val_type: &Type, + val_type: Type, key: &Value, number_of_slots: &Value, ) -> Result<(), IrError> { - if self.opt_ty_not_eq(&dst_val.get_type(self.context), &Some(*val_type)) { + if self.opt_ty_not_eq(&dst_val.get_type(self.context), &Some(val_type)) { Err(IrError::VerifyStateDestBadType( val_type.as_string(self.context), )) - } else if !matches!(key.get_type(self.context), Some(Type::B256)) { + } else if !key.get_type(self.context).is(Type::is_b256, self.context) { Err(IrError::VerifyStateKeyBadType) - } else if !matches!(number_of_slots.get_type(self.context), Some(Type::Uint(_))) { + } else if !number_of_slots + .get_type(self.context) + .is(Type::is_uint, self.context) + { Err(IrError::VerifyStateAccessNumOfSlots) } else { Ok(()) @@ -769,7 +800,7 @@ impl<'a> InstructionVerifier<'a> { } fn verify_state_load_word(&self, key: &Value) -> Result<(), IrError> { - if !matches!(key.get_type(self.context), Some(Type::B256)) { + if !key.get_type(self.context).is(Type::is_b256, self.context) { Err(IrError::VerifyStateKeyBadType) } else { Ok(()) @@ -777,11 +808,14 @@ impl<'a> InstructionVerifier<'a> { } fn verify_state_store_word(&self, dst_val: &Value, key: &Value) -> Result<(), IrError> { - if !matches!(key.get_type(self.context), Some(Type::B256)) { + if !key.get_type(self.context).is(Type::is_b256, self.context) { Err(IrError::VerifyStateKeyBadType) - } else if !matches!(dst_val.get_type(self.context), Some(Type::Uint(64))) { + } else if !dst_val + .get_type(self.context) + .is(Type::is_uint, self.context) + { Err(IrError::VerifyStateDestBadType( - Type::Uint(64).as_string(self.context), + Type::get_uint64(self.context).as_string(self.context), )) } else { Ok(()) @@ -808,11 +842,14 @@ impl<'a> InstructionVerifier<'a> { // Typically we don't want to make assumptions about the size of types in the IR. This is // here until we reintroduce pointers and don't need to care about type sizes (and whether // they'd fit in a 64 bit register). - match ty { - Type::Unit | Type::Bool => Some(1), - Type::Uint(n) => Some(*n as usize), - Type::B256 => Some(256), - _ => None, + if ty.is_unit(self.context) || ty.is_bool(self.context) { + Some(1) + } else if ty.is_uint(self.context) { + Some(ty.get_uint_width(self.context).unwrap() as usize) + } else if ty.is_b256(self.context) { + Some(256) + } else { + None } }