@@ -2,11 +2,11 @@ use crate::consts::{
2
2
constant, Constant ,
3
3
Constant :: { F32 , F64 } ,
4
4
} ;
5
- use crate :: utils:: * ;
5
+ use crate :: utils:: { span_lint_and_sugg , sugg } ;
6
6
use if_chain:: if_chain;
7
7
use rustc:: ty;
8
8
use rustc_errors:: Applicability ;
9
- use rustc_hir:: * ;
9
+ use rustc_hir:: { BinOpKind , Expr , ExprKind , UnOp } ;
10
10
use rustc_lint:: { LateContext , LateLintPass } ;
11
11
use rustc_session:: { declare_lint_pass, declare_tool_lint} ;
12
12
use std:: f32:: consts as f32_consts;
@@ -39,6 +39,7 @@ declare_clippy_lint! {
39
39
/// let _ = (1.0 + a).ln();
40
40
/// let _ = a.exp() - 1.0;
41
41
/// let _ = a.powf(2.0);
42
+ /// let _ = a * 2.0 + 4.0;
42
43
/// ```
43
44
///
44
45
/// is better expressed as
@@ -57,6 +58,7 @@ declare_clippy_lint! {
57
58
/// let _ = a.ln_1p();
58
59
/// let _ = a.exp_m1();
59
60
/// let _ = a.powi(2);
61
+ /// let _ = a.mul_add(2.0, 4.0);
60
62
/// ```
61
63
pub SUBOPTIMAL_FLOPS ,
62
64
nursery,
@@ -211,12 +213,12 @@ fn check_powf(cx: &LateContext<'_, '_>, expr: &Expr<'_>, args: &[Expr<'_>]) {
211
213
let ( help, suggestion) = if F32 ( 1.0 / 2.0 ) == value || F64 ( 1.0 / 2.0 ) == value {
212
214
(
213
215
"square-root of a number can be computed more efficiently and accurately" ,
214
- format ! ( "{}.sqrt()" , Sugg :: hir( cx, & args[ 0 ] , ".." ) )
216
+ format ! ( "{}.sqrt()" , Sugg :: hir( cx, & args[ 0 ] , ".." ) ) ,
215
217
)
216
218
} else if F32 ( 1.0 / 3.0 ) == value || F64 ( 1.0 / 3.0 ) == value {
217
219
(
218
220
"cube-root of a number can be computed more accurately" ,
219
- format ! ( "{}.cbrt()" , Sugg :: hir( cx, & args[ 0 ] , ".." ) )
221
+ format ! ( "{}.cbrt()" , Sugg :: hir( cx, & args[ 0 ] , ".." ) ) ,
220
222
)
221
223
} else if let Some ( exponent) = get_integer_from_float_constant ( & value) {
222
224
(
@@ -225,7 +227,7 @@ fn check_powf(cx: &LateContext<'_, '_>, expr: &Expr<'_>, args: &[Expr<'_>]) {
225
227
"{}.powi({})" ,
226
228
Sugg :: hir( cx, & args[ 0 ] , ".." ) ,
227
229
format_numeric_literal( & exponent. to_string( ) , None , false )
228
- )
230
+ ) ,
229
231
)
230
232
} else {
231
233
return ;
@@ -272,6 +274,52 @@ fn check_expm1(cx: &LateContext<'_, '_>, expr: &Expr<'_>) {
272
274
}
273
275
}
274
276
277
+ fn is_float_mul_expr < ' a > ( cx : & LateContext < ' _ , ' _ > , expr : & ' a Expr < ' a > ) -> Option < ( & ' a Expr < ' a > , & ' a Expr < ' a > ) > {
278
+ if_chain ! {
279
+ if let ExprKind :: Binary ( op, ref lhs, ref rhs) = & expr. kind;
280
+ if let BinOpKind :: Mul = op. node;
281
+ if cx. tables. expr_ty( lhs) . is_floating_point( ) ;
282
+ if cx. tables. expr_ty( rhs) . is_floating_point( ) ;
283
+ then {
284
+ return Some ( ( lhs, rhs) ) ;
285
+ }
286
+ }
287
+
288
+ None
289
+ }
290
+
291
+ // TODO: Fix rust-lang/rust-clippy#4735
292
+ fn check_fma ( cx : & LateContext < ' _ , ' _ > , expr : & Expr < ' _ > ) {
293
+ if_chain ! {
294
+ if let ExprKind :: Binary ( op, lhs, rhs) = & expr. kind;
295
+ if let BinOpKind :: Add = op. node;
296
+ then {
297
+ let ( recv, arg1, arg2) = if let Some ( ( inner_lhs, inner_rhs) ) = is_float_mul_expr( cx, lhs) {
298
+ ( inner_lhs, inner_rhs, rhs)
299
+ } else if let Some ( ( inner_lhs, inner_rhs) ) = is_float_mul_expr( cx, rhs) {
300
+ ( inner_lhs, inner_rhs, lhs)
301
+ } else {
302
+ return ;
303
+ } ;
304
+
305
+ span_lint_and_sugg(
306
+ cx,
307
+ SUBOPTIMAL_FLOPS ,
308
+ expr. span,
309
+ "multiply and add expressions can be calculated more efficiently and accurately" ,
310
+ "consider using" ,
311
+ format!(
312
+ "{}.mul_add({}, {})" ,
313
+ prepare_receiver_sugg( cx, recv) ,
314
+ Sugg :: hir( cx, arg1, ".." ) ,
315
+ Sugg :: hir( cx, arg2, ".." ) ,
316
+ ) ,
317
+ Applicability :: MachineApplicable ,
318
+ ) ;
319
+ }
320
+ }
321
+ }
322
+
275
323
impl < ' a , ' tcx > LateLintPass < ' a , ' tcx > for FloatingPointArithmetic {
276
324
fn check_expr ( & mut self , cx : & LateContext < ' a , ' tcx > , expr : & ' tcx Expr < ' _ > ) {
277
325
if let ExprKind :: MethodCall ( ref path, _, args) = & expr. kind {
@@ -287,6 +335,7 @@ impl<'a, 'tcx> LateLintPass<'a, 'tcx> for FloatingPointArithmetic {
287
335
}
288
336
} else {
289
337
check_expm1 ( cx, expr) ;
338
+ check_fma ( cx, expr) ;
290
339
}
291
340
}
292
341
}
0 commit comments