Skip to content

Commit 03bf899

Browse files
committed
Fix an NRE in UnreachableCaseInspector
Also refactors the class a bit and fixes several possible multiple enumeration issues.
1 parent b24cfcb commit 03bf899

File tree

1 file changed

+59
-52
lines changed

1 file changed

+59
-52
lines changed

Rubberduck.CodeAnalysis/Inspections/Concrete/UnreachableCaseInspection/UnreachableCaseInspector.cs

Lines changed: 59 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ public interface IUnreachableCaseInspector
1111
{
1212
void InspectForUnreachableCases();
1313
string SelectExpressionTypeName { get; }
14-
Func<string, ParserRuleContext, string> GetVariableDeclarationTypeName { set; get; }
1514
List<ParserRuleContext> UnreachableCases { get; }
1615
List<ParserRuleContext> InherentlyUnreachableCases { get; }
1716
List<ParserRuleContext> MismatchTypeCases { get; }
@@ -24,23 +23,22 @@ public class UnreachableCaseInspector : IUnreachableCaseInspector
2423
private readonly IEnumerable<VBAParser.CaseClauseContext> _caseClauses;
2524
private readonly ParserRuleContext _caseElseContext;
2625
private readonly IParseTreeValueFactory _valueFactory;
26+
private readonly Func<string, ParserRuleContext, string> _getVariableDeclarationTypeName;
2727
private IParseTreeValue _selectExpressionValue;
2828

2929
public UnreachableCaseInspector(VBAParser.SelectCaseStmtContext selectCaseContext,
3030
IParseTreeVisitorResults inspValues,
3131
IParseTreeValueFactory valueFactory,
32-
Func<string,ParserRuleContext,string> GetVariableTypeName = null)
32+
Func<string,ParserRuleContext,string> getVariableTypeName = null)
3333
{
3434
_valueFactory = valueFactory;
3535
_caseClauses = selectCaseContext.caseClause();
3636
_caseElseContext = selectCaseContext.caseElseClause();
37-
GetVariableDeclarationTypeName = GetVariableTypeName;
37+
_getVariableDeclarationTypeName = getVariableTypeName;
3838
ParseTreeValueResults = inspValues;
39-
SetSelectExpressionTypeName(selectCaseContext as ParserRuleContext, inspValues);
39+
SetSelectExpressionTypeName(selectCaseContext, inspValues);
4040
}
4141

42-
public Func<string, ParserRuleContext, string> GetVariableDeclarationTypeName { set; get; }
43-
4442
public List<ParserRuleContext> UnreachableCases { set; get; } = new List<ParserRuleContext>();
4543

4644
public List<ParserRuleContext> MismatchTypeCases { set; get; } = new List<ParserRuleContext>();
@@ -53,7 +51,7 @@ public UnreachableCaseInspector(VBAParser.SelectCaseStmtContext selectCaseContex
5351

5452
public string SelectExpressionTypeName { private set; get; } = string.Empty;
5553

56-
private IParseTreeVisitorResults ParseTreeValueResults { set; get; }
54+
private IParseTreeVisitorResults ParseTreeValueResults { get; }
5755

5856
public void InspectForUnreachableCases()
5957
{
@@ -71,7 +69,9 @@ public void InspectForUnreachableCases()
7169
foreach ( var range in caseClause.rangeClause())
7270
{
7371
var childResults = ParseTreeValueResults.GetChildResults(range);
74-
var childValues = childResults.Select(ch => ParseTreeValueResults.GetValue(ch));
72+
var childValues = childResults
73+
.Select(ch => ParseTreeValueResults.GetValue(ch))
74+
.ToList();
7575
if (childValues.Any(chr => chr.IsMismatchExpression))
7676
{
7777
containsMismatch = true;
@@ -136,16 +136,16 @@ private IExpressionFilter BuildRangeClauseFilter(IEnumerable<VBAParser.CaseClaus
136136
{
137137
var rangeClauseFilter = ExpressionFilterFactory.Create(SelectExpressionTypeName);
138138

139-
if (!(GetVariableDeclarationTypeName is null))
139+
if (!(_getVariableDeclarationTypeName is null))
140140
{
141141
foreach (var caseClause in caseClauses)
142142
{
143143
foreach (var rangeClause in caseClause.rangeClause())
144144
{
145145
var expression = GetRangeClauseExpression(rangeClause);
146-
if (!expression.LHS.ParsesToConstantValue)
146+
if (!expression?.LHS?.ParsesToConstantValue ?? false)
147147
{
148-
var typeName = GetVariableDeclarationTypeName(expression.LHS.Token, rangeClause);
148+
var typeName = _getVariableDeclarationTypeName(expression.LHS.Token, rangeClause);
149149
rangeClauseFilter.AddComparablePredicateFilter(expression.LHS.Token, typeName);
150150
}
151151
}
@@ -157,12 +157,12 @@ private IExpressionFilter BuildRangeClauseFilter(IEnumerable<VBAParser.CaseClaus
157157
private void SetSelectExpressionTypeName(ParserRuleContext context, IParseTreeVisitorResults inspValues)
158158
{
159159
var selectStmt = (VBAParser.SelectCaseStmtContext)context;
160-
if (TryDetectTypeHint(selectStmt.selectExpression().GetText(), out string typeName)
160+
if (TryDetectTypeHint(selectStmt.selectExpression().GetText(), out var typeName)
161161
&& InspectableTypes.Contains(typeName))
162162
{
163163
SelectExpressionTypeName = typeName;
164164
}
165-
else if (inspValues.TryGetValue(selectStmt.selectExpression(), out IParseTreeValue result)
165+
else if (inspValues.TryGetValue(selectStmt.selectExpression(), out var result)
166166
&& InspectableTypes.Contains(result.ValueType))
167167
{
168168
_selectExpressionValue = result;
@@ -181,34 +181,38 @@ private string DeriveTypeFromCaseClauses(IParseTreeVisitorResults inspValues, VB
181181
{
182182
foreach (var range in caseClause.rangeClause())
183183
{
184-
if (TryDetectTypeHint(range.GetText(), out string hintTypeName))
184+
if (TryDetectTypeHint(range.GetText(), out var hintTypeName))
185185
{
186186
caseClauseTypeNames.Add(hintTypeName);
187187
}
188188
else
189189
{
190-
var typeNames = from context in range.children
191-
where context is ParserRuleContext
192-
&& IsResultContext(context)
193-
select inspValues.GetValueType(context as ParserRuleContext);
190+
var typeNames = range.children
191+
.OfType<ParserRuleContext>()
192+
.Where(IsResultContext)
193+
.Select(inspValues.GetValueType);
194194

195195
caseClauseTypeNames.AddRange(typeNames);
196196
caseClauseTypeNames.RemoveAll(tp => !InspectableTypes.Contains(tp));
197197
}
198198
}
199199
}
200200

201-
if (TryGetSelectExpressionTypeNameFromTypes(caseClauseTypeNames, out string evalTypeName))
201+
if (TryGetSelectExpressionTypeNameFromTypes(caseClauseTypeNames, out var evalTypeName))
202202
{
203203
return evalTypeName;
204204
}
205+
205206
return string.Empty;
206207
}
207208

208-
private static bool TryGetSelectExpressionTypeNameFromTypes(IEnumerable<string> typeNames, out string typeName)
209+
private static bool TryGetSelectExpressionTypeNameFromTypes(ICollection<string> typeNames, out string typeName)
209210
{
210211
typeName = string.Empty;
211-
if (!typeNames.Any()) { return false; }
212+
if (!typeNames.Any())
213+
{
214+
return false;
215+
}
212216

213217
//If everything is declared as a Variant , we do not attempt to inspect the selectStatement
214218
if (typeNames.All(tn => tn.Equals(Tokens.Variant)))
@@ -229,7 +233,7 @@ private static bool TryGetSelectExpressionTypeNameFromTypes(IEnumerable<string>
229233
return true;
230234
}
231235

232-
//Mix of Integertypes and rational number types will be evaluated using Double or Currency
236+
//Mix of Integer types and rational number types will be evaluated using Double or Currency
233237
if (typeNames.All(tn => new List<string>() { Tokens.Long, Tokens.Integer, Tokens.Byte, Tokens.Single, Tokens.Double, Tokens.Currency }.Contains(tn)))
234238
{
235239
typeName = typeNames.Any(tk => tk.Equals(Tokens.Currency)) ? Tokens.Currency : Tokens.Double;
@@ -246,7 +250,7 @@ private static bool TryDetectTypeHint(string content, out string typeName)
246250
return false;
247251
}
248252

249-
if (SymbolList.TypeHintToTypeName.Keys.Any(th => content.EndsWith(th)))
253+
if (SymbolList.TypeHintToTypeName.Keys.Any(content.EndsWith))
250254
{
251255
var lastChar = content.Substring(content.Length - 1);
252256
typeName = SymbolList.TypeHintToTypeName[lastChar];
@@ -257,9 +261,10 @@ private static bool TryDetectTypeHint(string content, out string typeName)
257261

258262
private IRangeClauseExpression GetRangeClauseExpression(VBAParser.RangeClauseContext rangeClause)
259263
{
260-
var resultContexts = from ctxt in rangeClause.children
261-
where ctxt is ParserRuleContext && IsResultContext(ctxt)
262-
select ctxt as ParserRuleContext;
264+
var resultContexts = rangeClause.children
265+
.OfType<ParserRuleContext>()
266+
.Where(IsResultContext)
267+
.ToList();
263268

264269
if (!resultContexts.Any())
265270
{
@@ -272,41 +277,43 @@ private IRangeClauseExpression GetRangeClauseExpression(VBAParser.RangeClauseCon
272277
var rangeEndValue = ParseTreeValueResults.GetValue(rangeClause.GetChild<VBAParser.SelectEndValueContext>());
273278
return new RangeOfValuesExpression((rangeStartValue, rangeEndValue));
274279
}
275-
else if (rangeClause.IS() != null)
280+
281+
if (rangeClause.IS() != null)
276282
{
277-
var clauseValue = ParseTreeValueResults.GetValue(resultContexts.First());
283+
var isClauseValue = ParseTreeValueResults.GetValue(resultContexts.First());
278284
var opSymbol = rangeClause.GetChild<VBAParser.ComparisonOperatorContext>().GetText();
279-
return new IsClauseExpression(clauseValue, opSymbol);
285+
return new IsClauseExpression(isClauseValue, opSymbol);
280286
}
281-
else if (TryGetLogicSymbol(resultContexts.First(), out string symbol))
287+
288+
if (!TryGetLogicSymbol(resultContexts.First(), out string symbol))
282289
{
283-
var resultContext = resultContexts.First();
284-
var clauseValue = ParseTreeValueResults.GetValue(resultContext);
285-
if (clauseValue.ParsesToConstantValue)
286-
{
287-
return new ValueExpression(clauseValue);
288-
}
290+
return new ValueExpression(ParseTreeValueResults.GetValue(resultContexts.First()));
291+
}
289292

290-
if (resultContext is VBAParser.LogicalNotOpContext)
291-
{
293+
var resultContext = resultContexts.First();
294+
var clauseValue = ParseTreeValueResults.GetValue(resultContext);
295+
if (clauseValue.ParsesToConstantValue)
296+
{
297+
return new ValueExpression(clauseValue);
298+
}
299+
300+
switch (resultContext)
301+
{
302+
case VBAParser.LogicalNotOpContext _:
292303
return new UnaryExpression(clauseValue, symbol);
293-
}
294-
else if (resultContext is VBAParser.RelationalOpContext
295-
|| resultContext is VBAParser.LogicalEqvOpContext
296-
|| resultContext is VBAParser.LogicalImpOpContext)
304+
case VBAParser.RelationalOpContext _:
305+
case VBAParser.LogicalEqvOpContext _:
306+
case VBAParser.LogicalImpOpContext _:
297307
{
298-
(IParseTreeValue lhs, IParseTreeValue rhs) = CreateLogicPair(clauseValue, symbol, _valueFactory);
308+
var (lhs, rhs) = CreateLogicPair(clauseValue, symbol, _valueFactory);
299309
if (symbol.Equals(Tokens.Like))
300310
{
301311
return new LikeExpression(lhs, rhs);
302312
}
303313
return new BinaryExpression(lhs, rhs, symbol);
304314
}
305-
return null;
306-
}
307-
else
308-
{
309-
return new ValueExpression(ParseTreeValueResults.GetValue(resultContexts.First()));
315+
default:
316+
return null;
310317
}
311318
}
312319

@@ -323,8 +330,8 @@ private static bool TryGetLogicSymbol(ParserRuleContext context, out string opSy
323330
private static (IParseTreeValue lhs, IParseTreeValue rhs)
324331
CreateLogicPair(IParseTreeValue value, string opSymbol, IParseTreeValueFactory factory)
325332
{
326-
var operands = value.Token.Split(new string[] { opSymbol }, StringSplitOptions.None);
327-
if (operands.Count() == 2)
333+
var operands = value.Token.Split(new [] { opSymbol }, StringSplitOptions.None);
334+
if (operands.Length == 2)
328335
{
329336
var lhs = factory.Create(operands[0].Trim());
330337
var rhs = factory.Create(operands[1].Trim());
@@ -335,7 +342,7 @@ private static (IParseTreeValue lhs, IParseTreeValue rhs)
335342
return (lhs, rhs);
336343
}
337344

338-
if (operands.Count() == 1)
345+
if (operands.Length == 1)
339346
{
340347
var lhs = factory.Create(operands[0].Trim());
341348
return (lhs, null);
@@ -358,7 +365,7 @@ private static bool IsResultContext<TContext>(TContext context)
358365
|| context is VBAParser.SelectEndValueContext;
359366
}
360367

361-
private static List<string> InspectableTypes = new List<string>()
368+
private static readonly IReadOnlyList<string> InspectableTypes = new List<string>
362369
{
363370
Tokens.Byte,
364371
Tokens.Integer,

0 commit comments

Comments
 (0)