Skip to content

Commit

Permalink
Implement OR match pattern (FuelLabs#4348)
Browse files Browse the repository at this point in the history
## Description

Allow users to specify match patters such as:

```sway
let x = 1;
match x {
  0 | 1 => true,
  _ => false
}
```

We also check that all patterns in a disjunction declare the same set of
variables (not doing so is an error).

Fix FuelLabs#769

This requires a change in the pattern matching analysis to remove the
assumption that a specialized matrix of a vector pattern is always a
vector, which is now no longer true because or patterns can generate
multiple branches and therefore multiple rows.

## Checklist

- [x] I have linked to any relevant issues.
- [x] I have commented my code, particularly in hard-to-understand
areas.
- [x] I have updated the documentation where relevant (API docs, the
reference, and the Sway book).
- [x] I have added tests that prove my fix is effective or that my
feature works.
- [x] I have added (or requested a maintainer to add) the necessary
`Breaking*` or `New Feature` labels where relevant.
- [x] I have done my best to ensure that my PR adheres to [the Fuel Labs
Code Review
Standards](https://github.com/FuelLabs/rfcs/blob/master/text/code-standards/external-contributors.md).
- [x] I have requested a review from the relevant team or maintainers.

Co-authored-by: Joshua Batty <[email protected]>
  • Loading branch information
IGI-111 and JoshuaBatty authored Apr 18, 2023
1 parent 2e494f3 commit 30f869a
Show file tree
Hide file tree
Showing 33 changed files with 643 additions and 125 deletions.
1 change: 1 addition & 0 deletions docs/book/src/basics/control_flow.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ The basic syntax of a `match` statement is as follows:
let result = match expression {
pattern1 => code_to_execute_if_expression_matches_pattern1,
pattern2 => code_to_execute_if_expression_matches_pattern2,
pattern3 | pattern4 => code_to_execute_if_expression_matches_pattern3_or_pattern4
...
_ => code_to_execute_if_expression_matches_no_pattern,
}
Expand Down
1 change: 1 addition & 0 deletions docs/reference/src/code/language/control_flow/src/lib.sw
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ fn simple_match() {
0 => 10,
1 => 20,
5 => 50,
6 | 7 => 60,
catch_all => 0,
};
// ANCHOR_END: simple_match
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ The left side of the arrow `=>` is the pattern that we are matching on and the r

We check each arm starting from `0` and make our way down until we either find a match on our pattern or we reach the `catch_all` case.

The `|` operator can be used to produce a pattern that is a disjuction of other patterns.

The `catch_all` case is equivalent to an `else` in [if expressions](../if-expressions.md) and it does not have to be called `catch_all`. Any pattern declared after a `catch_all` case will not be matched because once the compiler sees the first `catch_all` it stop performing further checks.
15 changes: 10 additions & 5 deletions examples/match_statements/src/main.sw
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn on_odd(num: u64) {

fn main(num: u64) -> u64 {
// Match as an expression
let isEven = match num % 2 {
let is_even = match num % 2 {
0 => true,
_ => false,
};
Expand All @@ -29,17 +29,22 @@ fn main(num: u64) -> u64 {
Cloudy: (),
Snowy: (),
}
let currentWeather = Weather::Sunny;
let avgTemp = match currentWeather {
let current_weather = Weather::Sunny;
let avg_temp = match current_weather {
Weather::Sunny => 80,
Weather::Rainy => 50,
Weather::Cloudy => 60,
Weather::Snowy => 20,
};

let is_sunny = match current_weather {
Weather::Sunny => true,
Weather::Rainy | Weather::Cloudy | Weather::Snowy => false,
};

// match expression used for a return
let outsideTemp = Weather::Sunny;
match outsideTemp {
let outside_temp = Weather::Sunny;
match outside_temp {
Weather::Sunny => 80,
Weather::Rainy => 50,
Weather::Cloudy => 60,
Expand Down
10 changes: 10 additions & 0 deletions sway-ast/src/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ use crate::priv_prelude::*;

#[derive(Clone, Debug, Serialize)]
pub enum Pattern {
Or {
lhs: Box<Pattern>,
pipe_token: PipeToken,
rhs: Box<Pattern>,
},
Wildcard {
underscore_token: UnderscoreToken,
},
Expand All @@ -28,6 +33,11 @@ pub enum Pattern {
impl Spanned for Pattern {
fn span(&self) -> Span {
match self {
Pattern::Or {
lhs,
pipe_token,
rhs,
} => Span::join(Span::join(lhs.span(), pipe_token.span()), rhs.span()),
Pattern::Wildcard { underscore_token } => underscore_token.span(),
Pattern::Var {
reference,
Expand Down
7 changes: 6 additions & 1 deletion sway-core/src/language/parsed/expression/scrutinee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ use sway_types::{ident::Ident, span::Span, Spanned};
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Clone)]
pub enum Scrutinee {
Or {
elems: Vec<Scrutinee>,
span: Span,
},
CatchAll {
span: Span,
},
Expand Down Expand Up @@ -58,6 +62,7 @@ pub enum StructScrutineeField {
impl Spanned for Scrutinee {
fn span(&self) -> Span {
match self {
Scrutinee::Or { span, .. } => span.clone(),
Scrutinee::CatchAll { span } => span.clone(),
Scrutinee::Literal { span, .. } => span.clone(),
Scrutinee::Variable { span, .. } => span.clone(),
Expand Down Expand Up @@ -157,7 +162,7 @@ impl Scrutinee {
let value = value.gather_approximate_typeinfo_dependencies();
vec![name, value].concat()
}
Scrutinee::Tuple { elems, .. } => elems
Scrutinee::Tuple { elems, .. } | Scrutinee::Or { elems, .. } => elems
.iter()
.flat_map(|scrutinee| scrutinee.gather_approximate_typeinfo_dependencies())
.collect::<Vec<TypeInfo>>(),
Expand Down
2 changes: 1 addition & 1 deletion sway-core/src/language/ty/expression/match_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub(crate) struct TyMatchExpression {

#[derive(Debug)]
pub(crate) struct TyMatchBranch {
pub(crate) conditions: MatchReqMap,
pub(crate) cnf: MatchReqMap,
pub(crate) result: TyExpression,
#[allow(dead_code)]
pub(crate) span: Span,
Expand Down
1 change: 1 addition & 0 deletions sway-core/src/language/ty/expression/scrutinee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct TyScrutinee {

#[derive(Debug, Clone)]
pub enum TyScrutineeVariant {
Or(Vec<TyScrutinee>),
CatchAll,
Literal(Literal),
Variable(Ident),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,12 +320,22 @@ impl ConstructorFactory {
)
}
Pattern::Tuple(elems) => Pattern::Tuple(PatStack::fill_wildcards(elems.len())),
Pattern::Or(_) => {
errors.push(CompileError::Unimplemented(
"or patterns are not supported",
span.clone(),
));
return err(warnings, errors);
Pattern::Or(elems) => {
let mut pat_stack = PatStack::empty();
for pat in elems.into_iter() {
pat_stack.push(check!(
self.create_pattern_not_present(engines, PatStack::from_pattern(pat), span),
return err(warnings, errors),
warnings,
errors
));
}
check!(
Pattern::from_pat_stack(pat_stack, span),
return err(warnings, errors),
warnings,
errors
)
}
};
ok(pat, warnings, errors)
Expand Down Expand Up @@ -391,13 +401,28 @@ impl ConstructorFactory {
) -> CompileResult<bool> {
let mut warnings = vec![];
let mut errors = vec![];

// flatten or patterns
let pat_stack = check!(
pat_stack.clone().serialize_multi_patterns(span),
return err(warnings, errors),
warnings,
errors
)
.into_iter()
.fold(PatStack::empty(), |mut acc, mut pats| {
acc.append(&mut pats);
acc
});

if pat_stack.is_empty() {
return ok(false, warnings, errors);
}
if pat_stack.contains(&Pattern::Wildcard) {
return ok(true, warnings, errors);
}
let (first, rest) = check!(

let (first, mut rest) = check!(
pat_stack.split_first(span),
return err(warnings, errors),
warnings,
Expand Down Expand Up @@ -568,12 +593,18 @@ impl ConstructorFactory {
));
err(warnings, errors)
}
Pattern::Or(_) => {
errors.push(CompileError::Unimplemented(
"or patterns are not supported",
span.clone(),
));
err(warnings, errors)
Pattern::Or(mut elems) => {
elems.append(&mut rest);
ok(
check!(
self.is_complete_signature(engines, &elems, span),
return err(warnings, errors),
warnings,
errors
),
warnings,
errors,
)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,36 +74,6 @@ impl Matrix {
ok((self.rows.len(), n), warnings, errors)
}

/// Reports if the `Matrix` is equivalent to a vector (aka a single
/// `PatStack`).
pub(crate) fn is_a_vector(&self) -> bool {
self.rows.len() == 1
}

/// Checks to see if the `Matrix` is a vector, and if it is, returns the
/// single `PatStack` from its elements.
pub(crate) fn unwrap_vector(&self, span: &Span) -> CompileResult<PatStack> {
let warnings = vec![];
let mut errors = vec![];
if !self.is_a_vector() {
errors.push(CompileError::Internal(
"found invalid matrix size",
span.clone(),
));
return err(warnings, errors);
}
match self.rows.first() {
Some(first) => ok(first.clone(), warnings, errors),
None => {
errors.push(CompileError::Internal(
"found invalid matrix size",
span.clone(),
));
err(warnings, errors)
}
}
}

/// Computes Σ, where Σ is a `PatStack` containing the first element of
/// every row of the `Matrix`.
pub(crate) fn compute_sigma(&self, span: &Span) -> CompileResult<PatStack> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,6 @@ impl PatStack {
self.pats.iter()
}

pub(crate) fn into_iter(self) -> IntoIter<Pattern> {
self.pats.into_iter()
}

/// Flattens the contents of a `PatStack` into a `PatStack`.
pub(crate) fn flatten(&self) -> PatStack {
let mut flattened = PatStack::empty();
Expand Down Expand Up @@ -315,6 +311,14 @@ impl PatStack {
}
}

impl IntoIterator for PatStack {
type Item = Pattern;
type IntoIter = IntoIter<Pattern>;
fn into_iter(self) -> Self::IntoIter {
self.pats.into_iter()
}
}

impl From<Vec<Pattern>> for PatStack {
fn from(pats: Vec<Pattern>) -> Self {
PatStack { pats }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,18 @@ impl Pattern {
fields: new_fields,
})
}
ty::TyScrutineeVariant::Or(elems) => {
let mut new_elems = PatStack::empty();
for elem in elems.into_iter() {
new_elems.push(check!(
Pattern::from_scrutinee(elem),
return err(warnings, errors),
warnings,
errors
));
}
Pattern::Or(new_elems)
}
ty::TyScrutineeVariant::Tuple(elems) => {
let mut new_elems = PatStack::empty();
for elem in elems.into_iter() {
Expand Down Expand Up @@ -495,7 +507,31 @@ impl Pattern {
errors
)
}
Pattern::Or(_) => unreachable!(),
Pattern::Or(elems) => {
if elems.len() != args.len() {
errors.push(CompileError::Internal(
"malformed constructor request",
span.clone(),
));
return err(warnings, errors);
}
let pats: PatStack = check!(
args.serialize_multi_patterns(span),
return err(warnings, errors),
warnings,
errors
)
.into_iter()
.map(Pattern::Or)
.collect::<Vec<_>>()
.into();
check!(
Pattern::from_pat_stack(pats, span),
return err(warnings, errors),
warnings,
errors
)
}
};
ok(pat, warnings, errors)
}
Expand Down
Loading

0 comments on commit 30f869a

Please sign in to comment.