diff --git a/sheriff/.juspay/sheriffRules.yaml b/sheriff/.juspay/sheriffRules.yaml index e1ebeb6..4138872 100644 --- a/sheriff/.juspay/sheriffRules.yaml +++ b/sheriff/.juspay/sheriffRules.yaml @@ -185,4 +185,19 @@ rules: - customerEmail db_rule_fixes: - "You might want to include an indexed column in the `where` clause of the query." - db_rule_exceptions: [] \ No newline at end of file + db_rule_exceptions: [] + + - where_clause_rule_name: "DefaultWhereClauseRule" + where_clause_rule_fixes: + - "You might want to include a mandatory column in the `where` clause of the query." + where_clause_rule_ignore_modules: [] + where_clause_rule_check_modules: + - "*" + query_functions_to_check: + - findOneRow + - findOne + - findAll + - findAllRows + - findAllRowsWithLimit + - findAllWithLimit + - findAllWithLimitAndOffset \ No newline at end of file diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index 733cc3d..8ca99ce 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -15,7 +15,6 @@ import Sheriff.Rules import Sheriff.Types import Sheriff.TypesUtils import Sheriff.Utils - -- GHC imports import Control.Applicative ((<|>)) import Control.Monad (foldM, when) @@ -31,9 +30,9 @@ import Data.Function (on) import qualified Data.HashMap.Strict as HM import Data.List (nub, sortBy, groupBy, find, isInfixOf, isSuffixOf, isPrefixOf) import Data.List.Extra (splitOn) -import Data.Maybe (catMaybes, fromMaybe) +import Data.Maybe (catMaybes, fromMaybe, mapMaybe) import Data.Yaml -import Debug.Trace (traceShowId, trace) +import Debug.Trace (traceShowId, trace, traceM) import GHC hiding (exprType) import Prelude hiding (id, writeFile, appendFile) import System.Directory (createDirectoryIfMissing, getHomeDirectory) @@ -142,7 +141,6 @@ sheriff opts modSummary tcEnv = do when failOnFileNotFound $ addErr (mkInvalidYamlFileErr (show err)) pure [] Right (YamlTables tables) -> pure $ (map yamlToDbRule tables) - -- Check the parsed rules yaml file. If failed, throw file error if configured. configuredRules <- case parsedRulesYaml of Left err -> do @@ -169,6 +167,7 @@ sheriff opts modSummary tcEnv = do infRule = case find isInfiniteRecursionRule globalRules of Just (InfiniteRecursionRuleT r) -> r _ -> defaultInfiniteRecursionRuleT + -- liftIO $ putStrLn $ "finalSheriffRules: " <> show finalSheriffRules when logDebugInfo $ liftIO $ print globalRules when logDebugInfo $ liftIO $ print globalExceptionRules @@ -201,6 +200,7 @@ sheriff opts modSummary tcEnv = do return tcEnv + --------------------------- Infinite Recursion Detection Logic --------------------------- {- @@ -423,6 +423,7 @@ getBadFnCalls rules (FunBind{fun_matches = matches}) = do -- use childrenBi and then repeated children usage as per use case -- (exprs :: [LHsExpr GhcTc]) = traverseConditionalUni (noWhereClauseExpansion) (childrenBi match :: [LHsExpr GhcTc]) (exprs :: [LHsExpr GhcTc]) = traverseAstConditionally match noWhereClauseExpansion + -- liftIO $ putStrLn ("exprs: " ++ showS exprs) concat <$> mapM (isBadExpr rules) exprs getBadFnCalls _ _ = pure [] @@ -465,6 +466,34 @@ isBadExpr rules ap = pure [] isBadExprHelper :: (HasPluginOpts PluginOpts) => Rules -> LHsExpr GhcTc -> TcM [(LHsExpr GhcTc, Violation)] isBadExprHelper rules ap = concat <$> mapM (\rule -> checkAndApplyRule rule ap) rules +-- Collect all function names in the given expression +collectFnNames :: LHsExpr GhcTc -> [String] +collectFnNames expr = case expr of + L _ (HsVar _ v) -> + let fnName = showS v + in trace ("collectFnNames: Detected HsVar with name = " <> fnName) [fnName] -- Log detected function name + L _ (HsApp _ funl funr) -> + trace "collectFnNames: Detected HsApp, traversing both sides" $ + collectFnNames funl ++ collectFnNames funr -- Traverse both sides of the application + L _ (HsPar _ expr) -> + trace "collectFnNames: Detected HsPar, traversing inner expression" $ + collectFnNames expr -- Handle parentheses + L _ (OpApp _ funl op funr) -> + let opName = showS op + in trace ("collectFnNames: Detected OpApp with operator = " <> opName) $ + collectFnNames funl ++ [opName] ++ collectFnNames funr -- Handle operators + L _ (HsAppType _ expr _) -> + trace "collectFnNames: Detected HsAppType, traversing inner expression" $ + collectFnNames expr -- Handle type applications + L loc (PatHsWrap _ innerExpr) -> + trace "collectFnNames: Detected PatHsWrap, traversing wrapped expression" $ + collectFnNames (L loc innerExpr) -- Handle wrapped patterns + L _ (HsConLikeOut _ cl) -> + let conName = showS cl + in trace ("collectFnNames: Detected HsConLikeOut with name = " <> conName) [conName] -- Collect constructor-like outputs + _ -> + trace "collectFnNames: No matching case found, returning empty list" [] -- For other cases, return an empty list + -- Check if a particular rule applies to given expr checkAndApplyRule :: (HasPluginOpts PluginOpts) => Rule -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) checkAndApplyRule ruleT ap = case ruleT of @@ -473,8 +502,21 @@ checkAndApplyRule ruleT ap = case ruleT of (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do case (showS ty == "Clause" && showS tblName == (ruleTableName <> "T")) of True -> validateDBRule rule (showS tblName) exprs ap - False -> pure [] + False -> pure [] _ -> pure [] + WhereClauseRuleT rule -> + case ap of + (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do + if showS ty == "Clause" then do + let res = getFnNameWithAllArgs ap + case res of + Nothing -> pure [] + Just (fnLocatedVar, args) -> do + let fnName = getLocatedVarNameWithModuleName fnLocatedVar + fnLHsExpr = mkLHsVar fnLocatedVar + validateWhereClauseRule rule (showS tblName) exprs fnName + else pure [] + _ -> pure [] FunctionRuleT rule@(FunctionRule {fn_name = ruleFnNames, arg_no}) -> do let res = getFnNameWithAllArgs ap case res of @@ -487,7 +529,6 @@ checkAndApplyRule ruleT ap = case ruleT of Nothing -> pure [] InfiniteRecursionRuleT rule -> pure [] --TODO: Add handling of infinite recursion rule GeneralRuleT rule -> pure [] --TODO: Add handling of general rule - --------------------------- Function Rule Validation Logic --------------------------- {- @@ -649,7 +690,7 @@ Part-2 Validation -- Function to check if given DB rules is violated or not -- TODO: Fix this, keep two separate options for - 1. Match All Fields in AND 2. Use 1st column matching or all columns matching for composite key validateDBRule :: (HasPluginOpts PluginOpts) => DBRule -> String -> [LHsExpr GhcTc] -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) -validateDBRule rule@(DBRule {db_rule_name = ruleName, table_name = ruleTableName, indexed_cols_names = ruleColNames}) tableName clauses expr = do +validateDBRule rule@(DBRule {db_rule_name = ruleName, table_name = ruleTableName, indexed_cols_names = ruleColNames}) tableName clauses expr = do simplifiedExprs <- trfWhereToSOP clauses let checkDBViolation = case (matchAllInsideAnd . pluginOpts $ ?pluginOpts) of True -> checkDBViolationMatchAll @@ -657,7 +698,7 @@ validateDBRule rule@(DBRule {db_rule_name = ruleName, table_name = ruleTableName violations <- catMaybes <$> mapM checkDBViolation simplifiedExprs pure violations where - -- Since we need all columns to be indexed, we need to check for the columns in the order of composite key + -- Since we need all columns to be indexed, we need to check for the columns in the order of composite ke checkDBViolationMatchAll :: [SimplifiedIsClause] -> TcM (Maybe (LHsExpr GhcTc, Violation)) checkDBViolationMatchAll sop = do let isDbViolation (cls, colName, tableName) = (ruleTableName == tableName) && not (doesMatchColNameInDbRuleWithComposite colName ruleColNames (map (\(_, col, _) -> col) sop)) @@ -675,6 +716,49 @@ validateDBRule rule@(DBRule {db_rule_name = ruleName, table_name = ruleTableName [] -> pure Nothing ((clause, colName, tableName) : _) -> pure $ Just (clause, NonIndexedDBColumn colName tableName rule) +validateWhereClauseRule :: (HasPluginOpts PluginOpts) => WhereClauseRule -> String -> [LHsExpr GhcTc] -> String -> TcM ([(LHsExpr GhcTc, Violation)]) +validateWhereClauseRule rule tableName clauses fnName = do + liftIO $ putStrLn $ "Validating WhereClauseRule for table: " <> tableName <> " with clauses: " <> showS clauses + simplifiedExprs <- trfWhereToSOP clauses + let matched = concatMap (map (\(_, colName, _) -> colName)) simplifiedExprs -- Extract column names from all clauses + liftIO $ putStrLn $ "tableKey: " <> tableName <> ", Matched columns: " <> show matched <> ", fnName: " <> fnName + jsonData <- liftIO $ Char8.readFile "tables_and_fields_with_types.json" + case A.decode jsonData :: Maybe (HM.HashMap String (HM.HashMap String String)) of + Just jsonMap -> do + let tableKey = tableName + case HM.lookup tableKey jsonMap of + Just fieldsMap -> do + let fieldNames = HM.keys fieldsMap + let isBoolField field = case HM.lookup field fieldsMap of + Just value -> "Bool" `isInfixOf` value + _ -> False + let matchingFields = filter (\field -> ("disable" `isInfixOf` field || "enable" `isInfixOf` field) && isBoolField field) fieldNames + + liftIO $ putStrLn $ "Matching fields for " <> tableKey <> ": " <> show matchingFields + violation <- checkWhereClauseViolation tableKey matchingFields simplifiedExprs -- Pass all simplifiedExprs at once + pure $ maybe [] (\v -> [v]) violation + Nothing -> do + liftIO $ putStrLn $ "No fields found for table: " <> tableKey + pure [] + Nothing -> do + liftIO $ putStrLn "Failed to parse JSON file." + pure [] + where + checkWhereClauseViolation :: String -> [String] -> [[SimplifiedIsClause]] -> TcM (Maybe (LHsExpr GhcTc, Violation)) + checkWhereClauseViolation tableKey matchingFields sopList = do + let matched = concatMap (map (\(_, colName, _) -> colName)) sopList + let isWhereClauseViolation = not (null matchingFields) && all (\field -> field `notElem` matched) matchingFields + liftIO $ putStrLn $ "tableKey: " <> tableKey <> ", Matched columns: " <> show matched <> ", Is violation: " <> show isWhereClauseViolation + if null matchingFields + then pure Nothing + else case concat sopList of + [] -> pure Nothing + ((clause, _, _) : _) -> + if isWhereClauseViolation + -- then pure $ Just (clause, WhereClauseViolationDetected tableName matchingFields rule) + then pure Nothing + else pure Nothing + -- Check only for the ordering of the columns of the composite key doesMatchColNameInDbRuleWithComposite :: String -> [YamlTableKeys] -> [String] -> Bool doesMatchColNameInDbRuleWithComposite _ [] _ = False @@ -707,22 +791,37 @@ trfWhereToSOP [] = pure [[]] trfWhereToSOP (clause : ls) = do let res = getWhereClauseFnNameWithAllArgs clause (fnName, args) = fromMaybe ("NA", []) res + liftIO $ putStrLn $ "Processing clause: " <> showS clause <> ", Function name: " <> fnName <> ", Args: " <> showS args case (fnName, args) of ("And", [(L _ (PatExplicitList _ arg))]) -> do + liftIO $ putStrLn "Detected 'And' clause." curr <- trfWhereToSOP arg rem <- trfWhereToSOP ls pure [x <> y | x <- curr, y <- rem] ("Or", [(L _ (PatExplicitList _ arg))]) -> do + liftIO $ putStrLn "Detected 'Or' clause." curr <- foldM (\r cls -> fmap (<> r) $ trfWhereToSOP [cls]) [] arg rem <- trfWhereToSOP ls pure [x <> y | x <- curr, y <- rem] ("$WIs", [arg1, arg2]) -> do + liftIO $ putStrLn "Detected 'Is' clause." curr <- getIsClauseData arg1 arg2 clause rem <- trfWhereToSOP ls case curr of - Nothing -> pure rem - Just (tblName, colName) -> pure $ fmap (\lst -> (clause, tblName, colName) : lst) rem - (fn, _) -> when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ print $ "Invalid/unknown clause in `where` clause : " <> fn <> " at " <> (showS . getLoc2 $ clause)) >> trfWhereToSOP ls + Nothing -> do + liftIO $ putStrLn "Failed to extract 'Is' clause data." + pure rem + Just (tblName, colName) -> do + liftIO $ putStrLn $ "Extracted 'Is' clause data: Table = " <> tblName <> ", Column = " <> colName + pure $ fmap (\lst -> (clause, tblName, colName) : lst) rem + ("getField", a ) -> do + -- Handle `getField @"columnName"` syntax + liftIO $ putStrLn $ "Detected 'getField' clause for column: " <> showS a + -- rem <- trfWhereToSOP [] + pure [] + (fn, _) -> do + liftIO $ putStrLn $ "Invalid/unknown clause in `where` clause: " <> fn <> " at " <> (showS . getLoc2 $ clause) + trfWhereToSOP ls -- Get table field name and table name for the `Se.Is` clause -- Patterns to match 'getField`, `recordDot`, `overloadedRecordDot` (ghc > 9), selector (duplicate record fields), rec fields (ghc 9), lens @@ -741,12 +840,19 @@ getIsClauseData fieldArg _comp _clause = do ("$sel" : colName : tableName : []) -> pure $ Just (colName, tableName) _ -> when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ print "Invalid pattern for Selector way") >> pure Nothing RecordDot -> do + liftIO $ putStrLn $ "Debugging RecordDot: AST structure of fieldArg = " <> showS fieldArg + let allNodes = traverseAst fieldArg :: [HsExpr GhcTc] + liftIO $ putStrLn $ "All nodes in AST: " <> showS allNodes let tyApps = filter (\x -> case x of (HsApp _ (L _ (HsAppType _ _ fldName)) tableVar) -> True (PatHsWrap (WpCompose (WpEvApp (EvExpr _hasFld)) (WpCompose (WpTyApp _fldType) (WpTyApp tableVar))) (HsAppType _ _ fldName)) -> True + (HsApp _ (L _ fldName) tableVar) -> True -- Added fallback for simpler AST structure + (HsVar _ fldName) -> True -- Handle HsVar for field names + (HsOverLit _ (OverLit {ol_val = HsIsString _ colName})) -> True -- Handle string literals _ -> False - ) $ (traverseAst fieldArg :: [HsExpr GhcTc]) - if length tyApps > 0 + ) allNodes + liftIO $ putStrLn $ "Filtered tyApps: " <> showS tyApps + if not (null tyApps) then case head tyApps of (HsApp _ (L _ (HsAppType _ _ fldName)) tableVar) -> do @@ -762,6 +868,76 @@ getIsClauseData fieldArg _comp _clause = do TyConApp ty1 _ -> showS ty1 ty -> showS ty in pure $ Just (getStrFromHsWildCardBndrs fldName, take (length tblName' - 1) tblName') + (HsApp _ (L _ fldName) tableVar) -> do -- Handle simpler AST structure + typ <- getHsExprType (logTypeDebugging . pluginOpts $ ?pluginOpts) tableVar + let tblName' = case typ of + AppTy ty1 _ -> showS ty1 + TyConApp ty1 _ -> showS ty1 + ty -> showS ty + pure $ Just (showS fldName, take (length tblName' - 1) tblName') + (HsVar _ fldName) -> do -- Handle HsVar for field names + let fldNameStr = occNameString (nameOccName (idName (unLoc fldName))) -- Extract the name as a String + liftIO $ putStrLn $ "Detected HsVar: " <> fldNameStr + case fldNameStr of + "getField" -> do + let colName = case allNodes of + (HsOverLit _ (OverLit {ol_val = HsIsString _ colName})) : _ -> + let extractedColName = unpackFS colName + in trace ("Matched HsOverLit (HsIsString): Extracted column name = " <> extractedColName) extractedColName + + (HsVar _ directFldName) : _ -> + let directFldNameStr = occNameString (nameOccName (idName (unLoc directFldName))) + in trace ("Matched HsVar directly: Field name = " <> directFldNameStr) directFldNameStr + + (HsPar _ expr) : _ -> + -- Helper function to recursively unwrap expressions + let recUnwrap :: LHsExpr GhcTc -> LHsExpr GhcTc + recUnwrap expr = case unLoc expr of + HsPar _ innerExpr -> recUnwrap innerExpr + other -> expr + + exprStr = showS expr + debugExpr = showS (unLoc expr) -- Log the structure of unLoc expr + innerColName = (case unLoc (recUnwrap expr) of + -- Match HsOverLit directly + HsOverLit _ (OverLit {ol_val = HsIsString _ colName}) -> + let extractedColName = unpackFS colName + in trace ("Inner Matched HsOverLit (HsIsString): Extracted column name = " <> extractedColName) extractedColName + + -- Match HsLit directly + HsLit _ (HsString _ colName) -> + let extractedColName = unpackFS colName + in trace ("Inner Matched HsLit (HsString): Extracted column name = " <> extractedColName) extractedColName + + -- Match HsVar directly + HsVar _ name -> + let extractedColName = occNameString (nameOccName (idName (unLoc name))) + in trace ("Inner Matched HsVar: Extracted column name = " <> extractedColName) extractedColName + + -- Match HsApp (function application) + HsApp _ func arg -> + let funcStr = showS func + argStr = showS arg + in case unLoc (recUnwrap func) of + HsVar _ name | occNameString (nameOccName (idName (unLoc name))) == "getField" -> + case unLoc (recUnwrap arg) of + HsLit _ (HsString _ colName) -> + let extractedColName = unpackFS colName + in trace ("Inner Matched HsApp (getField): Extracted column name = " <> extractedColName) extractedColName + _ -> trace ("Inner Argument of getField is not a string literal. Defaulting to UnknownColumn") "UnknownColumn" + _ -> trace ("Inner Function is not getField. Defaulting to UnknownColumn") "UnknownColumn" + + -- Default case for unmatched patterns + _ -> trace ("Inner No matching case found for unLoc expr: " <> debugExpr <> ". Defaulting to UnknownColumn") "UnknownColumn" + ) + in trace ("Matched HsPar: Expression = " <> exprStr <> ", Inner column name = " <> innerColName) innerColName + _ -> trace ("No matching case found. Nodes in AST: " <> showS allNodes <> ". Defaulting to UnknownColumn") "UnknownColumn7" + liftIO $ putStrLn $ "Extracted column name: " <> colName + pure $ Just (colName, "AuthenticationAccountT") -- Replace with actual table name if available + _ -> pure $ Just (fldNameStr, "UnknownTable") + (HsOverLit _ (OverLit {ol_val = HsIsString _ colName})) -> do -- Handle string literals + liftIO $ putStrLn $ "Detected string literal: " <> unpackFS colName + pure $ Just (unpackFS colName, "UnknownTable") _ -> when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ putStrLn "HsAppType not present. Should never be the case as we already filtered.") >> pure Nothing else when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ putStrLn "HsAppType not present after filtering. Should never reach as already deduced RecordDot.") >> pure Nothing Lens -> do @@ -878,32 +1054,68 @@ getFnNameAndTypeableExprWithAllArgs _ = Nothing -- TODO: Verify the correctness of this function before moving it to utils -- Get function name with all it's arguments + getFnNameWithAllArgs :: LHsExpr GhcTc -> Maybe (Located Var, [LHsExpr GhcTc]) -getFnNameWithAllArgs (L loc (HsVar _ v)) = Just (getLocated v loc, []) -getFnNameWithAllArgs (L _ (HsConLikeOut _ cl)) = (\clId -> (noExprLoc clId, [])) <$> conLikeWrapId cl -getFnNameWithAllArgs (L _ (HsAppType _ expr _)) = getFnNameWithAllArgs expr -getFnNameWithAllArgs (L _ (HsApp _ (L loc (HsVar _ v)) funr)) = Just (getLocated v loc, [funr]) -getFnNameWithAllArgs (L _ (HsPar _ expr)) = getFnNameWithAllArgs expr -getFnNameWithAllArgs (L _ (HsApp _ funl funr)) = do - let res = getFnNameWithAllArgs funl - case res of - Nothing -> Nothing - Just (fnName, ls) -> Just (fnName, ls ++ [funr]) -getFnNameWithAllArgs (L loc (OpApp _ funl op funr)) = do - case showS op of - "($)" -> getFnNameWithAllArgs $ (L loc (HsApp noExtFieldOrAnn funl funr)) - _ -> Nothing -getFnNameWithAllArgs (L loc ap@(PatHsWrap _ expr)) = getFnNameWithAllArgs (L loc expr) +getFnNameWithAllArgs expr = case expr of + L loc (HsVar _ v) -> + -- trace ("getFnNameWithAllArgs: Detected HsVar with name = " <> showS v) $ + Just (getLocated v loc, []) + + L _ (HsConLikeOut _ cl) -> + -- trace "getFnNameWithAllArgs: Detected HsConLikeOut" $ + (\clId -> (noExprLoc clId, [])) <$> conLikeWrapId cl + + L _ (HsAppType _ expr _) -> + -- trace "getFnNameWithAllArgs: Detected HsAppType" $ + getFnNameWithAllArgs expr + + L _ (HsApp _ (L loc (HsVar _ v)) funr) -> + -- trace ("getFnNameWithAllArgs: Detected HsApp with HsVar function name = " <> showS v) $ + Just (getLocated v loc, [funr]) + + L _ (HsPar _ expr) -> + -- trace "getFnNameWithAllArgs: Detected HsPar" $ + getFnNameWithAllArgs expr + + L _ (HsApp _ funl funr) -> + -- trace "getFnNameWithAllArgs: Detected HsApp with nested function application" $ + case getFnNameWithAllArgs funl of + Nothing -> trace "getFnNameWithAllArgs: No function name found in nested HsApp" Nothing + Just (fnName, ls) -> + trace ("getFnNameWithAllArgs: Found function name = " <> showS fnName <> " with arguments = " <> showS ls) $ + Just (fnName, ls ++ [funr]) + + L loc (OpApp _ funl op funr) -> + -- trace ("getFnNameWithAllArgs: Detected OpApp with operator = " <> showS op) $ + case showS op of + "($)" -> + -- trace "getFnNameWithAllArgs: Detected ($) operator, treating as HsApp" $ + getFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) + _ -> trace "getFnNameWithAllArgs: Unsupported operator, returning Nothing" Nothing + + L loc ap@(PatHsWrap _ expr) -> + -- trace "getFnNameWithAllArgs: Detected PatHsWrap" $ + getFnNameWithAllArgs (L loc expr) + #if __GLASGOW_HASKELL__ >= 900 -getFnNameWithAllArgs (L loc ap@(PatHsExpansion orig expanded)) = - case (orig, expanded) of - ((OpApp _ _ op _), (HsApp _ (L _ (HsApp _ op' funl)) funr)) -> case showS op of - "($)" -> getFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) - _ -> getFnNameWithAllArgs (L loc expanded) - _ -> getFnNameWithAllArgs (L loc expanded) + L loc ap@(PatHsExpansion orig expanded) -> + trace "getFnNameWithAllArgs: Detected PatHsExpansion" $ + case (orig, expanded) of + ((OpApp _ _ op _), (HsApp _ (L _ (HsApp _ op' funl)) funr)) -> + case showS op of + "($)" -> + trace "getFnNameWithAllArgs: Detected ($) operator in PatHsExpansion, treating as HsApp" $ + getFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) + _ -> + trace "getFnNameWithAllArgs: Unsupported operator in PatHsExpansion, processing expanded expression" $ + getFnNameWithAllArgs (L loc expanded) + _ -> + trace "getFnNameWithAllArgs: Processing expanded expression in PatHsExpansion" $ + getFnNameWithAllArgs (L loc expanded) #endif -getFnNameWithAllArgs _ = Nothing + _ -> + trace "getFnNameWithAllArgs: No matching case found, returning Nothing" Nothing --------------------------- Sheriff Plugin Utils --------------------------- -- Transform the FnBlockedInArg Violation with correct expression trfViolationErrorInfo :: (HasPluginOpts PluginOpts) => Violation -> LHsExpr GhcTc -> LHsExpr GhcTc -> TcM Violation diff --git a/sheriff/src/Sheriff/Types.hs b/sheriff/src/Sheriff/Types.hs index 24fd04c..1615490 100644 --- a/sheriff/src/Sheriff/Types.hs +++ b/sheriff/src/Sheriff/Types.hs @@ -191,6 +191,35 @@ instance FromJSON FunctionRule where fn_rule_ignore_functions <- o .:? "fn_rule_ignore_functions" .!= (fn_rule_ignore_functions defaultFunctionRule) return FunctionRule {..} +data WhereClauseRule = + WhereClauseRule + { + where_clause_rule_name :: String, + where_clause_rule_fixes :: Suggestions, + where_clause_rule_ignore_modules :: Modules, + where_clause_rule_check_modules :: Modules, + query_functions_to_check :: [String] + } + deriving (Show, Eq) + +defaultWhereClauseRule :: WhereClauseRule +defaultWhereClauseRule = WhereClauseRule { + where_clause_rule_name = "WhereClauseRule", + where_clause_rule_fixes = ["You Might want to include an mandatory column in the `where` clause of the query."], + where_clause_rule_ignore_modules = [], + where_clause_rule_check_modules = ["*"], + query_functions_to_check = ["find", "findOne", "findoneRow" , "findAllRows"] + } + +instance FromJSON WhereClauseRule where + parseJSON = withObject "WhereClauseRule" $ \o -> do + where_clause_rule_name <- o .: "where_clause_rule_name" + where_clause_rule_fixes <- o .: "where_clause_rule_fixes" + where_clause_rule_ignore_modules <- o .:? "where_clause_rule_ignore_modules" .!= (where_clause_rule_ignore_modules defaultWhereClauseRule) + where_clause_rule_check_modules <- o .:? "where_clause_rule_check_modules" .!= (where_clause_rule_check_modules defaultWhereClauseRule) + query_functions_to_check <- o .:? "query_functions_to_check" .!= (query_functions_to_check defaultWhereClauseRule) + return WhereClauseRule {..} + data InfiniteRecursionRule = InfiniteRecursionRule { @@ -316,11 +345,12 @@ data Rule = DBRuleT DBRule | FunctionRuleT FunctionRule | InfiniteRecursionRuleT InfiniteRecursionRule + | WhereClauseRuleT WhereClauseRule | GeneralRuleT GeneralRule deriving (Show, Eq) instance FromJSON Rule where - parseJSON str = (DBRuleT <$> parseJSON str) <|> (FunctionRuleT <$> parseJSON str) <|> (InfiniteRecursionRuleT <$> parseJSON str) <|> (GeneralRuleT <$> parseJSON str) <|> (fail $ "Invalid Rule: " <> show str) + parseJSON str = (DBRuleT <$> parseJSON str) <|> (FunctionRuleT <$> parseJSON str) <|> (InfiniteRecursionRuleT <$> parseJSON str) <|> (WhereClauseRuleT <$> parseJSON str) <|>(GeneralRuleT <$> parseJSON str) <|> (fail $ "Invalid Rule: " <> show str) data LocalVar = FnArg Var | FnWhere Var | FnLocal Var deriving (Eq) @@ -344,15 +374,17 @@ data Violation = | FnUseBlocked String FunctionRule | FnSigBlocked String String FunctionRule | InfiniteRecursionDetected InfiniteRecursionRule + | WhereClauseViolationDetected String [String] WhereClauseRule | NoViolation deriving (Eq) instance Show Violation where show violation = case violation of - (ArgTypeBlocked typ exprTy ruleFnName rule) -> "Use of '" <> ruleFnName <> "' on '" <> typ <> "' is not allowed in the overall expression type '" <> exprTy <> "'." - (FnBlockedInArg (fnName, typ) ruleFnName _ rule) -> "Use of '" <> fnName <> "' on type '" <> typ <> "' inside argument of '" <> ruleFnName <> "' is not allowed." - (FnUseBlocked ruleFnName rule) -> "Use of '" <> ruleFnName <> "' in the code is not allowed." - (FnSigBlocked ruleFnName ruleFnSig rule) -> "Use of '" <> ruleFnName <> "' with signature '" <> ruleFnSig <> "' is not allowed in the code." - (NonIndexedDBColumn colName tableName _) -> "Querying on non-indexed column '" <> colName <> "' of table '" <> (tableName) <> "' is not allowed." - (InfiniteRecursionDetected _) -> "Infinite recursion detected in expression" - NoViolation -> "NoViolation" \ No newline at end of file + (ArgTypeBlocked typ exprTy ruleFnName rule) -> "Use of '" <> ruleFnName <> "' on '" <> typ <> "' is not allowed in the overall expression type '" <> exprTy <> "'." + (FnBlockedInArg (fnName, typ) ruleFnName _ rule) -> "Use of '" <> fnName <> "' on type '" <> typ <> "' inside argument of '" <> ruleFnName <> "' is not allowed." + (FnUseBlocked ruleFnName rule) -> "Use of '" <> ruleFnName <> "' in the code is not allowed." + (FnSigBlocked ruleFnName ruleFnSig rule) -> "Use of '" <> ruleFnName <> "' with signature '" <> ruleFnSig <> "' is not allowed in the code." + (NonIndexedDBColumn colName tableName _) -> "Querying on non-indexed column '" <> colName <> "' of table '" <> (tableName) <> "' is not allowed." + (InfiniteRecursionDetected _) -> "Infinite recursion detected in expression" + (WhereClauseViolationDetected tableName colNames _) -> "Where clause rule violation: Missing mandatory querying fields on table '" <> tableName <> "' for column '" <> show colNames <> "'." + NoViolation -> "NoViolation" \ No newline at end of file diff --git a/sheriff/src/Sheriff/TypesUtils.hs b/sheriff/src/Sheriff/TypesUtils.hs index 50f5555..f7a55e5 100644 --- a/sheriff/src/Sheriff/TypesUtils.hs +++ b/sheriff/src/Sheriff/TypesUtils.hs @@ -38,6 +38,7 @@ getViolationSuggestions v = case v of FnSigBlocked _ _ r -> fn_rule_fixes r NonIndexedDBColumn _ _ r -> db_rule_fixes r InfiniteRecursionDetected r -> infinite_recursion_rule_fixes r + WhereClauseViolationDetected _ _ r -> where_clause_rule_fixes r NoViolation -> [] getViolationType :: Violation -> String @@ -48,6 +49,7 @@ getViolationType v = case v of FnSigBlocked _ _ _ -> "FnSigBlocked" NonIndexedDBColumn _ _ _ -> "NonIndexedDBColumn" InfiniteRecursionDetected _ -> "InfiniteRecursionDetected" + WhereClauseViolationDetected _ _ _-> "WhereClauseRule" NoViolation -> "NoViolation" getViolationRule :: Violation -> Rule @@ -58,6 +60,7 @@ getViolationRule v = case v of FnSigBlocked _ _ r -> FunctionRuleT r NonIndexedDBColumn _ _ r -> DBRuleT r InfiniteRecursionDetected r -> InfiniteRecursionRuleT r + WhereClauseViolationDetected _ _ r -> WhereClauseRuleT r NoViolation -> defaultRule getViolationRuleName :: Violation -> String @@ -68,6 +71,7 @@ getViolationRuleName v = case v of FnSigBlocked _ _ r -> fn_rule_name r NonIndexedDBColumn _ _ r -> db_rule_name r InfiniteRecursionDetected r -> infinite_recursion_rule_name r + WhereClauseViolationDetected _ _ r -> where_clause_rule_name r NoViolation -> "NA" getViolationRuleExceptions :: Violation -> Rules