diff --git a/lib/comment-directive-parser.js b/lib/comment-directive-parser.js index 51dd24a6..b9326b62 100644 --- a/lib/comment-directive-parser.js +++ b/lib/comment-directive-parser.js @@ -68,7 +68,6 @@ class CommentDirectiveParser { isRuleEnabled(line, ruleId) { return this.ruleStore.isRuleEnabled(line, ruleId); } - } diff --git a/lib/common/tree-traversing.js b/lib/common/tree-traversing.js index 15a84446..602d6bf6 100644 --- a/lib/common/tree-traversing.js +++ b/lib/common/tree-traversing.js @@ -87,4 +87,15 @@ TreeTraversing.hasMethodCalls = function (ctx, methodNames) { }; +TreeTraversing.findPropertyInParents = function (ctx, property) { + let curCtx = ctx; + + while (curCtx !== null && !curCtx[property]) { + curCtx = curCtx.parentCtx; + } + + return curCtx && curCtx[property]; +}; + + module.exports = TreeTraversing; \ No newline at end of file diff --git a/lib/rules/security/reentrancy.js b/lib/rules/security/reentrancy.js index 28a8015c..01c73b43 100644 --- a/lib/rules/security/reentrancy.js +++ b/lib/rules/security/reentrancy.js @@ -1,6 +1,7 @@ const BaseChecker = require('./../base-checker'); const _ = require('lodash'); -const { hasMethodCalls } = require('./../../common/tree-traversing'); +const TreeTraversing = require('./../../common/tree-traversing'); +const { typeOf, hasMethodCalls, findPropertyInParents } = TreeTraversing; class ReentrancyChecker extends BaseChecker { @@ -9,8 +10,17 @@ class ReentrancyChecker extends BaseChecker { super(reporter); } + enterContractDefinition(ctx) { + ctx.stateDeclarationScope = new StateDeclarationScope(); + const scope = ctx.stateDeclarationScope; + + new ContractDefinition(ctx) + .stateDefinitions() + .forEach(i => scope.trackStateDeclaration(i)); + } + enterFunctionDefinition(ctx) { - ctx.effects = new Effects(); + ctx.effects = new Effects(StateDeclarationScope.of(ctx)); } enterExpression(ctx) { @@ -21,7 +31,7 @@ class ReentrancyChecker extends BaseChecker { _checkAssignment(ctx) { const effects = Effects.of(ctx); - if (isAssignOperator(ctx) && effects && !effects.isAllowedAssign()) { + if (isAssignOperator(ctx) && effects && !effects.isAllowedAssign(ctx)) { this._warn(ctx); } } @@ -38,29 +48,96 @@ class ReentrancyChecker extends BaseChecker { } -class Effects { +class StateDeclarationScope { static of (ctx) { - let curCtx = ctx; + return findPropertyInParents(ctx, 'stateDeclarationScope'); + } - while (curCtx !== null && !curCtx.effects) { - curCtx = curCtx.parentCtx; - } + constructor () { + this.states = []; + } - return curCtx.effects; + trackStateDeclaration (stateDefinition) { + const stateName = stateDefinition.stateName(); + this.states.push(stateName); } +} - constructor () { + +class ContractDefinition { + + constructor (ctx) { + this.ctx = ctx; + } + + stateDefinitions() { + return this.ctx + .children + .map(i => new ContractPart(i)) + .filter(i => i.isStateDefinition()) + .map(i => i.getStateDefinition()); + } +} + + +class ContractPart { + + constructor (ctx) { + this.ctx = ctx; + } + + isStateDefinition() { + return typeOf(this._firstChild()) === 'stateVariableDeclaration'; + } + + getStateDefinition () { + return new StateDefinition(this._firstChild()); + } + + _firstChild () { + return _.first(this.ctx.children); + } +} + +class StateDefinition { + + constructor(ctx) { + this.ctx = ctx; + } + + stateName() { + return _(this.ctx.children) + .find(i => typeOf(i) === 'identifier') + .getText(); + } +} + + +class Effects { + + static of (ctx) { + return findPropertyInParents(ctx, 'effects'); + } + + constructor (statesScope) { + this.states = statesScope && statesScope.states; this.hasTransfer = false; } - isAllowedAssign() { - return !this.hasTransfer; + isAllowedAssign(ctx) { + const assignee = ctx.children[0].getText(); + + return !(this.hasTransfer && this._isContainsStateName(assignee)); } trackTransfer() { this.hasTransfer = true; } + + _isContainsStateName(expressionText) { + return this.states.some(i => expressionText.includes(i)); + } } diff --git a/test/security-rules.js b/test/security-rules.js index de2c3fdf..38a3c347 100644 --- a/test/security-rules.js +++ b/test/security-rules.js @@ -185,15 +185,23 @@ describe('Linter - SecurityRules', function() { describe('Reentrancy', function () { const REENTRANCY_ERROR = [ - funcWith(` - uint amount = shares[msg.sender]; - bool a = msg.sender.send(amount); - if (a) { shares[msg.sender] = 0; } + contractWith(` + mapping(address => uint) private shares; + + function b() external { + uint amount = shares[msg.sender]; + bool a = msg.sender.send(amount); + if (a) { shares[msg.sender] = 0; } + } `), - funcWith(` - uint amount = shares[msg.sender]; - msg.sender.transfer(amount); - shares[msg.sender] = 0; + contractWith(` + mapping(address => uint) private shares; + + function b() external { + uint amount = shares[msg.sender]; + msg.sender.transfer(amount); + shares[msg.sender] = 0; + } `) ]; @@ -207,15 +215,29 @@ describe('Linter - SecurityRules', function() { ); const NO_REENTRANCY_ERRORS = [ - funcWith(` - uint amount = shares[msg.sender]; - user.test(amount); - shares[msg.sender] = 0; + contractWith(` + mapping(address => uint) private shares; + + function b() external { + uint amount = shares[msg.sender]; + shares[msg.sender] = 0; + msg.sender.transfer(amount); + } + `), + contractWith(` + mapping(address => uint) private shares; + + function b() external { + uint amount = shares[msg.sender]; + user.test(amount); + shares[msg.sender] = 0; + } `), funcWith(` + uint[] shares; uint amount = shares[msg.sender]; - shares[msg.sender] = 0; msg.sender.transfer(amount); + shares[msg.sender] = 0; `) ];