diff --git a/src/main/scala-3/com/sksamuel/scapegoat/Inspection.scala b/src/main/scala-3/com/sksamuel/scapegoat/Inspection.scala index 0070927e..3087b591 100644 --- a/src/main/scala-3/com/sksamuel/scapegoat/Inspection.scala +++ b/src/main/scala-3/com/sksamuel/scapegoat/Inspection.scala @@ -11,7 +11,7 @@ abstract class Inspection( val explanation: String ) extends InspectionBase { - val self: Inspection = this + implicit val self: Inspection = this def inspect(feedback: Feedback[SourcePosition], tree: tpd.Tree)(using Context): Unit diff --git a/src/main/scala-3/com/sksamuel/scapegoat/InspectionTraverser.scala b/src/main/scala-3/com/sksamuel/scapegoat/InspectionTraverser.scala index 21a9b9b8..fbc147d5 100644 --- a/src/main/scala-3/com/sksamuel/scapegoat/InspectionTraverser.scala +++ b/src/main/scala-3/com/sksamuel/scapegoat/InspectionTraverser.scala @@ -1,10 +1,46 @@ package com.sksamuel.scapegoat +import dotty.tools.dotc.ast.Trees +import dotty.tools.dotc.ast.tpd import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Constants import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Names +import dotty.tools.dotc.core.Symbols +import dotty.tools.dotc.core.Symbols.requiredClass import dotty.tools.dotc.util.NoSource -abstract class InspectionTraverser extends TreeTraverser { +abstract class InspectionTraverser(using inspection: Inspection) extends TreeTraverser { + + override protected def traverseChildren(tree: tpd.Tree)(using Context): Unit = { + if (!isSuppressed(tree)) { + super.traverseChildren(tree) + } + } + + private def isSuppressed(t: tpd.Tree)(using Context): Boolean = { + val symbol = t.symbol + val annotation = symbol.getAnnotation(requiredClass("java.lang.SuppressWarnings")) + val arg = annotation.flatMap(_.argument(0)).map(extractArg) + val inspectionName = inspection.getClass.getSimpleName + arg match { + case Some( + Apply(Apply(TypeApply(Select(Ident(array), _), _), List(Typed(SeqLiteral(args, _), _))), _) + ) => + args.exists { + case Literal(value) if value.tag == Constants.StringTag => + value.stringValue == "all" || value.stringValue == inspectionName + case _ => false + } + case _ => false + } + } + + // Scala 3.3 doesn't insert NamedArg, skip it while trying to match + private def extractArg(t: tpd.Tree): tpd.Tree = t match { + case NamedArg(_, tree) => tree + case _ => t + } extension (tree: Tree)(using Context) def asSnippet: Option[String] = tree.source match @@ -12,3 +48,7 @@ abstract class InspectionTraverser extends TreeTraverser { case _ => Some(tree.source.content().slice(tree.sourcePos.start, tree.sourcePos.end).mkString) } + +object InspectionTraverser { + val array = Names.termName("Array") +} diff --git a/src/test/scala-3/com/sksamuel/scapegoat/InspectionTraverserTest.scala b/src/test/scala-3/com/sksamuel/scapegoat/InspectionTraverserTest.scala new file mode 100644 index 00000000..0aa5b3b4 --- /dev/null +++ b/src/test/scala-3/com/sksamuel/scapegoat/InspectionTraverserTest.scala @@ -0,0 +1,68 @@ +package com.sksamuel.scapegoat + +import com.sksamuel.scapegoat.inspections.option.OptionGet + +class InspectionTraverserTest extends InspectionTest(classOf[OptionGet]) { + "InspectionTraverser" - { + "should ignore all inspection based on SuppressWarnings on class" in { + val code = """ + @SuppressWarnings(Array("all")) + class Test { + val o = Option("sammy") + o.get + }""".stripMargin + + val feedback = runner.compileCodeSnippet(code) + feedback.errors.assertable shouldEqual Seq.empty + } + + "should ignore specific inspection based on SuppressWarnings on class" in { + val code = """ + @SuppressWarnings(Array("OptionGet")) + class Test { + val o = Option("sammy") + o.get + }""".stripMargin + + val feedback = runner.compileCodeSnippet(code) + feedback.errors.assertable shouldEqual Seq.empty + } + + "should ignore specific inspection based on SuppressWarnings on class (Different warning)" in { + val code = """ + @SuppressWarnings(Array("AvoidRequire")) + class Test { + val o = Option("sammy") + o.get + }""".stripMargin + + val feedback = runner.compileCodeSnippet(code) + feedback.errors.assertable shouldEqual Seq( + warning(4, Levels.Error, Some("o.get")) + ) + } + + "should ignore all inspection based on SuppressWarnings on method" in { + val code = """ + class Test { + @SuppressWarnings(Array("all")) + def func(): String = { + // ignored violation + val o = Option("sammy") + o.get + } + + // violation + val o2 = Option("sammy") + o2.get + + func() + }""".stripMargin + + val feedback = runner.compileCodeSnippet(code) + feedback.errors.assertable shouldEqual Seq( + warning(11, Levels.Error, Some("o2.get")) + ) + } + } +}