From fca5d1441778b75d62db799d9ff432bb73b0c382 Mon Sep 17 00:00:00 2001 From: Sailaja Date: Thu, 12 Jun 2025 22:06:03 +0530 Subject: [PATCH 01/15] First Changes --- sheriff/src/Sheriff/Plugin.hs | 358 ++++++++++++++++++++++++++++++++-- 1 file changed, 343 insertions(+), 15 deletions(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index 733cc3d..76bf3f6 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -5,37 +5,51 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE FlexibleContexts #-} module Sheriff.Plugin (plugin) where -- Sheriff imports +import Control.Exception (IOException, catch) import Sheriff.CommonTypes import Sheriff.Patterns import Sheriff.Rules import Sheriff.Types import Sheriff.TypesUtils import Sheriff.Utils - +import Text.Regex.TDFA ((=~)) -- GHC imports import Control.Applicative ((<|>)) import Control.Monad (foldM, when) import Control.Monad.IO.Class (MonadIO (..)) import Control.Monad.State -import Data.Aeson as A +import Control.Reference (biplateRef, (^?)) +import Data.Aeson as A +import qualified Data.Aeson.KeyMap as KM +import Data.Function (on, id) import Data.Aeson.Encode.Pretty (encodePretty) import Data.Bool (bool) -import Data.ByteString.Lazy (writeFile, appendFile) +import Data.ByteString.Lazy (writeFile, appendFile, readFile) import qualified Data.ByteString.Lazy.Char8 as Char8 import Data.Data import Data.Function (on) +import Data.Functor.Identity (runIdentity) import qualified Data.HashMap.Strict as HM import Data.List (nub, sortBy, groupBy, find, isInfixOf, isSuffixOf, isPrefixOf) import Data.List.Extra (splitOn) -import Data.Maybe (catMaybes, fromMaybe) -import Data.Yaml -import Debug.Trace (traceShowId, trace) +import qualified Data.Map as Map +import Data.Maybe (catMaybes, fromMaybe, listToMaybe) +import Data.Typeable (typeOf) +import Data.Yaml hiding (decode) +import qualified Data.Text as T +import Debug.Trace (traceShowId, trace, traceM) import GHC hiding (exprType) -import Prelude hiding (id, writeFile, appendFile) +import GHC.Types.TypeEnv (typeEnvElts) +import Prelude hiding (id, writeFile, appendFile, readFile) +import Text.Show.Pretty (ppShow) +import Data.Data (Data, toConstr, gmapQ) +import Data.Generics (everything, mkQ) +import Language.Haskell.Exts (parseFile, prettyPrint, ParseResult(..)) import System.Directory (createDirectoryIfMissing, getHomeDirectory) #if __GLASGOW_HASKELL__ >= 900 @@ -45,6 +59,7 @@ import GHC.Core.InstEnv import GHC.Core.TyCo.Rep import GHC.Data.Bag import GHC.HsToCore.Monad +import GHC.Hs.Expr (HsExpr(..)) import GHC.HsToCore.Expr import GHC.Plugins hiding ((<>), getHscEnv, purePlugin) import GHC.Tc.Types @@ -69,6 +84,8 @@ import TcType import TyCoRep #endif +-- type VarBindingMap = Map.Map String (LHsExpr GhcTc) + plugin :: Plugin plugin = defaultPlugin { typeCheckResultAction = sheriff @@ -78,6 +95,10 @@ plugin = defaultPlugin { purePlugin :: [CommandLineOption] -> IO PluginRecompile purePlugin _ = return NoForceRecompile + + + + --------------------------- Core Logic --------------------------- {- @@ -121,6 +142,7 @@ type SheriffTcM = StateT (HM.HashMap NameModuleValue NameModuleValue) TcM sheriff :: [CommandLineOption] -> ModSummary -> TcGblEnv -> TcM TcGblEnv sheriff opts modSummary tcEnv = do + -- STAGE-1 let moduleName' = moduleNameString $ moduleName $ ms_mod modSummary pluginOpts@PluginOpts{..} = decodeAndUpdateOpts opts defaultPluginOpts @@ -142,7 +164,8 @@ sheriff opts modSummary tcEnv = do when failOnFileNotFound $ addErr (mkInvalidYamlFileErr (show err)) pure [] Right (YamlTables tables) -> pure $ (map yamlToDbRule tables) - + + -- Check the parsed rules yaml file. If failed, throw file error if configured. configuredRules <- case parsedRulesYaml of Left err -> do @@ -178,7 +201,61 @@ sheriff opts modSummary tcEnv = do insts <- tcg_insts . env_gbl <$> getEnv let namesModTuple = concatMap (\inst -> let clsName = className (is_cls inst) in (is_dfun_name inst, clsName) : fmap (\clsMethod -> (varName clsMethod, clsName)) (classMethods $ is_cls inst)) insts nameModMap = foldr (\(name, clsName) r -> HM.insert (NMV_Name name) (NMV_ClassModule clsName (getModuleName clsName)) r) HM.empty namesModTuple + -- let binds = bagToList $ tcg_binds tcEnv + -- liftIO $ putStrLn ("AST of binds:\n" ++ ppShow binds) + -- liftIO $ putStrLn ("Extracted bind names: " ++ OP.showSDocUnsafe (OP.ppr binds)) + + -- let mymAction = do + -- results <- forM binds extractExprFromBind + -- return (results) + -- let (extractedAll, finalState) = runMyM mymAction + + -- let jsonFilePath = "/Users/sailaja.b/spider/tree.json" + + -- Read the JSON file + -- jsonData <- liftIO (Char8.readFile jsonFilePath `catch` handleReadError) + -- let parsedData = eitherDecode jsonData :: Either String FunctionMap + + -- case parsedData of + -- Right functionMap -> do + -- let filteredEntries = filterEntries functionMap + -- forM_ (HM.toList filteredEntries) $ \(key, functionEntry) -> do + -- let codeContent = code functionEntry + -- -- let codeContent = "res <- findOneRow\n dbConf meshConfig\n $ [Is DB.filterGroupId (Eq groupId),\n Is DB.dimensionValue (Eq dimension)]\n toDomainAll res parseTenantConfigFilter\n ! #function_name \"getTenantConfigFilterByGroupIdAndDimensionValue\"\n ! #parser_name \"parseTenantConfigFilter\"" :: String + + -- -- let regex1 = "findOneRow" :: String + -- -- let regex2 = "findOneRow[[:space:]]+([^[:space:]]+)" :: String + -- -- let regex3 = "findOneRow[[:space:]]+([^[:space:]]+)[[:space:]]+([^[:space:]]+)" :: String + -- -- let regex4 = "findOneRow[[:space:]]+([^[:space:]]+)[[:space:]]+([^[:space:]]+)[[:space:]]+\\[([[:print:][:space:]]*)\\]" :: String + -- let regexFlexible = "(findOneRow|findAllRows)[[:space:]]+([^[:space:]]+)[[:space:]]+([^[:space:]]+)[[:space:]]+.*[[:space:]]*\\[([[:print:][:space:]]*)\\]" :: String + -- -- Test each regex pattern + -- -- let matches1 = codeContent =~ regex1 :: [[String]] + -- -- let matches2 = codeContent =~ regex2 :: [[String]] + -- -- let matches3 = codeContent =~ regex3 :: [[String]] + -- -- let matches4 = codeContent =~ regex4 :: [[String]] + -- let matches5 = codeContent =~ regexFlexible :: [[String]] + + -- -- Print results + -- -- liftIO $ putStrLn $ "Matches with Regex 1: " ++ show matches1 + -- -- liftIO $ putStrLn $ "Matches with Regex 2: " ++ show matches2 + -- -- liftIO $ putStrLn $ "Matches with Regex 3: " ++ show matches3 + -- -- liftIO $ putStrLn $ "Matches with Regex 4: " ++ show matches4 + -- let firstElements = map (\innerList -> case innerList of + -- (a:_) -> a + -- [] -> "No match") matches5 + -- liftIO $ putStrLn $ "Function Name: " ++ key ++ " ,First elements: " ++ show firstElements + + -- Left err -> liftIO $ putStrLn $ "Failed to parse JSON: " ++ err + + -- Example: Extract specific data from the FunctionMap + -- let extractedData = fromMaybe HM.empty (decode jsonData :: Maybe FunctionMap) + -- liftIO $ putStrLn $ "Extracted data: " ++ show extractedData + + + -- liftIO $ putStrLn $ "๐Ÿ“ Final state: " ++ show finalState + -- liftIO $ putStrLn $ "๐Ÿ“Œ Extracted expressions: " ++ OP.showSDocUnsafe (OP.ppr extractedAll) + rawErrors <- concat <$> (mapM (loopOverModBinds finalSheriffRules) $ bagToList $ tcg_binds tcEnv) (rawInfiniteRecursionErrors, _) <- flip runStateT nameModMap $ concat <$> (mapM (checkInfiniteRecursion True infRule) $ bagToList $ tcg_binds tcEnv) @@ -201,7 +278,230 @@ sheriff opts modSummary tcEnv = do return tcEnv ---------------------------- Infinite Recursion Detection Logic --------------------------- +handleReadError :: IOException -> IO Char8.ByteString +handleReadError e = do + putStrLn $ "c " ++ show e + return Char8.empty + +-- filterEntries :: HM.HashMap String FunctionEntry -> HM.HashMap String FunctionEntry +-- filterEntries functionMap = +-- HM.filter containsFindFunctions functionMap +-- where +-- containsFindFunctions :: FunctionEntry -> Bool +-- containsFindFunctions entry = +-- case toJSON entry of +-- Object obj -> -- Ensure the JSON value is an Object +-- case KM.lookup "functions_it_is_calling" obj of +-- Just (Array functions) -> +-- any (\func -> func == "findOneRow" || func == "findAllRows") functions +-- _ -> False +-- _ -> False -- Handle cases where the JSON value is not an Object + +type TableMap = Map.Map String String +type ClauseMap = Map.Map String String + + +data MyState = MyState + { tableMap :: TableMap + , clauseMap :: ClauseMap + } deriving (Show) + +type MyM = State MyState + +runMyM :: MyM a -> (a, MyState) +runMyM action = runState action initialMyState + +initialMyState :: MyState +initialMyState = MyState Map.empty Map.empty + +extractExprFromBind :: LHsBindLR GhcTc GhcTc -> MyM (Maybe (String, String, String)) +extractExprFromBind (L loc bind) = do + traceM "extractExprFromBind called" + myMap <- get + case bind of + FunBind { fun_matches = MG { mg_alts = L _ matches } } -> do + traceM "Matched FunBind" + results <- forM matches $ \(L _ match) -> case match of + Match { m_grhss = GRHSs _ grhss _ } -> do + innerResults <- forM grhss $ \(L _ grhs) -> case grhs of + GRHS _ _ body -> do + let exprs = body ^? biplateRef :: [LHsExpr GhcTc] + extracted <- mapM (\expr -> extractQueryInfo expr (L loc bind)) exprs + traceM $ "body from FunBind: " ++ OP.showSDocUnsafe (OP.ppr body) ++ "exprs from FunBind: " ++ OP.showSDocUnsafe (OP.ppr exprs) ++ "\n๐Ÿ“Œ Query Info:\n" ++ unlines (map showTriple (catMaybes extracted)) + pure (catMaybes extracted) + pure (concat innerResults) + pure $ listToMaybe (concat results) + + PatBind { pat_rhs = GRHSs _ grhss _ } -> do + traceM "Matched PatBind" + results <- forM grhss $ \(L _ grhs) -> case grhs of + GRHS _ _ body -> do + let exprs = body ^? biplateRef :: [LHsExpr GhcTc] + extracted <- mapM (\expr -> extractQueryInfo expr (L loc bind)) exprs + traceM $ "body from FunBind: " ++ OP.showSDocUnsafe (OP.ppr body) ++ "exprs from PatBind: " ++ OP.showSDocUnsafe (OP.ppr exprs) ++ "\n๐Ÿ“Œ Query Info:\n" ++ unlines (map showTriple (catMaybes extracted)) + pure (catMaybes extracted) + pure $ listToMaybe (concat results) + + AbsBinds { abs_binds = binds } -> do + traceM "Matched AbsBinds" + results <- mapM extractExprFromBind (bagToList binds) + pure $ listToMaybe (catMaybes results) + + _ -> do + traceM "No matching bind constructor (not FunBind or PatBind)" + pure Nothing + +extractQueryInfo :: LHsExpr GhcTc -> LHsBindLR GhcTc GhcTc -> MyM (Maybe (String, String, String)) +extractQueryInfo expr bindings = do + traceM ("๐Ÿ” Called extractQueryInfo with expr: " ++ OP.showSDocUnsafe (OP.ppr expr)) + (fn, args) <- flattenHsAppM expr + let fnName = OP.showSDocUnsafe (OP.ppr fn) + if fnName `elem` ["findOneRow", "findAllRows"] && length args == 3 + then do + let clause = OP.showSDocUnsafe (OP.ppr (args !! 2)) + let typeStr = show (typeOf (unLoc (args !! 2))) + let unlocatedBindings = [unLoc bindings] + let isWhereClause = hasIsOrEmptyList (args !! 2) + clauseMapVal <- gets clauseMap + let whereClause = if isWhereClause + then clause + else fromMaybe " trace "Matched: Empty list" True + ExplicitList _ xs -> trace ("Checking list of length " ++ show (length xs)) $ any hasIsOrEmptyList xs + + HsApp _ fun arg -> + let funStr = showSDocUnsafe (ppr (unLoc fun)) + matches = any (`isPrefixOf` funStr) ["Is", "And", "Or"] + isListCons = funStr == ":" || funStr == "[]" + in trace ("Matched: HsApp, head string: " ++ funStr ++ ", matches? " ++ show matches ++ ", isListCons? " ++ show isListCons) $ + matches || isListCons || hasIsOrEmptyList fun || hasIsOrEmptyList arg + OpApp _ fun op arg -> + let funStr = showSDocUnsafe (ppr (unLoc fun)) + opStr = showSDocUnsafe (ppr (unLoc op)) + argStr = showSDocUnsafe (ppr (unLoc arg)) + matches = any (`isPrefixOf` funStr) ["Is", "And", "Or"] + isListCons = funStr == ":" || funStr == "[]" + in trace ("Matched: OpApp, fun: " ++ funStr ++ ", op: " ++ opStr ++ ", arg: " ++ argStr ++ + ", matches? " ++ show matches ++ ", isListCons? " ++ show isListCons) $ + matches || isListCons || hasIsOrEmptyList fun || hasIsOrEmptyList op || hasIsOrEmptyList arg + + HsPar _ inner -> trace "Matched: HsPar" $ hasIsOrEmptyList inner + + XExpr (WrapExpr innerExpr) -> + case innerExpr of + HsWrap _ exprInner -> + trace ("Matched: XExpr WrapExpr with location " ++ showSDocUnsafe (ppr exprInner)) $ + hasIsOrEmptyList (noLocA exprInner) + + XExpr _ -> trace "XExpr case: cannot handle yet" False + + other -> trace ("๐Ÿ›‘ we came to other case: " ++ showSDocUnsafe (ppr other)) False +-- unwrapExpr :: LHsExpr GhcTc -> LHsExpr GhcTc +-- unwrapExpr lexpr = +-- case unLoc lexpr of +-- HsWrap _ inner -> noLocA inner -- inner :: HsExpr GhcTc +-- other -> lexpr + +-- unwrapExprFromHsExpr :: HsExpr GhcTc -> HsExpr GhcTc +-- unwrapExprFromHsExpr (HsWrap _ inner) = unwrapExprFromHsExpr inner +-- unwrapExprFromHsExpr expr = expr + + + +flattenHsAppM :: LHsExpr GhcTc -> MyM (HsExpr GhcTc, [LHsExpr GhcTc]) +flattenHsAppM expr = do + traceM $ "๐Ÿงพ Received expr: " ++ OP.showSDocUnsafe (OP.ppr expr) + let doStmts = case expr of + L _ (HsDo _ _ (L _ stmts)) -> stmts + _ -> [] + + forM_ doStmts $ \stmt -> case stmt of + L _ (BindStmt _ (L _ (VarPat _ (L _ varName))) rhsExpr) -> do -- To get dbConfig (Table Name) + traceM $ "๐Ÿ“ฆ RHS Expr: " ++ OP.showSDocUnsafe (OP.ppr rhsExpr) + let normalizedExpr = stripExpr rhsExpr + case normalizedExpr of + L _ (HsAppType _ (L _ (HsVar _ (L _ fnName))) (HsWC _ innerType)) + | occNameString (occName fnName) `elem` ["getEulerDbConf", "getEulerPsqlDbConf"] -> do + let lhsVarStr = OP.showSDocUnsafe (OP.ppr varName) + typeStr = case innerType of + L _ (HsTyVar _ _ (L _ name)) -> occNameString (occName name) + _ -> "unknown_type" + traceM $ "๐Ÿ“ฅ Inserting into map: " ++ lhsVarStr ++ " -> " ++ typeStr + modify $ \s -> s { tableMap = Map.insert lhsVarStr typeStr (tableMap s) } -- Map Insertion (if dbcongig1 <- getEulerDbConf @orderReferenceT) , ["dbcongig1", "orderReferenceT"] get inserted + _ -> pure () + -- To get where clause => + -- 1. Written in let statement => let whereclause = [Is..] + -- 2. Directly Passed to select functions => findonerow tablename [Is..] + -- 3. Passed as a parameter to the function => fun1 whereClause = do + -- findonerow tableame whereClause + -- 4. A local function or local value defined using a where clause + L _ (LetStmt _ localBindsL) -> + (case localBindsL of + HsValBinds _ valBinds -> + case valBinds of + XValBindsLR (NValBinds bindList _) -> do + traceM "๐Ÿ” Processing LetStmt -> HsValBinds -> XValBindsLR" + forM_ bindList $ \(recFlag, bindBag) -> do + traceM $ "๐ŸŒ€ Processing bind list with RecFlag: " ++ showSDocUnsafe (ppr recFlag) ++ " ,bindBag: " ++ showSDocUnsafe (ppr bindBag) + let binds = bagToList bindBag + traceM $ "๐Ÿ” Number of binds in bindBag: " ++ show (length binds) + if null binds + then traceM "โš ๏ธ bindBag is empty โ€“ nothing to process" + else forM_ (bagToList bindBag) $ \(L _ bind) -> do + traceM $ "๐Ÿ” bind constructor: " ++ showConstr (toConstr bind) + case bind of + FunBind { fun_id = L _ varName + , fun_matches = MG _ (L _ [L _ (Match _ _ _ (GRHSs _ [L _ (GRHS _ [] body)] _) )]) _ + } -> do + traceM $ "๐Ÿ“ฆ Found FunBind with var: " ++ showSDocUnsafe (ppr varName) + traceM $ "๐Ÿ“ฆ Let RHS: " ++ showSDocUnsafe (ppr body) + let normalizedExpr = stripExpr body + varStr = showSDocUnsafe (ppr varName) + traceM $ "๐Ÿงน Normalized RHS: " ++ showSDocUnsafe (ppr normalizedExpr) + let result = hasIsOrEmptyList normalizedExpr + traceM $ "๐Ÿงช hasIsOrEmptyList result: " ++ show result + when result $ do + traceM $ "โœ… Match: RHS has 'is' or '[]', recording clause for: " ++ varStr + modify $ \s -> s { clauseMap = Map.insert varStr (showSDocUnsafe (ppr normalizedExpr)) (clauseMap s) } + _ -> traceM "โ›” Skipping non-PatBind or unhandled bind pattern" + _ -> traceM "โš ๏ธ valBinds is not XValBindsLR -> skipping" + _ -> pure ()) + + _ -> pure () + + go expr [] + where + go :: LHsExpr GhcTc -> [LHsExpr GhcTc] -> MyM (HsExpr GhcTc, [LHsExpr GhcTc]) + go (L _ (HsApp _ f x)) args = go f (x : args) + go (L _ f) args = pure (f, args) + + +stripExpr :: LHsExpr GhcTc -> LHsExpr GhcTc +stripExpr (L l (HsPar _ e)) = stripExpr e +stripExpr (L l (HsAppType x e t)) = L l (HsAppType x (stripExpr e) t) +stripExpr (L l (XExpr (WrapExpr (HsWrap _ e)))) = stripExpr (L l e) +stripExpr other = other + +showTriple :: (String, String, String) -> String +showTriple (a, b, c) = "(" ++ a ++ ", " ++ b ++ ", " ++ c ++ ")" + {- 1. Check if bind is AbsBind, add a mapping from mono to poly Var and recurse for binds @@ -471,8 +771,13 @@ checkAndApplyRule ruleT ap = case ruleT of DBRuleT rule@(DBRule {table_name = ruleTableName}) -> case ap of (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do - case (showS ty == "Clause" && showS tblName == (ruleTableName <> "T")) of - True -> validateDBRule rule (showS tblName) exprs ap + case (showS ty == "Clause") of + True -> do + simplifiedExprs <- trfWhereToSOP exprs + checkWhereClauseRule <- mapM (validateWhereClauseRule (showS tblName)) simplifiedExprs + liftIO $ putStrLn $ "Checking where clause rule for table: " <> (showS tblName) <> " ,clauses: " <> showS exprs <> " ,checkWhereClauseRule: " <> show checkWhereClauseRule + if (showS tblName == (ruleTableName <> "T")) then validateDBRule simplifiedExprs rule (showS tblName) exprs ap + else pure [] False -> pure [] _ -> pure [] FunctionRuleT rule@(FunctionRule {fn_name = ruleFnNames, arg_no}) -> do @@ -648,16 +953,16 @@ Part-2 Validation -} -- Function to check if given DB rules is violated or not -- TODO: Fix this, keep two separate options for - 1. Match All Fields in AND 2. Use 1st column matching or all columns matching for composite key -validateDBRule :: (HasPluginOpts PluginOpts) => DBRule -> String -> [LHsExpr GhcTc] -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) -validateDBRule rule@(DBRule {db_rule_name = ruleName, table_name = ruleTableName, indexed_cols_names = ruleColNames}) tableName clauses expr = do - simplifiedExprs <- trfWhereToSOP clauses +validateDBRule :: (HasPluginOpts PluginOpts) => [SimplifiedIsClause] -> DBRule -> String -> [LHsExpr GhcTc] -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) +validateDBRule simplifiedExprs rule@(DBRule {db_rule_name = ruleName, table_name = ruleTableName, indexed_cols_names = ruleColNames}) tableName clauses expr = do + let checkDBViolation = case (matchAllInsideAnd . pluginOpts $ ?pluginOpts) of True -> checkDBViolationMatchAll False -> checkDBViolationWithoutMatchAll violations <- catMaybes <$> mapM checkDBViolation simplifiedExprs pure violations where - -- Since we need all columns to be indexed, we need to check for the columns in the order of composite key + -- Since we need all columns to be indexed, we need to check for the columns in the order of composite ke checkDBViolationMatchAll :: [SimplifiedIsClause] -> TcM (Maybe (LHsExpr GhcTc, Violation)) checkDBViolationMatchAll sop = do let isDbViolation (cls, colName, tableName) = (ruleTableName == tableName) && not (doesMatchColNameInDbRuleWithComposite colName ruleColNames (map (\(_, col, _) -> col) sop)) @@ -675,6 +980,29 @@ validateDBRule rule@(DBRule {db_rule_name = ruleName, table_name = ruleTableName [] -> pure Nothing ((clause, colName, tableName) : _) -> pure $ Just (clause, NonIndexedDBColumn colName tableName rule) +validateWhereClauseRule :: String -> [SimplifiedIsClause] -> TcM Bool +validateWhereClauseRule tableName simplifiedExprs = do + -- Read JSON file + jsonData <- liftIO $ readFile "/Users/sailaja.b/euler-db/tables_and_fields_with_types.json" + case decode jsonData :: Maybe (HM.HashMap String (HM.HashMap String String)) of + Just jsonMap -> do + let tableKey = tableName + case HM.lookup tableKey jsonMap of + Just fieldsMap -> do + -- Extract field names from the nested structure + let fieldNames = HM.keys fieldsMap + let matchingFields = filter (\field -> "disable" `isInfixOf` field || "enable" `isInfixOf` field) fieldNames + liftIO $ putStrLn $ "Matching fields for " <> tableKey <> ": " <> show matchingFields + -- Check if any matching field exists in simplifiedExprs + let fieldExists = any (\(_, colName, _) -> any (\field -> field == colName) matchingFields) simplifiedExprs + pure fieldExists + Nothing -> do + liftIO $ putStrLn $ "No fields found for table: " <> tableKey + pure False + Nothing -> do + liftIO $ putStrLn "Failed to parse JSON file." + pure False + -- Check only for the ordering of the columns of the composite key doesMatchColNameInDbRuleWithComposite :: String -> [YamlTableKeys] -> [String] -> Bool doesMatchColNameInDbRuleWithComposite _ [] _ = False From 4fb84a4baf7c41b19af503fd62140feae7242bf4 Mon Sep 17 00:00:00 2001 From: Sailaja Date: Tue, 24 Jun 2025 16:00:34 +0530 Subject: [PATCH 02/15] local --- sheriff/.juspay/sheriffRules.yaml | 17 +- sheriff/src/Sheriff/Plugin.hs | 564 +++++++++++------------------- sheriff/src/Sheriff/Types.hs | 48 ++- sheriff/src/Sheriff/TypesUtils.hs | 4 + 4 files changed, 269 insertions(+), 364 deletions(-) diff --git a/sheriff/.juspay/sheriffRules.yaml b/sheriff/.juspay/sheriffRules.yaml index e1ebeb6..4138872 100644 --- a/sheriff/.juspay/sheriffRules.yaml +++ b/sheriff/.juspay/sheriffRules.yaml @@ -185,4 +185,19 @@ rules: - customerEmail db_rule_fixes: - "You might want to include an indexed column in the `where` clause of the query." - db_rule_exceptions: [] \ No newline at end of file + db_rule_exceptions: [] + + - where_clause_rule_name: "DefaultWhereClauseRule" + where_clause_rule_fixes: + - "You might want to include a mandatory column in the `where` clause of the query." + where_clause_rule_ignore_modules: [] + where_clause_rule_check_modules: + - "*" + query_functions_to_check: + - findOneRow + - findOne + - findAll + - findAllRows + - findAllRowsWithLimit + - findAllWithLimit + - findAllWithLimitAndOffset \ No newline at end of file diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index 76bf3f6..9f8b3eb 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -5,51 +5,36 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeSynonymInstances #-} -{-# LANGUAGE FlexibleContexts #-} module Sheriff.Plugin (plugin) where -- Sheriff imports -import Control.Exception (IOException, catch) import Sheriff.CommonTypes import Sheriff.Patterns import Sheriff.Rules import Sheriff.Types import Sheriff.TypesUtils import Sheriff.Utils -import Text.Regex.TDFA ((=~)) -- GHC imports import Control.Applicative ((<|>)) import Control.Monad (foldM, when) import Control.Monad.IO.Class (MonadIO (..)) import Control.Monad.State -import Control.Reference (biplateRef, (^?)) -import Data.Aeson as A -import qualified Data.Aeson.KeyMap as KM -import Data.Function (on, id) +import Data.Aeson as A import Data.Aeson.Encode.Pretty (encodePretty) import Data.Bool (bool) -import Data.ByteString.Lazy (writeFile, appendFile, readFile) +import Data.ByteString.Lazy (writeFile, appendFile) import qualified Data.ByteString.Lazy.Char8 as Char8 import Data.Data import Data.Function (on) -import Data.Functor.Identity (runIdentity) import qualified Data.HashMap.Strict as HM import Data.List (nub, sortBy, groupBy, find, isInfixOf, isSuffixOf, isPrefixOf) import Data.List.Extra (splitOn) -import qualified Data.Map as Map -import Data.Maybe (catMaybes, fromMaybe, listToMaybe) -import Data.Typeable (typeOf) -import Data.Yaml hiding (decode) -import qualified Data.Text as T +import Data.Maybe (catMaybes, fromMaybe, mapMaybe) +import Data.Yaml import Debug.Trace (traceShowId, trace, traceM) import GHC hiding (exprType) -import GHC.Types.TypeEnv (typeEnvElts) -import Prelude hiding (id, writeFile, appendFile, readFile) -import Text.Show.Pretty (ppShow) -import Data.Data (Data, toConstr, gmapQ) -import Data.Generics (everything, mkQ) -import Language.Haskell.Exts (parseFile, prettyPrint, ParseResult(..)) +import Prelude hiding (id, writeFile, appendFile) import System.Directory (createDirectoryIfMissing, getHomeDirectory) #if __GLASGOW_HASKELL__ >= 900 @@ -59,7 +44,6 @@ import GHC.Core.InstEnv import GHC.Core.TyCo.Rep import GHC.Data.Bag import GHC.HsToCore.Monad -import GHC.Hs.Expr (HsExpr(..)) import GHC.HsToCore.Expr import GHC.Plugins hiding ((<>), getHscEnv, purePlugin) import GHC.Tc.Types @@ -84,8 +68,6 @@ import TcType import TyCoRep #endif --- type VarBindingMap = Map.Map String (LHsExpr GhcTc) - plugin :: Plugin plugin = defaultPlugin { typeCheckResultAction = sheriff @@ -95,10 +77,6 @@ plugin = defaultPlugin { purePlugin :: [CommandLineOption] -> IO PluginRecompile purePlugin _ = return NoForceRecompile - - - - --------------------------- Core Logic --------------------------- {- @@ -142,7 +120,6 @@ type SheriffTcM = StateT (HM.HashMap NameModuleValue NameModuleValue) TcM sheriff :: [CommandLineOption] -> ModSummary -> TcGblEnv -> TcM TcGblEnv sheriff opts modSummary tcEnv = do - -- STAGE-1 let moduleName' = moduleNameString $ moduleName $ ms_mod modSummary pluginOpts@PluginOpts{..} = decodeAndUpdateOpts opts defaultPluginOpts @@ -164,8 +141,6 @@ sheriff opts modSummary tcEnv = do when failOnFileNotFound $ addErr (mkInvalidYamlFileErr (show err)) pure [] Right (YamlTables tables) -> pure $ (map yamlToDbRule tables) - - -- Check the parsed rules yaml file. If failed, throw file error if configured. configuredRules <- case parsedRulesYaml of Left err -> do @@ -192,6 +167,7 @@ sheriff opts modSummary tcEnv = do infRule = case find isInfiniteRecursionRule globalRules of Just (InfiniteRecursionRuleT r) -> r _ -> defaultInfiniteRecursionRuleT + liftIO $ putStrLn $ "finalSheriffRules: " <> show finalSheriffRules when logDebugInfo $ liftIO $ print globalRules when logDebugInfo $ liftIO $ print globalExceptionRules @@ -201,61 +177,7 @@ sheriff opts modSummary tcEnv = do insts <- tcg_insts . env_gbl <$> getEnv let namesModTuple = concatMap (\inst -> let clsName = className (is_cls inst) in (is_dfun_name inst, clsName) : fmap (\clsMethod -> (varName clsMethod, clsName)) (classMethods $ is_cls inst)) insts nameModMap = foldr (\(name, clsName) r -> HM.insert (NMV_Name name) (NMV_ClassModule clsName (getModuleName clsName)) r) HM.empty namesModTuple - -- let binds = bagToList $ tcg_binds tcEnv - -- liftIO $ putStrLn ("AST of binds:\n" ++ ppShow binds) - -- liftIO $ putStrLn ("Extracted bind names: " ++ OP.showSDocUnsafe (OP.ppr binds)) - - -- let mymAction = do - -- results <- forM binds extractExprFromBind - -- return (results) - -- let (extractedAll, finalState) = runMyM mymAction - - -- let jsonFilePath = "/Users/sailaja.b/spider/tree.json" - - -- Read the JSON file - -- jsonData <- liftIO (Char8.readFile jsonFilePath `catch` handleReadError) - -- let parsedData = eitherDecode jsonData :: Either String FunctionMap - - -- case parsedData of - -- Right functionMap -> do - -- let filteredEntries = filterEntries functionMap - -- forM_ (HM.toList filteredEntries) $ \(key, functionEntry) -> do - -- let codeContent = code functionEntry - -- -- let codeContent = "res <- findOneRow\n dbConf meshConfig\n $ [Is DB.filterGroupId (Eq groupId),\n Is DB.dimensionValue (Eq dimension)]\n toDomainAll res parseTenantConfigFilter\n ! #function_name \"getTenantConfigFilterByGroupIdAndDimensionValue\"\n ! #parser_name \"parseTenantConfigFilter\"" :: String - - -- -- let regex1 = "findOneRow" :: String - -- -- let regex2 = "findOneRow[[:space:]]+([^[:space:]]+)" :: String - -- -- let regex3 = "findOneRow[[:space:]]+([^[:space:]]+)[[:space:]]+([^[:space:]]+)" :: String - -- -- let regex4 = "findOneRow[[:space:]]+([^[:space:]]+)[[:space:]]+([^[:space:]]+)[[:space:]]+\\[([[:print:][:space:]]*)\\]" :: String - -- let regexFlexible = "(findOneRow|findAllRows)[[:space:]]+([^[:space:]]+)[[:space:]]+([^[:space:]]+)[[:space:]]+.*[[:space:]]*\\[([[:print:][:space:]]*)\\]" :: String - -- -- Test each regex pattern - -- -- let matches1 = codeContent =~ regex1 :: [[String]] - -- -- let matches2 = codeContent =~ regex2 :: [[String]] - -- -- let matches3 = codeContent =~ regex3 :: [[String]] - -- -- let matches4 = codeContent =~ regex4 :: [[String]] - -- let matches5 = codeContent =~ regexFlexible :: [[String]] - - -- -- Print results - -- -- liftIO $ putStrLn $ "Matches with Regex 1: " ++ show matches1 - -- -- liftIO $ putStrLn $ "Matches with Regex 2: " ++ show matches2 - -- -- liftIO $ putStrLn $ "Matches with Regex 3: " ++ show matches3 - -- -- liftIO $ putStrLn $ "Matches with Regex 4: " ++ show matches4 - -- let firstElements = map (\innerList -> case innerList of - -- (a:_) -> a - -- [] -> "No match") matches5 - -- liftIO $ putStrLn $ "Function Name: " ++ key ++ " ,First elements: " ++ show firstElements - - -- Left err -> liftIO $ putStrLn $ "Failed to parse JSON: " ++ err - - -- Example: Extract specific data from the FunctionMap - -- let extractedData = fromMaybe HM.empty (decode jsonData :: Maybe FunctionMap) - -- liftIO $ putStrLn $ "Extracted data: " ++ show extractedData - - - -- liftIO $ putStrLn $ "๐Ÿ“ Final state: " ++ show finalState - -- liftIO $ putStrLn $ "๐Ÿ“Œ Extracted expressions: " ++ OP.showSDocUnsafe (OP.ppr extractedAll) - rawErrors <- concat <$> (mapM (loopOverModBinds finalSheriffRules) $ bagToList $ tcg_binds tcEnv) (rawInfiniteRecursionErrors, _) <- flip runStateT nameModMap $ concat <$> (mapM (checkInfiniteRecursion True infRule) $ bagToList $ tcg_binds tcEnv) @@ -278,230 +200,8 @@ sheriff opts modSummary tcEnv = do return tcEnv -handleReadError :: IOException -> IO Char8.ByteString -handleReadError e = do - putStrLn $ "c " ++ show e - return Char8.empty - --- filterEntries :: HM.HashMap String FunctionEntry -> HM.HashMap String FunctionEntry --- filterEntries functionMap = --- HM.filter containsFindFunctions functionMap --- where --- containsFindFunctions :: FunctionEntry -> Bool --- containsFindFunctions entry = --- case toJSON entry of --- Object obj -> -- Ensure the JSON value is an Object --- case KM.lookup "functions_it_is_calling" obj of --- Just (Array functions) -> --- any (\func -> func == "findOneRow" || func == "findAllRows") functions --- _ -> False --- _ -> False -- Handle cases where the JSON value is not an Object - -type TableMap = Map.Map String String -type ClauseMap = Map.Map String String - - -data MyState = MyState - { tableMap :: TableMap - , clauseMap :: ClauseMap - } deriving (Show) - -type MyM = State MyState - -runMyM :: MyM a -> (a, MyState) -runMyM action = runState action initialMyState - -initialMyState :: MyState -initialMyState = MyState Map.empty Map.empty - -extractExprFromBind :: LHsBindLR GhcTc GhcTc -> MyM (Maybe (String, String, String)) -extractExprFromBind (L loc bind) = do - traceM "extractExprFromBind called" - myMap <- get - case bind of - FunBind { fun_matches = MG { mg_alts = L _ matches } } -> do - traceM "Matched FunBind" - results <- forM matches $ \(L _ match) -> case match of - Match { m_grhss = GRHSs _ grhss _ } -> do - innerResults <- forM grhss $ \(L _ grhs) -> case grhs of - GRHS _ _ body -> do - let exprs = body ^? biplateRef :: [LHsExpr GhcTc] - extracted <- mapM (\expr -> extractQueryInfo expr (L loc bind)) exprs - traceM $ "body from FunBind: " ++ OP.showSDocUnsafe (OP.ppr body) ++ "exprs from FunBind: " ++ OP.showSDocUnsafe (OP.ppr exprs) ++ "\n๐Ÿ“Œ Query Info:\n" ++ unlines (map showTriple (catMaybes extracted)) - pure (catMaybes extracted) - pure (concat innerResults) - pure $ listToMaybe (concat results) - - PatBind { pat_rhs = GRHSs _ grhss _ } -> do - traceM "Matched PatBind" - results <- forM grhss $ \(L _ grhs) -> case grhs of - GRHS _ _ body -> do - let exprs = body ^? biplateRef :: [LHsExpr GhcTc] - extracted <- mapM (\expr -> extractQueryInfo expr (L loc bind)) exprs - traceM $ "body from FunBind: " ++ OP.showSDocUnsafe (OP.ppr body) ++ "exprs from PatBind: " ++ OP.showSDocUnsafe (OP.ppr exprs) ++ "\n๐Ÿ“Œ Query Info:\n" ++ unlines (map showTriple (catMaybes extracted)) - pure (catMaybes extracted) - pure $ listToMaybe (concat results) - - AbsBinds { abs_binds = binds } -> do - traceM "Matched AbsBinds" - results <- mapM extractExprFromBind (bagToList binds) - pure $ listToMaybe (catMaybes results) - - _ -> do - traceM "No matching bind constructor (not FunBind or PatBind)" - pure Nothing - -extractQueryInfo :: LHsExpr GhcTc -> LHsBindLR GhcTc GhcTc -> MyM (Maybe (String, String, String)) -extractQueryInfo expr bindings = do - traceM ("๐Ÿ” Called extractQueryInfo with expr: " ++ OP.showSDocUnsafe (OP.ppr expr)) - (fn, args) <- flattenHsAppM expr - let fnName = OP.showSDocUnsafe (OP.ppr fn) - if fnName `elem` ["findOneRow", "findAllRows"] && length args == 3 - then do - let clause = OP.showSDocUnsafe (OP.ppr (args !! 2)) - let typeStr = show (typeOf (unLoc (args !! 2))) - let unlocatedBindings = [unLoc bindings] - let isWhereClause = hasIsOrEmptyList (args !! 2) - clauseMapVal <- gets clauseMap - let whereClause = if isWhereClause - then clause - else fromMaybe " trace "Matched: Empty list" True - ExplicitList _ xs -> trace ("Checking list of length " ++ show (length xs)) $ any hasIsOrEmptyList xs - - HsApp _ fun arg -> - let funStr = showSDocUnsafe (ppr (unLoc fun)) - matches = any (`isPrefixOf` funStr) ["Is", "And", "Or"] - isListCons = funStr == ":" || funStr == "[]" - in trace ("Matched: HsApp, head string: " ++ funStr ++ ", matches? " ++ show matches ++ ", isListCons? " ++ show isListCons) $ - matches || isListCons || hasIsOrEmptyList fun || hasIsOrEmptyList arg - OpApp _ fun op arg -> - let funStr = showSDocUnsafe (ppr (unLoc fun)) - opStr = showSDocUnsafe (ppr (unLoc op)) - argStr = showSDocUnsafe (ppr (unLoc arg)) - matches = any (`isPrefixOf` funStr) ["Is", "And", "Or"] - isListCons = funStr == ":" || funStr == "[]" - in trace ("Matched: OpApp, fun: " ++ funStr ++ ", op: " ++ opStr ++ ", arg: " ++ argStr ++ - ", matches? " ++ show matches ++ ", isListCons? " ++ show isListCons) $ - matches || isListCons || hasIsOrEmptyList fun || hasIsOrEmptyList op || hasIsOrEmptyList arg - - HsPar _ inner -> trace "Matched: HsPar" $ hasIsOrEmptyList inner - - XExpr (WrapExpr innerExpr) -> - case innerExpr of - HsWrap _ exprInner -> - trace ("Matched: XExpr WrapExpr with location " ++ showSDocUnsafe (ppr exprInner)) $ - hasIsOrEmptyList (noLocA exprInner) - - XExpr _ -> trace "XExpr case: cannot handle yet" False - - other -> trace ("๐Ÿ›‘ we came to other case: " ++ showSDocUnsafe (ppr other)) False --- unwrapExpr :: LHsExpr GhcTc -> LHsExpr GhcTc --- unwrapExpr lexpr = --- case unLoc lexpr of --- HsWrap _ inner -> noLocA inner -- inner :: HsExpr GhcTc --- other -> lexpr - --- unwrapExprFromHsExpr :: HsExpr GhcTc -> HsExpr GhcTc --- unwrapExprFromHsExpr (HsWrap _ inner) = unwrapExprFromHsExpr inner --- unwrapExprFromHsExpr expr = expr - - - -flattenHsAppM :: LHsExpr GhcTc -> MyM (HsExpr GhcTc, [LHsExpr GhcTc]) -flattenHsAppM expr = do - traceM $ "๐Ÿงพ Received expr: " ++ OP.showSDocUnsafe (OP.ppr expr) - let doStmts = case expr of - L _ (HsDo _ _ (L _ stmts)) -> stmts - _ -> [] - - forM_ doStmts $ \stmt -> case stmt of - L _ (BindStmt _ (L _ (VarPat _ (L _ varName))) rhsExpr) -> do -- To get dbConfig (Table Name) - traceM $ "๐Ÿ“ฆ RHS Expr: " ++ OP.showSDocUnsafe (OP.ppr rhsExpr) - let normalizedExpr = stripExpr rhsExpr - case normalizedExpr of - L _ (HsAppType _ (L _ (HsVar _ (L _ fnName))) (HsWC _ innerType)) - | occNameString (occName fnName) `elem` ["getEulerDbConf", "getEulerPsqlDbConf"] -> do - let lhsVarStr = OP.showSDocUnsafe (OP.ppr varName) - typeStr = case innerType of - L _ (HsTyVar _ _ (L _ name)) -> occNameString (occName name) - _ -> "unknown_type" - traceM $ "๐Ÿ“ฅ Inserting into map: " ++ lhsVarStr ++ " -> " ++ typeStr - modify $ \s -> s { tableMap = Map.insert lhsVarStr typeStr (tableMap s) } -- Map Insertion (if dbcongig1 <- getEulerDbConf @orderReferenceT) , ["dbcongig1", "orderReferenceT"] get inserted - _ -> pure () - -- To get where clause => - -- 1. Written in let statement => let whereclause = [Is..] - -- 2. Directly Passed to select functions => findonerow tablename [Is..] - -- 3. Passed as a parameter to the function => fun1 whereClause = do - -- findonerow tableame whereClause - -- 4. A local function or local value defined using a where clause - L _ (LetStmt _ localBindsL) -> - (case localBindsL of - HsValBinds _ valBinds -> - case valBinds of - XValBindsLR (NValBinds bindList _) -> do - traceM "๐Ÿ” Processing LetStmt -> HsValBinds -> XValBindsLR" - forM_ bindList $ \(recFlag, bindBag) -> do - traceM $ "๐ŸŒ€ Processing bind list with RecFlag: " ++ showSDocUnsafe (ppr recFlag) ++ " ,bindBag: " ++ showSDocUnsafe (ppr bindBag) - let binds = bagToList bindBag - traceM $ "๐Ÿ” Number of binds in bindBag: " ++ show (length binds) - if null binds - then traceM "โš ๏ธ bindBag is empty โ€“ nothing to process" - else forM_ (bagToList bindBag) $ \(L _ bind) -> do - traceM $ "๐Ÿ” bind constructor: " ++ showConstr (toConstr bind) - case bind of - FunBind { fun_id = L _ varName - , fun_matches = MG _ (L _ [L _ (Match _ _ _ (GRHSs _ [L _ (GRHS _ [] body)] _) )]) _ - } -> do - traceM $ "๐Ÿ“ฆ Found FunBind with var: " ++ showSDocUnsafe (ppr varName) - traceM $ "๐Ÿ“ฆ Let RHS: " ++ showSDocUnsafe (ppr body) - let normalizedExpr = stripExpr body - varStr = showSDocUnsafe (ppr varName) - traceM $ "๐Ÿงน Normalized RHS: " ++ showSDocUnsafe (ppr normalizedExpr) - let result = hasIsOrEmptyList normalizedExpr - traceM $ "๐Ÿงช hasIsOrEmptyList result: " ++ show result - when result $ do - traceM $ "โœ… Match: RHS has 'is' or '[]', recording clause for: " ++ varStr - modify $ \s -> s { clauseMap = Map.insert varStr (showSDocUnsafe (ppr normalizedExpr)) (clauseMap s) } - _ -> traceM "โ›” Skipping non-PatBind or unhandled bind pattern" - _ -> traceM "โš ๏ธ valBinds is not XValBindsLR -> skipping" - _ -> pure ()) - - _ -> pure () - - go expr [] - where - go :: LHsExpr GhcTc -> [LHsExpr GhcTc] -> MyM (HsExpr GhcTc, [LHsExpr GhcTc]) - go (L _ (HsApp _ f x)) args = go f (x : args) - go (L _ f) args = pure (f, args) - - -stripExpr :: LHsExpr GhcTc -> LHsExpr GhcTc -stripExpr (L l (HsPar _ e)) = stripExpr e -stripExpr (L l (HsAppType x e t)) = L l (HsAppType x (stripExpr e) t) -stripExpr (L l (XExpr (WrapExpr (HsWrap _ e)))) = stripExpr (L l e) -stripExpr other = other - -showTriple :: (String, String, String) -> String -showTriple (a, b, c) = "(" ++ a ++ ", " ++ b ++ ", " ++ c ++ ")" +--------------------------- Infinite Recursion Detection Logic --------------------------- {- 1. Check if bind is AbsBind, add a mapping from mono to poly Var and recurse for binds @@ -723,6 +423,7 @@ getBadFnCalls rules (FunBind{fun_matches = matches}) = do -- use childrenBi and then repeated children usage as per use case -- (exprs :: [LHsExpr GhcTc]) = traverseConditionalUni (noWhereClauseExpansion) (childrenBi match :: [LHsExpr GhcTc]) (exprs :: [LHsExpr GhcTc]) = traverseAstConditionally match noWhereClauseExpansion + -- liftIO $ putStrLn ("exprs: " ++ showS exprs) concat <$> mapM (isBadExpr rules) exprs getBadFnCalls _ _ = pure [] @@ -769,15 +470,19 @@ isBadExprHelper rules ap = concat <$> mapM (\rule -> checkAndApplyRule rule ap) checkAndApplyRule :: (HasPluginOpts PluginOpts) => Rule -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) checkAndApplyRule ruleT ap = case ruleT of DBRuleT rule@(DBRule {table_name = ruleTableName}) -> + case ap of + (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do + case (showS ty == "Clause" && showS tblName == (ruleTableName <> "T")) of + True -> validateDBRule rule (showS tblName) exprs ap + False -> pure [] + _ -> pure [] + WhereClauseRuleT rule -> case ap of (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do case (showS ty == "Clause") of - True -> do - simplifiedExprs <- trfWhereToSOP exprs - checkWhereClauseRule <- mapM (validateWhereClauseRule (showS tblName)) simplifiedExprs - liftIO $ putStrLn $ "Checking where clause rule for table: " <> (showS tblName) <> " ,clauses: " <> showS exprs <> " ,checkWhereClauseRule: " <> show checkWhereClauseRule - if (showS tblName == (ruleTableName <> "T")) then validateDBRule simplifiedExprs rule (showS tblName) exprs ap - else pure [] + True -> do + let (fnName, _) = fromMaybe (error "No function name found", []) $ getFnNameWithAllArgs ap + validateWhereClauseRule rule (showS tblName) exprs (showSDocUnsafe $ ppr fnName) False -> pure [] _ -> pure [] FunctionRuleT rule@(FunctionRule {fn_name = ruleFnNames, arg_no}) -> do @@ -953,9 +658,9 @@ Part-2 Validation -} -- Function to check if given DB rules is violated or not -- TODO: Fix this, keep two separate options for - 1. Match All Fields in AND 2. Use 1st column matching or all columns matching for composite key -validateDBRule :: (HasPluginOpts PluginOpts) => [SimplifiedIsClause] -> DBRule -> String -> [LHsExpr GhcTc] -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) -validateDBRule simplifiedExprs rule@(DBRule {db_rule_name = ruleName, table_name = ruleTableName, indexed_cols_names = ruleColNames}) tableName clauses expr = do - +validateDBRule :: (HasPluginOpts PluginOpts) => DBRule -> String -> [LHsExpr GhcTc] -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) +validateDBRule rule@(DBRule {db_rule_name = ruleName, table_name = ruleTableName, indexed_cols_names = ruleColNames}) tableName clauses expr = do + simplifiedExprs <- trfWhereToSOP clauses let checkDBViolation = case (matchAllInsideAnd . pluginOpts $ ?pluginOpts) of True -> checkDBViolationMatchAll False -> checkDBViolationWithoutMatchAll @@ -980,28 +685,48 @@ validateDBRule simplifiedExprs rule@(DBRule {db_rule_name = ruleName, table_name [] -> pure Nothing ((clause, colName, tableName) : _) -> pure $ Just (clause, NonIndexedDBColumn colName tableName rule) -validateWhereClauseRule :: String -> [SimplifiedIsClause] -> TcM Bool -validateWhereClauseRule tableName simplifiedExprs = do - -- Read JSON file - jsonData <- liftIO $ readFile "/Users/sailaja.b/euler-db/tables_and_fields_with_types.json" - case decode jsonData :: Maybe (HM.HashMap String (HM.HashMap String String)) of +validateWhereClauseRule :: (HasPluginOpts PluginOpts) => WhereClauseRule -> String -> [LHsExpr GhcTc] -> String -> TcM ([(LHsExpr GhcTc, Violation)]) +validateWhereClauseRule rule tableName clauses fnName = do + liftIO $ putStrLn $ "Validating WhereClauseRule for table: " <> tableName <> " with clauses: " <> showS clauses + simplifiedExprs <- trfWhereToSOP clauses + let matched = concatMap (map (\(_, colName, _) -> colName)) simplifiedExprs -- Extract column names from all clauses + liftIO $ putStrLn $ "tableKey: " <> tableName <> ", Matched columns: " <> show matched <> ", fnName: " <> fnName + jsonData <- liftIO $ Char8.readFile "tables_and_fields_with_types.json" + case A.decode jsonData :: Maybe (HM.HashMap String (HM.HashMap String String)) of Just jsonMap -> do let tableKey = tableName case HM.lookup tableKey jsonMap of Just fieldsMap -> do - -- Extract field names from the nested structure let fieldNames = HM.keys fieldsMap - let matchingFields = filter (\field -> "disable" `isInfixOf` field || "enable" `isInfixOf` field) fieldNames + let isBoolField field = case HM.lookup field fieldsMap of + Just value -> "Bool" `isInfixOf` value + _ -> False + let matchingFields = filter (\field -> ("disable" `isInfixOf` field || "enable" `isInfixOf` field) && isBoolField field) fieldNames + liftIO $ putStrLn $ "Matching fields for " <> tableKey <> ": " <> show matchingFields - -- Check if any matching field exists in simplifiedExprs - let fieldExists = any (\(_, colName, _) -> any (\field -> field == colName) matchingFields) simplifiedExprs - pure fieldExists + violation <- checkWhereClauseViolation tableKey matchingFields simplifiedExprs -- Pass all simplifiedExprs at once + pure $ maybe [] (\v -> [v]) violation Nothing -> do liftIO $ putStrLn $ "No fields found for table: " <> tableKey - pure False + pure [] Nothing -> do liftIO $ putStrLn "Failed to parse JSON file." - pure False + pure [] + where + checkWhereClauseViolation :: String -> [String] -> [[SimplifiedIsClause]] -> TcM (Maybe (LHsExpr GhcTc, Violation)) + checkWhereClauseViolation tableKey matchingFields sopList = do + let matched = concatMap (map (\(_, colName, _) -> colName)) sopList + let isWhereClauseViolation = not (null matchingFields) && all (\field -> field `notElem` matched) matchingFields + liftIO $ putStrLn $ "tableKey: " <> tableKey <> ", Matched columns: " <> show matched <> ", Is violation: " <> show isWhereClauseViolation + if null matchingFields + then pure Nothing + else case concat sopList of + [] -> pure Nothing + ((clause, _, _) : _) -> + if isWhereClauseViolation + -- then pure $ Just (clause, WhereClauseViolationDetected tableName matchingFields rule) + then pure Nothing + else pure Nothing -- Check only for the ordering of the columns of the composite key doesMatchColNameInDbRuleWithComposite :: String -> [YamlTableKeys] -> [String] -> Bool @@ -1035,22 +760,37 @@ trfWhereToSOP [] = pure [[]] trfWhereToSOP (clause : ls) = do let res = getWhereClauseFnNameWithAllArgs clause (fnName, args) = fromMaybe ("NA", []) res + liftIO $ putStrLn $ "Processing clause: " <> showS clause <> ", Function name: " <> fnName <> ", Args: " <> showS args case (fnName, args) of ("And", [(L _ (PatExplicitList _ arg))]) -> do + liftIO $ putStrLn "Detected 'And' clause." curr <- trfWhereToSOP arg rem <- trfWhereToSOP ls pure [x <> y | x <- curr, y <- rem] ("Or", [(L _ (PatExplicitList _ arg))]) -> do + liftIO $ putStrLn "Detected 'Or' clause." curr <- foldM (\r cls -> fmap (<> r) $ trfWhereToSOP [cls]) [] arg rem <- trfWhereToSOP ls pure [x <> y | x <- curr, y <- rem] ("$WIs", [arg1, arg2]) -> do + liftIO $ putStrLn "Detected 'Is' clause." curr <- getIsClauseData arg1 arg2 clause rem <- trfWhereToSOP ls case curr of - Nothing -> pure rem - Just (tblName, colName) -> pure $ fmap (\lst -> (clause, tblName, colName) : lst) rem - (fn, _) -> when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ print $ "Invalid/unknown clause in `where` clause : " <> fn <> " at " <> (showS . getLoc2 $ clause)) >> trfWhereToSOP ls + Nothing -> do + liftIO $ putStrLn "Failed to extract 'Is' clause data." + pure rem + Just (tblName, colName) -> do + liftIO $ putStrLn $ "Extracted 'Is' clause data: Table = " <> tblName <> ", Column = " <> colName + pure $ fmap (\lst -> (clause, tblName, colName) : lst) rem + ("getField", a ) -> do + -- Handle `getField @"columnName"` syntax + liftIO $ putStrLn $ "Detected 'getField' clause for column: " <> showS a + -- rem <- trfWhereToSOP [] + pure [] + (fn, _) -> do + liftIO $ putStrLn $ "Invalid/unknown clause in `where` clause: " <> fn <> " at " <> (showS . getLoc2 $ clause) + trfWhereToSOP ls -- Get table field name and table name for the `Se.Is` clause -- Patterns to match 'getField`, `recordDot`, `overloadedRecordDot` (ghc > 9), selector (duplicate record fields), rec fields (ghc 9), lens @@ -1069,12 +809,19 @@ getIsClauseData fieldArg _comp _clause = do ("$sel" : colName : tableName : []) -> pure $ Just (colName, tableName) _ -> when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ print "Invalid pattern for Selector way") >> pure Nothing RecordDot -> do + liftIO $ putStrLn $ "Debugging RecordDot: AST structure of fieldArg = " <> showS fieldArg + let allNodes = traverseAst fieldArg :: [HsExpr GhcTc] + liftIO $ putStrLn $ "All nodes in AST: " <> showS allNodes let tyApps = filter (\x -> case x of (HsApp _ (L _ (HsAppType _ _ fldName)) tableVar) -> True (PatHsWrap (WpCompose (WpEvApp (EvExpr _hasFld)) (WpCompose (WpTyApp _fldType) (WpTyApp tableVar))) (HsAppType _ _ fldName)) -> True + (HsApp _ (L _ fldName) tableVar) -> True -- Added fallback for simpler AST structure + (HsVar _ fldName) -> True -- Handle HsVar for field names + (HsOverLit _ (OverLit {ol_val = HsIsString _ colName})) -> True -- Handle string literals _ -> False - ) $ (traverseAst fieldArg :: [HsExpr GhcTc]) - if length tyApps > 0 + ) allNodes + liftIO $ putStrLn $ "Filtered tyApps: " <> showS tyApps + if not (null tyApps) then case head tyApps of (HsApp _ (L _ (HsAppType _ _ fldName)) tableVar) -> do @@ -1090,6 +837,76 @@ getIsClauseData fieldArg _comp _clause = do TyConApp ty1 _ -> showS ty1 ty -> showS ty in pure $ Just (getStrFromHsWildCardBndrs fldName, take (length tblName' - 1) tblName') + (HsApp _ (L _ fldName) tableVar) -> do -- Handle simpler AST structure + typ <- getHsExprType (logTypeDebugging . pluginOpts $ ?pluginOpts) tableVar + let tblName' = case typ of + AppTy ty1 _ -> showS ty1 + TyConApp ty1 _ -> showS ty1 + ty -> showS ty + pure $ Just (showS fldName, take (length tblName' - 1) tblName') + (HsVar _ fldName) -> do -- Handle HsVar for field names + let fldNameStr = occNameString (nameOccName (idName (unLoc fldName))) -- Extract the name as a String + liftIO $ putStrLn $ "Detected HsVar: " <> fldNameStr + case fldNameStr of + "getField" -> do + let colName = case allNodes of + (HsOverLit _ (OverLit {ol_val = HsIsString _ colName})) : _ -> + let extractedColName = unpackFS colName + in trace ("Matched HsOverLit (HsIsString): Extracted column name = " <> extractedColName) extractedColName + + (HsVar _ directFldName) : _ -> + let directFldNameStr = occNameString (nameOccName (idName (unLoc directFldName))) + in trace ("Matched HsVar directly: Field name = " <> directFldNameStr) directFldNameStr + + (HsPar _ expr) : _ -> + -- Helper function to recursively unwrap expressions + let recUnwrap :: LHsExpr GhcTc -> LHsExpr GhcTc + recUnwrap expr = case unLoc expr of + HsPar _ innerExpr -> recUnwrap innerExpr + other -> expr + + exprStr = showS expr + debugExpr = showS (unLoc expr) -- Log the structure of unLoc expr + innerColName = (case unLoc (recUnwrap expr) of + -- Match HsOverLit directly + HsOverLit _ (OverLit {ol_val = HsIsString _ colName}) -> + let extractedColName = unpackFS colName + in trace ("Inner Matched HsOverLit (HsIsString): Extracted column name = " <> extractedColName) extractedColName + + -- Match HsLit directly + HsLit _ (HsString _ colName) -> + let extractedColName = unpackFS colName + in trace ("Inner Matched HsLit (HsString): Extracted column name = " <> extractedColName) extractedColName + + -- Match HsVar directly + HsVar _ name -> + let extractedColName = occNameString (nameOccName (idName (unLoc name))) + in trace ("Inner Matched HsVar: Extracted column name = " <> extractedColName) extractedColName + + -- Match HsApp (function application) + HsApp _ func arg -> + let funcStr = showS func + argStr = showS arg + in case unLoc (recUnwrap func) of + HsVar _ name | occNameString (nameOccName (idName (unLoc name))) == "getField" -> + case unLoc (recUnwrap arg) of + HsLit _ (HsString _ colName) -> + let extractedColName = unpackFS colName + in trace ("Inner Matched HsApp (getField): Extracted column name = " <> extractedColName) extractedColName + _ -> trace ("Inner Argument of getField is not a string literal. Defaulting to UnknownColumn") "UnknownColumn" + _ -> trace ("Inner Function is not getField. Defaulting to UnknownColumn") "UnknownColumn" + + -- Default case for unmatched patterns + _ -> trace ("Inner No matching case found for unLoc expr: " <> debugExpr <> ". Defaulting to UnknownColumn") "UnknownColumn" + ) + in trace ("Matched HsPar: Expression = " <> exprStr <> ", Inner column name = " <> innerColName) innerColName + _ -> trace ("No matching case found. Nodes in AST: " <> showS allNodes <> ". Defaulting to UnknownColumn") "UnknownColumn7" + liftIO $ putStrLn $ "Extracted column name: " <> colName + pure $ Just (colName, "AuthenticationAccountT") -- Replace with actual table name if available + _ -> pure $ Just (fldNameStr, "UnknownTable") + (HsOverLit _ (OverLit {ol_val = HsIsString _ colName})) -> do -- Handle string literals + liftIO $ putStrLn $ "Detected string literal: " <> unpackFS colName + pure $ Just (unpackFS colName, "UnknownTable") _ -> when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ putStrLn "HsAppType not present. Should never be the case as we already filtered.") >> pure Nothing else when ((logWarnInfo . pluginOpts $ ?pluginOpts)) (liftIO $ putStrLn "HsAppType not present after filtering. Should never reach as already deduced RecordDot.") >> pure Nothing Lens -> do @@ -1206,31 +1023,68 @@ getFnNameAndTypeableExprWithAllArgs _ = Nothing -- TODO: Verify the correctness of this function before moving it to utils -- Get function name with all it's arguments + getFnNameWithAllArgs :: LHsExpr GhcTc -> Maybe (Located Var, [LHsExpr GhcTc]) -getFnNameWithAllArgs (L loc (HsVar _ v)) = Just (getLocated v loc, []) -getFnNameWithAllArgs (L _ (HsConLikeOut _ cl)) = (\clId -> (noExprLoc clId, [])) <$> conLikeWrapId cl -getFnNameWithAllArgs (L _ (HsAppType _ expr _)) = getFnNameWithAllArgs expr -getFnNameWithAllArgs (L _ (HsApp _ (L loc (HsVar _ v)) funr)) = Just (getLocated v loc, [funr]) -getFnNameWithAllArgs (L _ (HsPar _ expr)) = getFnNameWithAllArgs expr -getFnNameWithAllArgs (L _ (HsApp _ funl funr)) = do - let res = getFnNameWithAllArgs funl - case res of - Nothing -> Nothing - Just (fnName, ls) -> Just (fnName, ls ++ [funr]) -getFnNameWithAllArgs (L loc (OpApp _ funl op funr)) = do - case showS op of - "($)" -> getFnNameWithAllArgs $ (L loc (HsApp noExtFieldOrAnn funl funr)) - _ -> Nothing -getFnNameWithAllArgs (L loc ap@(PatHsWrap _ expr)) = getFnNameWithAllArgs (L loc expr) +getFnNameWithAllArgs expr = case expr of + L loc (HsVar _ v) -> + trace ("getFnNameWithAllArgs: Detected HsVar with name = " <> showS v) $ + Just (getLocated v loc, []) + + L _ (HsConLikeOut _ cl) -> + trace "getFnNameWithAllArgs: Detected HsConLikeOut" $ + (\clId -> (noExprLoc clId, [])) <$> conLikeWrapId cl + + L _ (HsAppType _ expr _) -> + trace "getFnNameWithAllArgs: Detected HsAppType" $ + getFnNameWithAllArgs expr + + L _ (HsApp _ (L loc (HsVar _ v)) funr) -> + trace ("getFnNameWithAllArgs: Detected HsApp with HsVar function name = " <> showS v) $ + Just (getLocated v loc, [funr]) + + L _ (HsPar _ expr) -> + trace "getFnNameWithAllArgs: Detected HsPar" $ + getFnNameWithAllArgs expr + + L _ (HsApp _ funl funr) -> + trace "getFnNameWithAllArgs: Detected HsApp with nested function application" $ + case getFnNameWithAllArgs funl of + Nothing -> trace "getFnNameWithAllArgs: No function name found in nested HsApp" Nothing + Just (fnName, ls) -> + trace ("getFnNameWithAllArgs: Found function name = " <> showS fnName <> " with arguments = " <> showS ls) $ + Just (fnName, ls ++ [funr]) + + L loc (OpApp _ funl op funr) -> + trace ("getFnNameWithAllArgs: Detected OpApp with operator = " <> showS op) $ + case showS op of + "($)" -> + trace "getFnNameWithAllArgs: Detected ($) operator, treating as HsApp" $ + getFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) + _ -> trace "getFnNameWithAllArgs: Unsupported operator, returning Nothing" Nothing + + L loc ap@(PatHsWrap _ expr) -> + trace "getFnNameWithAllArgs: Detected PatHsWrap" $ + getFnNameWithAllArgs (L loc expr) + #if __GLASGOW_HASKELL__ >= 900 -getFnNameWithAllArgs (L loc ap@(PatHsExpansion orig expanded)) = - case (orig, expanded) of - ((OpApp _ _ op _), (HsApp _ (L _ (HsApp _ op' funl)) funr)) -> case showS op of - "($)" -> getFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) - _ -> getFnNameWithAllArgs (L loc expanded) - _ -> getFnNameWithAllArgs (L loc expanded) + L loc ap@(PatHsExpansion orig expanded) -> + trace "getFnNameWithAllArgs: Detected PatHsExpansion" $ + case (orig, expanded) of + ((OpApp _ _ op _), (HsApp _ (L _ (HsApp _ op' funl)) funr)) -> + case showS op of + "($)" -> + trace "getFnNameWithAllArgs: Detected ($) operator in PatHsExpansion, treating as HsApp" $ + getFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) + _ -> + trace "getFnNameWithAllArgs: Unsupported operator in PatHsExpansion, processing expanded expression" $ + getFnNameWithAllArgs (L loc expanded) + _ -> + trace "getFnNameWithAllArgs: Processing expanded expression in PatHsExpansion" $ + getFnNameWithAllArgs (L loc expanded) #endif -getFnNameWithAllArgs _ = Nothing + + _ -> + trace "getFnNameWithAllArgs: No matching case found, returning Nothing" Nothing --------------------------- Sheriff Plugin Utils --------------------------- -- Transform the FnBlockedInArg Violation with correct expression diff --git a/sheriff/src/Sheriff/Types.hs b/sheriff/src/Sheriff/Types.hs index 24fd04c..1615490 100644 --- a/sheriff/src/Sheriff/Types.hs +++ b/sheriff/src/Sheriff/Types.hs @@ -191,6 +191,35 @@ instance FromJSON FunctionRule where fn_rule_ignore_functions <- o .:? "fn_rule_ignore_functions" .!= (fn_rule_ignore_functions defaultFunctionRule) return FunctionRule {..} +data WhereClauseRule = + WhereClauseRule + { + where_clause_rule_name :: String, + where_clause_rule_fixes :: Suggestions, + where_clause_rule_ignore_modules :: Modules, + where_clause_rule_check_modules :: Modules, + query_functions_to_check :: [String] + } + deriving (Show, Eq) + +defaultWhereClauseRule :: WhereClauseRule +defaultWhereClauseRule = WhereClauseRule { + where_clause_rule_name = "WhereClauseRule", + where_clause_rule_fixes = ["You Might want to include an mandatory column in the `where` clause of the query."], + where_clause_rule_ignore_modules = [], + where_clause_rule_check_modules = ["*"], + query_functions_to_check = ["find", "findOne", "findoneRow" , "findAllRows"] + } + +instance FromJSON WhereClauseRule where + parseJSON = withObject "WhereClauseRule" $ \o -> do + where_clause_rule_name <- o .: "where_clause_rule_name" + where_clause_rule_fixes <- o .: "where_clause_rule_fixes" + where_clause_rule_ignore_modules <- o .:? "where_clause_rule_ignore_modules" .!= (where_clause_rule_ignore_modules defaultWhereClauseRule) + where_clause_rule_check_modules <- o .:? "where_clause_rule_check_modules" .!= (where_clause_rule_check_modules defaultWhereClauseRule) + query_functions_to_check <- o .:? "query_functions_to_check" .!= (query_functions_to_check defaultWhereClauseRule) + return WhereClauseRule {..} + data InfiniteRecursionRule = InfiniteRecursionRule { @@ -316,11 +345,12 @@ data Rule = DBRuleT DBRule | FunctionRuleT FunctionRule | InfiniteRecursionRuleT InfiniteRecursionRule + | WhereClauseRuleT WhereClauseRule | GeneralRuleT GeneralRule deriving (Show, Eq) instance FromJSON Rule where - parseJSON str = (DBRuleT <$> parseJSON str) <|> (FunctionRuleT <$> parseJSON str) <|> (InfiniteRecursionRuleT <$> parseJSON str) <|> (GeneralRuleT <$> parseJSON str) <|> (fail $ "Invalid Rule: " <> show str) + parseJSON str = (DBRuleT <$> parseJSON str) <|> (FunctionRuleT <$> parseJSON str) <|> (InfiniteRecursionRuleT <$> parseJSON str) <|> (WhereClauseRuleT <$> parseJSON str) <|>(GeneralRuleT <$> parseJSON str) <|> (fail $ "Invalid Rule: " <> show str) data LocalVar = FnArg Var | FnWhere Var | FnLocal Var deriving (Eq) @@ -344,15 +374,17 @@ data Violation = | FnUseBlocked String FunctionRule | FnSigBlocked String String FunctionRule | InfiniteRecursionDetected InfiniteRecursionRule + | WhereClauseViolationDetected String [String] WhereClauseRule | NoViolation deriving (Eq) instance Show Violation where show violation = case violation of - (ArgTypeBlocked typ exprTy ruleFnName rule) -> "Use of '" <> ruleFnName <> "' on '" <> typ <> "' is not allowed in the overall expression type '" <> exprTy <> "'." - (FnBlockedInArg (fnName, typ) ruleFnName _ rule) -> "Use of '" <> fnName <> "' on type '" <> typ <> "' inside argument of '" <> ruleFnName <> "' is not allowed." - (FnUseBlocked ruleFnName rule) -> "Use of '" <> ruleFnName <> "' in the code is not allowed." - (FnSigBlocked ruleFnName ruleFnSig rule) -> "Use of '" <> ruleFnName <> "' with signature '" <> ruleFnSig <> "' is not allowed in the code." - (NonIndexedDBColumn colName tableName _) -> "Querying on non-indexed column '" <> colName <> "' of table '" <> (tableName) <> "' is not allowed." - (InfiniteRecursionDetected _) -> "Infinite recursion detected in expression" - NoViolation -> "NoViolation" \ No newline at end of file + (ArgTypeBlocked typ exprTy ruleFnName rule) -> "Use of '" <> ruleFnName <> "' on '" <> typ <> "' is not allowed in the overall expression type '" <> exprTy <> "'." + (FnBlockedInArg (fnName, typ) ruleFnName _ rule) -> "Use of '" <> fnName <> "' on type '" <> typ <> "' inside argument of '" <> ruleFnName <> "' is not allowed." + (FnUseBlocked ruleFnName rule) -> "Use of '" <> ruleFnName <> "' in the code is not allowed." + (FnSigBlocked ruleFnName ruleFnSig rule) -> "Use of '" <> ruleFnName <> "' with signature '" <> ruleFnSig <> "' is not allowed in the code." + (NonIndexedDBColumn colName tableName _) -> "Querying on non-indexed column '" <> colName <> "' of table '" <> (tableName) <> "' is not allowed." + (InfiniteRecursionDetected _) -> "Infinite recursion detected in expression" + (WhereClauseViolationDetected tableName colNames _) -> "Where clause rule violation: Missing mandatory querying fields on table '" <> tableName <> "' for column '" <> show colNames <> "'." + NoViolation -> "NoViolation" \ No newline at end of file diff --git a/sheriff/src/Sheriff/TypesUtils.hs b/sheriff/src/Sheriff/TypesUtils.hs index 50f5555..f7a55e5 100644 --- a/sheriff/src/Sheriff/TypesUtils.hs +++ b/sheriff/src/Sheriff/TypesUtils.hs @@ -38,6 +38,7 @@ getViolationSuggestions v = case v of FnSigBlocked _ _ r -> fn_rule_fixes r NonIndexedDBColumn _ _ r -> db_rule_fixes r InfiniteRecursionDetected r -> infinite_recursion_rule_fixes r + WhereClauseViolationDetected _ _ r -> where_clause_rule_fixes r NoViolation -> [] getViolationType :: Violation -> String @@ -48,6 +49,7 @@ getViolationType v = case v of FnSigBlocked _ _ _ -> "FnSigBlocked" NonIndexedDBColumn _ _ _ -> "NonIndexedDBColumn" InfiniteRecursionDetected _ -> "InfiniteRecursionDetected" + WhereClauseViolationDetected _ _ _-> "WhereClauseRule" NoViolation -> "NoViolation" getViolationRule :: Violation -> Rule @@ -58,6 +60,7 @@ getViolationRule v = case v of FnSigBlocked _ _ r -> FunctionRuleT r NonIndexedDBColumn _ _ r -> DBRuleT r InfiniteRecursionDetected r -> InfiniteRecursionRuleT r + WhereClauseViolationDetected _ _ r -> WhereClauseRuleT r NoViolation -> defaultRule getViolationRuleName :: Violation -> String @@ -68,6 +71,7 @@ getViolationRuleName v = case v of FnSigBlocked _ _ r -> fn_rule_name r NonIndexedDBColumn _ _ r -> db_rule_name r InfiniteRecursionDetected r -> infinite_recursion_rule_name r + WhereClauseViolationDetected _ _ r -> where_clause_rule_name r NoViolation -> "NA" getViolationRuleExceptions :: Violation -> Rules From 87cb239de8e7390fdb9e803d5409c97e6ecc593f Mon Sep 17 00:00:00 2001 From: Sailaja Date: Fri, 27 Jun 2025 15:36:32 +0530 Subject: [PATCH 03/15] local --- sheriff/src/Sheriff/Plugin.hs | 77 +++++++++++++++++++++++++++-------- 1 file changed, 61 insertions(+), 16 deletions(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index 9f8b3eb..331d3db 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -481,8 +481,9 @@ checkAndApplyRule ruleT ap = case ruleT of (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do case (showS ty == "Clause") of True -> do - let (fnName, _) = fromMaybe (error "No function name found", []) $ getFnNameWithAllArgs ap - validateWhereClauseRule rule (showS tblName) exprs (showSDocUnsafe $ ppr fnName) + let fnLocatedVar = fromMaybe (error "No function name found") $ getFnName ap + fnName = getLocatedVarNameWithModuleName fnLocatedVar + validateWhereClauseRule rule (showS tblName) exprs fnName False -> pure [] _ -> pure [] FunctionRuleT rule@(FunctionRule {fn_name = ruleFnNames, arg_no}) -> do @@ -1024,62 +1025,106 @@ getFnNameAndTypeableExprWithAllArgs _ = Nothing -- TODO: Verify the correctness of this function before moving it to utils -- Get function name with all it's arguments +getFnName :: LHsExpr GhcTc -> Maybe (Located Var) +getFnName expr = case expr of + L loc (HsVar _ v) -> + trace ("getFnName: Detected HsVar with name = " <> showS v) $ + Just (getLocated v loc) + + L _ (HsConLikeOut _ cl) -> + trace "getFnName: Detected HsConLikeOut" $ + (\clId -> noExprLoc clId) <$> conLikeWrapId cl + + L _ (HsAppType _ expr _) -> + trace "getFnName: Detected HsAppType" $ + getFnName expr + + L _ (HsApp _ (L loc (HsVar _ v)) _) -> + trace ("getFnName: Detected HsApp with HsVar function name = " <> showS v) $ + Just (getLocated v loc) + + L _ (HsPar _ expr) -> + trace "getFnName: Detected HsPar" $ + getFnName expr + + L _ (HsApp _ funl _) -> + trace "getFnName: Detected HsApp with nested function application" $ + getFnName funl + + L loc (OpApp _ funl op _) -> + trace ("getFnName: Detected OpApp with operator = " <> showS op) $ + case showS op of + "($)" -> + trace "getFnName: Detected ($) operator, treating as HsApp" $ + getFnName (L loc (HsApp noExtFieldOrAnn funl funl)) + _ -> trace "getFnName: Unsupported operator, returning Nothing" Nothing + + L loc ap@(PatHsWrap _ expr) -> + trace "getFnName: Detected PatHsWrap" $ + getFnName (L loc expr) + +#if __GLASGOW_HASKELL__ >= 900 + L loc ap@(PatHsExpansion orig expanded) -> + trace "getFnName: Detected PatHsExpansion" $ + case (orig, expanded) of + ((OpApp _ _ op _), (HsApp _ (L _ (HsApp _ op' funl)) _)) -> + case showS op of + "($)" -> + trace "getFnName: Detected ($) operator in PatHsExpansion, treating as HsApp" $ + getFnName (L loc (HsApp noExtFieldOrAnn funl funl)) + _ -> + trace "getFnName: Unsupported operator in PatHsExpansion, processing expanded expression" $ + getFnName (L loc expanded) + _ -> + trace "getFnName: Processing expanded expression in PatHsExpansion" $ + getFnName (L loc expanded) +#endif + + _ -> + trace "getFnName: No matching case found, returning Nothing" Nothing + getFnNameWithAllArgs :: LHsExpr GhcTc -> Maybe (Located Var, [LHsExpr GhcTc]) getFnNameWithAllArgs expr = case expr of L loc (HsVar _ v) -> - trace ("getFnNameWithAllArgs: Detected HsVar with name = " <> showS v) $ Just (getLocated v loc, []) L _ (HsConLikeOut _ cl) -> - trace "getFnNameWithAllArgs: Detected HsConLikeOut" $ (\clId -> (noExprLoc clId, [])) <$> conLikeWrapId cl L _ (HsAppType _ expr _) -> - trace "getFnNameWithAllArgs: Detected HsAppType" $ getFnNameWithAllArgs expr L _ (HsApp _ (L loc (HsVar _ v)) funr) -> - trace ("getFnNameWithAllArgs: Detected HsApp with HsVar function name = " <> showS v) $ Just (getLocated v loc, [funr]) L _ (HsPar _ expr) -> - trace "getFnNameWithAllArgs: Detected HsPar" $ getFnNameWithAllArgs expr L _ (HsApp _ funl funr) -> - trace "getFnNameWithAllArgs: Detected HsApp with nested function application" $ case getFnNameWithAllArgs funl of Nothing -> trace "getFnNameWithAllArgs: No function name found in nested HsApp" Nothing Just (fnName, ls) -> - trace ("getFnNameWithAllArgs: Found function name = " <> showS fnName <> " with arguments = " <> showS ls) $ Just (fnName, ls ++ [funr]) L loc (OpApp _ funl op funr) -> - trace ("getFnNameWithAllArgs: Detected OpApp with operator = " <> showS op) $ case showS op of "($)" -> - trace "getFnNameWithAllArgs: Detected ($) operator, treating as HsApp" $ getFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) _ -> trace "getFnNameWithAllArgs: Unsupported operator, returning Nothing" Nothing L loc ap@(PatHsWrap _ expr) -> - trace "getFnNameWithAllArgs: Detected PatHsWrap" $ getFnNameWithAllArgs (L loc expr) #if __GLASGOW_HASKELL__ >= 900 L loc ap@(PatHsExpansion orig expanded) -> - trace "getFnNameWithAllArgs: Detected PatHsExpansion" $ case (orig, expanded) of ((OpApp _ _ op _), (HsApp _ (L _ (HsApp _ op' funl)) funr)) -> case showS op of "($)" -> - trace "getFnNameWithAllArgs: Detected ($) operator in PatHsExpansion, treating as HsApp" $ getFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) _ -> - trace "getFnNameWithAllArgs: Unsupported operator in PatHsExpansion, processing expanded expression" $ getFnNameWithAllArgs (L loc expanded) _ -> - trace "getFnNameWithAllArgs: Processing expanded expression in PatHsExpansion" $ getFnNameWithAllArgs (L loc expanded) #endif From ab859539a773ec6dd576796763e288e01e2393f9 Mon Sep 17 00:00:00 2001 From: Sailaja Date: Fri, 27 Jun 2025 16:16:15 +0530 Subject: [PATCH 04/15] local --- sheriff/src/Sheriff/Plugin.hs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index 331d3db..37011fe 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -167,7 +167,7 @@ sheriff opts modSummary tcEnv = do infRule = case find isInfiniteRecursionRule globalRules of Just (InfiniteRecursionRuleT r) -> r _ -> defaultInfiniteRecursionRuleT - liftIO $ putStrLn $ "finalSheriffRules: " <> show finalSheriffRules + -- liftIO $ putStrLn $ "finalSheriffRules: " <> show finalSheriffRules when logDebugInfo $ liftIO $ print globalRules when logDebugInfo $ liftIO $ print globalExceptionRules @@ -479,12 +479,15 @@ checkAndApplyRule ruleT ap = case ruleT of WhereClauseRuleT rule -> case ap of (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do - case (showS ty == "Clause") of - True -> do - let fnLocatedVar = fromMaybe (error "No function name found") $ getFnName ap - fnName = getLocatedVarNameWithModuleName fnLocatedVar - validateWhereClauseRule rule (showS tblName) exprs fnName - False -> pure [] + if showS ty == "Clause" + then case getFnName ap of + Nothing -> do + liftIO $ putStrLn "No function name found, skipping processing." + pure [] -- Handle the absence of a function name gracefully + Just fnLocatedVar -> do + let fnName = getLocatedVarNameWithModuleName fnLocatedVar + validateWhereClauseRule rule (showS tblName) exprs fnName + else pure [] _ -> pure [] FunctionRuleT rule@(FunctionRule {fn_name = ruleFnNames, arg_no}) -> do let res = getFnNameWithAllArgs ap From 486238f53e6bb3a4a95ccadfff18ced67273f844 Mon Sep 17 00:00:00 2001 From: Sailaja Date: Fri, 27 Jun 2025 16:37:54 +0530 Subject: [PATCH 05/15] local --- sheriff/src/Sheriff/Plugin.hs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index 37011fe..b45861f 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -479,15 +479,16 @@ checkAndApplyRule ruleT ap = case ruleT of WhereClauseRuleT rule -> case ap of (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do - if showS ty == "Clause" - then case getFnName ap of - Nothing -> do - liftIO $ putStrLn "No function name found, skipping processing." - pure [] -- Handle the absence of a function name gracefully - Just fnLocatedVar -> do - let fnName = getLocatedVarNameWithModuleName fnLocatedVar - validateWhereClauseRule rule (showS tblName) exprs fnName - else pure [] + if showS ty == "Clause" + then case getFnName ap of + Nothing -> do + liftIO $ putStrLn "No function name found, skipping processing." + liftIO $ putStrLn $ "Debug: Raw expression = " <> showSDocUnsafe (ppr ap) + pure [] -- Handle the absence of a function name gracefully + Just fnLocatedVar -> do + let fnName = getLocatedVarNameWithModuleName fnLocatedVar + validateWhereClauseRule rule (showS tblName) exprs fnName + else pure [] _ -> pure [] FunctionRuleT rule@(FunctionRule {fn_name = ruleFnNames, arg_no}) -> do let res = getFnNameWithAllArgs ap From 7e74a18e10e548476edb7c739d589521161b9e65 Mon Sep 17 00:00:00 2001 From: Sailaja Date: Fri, 27 Jun 2025 16:49:00 +0530 Subject: [PATCH 06/15] local --- sheriff/src/Sheriff/Plugin.hs | 66 +++-------------------------------- 1 file changed, 4 insertions(+), 62 deletions(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index b45861f..4b84ee9 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -480,14 +480,14 @@ checkAndApplyRule ruleT ap = case ruleT of case ap of (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do if showS ty == "Clause" - then case getFnName ap of + then case getFnNameWithAllArgs ap of + Just (fnLocatedVar,_) -> do + let fnName = getLocatedVarNameWithModuleName fnLocatedVar + validateWhereClauseRule rule (showS tblName) exprs fnName Nothing -> do liftIO $ putStrLn "No function name found, skipping processing." liftIO $ putStrLn $ "Debug: Raw expression = " <> showSDocUnsafe (ppr ap) pure [] -- Handle the absence of a function name gracefully - Just fnLocatedVar -> do - let fnName = getLocatedVarNameWithModuleName fnLocatedVar - validateWhereClauseRule rule (showS tblName) exprs fnName else pure [] _ -> pure [] FunctionRuleT rule@(FunctionRule {fn_name = ruleFnNames, arg_no}) -> do @@ -1029,64 +1029,6 @@ getFnNameAndTypeableExprWithAllArgs _ = Nothing -- TODO: Verify the correctness of this function before moving it to utils -- Get function name with all it's arguments -getFnName :: LHsExpr GhcTc -> Maybe (Located Var) -getFnName expr = case expr of - L loc (HsVar _ v) -> - trace ("getFnName: Detected HsVar with name = " <> showS v) $ - Just (getLocated v loc) - - L _ (HsConLikeOut _ cl) -> - trace "getFnName: Detected HsConLikeOut" $ - (\clId -> noExprLoc clId) <$> conLikeWrapId cl - - L _ (HsAppType _ expr _) -> - trace "getFnName: Detected HsAppType" $ - getFnName expr - - L _ (HsApp _ (L loc (HsVar _ v)) _) -> - trace ("getFnName: Detected HsApp with HsVar function name = " <> showS v) $ - Just (getLocated v loc) - - L _ (HsPar _ expr) -> - trace "getFnName: Detected HsPar" $ - getFnName expr - - L _ (HsApp _ funl _) -> - trace "getFnName: Detected HsApp with nested function application" $ - getFnName funl - - L loc (OpApp _ funl op _) -> - trace ("getFnName: Detected OpApp with operator = " <> showS op) $ - case showS op of - "($)" -> - trace "getFnName: Detected ($) operator, treating as HsApp" $ - getFnName (L loc (HsApp noExtFieldOrAnn funl funl)) - _ -> trace "getFnName: Unsupported operator, returning Nothing" Nothing - - L loc ap@(PatHsWrap _ expr) -> - trace "getFnName: Detected PatHsWrap" $ - getFnName (L loc expr) - -#if __GLASGOW_HASKELL__ >= 900 - L loc ap@(PatHsExpansion orig expanded) -> - trace "getFnName: Detected PatHsExpansion" $ - case (orig, expanded) of - ((OpApp _ _ op _), (HsApp _ (L _ (HsApp _ op' funl)) _)) -> - case showS op of - "($)" -> - trace "getFnName: Detected ($) operator in PatHsExpansion, treating as HsApp" $ - getFnName (L loc (HsApp noExtFieldOrAnn funl funl)) - _ -> - trace "getFnName: Unsupported operator in PatHsExpansion, processing expanded expression" $ - getFnName (L loc expanded) - _ -> - trace "getFnName: Processing expanded expression in PatHsExpansion" $ - getFnName (L loc expanded) -#endif - - _ -> - trace "getFnName: No matching case found, returning Nothing" Nothing - getFnNameWithAllArgs :: LHsExpr GhcTc -> Maybe (Located Var, [LHsExpr GhcTc]) getFnNameWithAllArgs expr = case expr of L loc (HsVar _ v) -> From 40e748817f8e2d34c22308d75dcb67a338da377e Mon Sep 17 00:00:00 2001 From: Sailaja Date: Fri, 27 Jun 2025 17:18:24 +0530 Subject: [PATCH 07/15] local --- sheriff/src/Sheriff/Plugin.hs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index 4b84ee9..2b9a4f8 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -1032,51 +1032,64 @@ getFnNameAndTypeableExprWithAllArgs _ = Nothing getFnNameWithAllArgs :: LHsExpr GhcTc -> Maybe (Located Var, [LHsExpr GhcTc]) getFnNameWithAllArgs expr = case expr of L loc (HsVar _ v) -> + trace ("getFnNameWithAllArgs: Detected HsVar with name = " <> showS v) $ Just (getLocated v loc, []) L _ (HsConLikeOut _ cl) -> + trace "getFnNameWithAllArgs: Detected HsConLikeOut" $ (\clId -> (noExprLoc clId, [])) <$> conLikeWrapId cl L _ (HsAppType _ expr _) -> + trace "getFnNameWithAllArgs: Detected HsAppType" $ getFnNameWithAllArgs expr L _ (HsApp _ (L loc (HsVar _ v)) funr) -> + trace ("getFnNameWithAllArgs: Detected HsApp with HsVar function name = " <> showS v) $ Just (getLocated v loc, [funr]) L _ (HsPar _ expr) -> + trace "getFnNameWithAllArgs: Detected HsPar" $ getFnNameWithAllArgs expr L _ (HsApp _ funl funr) -> + trace "getFnNameWithAllArgs: Detected HsApp with nested function application" $ case getFnNameWithAllArgs funl of Nothing -> trace "getFnNameWithAllArgs: No function name found in nested HsApp" Nothing Just (fnName, ls) -> + trace ("getFnNameWithAllArgs: Found function name = " <> showS fnName <> " with arguments = " <> showS ls) $ Just (fnName, ls ++ [funr]) L loc (OpApp _ funl op funr) -> + trace ("getFnNameWithAllArgs: Detected OpApp with operator = " <> showS op) $ case showS op of "($)" -> + trace "getFnNameWithAllArgs: Detected ($) operator, treating as HsApp" $ getFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) _ -> trace "getFnNameWithAllArgs: Unsupported operator, returning Nothing" Nothing L loc ap@(PatHsWrap _ expr) -> + trace "getFnNameWithAllArgs: Detected PatHsWrap" $ getFnNameWithAllArgs (L loc expr) #if __GLASGOW_HASKELL__ >= 900 L loc ap@(PatHsExpansion orig expanded) -> + trace "getFnNameWithAllArgs: Detected PatHsExpansion" $ case (orig, expanded) of ((OpApp _ _ op _), (HsApp _ (L _ (HsApp _ op' funl)) funr)) -> case showS op of "($)" -> + trace "getFnNameWithAllArgs: Detected ($) operator in PatHsExpansion, treating as HsApp" $ getFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) _ -> + trace "getFnNameWithAllArgs: Unsupported operator in PatHsExpansion, processing expanded expression" $ getFnNameWithAllArgs (L loc expanded) _ -> + trace "getFnNameWithAllArgs: Processing expanded expression in PatHsExpansion" $ getFnNameWithAllArgs (L loc expanded) #endif _ -> trace "getFnNameWithAllArgs: No matching case found, returning Nothing" Nothing - --------------------------- Sheriff Plugin Utils --------------------------- -- Transform the FnBlockedInArg Violation with correct expression trfViolationErrorInfo :: (HasPluginOpts PluginOpts) => Violation -> LHsExpr GhcTc -> LHsExpr GhcTc -> TcM Violation From 076418e08b2e6c3d4b668457becea1208be07dd0 Mon Sep 17 00:00:00 2001 From: Sailaja Date: Fri, 27 Jun 2025 18:15:53 +0530 Subject: [PATCH 08/15] local --- sheriff/src/Sheriff/Plugin.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index 2b9a4f8..a528064 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -1045,7 +1045,7 @@ getFnNameWithAllArgs expr = case expr of L _ (HsApp _ (L loc (HsVar _ v)) funr) -> trace ("getFnNameWithAllArgs: Detected HsApp with HsVar function name = " <> showS v) $ - Just (getLocated v loc, [funr]) + Nothing L _ (HsPar _ expr) -> trace "getFnNameWithAllArgs: Detected HsPar" $ @@ -1057,7 +1057,7 @@ getFnNameWithAllArgs expr = case expr of Nothing -> trace "getFnNameWithAllArgs: No function name found in nested HsApp" Nothing Just (fnName, ls) -> trace ("getFnNameWithAllArgs: Found function name = " <> showS fnName <> " with arguments = " <> showS ls) $ - Just (fnName, ls ++ [funr]) + Nothing L loc (OpApp _ funl op funr) -> trace ("getFnNameWithAllArgs: Detected OpApp with operator = " <> showS op) $ From ec718035a68e7d13fb3388cf66f521f3ce1d9bb6 Mon Sep 17 00:00:00 2001 From: Sailaja Date: Fri, 27 Jun 2025 18:33:11 +0530 Subject: [PATCH 09/15] local --- sheriff/src/Sheriff/Plugin.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index a528064..2b9a4f8 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -1045,7 +1045,7 @@ getFnNameWithAllArgs expr = case expr of L _ (HsApp _ (L loc (HsVar _ v)) funr) -> trace ("getFnNameWithAllArgs: Detected HsApp with HsVar function name = " <> showS v) $ - Nothing + Just (getLocated v loc, [funr]) L _ (HsPar _ expr) -> trace "getFnNameWithAllArgs: Detected HsPar" $ @@ -1057,7 +1057,7 @@ getFnNameWithAllArgs expr = case expr of Nothing -> trace "getFnNameWithAllArgs: No function name found in nested HsApp" Nothing Just (fnName, ls) -> trace ("getFnNameWithAllArgs: Found function name = " <> showS fnName <> " with arguments = " <> showS ls) $ - Nothing + Just (fnName, ls ++ [funr]) L loc (OpApp _ funl op funr) -> trace ("getFnNameWithAllArgs: Detected OpApp with operator = " <> showS op) $ From 12718a748849a9fd58ad504b06d1f2add4dbb280 Mon Sep 17 00:00:00 2001 From: Sailaja Date: Mon, 30 Jun 2025 10:28:12 +0530 Subject: [PATCH 10/15] local --- sheriff/src/Sheriff/Plugin.hs | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index 2b9a4f8..dfedddd 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -480,15 +480,20 @@ checkAndApplyRule ruleT ap = case ruleT of case ap of (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do if showS ty == "Clause" - then case getFnNameWithAllArgs ap of - Just (fnLocatedVar,_) -> do - let fnName = getLocatedVarNameWithModuleName fnLocatedVar - validateWhereClauseRule rule (showS tblName) exprs fnName - Nothing -> do - liftIO $ putStrLn "No function name found, skipping processing." - liftIO $ putStrLn $ "Debug: Raw expression = " <> showSDocUnsafe (ppr ap) - pure [] -- Handle the absence of a function name gracefully - else pure [] + then case exprs of + (firstExpr:_) -> do + liftIO $ putStrLn $ "Debug: Extracting first expression from list = " <> showSDocUnsafe (ppr firstExpr) + case getFnNameWithAllArgs firstExpr of + Just (fnLocatedVar, _) -> do + let fnName = getLocatedVarNameWithModuleName fnLocatedVar + validateWhereClauseRule rule (showS tblName) exprs fnName + Nothing -> do + liftIO $ putStrLn "No function name found, skipping processing." + pure [] + [] -> do + liftIO $ putStrLn "Debug: Empty expression list, skipping processing." + pure [] + else pure [] _ -> pure [] FunctionRuleT rule@(FunctionRule {fn_name = ruleFnNames, arg_no}) -> do let res = getFnNameWithAllArgs ap From 243a32f2ea090dcaa02cd654d752f4c9fae426c6 Mon Sep 17 00:00:00 2001 From: Sailaja Date: Mon, 30 Jun 2025 10:54:58 +0530 Subject: [PATCH 11/15] local --- sheriff/src/Sheriff/Plugin.hs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index dfedddd..fab08ae 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -466,6 +466,18 @@ isBadExpr rules ap = pure [] isBadExprHelper :: (HasPluginOpts PluginOpts) => Rules -> LHsExpr GhcTc -> TcM [(LHsExpr GhcTc, Violation)] isBadExprHelper rules ap = concat <$> mapM (\rule -> checkAndApplyRule rule ap) rules +-- Collect all function names in the given expression +collectFnNames :: LHsExpr GhcTc -> [String] +collectFnNames expr = case expr of + L _ (HsVar _ v) -> [showS v] -- Collect the function name + L _ (HsApp _ funl funr) -> collectFnNames funl ++ collectFnNames funr -- Traverse both sides of the application + L _ (HsPar _ expr) -> collectFnNames expr -- Handle parentheses + L _ (OpApp _ funl op funr) -> collectFnNames funl ++ [showS op] ++ collectFnNames funr -- Handle operators + L _ (HsAppType _ expr _) -> collectFnNames expr -- Handle type applications + L loc (PatHsWrap _ expr) -> collectFnNames (L loc expr) -- Handle wrapped patterns + L _ (HsConLikeOut _ cl) -> [showS cl] -- Collect constructor-like outputs + _ -> [] -- For other cases, return an empty list + -- Check if a particular rule applies to given expr checkAndApplyRule :: (HasPluginOpts PluginOpts) => Rule -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) checkAndApplyRule ruleT ap = case ruleT of @@ -483,6 +495,8 @@ checkAndApplyRule ruleT ap = case ruleT of then case exprs of (firstExpr:_) -> do liftIO $ putStrLn $ "Debug: Extracting first expression from list = " <> showSDocUnsafe (ppr firstExpr) + let fnNames = collectFnNames firstExpr + liftIO $ putStrLn $ "Debug: Collected function names = " <> show fnNames case getFnNameWithAllArgs firstExpr of Just (fnLocatedVar, _) -> do let fnName = getLocatedVarNameWithModuleName fnLocatedVar @@ -507,7 +521,6 @@ checkAndApplyRule ruleT ap = case ruleT of Nothing -> pure [] InfiniteRecursionRuleT rule -> pure [] --TODO: Add handling of infinite recursion rule GeneralRuleT rule -> pure [] --TODO: Add handling of general rule - --------------------------- Function Rule Validation Logic --------------------------- {- From 7de72c9164a7f5f571d6fe591a4e3e9973164fe8 Mon Sep 17 00:00:00 2001 From: Sailaja Date: Mon, 30 Jun 2025 11:13:33 +0530 Subject: [PATCH 12/15] local --- sheriff/src/Sheriff/Plugin.hs | 50 +++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index fab08ae..c07c3ce 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -469,14 +469,30 @@ isBadExprHelper rules ap = concat <$> mapM (\rule -> checkAndApplyRule rule ap) -- Collect all function names in the given expression collectFnNames :: LHsExpr GhcTc -> [String] collectFnNames expr = case expr of - L _ (HsVar _ v) -> [showS v] -- Collect the function name - L _ (HsApp _ funl funr) -> collectFnNames funl ++ collectFnNames funr -- Traverse both sides of the application - L _ (HsPar _ expr) -> collectFnNames expr -- Handle parentheses - L _ (OpApp _ funl op funr) -> collectFnNames funl ++ [showS op] ++ collectFnNames funr -- Handle operators - L _ (HsAppType _ expr _) -> collectFnNames expr -- Handle type applications - L loc (PatHsWrap _ expr) -> collectFnNames (L loc expr) -- Handle wrapped patterns - L _ (HsConLikeOut _ cl) -> [showS cl] -- Collect constructor-like outputs - _ -> [] -- For other cases, return an empty list + L _ (HsVar _ v) -> + let fnName = showS v + in trace ("collectFnNames: Detected HsVar with name = " <> fnName) [fnName] -- Log detected function name + L _ (HsApp _ funl funr) -> + trace "collectFnNames: Detected HsApp, traversing both sides" $ + collectFnNames funl ++ collectFnNames funr -- Traverse both sides of the application + L _ (HsPar _ expr) -> + trace "collectFnNames: Detected HsPar, traversing inner expression" $ + collectFnNames expr -- Handle parentheses + L _ (OpApp _ funl op funr) -> + let opName = showS op + in trace ("collectFnNames: Detected OpApp with operator = " <> opName) $ + collectFnNames funl ++ [opName] ++ collectFnNames funr -- Handle operators + L _ (HsAppType _ expr _) -> + trace "collectFnNames: Detected HsAppType, traversing inner expression" $ + collectFnNames expr -- Handle type applications + L loc (PatHsWrap _ innerExpr) -> + trace "collectFnNames: Detected PatHsWrap, traversing wrapped expression" $ + collectFnNames (L loc innerExpr) -- Handle wrapped patterns + L _ (HsConLikeOut _ cl) -> + let conName = showS cl + in trace ("collectFnNames: Detected HsConLikeOut with name = " <> conName) [conName] -- Collect constructor-like outputs + _ -> + trace "collectFnNames: No matching case found, returning empty list" [] -- For other cases, return an empty list -- Check if a particular rule applies to given expr checkAndApplyRule :: (HasPluginOpts PluginOpts) => Rule -> LHsExpr GhcTc -> TcM ([(LHsExpr GhcTc, Violation)]) @@ -1050,27 +1066,27 @@ getFnNameAndTypeableExprWithAllArgs _ = Nothing getFnNameWithAllArgs :: LHsExpr GhcTc -> Maybe (Located Var, [LHsExpr GhcTc]) getFnNameWithAllArgs expr = case expr of L loc (HsVar _ v) -> - trace ("getFnNameWithAllArgs: Detected HsVar with name = " <> showS v) $ + -- trace ("getFnNameWithAllArgs: Detected HsVar with name = " <> showS v) $ Just (getLocated v loc, []) L _ (HsConLikeOut _ cl) -> - trace "getFnNameWithAllArgs: Detected HsConLikeOut" $ + -- trace "getFnNameWithAllArgs: Detected HsConLikeOut" $ (\clId -> (noExprLoc clId, [])) <$> conLikeWrapId cl L _ (HsAppType _ expr _) -> - trace "getFnNameWithAllArgs: Detected HsAppType" $ + -- trace "getFnNameWithAllArgs: Detected HsAppType" $ getFnNameWithAllArgs expr L _ (HsApp _ (L loc (HsVar _ v)) funr) -> - trace ("getFnNameWithAllArgs: Detected HsApp with HsVar function name = " <> showS v) $ + -- trace ("getFnNameWithAllArgs: Detected HsApp with HsVar function name = " <> showS v) $ Just (getLocated v loc, [funr]) L _ (HsPar _ expr) -> - trace "getFnNameWithAllArgs: Detected HsPar" $ + -- trace "getFnNameWithAllArgs: Detected HsPar" $ getFnNameWithAllArgs expr L _ (HsApp _ funl funr) -> - trace "getFnNameWithAllArgs: Detected HsApp with nested function application" $ + -- trace "getFnNameWithAllArgs: Detected HsApp with nested function application" $ case getFnNameWithAllArgs funl of Nothing -> trace "getFnNameWithAllArgs: No function name found in nested HsApp" Nothing Just (fnName, ls) -> @@ -1078,15 +1094,15 @@ getFnNameWithAllArgs expr = case expr of Just (fnName, ls ++ [funr]) L loc (OpApp _ funl op funr) -> - trace ("getFnNameWithAllArgs: Detected OpApp with operator = " <> showS op) $ + -- trace ("getFnNameWithAllArgs: Detected OpApp with operator = " <> showS op) $ case showS op of "($)" -> - trace "getFnNameWithAllArgs: Detected ($) operator, treating as HsApp" $ + -- trace "getFnNameWithAllArgs: Detected ($) operator, treating as HsApp" $ getFnNameWithAllArgs (L loc (HsApp noExtFieldOrAnn funl funr)) _ -> trace "getFnNameWithAllArgs: Unsupported operator, returning Nothing" Nothing L loc ap@(PatHsWrap _ expr) -> - trace "getFnNameWithAllArgs: Detected PatHsWrap" $ + -- trace "getFnNameWithAllArgs: Detected PatHsWrap" $ getFnNameWithAllArgs (L loc expr) #if __GLASGOW_HASKELL__ >= 900 From 5656ed1088ada4c59ac8c73b9638f3f6d4a3c676 Mon Sep 17 00:00:00 2001 From: Sailaja Date: Mon, 30 Jun 2025 11:27:27 +0530 Subject: [PATCH 13/15] local --- sheriff/src/Sheriff/Plugin.hs | 49 +++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index c07c3ce..acaa049 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -505,26 +505,35 @@ checkAndApplyRule ruleT ap = case ruleT of False -> pure [] _ -> pure [] WhereClauseRuleT rule -> - case ap of - (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do - if showS ty == "Clause" - then case exprs of - (firstExpr:_) -> do - liftIO $ putStrLn $ "Debug: Extracting first expression from list = " <> showSDocUnsafe (ppr firstExpr) - let fnNames = collectFnNames firstExpr - liftIO $ putStrLn $ "Debug: Collected function names = " <> show fnNames - case getFnNameWithAllArgs firstExpr of - Just (fnLocatedVar, _) -> do - let fnName = getLocatedVarNameWithModuleName fnLocatedVar - validateWhereClauseRule rule (showS tblName) exprs fnName - Nothing -> do - liftIO $ putStrLn "No function name found, skipping processing." - pure [] - [] -> do - liftIO $ putStrLn "Debug: Empty expression list, skipping processing." - pure [] - else pure [] - _ -> pure [] + case ap of + (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do + if showS ty == "Clause" then + case exprs of + [] -> do + liftIO $ putStrLn "Debug: Empty expression list, skipping processing." + pure [] + _ -> do + liftIO $ putStrLn $ "Debug: Processing all expressions in the list = " <> showSDocUnsafe (ppr exprs) + -- Iterate over all expressions in the list and collect results + results <- forM exprs $ \expr -> do + liftIO $ putStrLn $ "Debug: Extracting expression = " <> showSDocUnsafe (ppr expr) + + -- Collect function names from the current expression + let fnNames = collectFnNames expr + liftIO $ putStrLn $ "Debug: Collected function names = " <> show fnNames + + -- Process the current expression to extract the function name and validate the rule + case getFnNameWithAllArgs expr of + Just (fnLocatedVar, _) -> do + let fnName = getLocatedVarNameWithModuleName fnLocatedVar + liftIO $ putStrLn $ "Debug: Extracted function name = " <> fnName + validateWhereClauseRule rule (showS tblName) exprs fnName + Nothing -> do + liftIO $ putStrLn "No function name found, skipping processing." + pure [] -- Return an empty list for this expression + pure (concat results) -- Combine all results into a single list + else pure [] + _ -> pure [] FunctionRuleT rule@(FunctionRule {fn_name = ruleFnNames, arg_no}) -> do let res = getFnNameWithAllArgs ap case res of From 10d1b94cbfdf21306a0f01410094595171a3aab3 Mon Sep 17 00:00:00 2001 From: Sailaja Date: Mon, 30 Jun 2025 11:56:58 +0530 Subject: [PATCH 14/15] local --- sheriff/src/Sheriff/Plugin.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index acaa049..0e8d054 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -519,11 +519,11 @@ checkAndApplyRule ruleT ap = case ruleT of liftIO $ putStrLn $ "Debug: Extracting expression = " <> showSDocUnsafe (ppr expr) -- Collect function names from the current expression - let fnNames = collectFnNames expr + let fnNames = collectFnNames ap liftIO $ putStrLn $ "Debug: Collected function names = " <> show fnNames -- Process the current expression to extract the function name and validate the rule - case getFnNameWithAllArgs expr of + case getFnNameWithAllArgs ap of Just (fnLocatedVar, _) -> do let fnName = getLocatedVarNameWithModuleName fnLocatedVar liftIO $ putStrLn $ "Debug: Extracted function name = " <> fnName From 5c72b7b7036a0a4bf2e322aacc46c8002fa671c9 Mon Sep 17 00:00:00 2001 From: Sailaja Date: Mon, 30 Jun 2025 12:26:57 +0530 Subject: [PATCH 15/15] local --- sheriff/src/Sheriff/Plugin.hs | 35 +++++++++-------------------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/sheriff/src/Sheriff/Plugin.hs b/sheriff/src/Sheriff/Plugin.hs index 0e8d054..8ca99ce 100644 --- a/sheriff/src/Sheriff/Plugin.hs +++ b/sheriff/src/Sheriff/Plugin.hs @@ -507,32 +507,15 @@ checkAndApplyRule ruleT ap = case ruleT of WhereClauseRuleT rule -> case ap of (L _ (PatExplicitList (TyConApp ty [_, tblName]) exprs)) -> do - if showS ty == "Clause" then - case exprs of - [] -> do - liftIO $ putStrLn "Debug: Empty expression list, skipping processing." - pure [] - _ -> do - liftIO $ putStrLn $ "Debug: Processing all expressions in the list = " <> showSDocUnsafe (ppr exprs) - -- Iterate over all expressions in the list and collect results - results <- forM exprs $ \expr -> do - liftIO $ putStrLn $ "Debug: Extracting expression = " <> showSDocUnsafe (ppr expr) - - -- Collect function names from the current expression - let fnNames = collectFnNames ap - liftIO $ putStrLn $ "Debug: Collected function names = " <> show fnNames - - -- Process the current expression to extract the function name and validate the rule - case getFnNameWithAllArgs ap of - Just (fnLocatedVar, _) -> do - let fnName = getLocatedVarNameWithModuleName fnLocatedVar - liftIO $ putStrLn $ "Debug: Extracted function name = " <> fnName - validateWhereClauseRule rule (showS tblName) exprs fnName - Nothing -> do - liftIO $ putStrLn "No function name found, skipping processing." - pure [] -- Return an empty list for this expression - pure (concat results) -- Combine all results into a single list - else pure [] + if showS ty == "Clause" then do + let res = getFnNameWithAllArgs ap + case res of + Nothing -> pure [] + Just (fnLocatedVar, args) -> do + let fnName = getLocatedVarNameWithModuleName fnLocatedVar + fnLHsExpr = mkLHsVar fnLocatedVar + validateWhereClauseRule rule (showS tblName) exprs fnName + else pure [] _ -> pure [] FunctionRuleT rule@(FunctionRule {fn_name = ruleFnNames, arg_no}) -> do let res = getFnNameWithAllArgs ap