diff --git a/cabal.project.freeze b/cabal.project.freeze index ba7ba67..edc8d7f 100644 --- a/cabal.project.freeze +++ b/cabal.project.freeze @@ -34,7 +34,8 @@ constraints: any.Cabal ==3.8.1.0, distributive +semigroups +tagged, any.dlist ==1.0, dlist -werror, - any.exceptions ==0.10.5, + any.exceptions ==0.10.8, + exceptions +transformers-0-4, any.extra ==1.7.16, any.file-embed ==0.0.16.0, any.filepath ==1.4.2.2, @@ -54,7 +55,7 @@ constraints: any.Cabal ==3.8.1.0, any.megaparsec ==9.6.1, megaparsec -dev, any.mmorph ==1.2.0, - any.mtl ==2.2.2, + any.mtl ==2.2.2 || ==2.3.1, any.optparse-applicative ==0.18.1.0, optparse-applicative +process, any.os-string ==2.0.3, diff --git a/lib/TypeChecker/HindleyMilner.hs b/lib/TypeChecker/HindleyMilner.hs index 6bb6218..980b081 100644 --- a/lib/TypeChecker/HindleyMilner.hs +++ b/lib/TypeChecker/HindleyMilner.hs @@ -10,7 +10,25 @@ {-# LANGUAGE StandaloneDeriving #-} {-# OPTIONS_GHC -Wno-orphans #-} -module TypeChecker.HindleyMilner where +module TypeChecker.HindleyMilner + ( Infer, + TypeError (..), + UType, + Polytype, + applyBindings, + generalize, + toPolytype, + toUType, + withBinding, + fresh, + Poly (..), + UTerm (UTVar, UTUnit, UTBool, UTInt, UTFun), + (=:=), + lookup, + TypeF (..), + mkVarName, + ) +where import Control.Monad.Except import Control.Monad.Reader @@ -27,86 +45,78 @@ import Data.Set (Set, (\\)) import qualified Data.Set as S import Data.Text (pack) import GHC.Generics (Generic1) -import Trees.Common (Identifier, Type (..)) +import qualified Trees.Common as L -- Lang import Prelude hiding (lookup) -data HType a - = TyVarF Identifier - | TyUnitF - | TyBoolF - | TyIntF - | TyFunF a a - deriving (Show, Eq, Functor, Foldable, Traversable, Generic1, Unifiable) - -type TypeF = Fix HType - -type UType = UTerm HType IntVar +-- * Type -data Poly t = Forall [Identifier] t - deriving (Eq, Show, Functor) +type Type = Fix TypeF -type Polytype = Poly TypeF +data TypeF a + = TVarF L.Identifier + | TUnitF + | TBoolF + | TIntF + | TFunF a a + deriving (Show, Eq, Functor, Foldable, Traversable, Generic1, Unifiable) -type UPolytype = Poly UType +-- * UType --- TypeF +type UType = UTerm TypeF IntVar -pattern TyVar :: Identifier -> TypeF -pattern TyVar v = Fix (TyVarF v) +pattern UTVar :: L.Identifier -> UType +pattern UTVar var = UTerm (TVarF var) -pattern TyUnit :: TypeF -pattern TyUnit = Fix TyUnitF +pattern UTUnit :: UType +pattern UTUnit = UTerm TUnitF -pattern TyBool :: TypeF -pattern TyBool = Fix TyBoolF +pattern UTBool :: UType +pattern UTBool = UTerm TBoolF -pattern TyInt :: TypeF -pattern TyInt = Fix TyIntF +pattern UTInt :: UType +pattern UTInt = UTerm TIntF -pattern TyFun :: TypeF -> TypeF -> TypeF -pattern TyFun t1 t2 = Fix (TyFunF t1 t2) +pattern UTFun :: UType -> UType -> UType +pattern UTFun funT argT = UTerm (TFunF funT argT) --- UType +-- * Polytype -pattern UTyVar :: Identifier -> UType -pattern UTyVar v = UTerm (TyVarF v) +data Poly t = Forall [L.Identifier] t + deriving (Eq, Show, Functor) -pattern UTyUnit :: UType -pattern UTyUnit = UTerm TyUnitF +type Polytype = Poly Type -pattern UTyBool :: UType -pattern UTyBool = UTerm TyBoolF +type UPolytype = Poly UType -pattern UTyInt :: UType -pattern UTyInt = UTerm TyIntF +-- * Converters -pattern UTyFun :: UType -> UType -> UType -pattern UTyFun t1 t2 = UTerm (TyFunF t1 t2) +toUType :: L.Type -> UType +toUType = \case + L.TUnit -> UTUnit + L.TBool -> UTBool + L.TInt -> UTInt + L.TFun funT argT -> UTFun (toUType funT) (toUType argT) -toTypeF :: Type -> TypeF -toTypeF = \case - TUnit -> TyUnit - TBool -> TyBool - TInt -> TyInt - TFun t1 t2 -> TyFun (toTypeF t1) (toTypeF t2) +toPolytype :: UPolytype -> Polytype +toPolytype = fmap (fromJust . freeze) -fromTypeToUType :: Type -> UType -fromTypeToUType = \case - TUnit -> UTyUnit - TBool -> UTyBool - TInt -> UTyInt - TFun t1 t2 -> UTyFun (fromTypeToUType t1) (fromTypeToUType t2) +-- * Infer -type Infer = ReaderT Ctx (ExceptT TypeError (IntBindingT HType Identity)) +type Infer = ReaderT Ctx (ExceptT TypeError (IntBindingT TypeF Identity)) -type Ctx = Map Identifier UPolytype +type Ctx = Map L.Identifier UPolytype -lookup :: LookUpType -> Infer UType -lookup (Var v) = do - ctx <- ask - maybe (throwError $ UnboundVar v) instantiate (M.lookup v ctx) +lookup :: L.Identifier -> Infer UType +lookup var = do + varUPT <- asks $ M.lookup var + maybe (throwError $ UnboundVar var) instantiate varUPT + where + instantiate :: UPolytype -> Infer UType + instantiate (Forall xs uty) = do + xs' <- mapM (const fresh) xs + return $ substU (M.fromList (zip (map Left xs) xs')) uty -withBinding :: (MonadReader Ctx m) => Identifier -> UPolytype -> m a -> m a +withBinding :: (MonadReader Ctx m) => L.Identifier -> UPolytype -> m a -> m a withBinding x ty = local (M.insert x ty) ucata :: (Functor t) => (v -> a) -> (t a -> a) -> UTerm t v -> a @@ -115,8 +125,10 @@ ucata f g (UTerm t) = g (fmap (ucata f g) t) deriving instance Ord IntVar +-- * FreeVars + class FreeVars a where - freeVars :: a -> Infer (Set (Either Identifier IntVar)) + freeVars :: a -> Infer (Set (Either L.Identifier IntVar)) instance FreeVars UType where freeVars ut = do @@ -124,7 +136,7 @@ instance FreeVars UType where let ftvs = ucata (const S.empty) - (\case TyVarF x -> S.singleton (Left x); f -> fold f) + (\case TVarF x -> S.singleton (Left x); f -> fold f) ut return $ fuvs `S.union` ftvs @@ -134,67 +146,48 @@ instance FreeVars UPolytype where instance FreeVars Ctx where freeVars = fmap S.unions . mapM freeVars . M.elems -newtype LookUpType = Var Identifier +fresh :: Infer UType +fresh = UVar <$> lift (lift freeVar) + +-- * Errors data TypeError where Unreachable :: TypeError - UnboundVar :: Identifier -> TypeError + UnboundVar :: L.Identifier -> TypeError Infinite :: IntVar -> UType -> TypeError ImpossibleBinOpApplication :: UType -> UType -> TypeError ImpossibleUnOpApplication :: UType -> TypeError - Mismatch :: HType UType -> HType UType -> TypeError + Mismatch :: TypeF UType -> TypeF UType -> TypeError deriving (Show) -instance Fallible HType IntVar TypeError where +instance Fallible TypeF IntVar TypeError where occursFailure = Infinite mismatchFailure = Mismatch -fresh :: Infer UType -fresh = UVar <$> lift (lift freeVar) - (=:=) :: UType -> UType -> Infer UType s =:= t = lift $ s U.=:= t applyBindings :: UType -> Infer UType applyBindings = lift . U.applyBindings -instantiate :: UPolytype -> Infer UType -instantiate (Forall xs uty) = do - xs' <- mapM (const fresh) xs - return $ substU (M.fromList (zip (map Left xs) xs')) uty - -substU :: Map (Either Identifier IntVar) UType -> UType -> UType +substU :: Map (Either L.Identifier IntVar) UType -> UType -> UType substU m = ucata (\v -> fromMaybe (UVar v) (M.lookup (Right v) m)) ( \case - TyVarF v -> fromMaybe (UTyVar v) (M.lookup (Left v) m) + TVarF v -> fromMaybe (UTVar v) (M.lookup (Left v) m) f -> UTerm f ) -skolemize :: UPolytype -> Infer UType -skolemize (Forall xs uty) = do - xs' <- mapM (const fresh) xs - return $ substU (M.fromList (zip (map Left xs) (map toSkolem xs'))) uty - where - toSkolem (UVar v) = UTyVar (mkVarName "s" v) - toSkolem _ = undefined -- We can't reach another situation, because we previously give `fresh` variable - -mkVarName :: String -> IntVar -> Identifier -mkVarName nm (IntVar v) = pack (nm ++ show (v + (maxBound :: Int) + 1)) +mkVarName :: String -> IntVar -> L.Identifier +mkVarName nm (IntVar v) = pack (nm <> show (v + (maxBound :: Int) + 1)) generalize :: UType -> Infer UPolytype generalize uty = do uty' <- applyBindings uty ctx <- ask - tmfvs <- freeVars uty' - ctxfvs <- freeVars ctx - let fvs = S.toList $ tmfvs \\ ctxfvs + tmFreeVars <- freeVars uty' + ctxFreeVars <- freeVars ctx + let fvs = S.toList $ tmFreeVars \\ ctxFreeVars xs = map (either id (mkVarName "a")) fvs - return $ Forall xs (substU (M.fromList (zip fvs (map UTyVar xs))) uty') - -toUPolytype :: Polytype -> UPolytype -toUPolytype = fmap unfreeze - -fromUPolytype :: UPolytype -> Polytype -fromUPolytype = fmap (fromJust . freeze) + return $ Forall xs (substU (M.fromList (zip fvs (map UTVar xs))) uty') diff --git a/lib/TypeChecker/PrettyPrinter.hs b/lib/TypeChecker/PrettyPrinter.hs index c88e355..5efccb9 100644 --- a/lib/TypeChecker/PrettyPrinter.hs +++ b/lib/TypeChecker/PrettyPrinter.hs @@ -23,12 +23,12 @@ type Prec = Int instance (Pretty (t (Fix t))) => Pretty (Fix t) where prettyPrec p = prettyPrec p . unFix -instance (Pretty t) => Pretty (HType t) where - prettyPrec _ (TyVarF x) = unpack x - prettyPrec _ TyUnitF = "unit" - prettyPrec _ TyBoolF = "bool" - prettyPrec _ TyIntF = "int" - prettyPrec p (TyFunF ty1 ty2) = +instance (Pretty t) => Pretty (TypeF t) where + prettyPrec _ (TVarF x) = unpack x + prettyPrec _ TUnitF = "unit" + prettyPrec _ TBoolF = "bool" + prettyPrec _ TIntF = "int" + prettyPrec p (TFunF ty1 ty2) = mparens (p > 0) $ prettyPrec 1 ty1 ++ " -> " ++ prettyPrec 0 ty2 instance (Pretty (t (UTerm t v)), Pretty v) => Pretty (UTerm t v) where diff --git a/lib/TypeChecker/TypeChecker.hs b/lib/TypeChecker/TypeChecker.hs index 5e5e24b..e8496c1 100644 --- a/lib/TypeChecker/TypeChecker.hs +++ b/lib/TypeChecker/TypeChecker.hs @@ -14,6 +14,7 @@ import Data.Maybe import Parser.Ast import qualified StdLib import Trees.Common +import qualified Trees.Common as L import TypeChecker.HindleyMilner import Prelude hiding (lookup) @@ -28,14 +29,14 @@ inferProgram (Program stmts) = runInfer $ withStdLib (inferStatements stmts) runInfer :: Infer UType -> Either TypeError Polytype runInfer = (>>= applyBindings) - >>> (>>= (generalize >>> fmap fromUPolytype)) + >>> (>>= (generalize >>> fmap toPolytype)) >>> flip runReaderT M.empty >>> runExceptT >>> evalIntBindingT >>> runIdentity withStdLib infer = do - let generalizeDecl (ident, t) = (ident,) <$> generalize (fromTypeToUType t) + let generalizeDecl (ident, t) = (ident,) <$> generalize (toUType t) generalizedDecls <- mapM generalizeDecl StdLib.typedDecls local (M.union (M.fromList generalizedDecls)) infer @@ -45,102 +46,95 @@ inferStatements :: [Statement] -> Infer UType inferStatements x = inferStatements' x (throwError Unreachable) inferStatements' :: [Statement] -> Infer UType -> Infer UType -inferStatements' [] pr = pr -inferStatements' ((StmtExpr e) : xs) _ = do +inferStatements' [] t = t +inferStatements' ((StmtExpr e) : stmts) _ = do res <- inferExpr e - inferStatements' xs (return res) -inferStatements' ((StmtDecl (DeclVar (ident, t) body)) : xs) _ = do - res <- inferExpr body - vType <- maybe (return res) ((=:=) res <$> fromTypeToUType) t - pvType <- generalize vType - withBinding ident pvType (inferStatements' xs $ return vType) -inferStatements' ((StmtDecl (DeclFun ident False fun)) : xs) _ = do - res <- inferFun fun - withBinding ident (Forall [] res) (inferStatements' xs $ return res) -inferStatements' ((StmtDecl (DeclFun ident True fun)) : xs) _ = do - preT <- fresh - next <- withBinding ident (Forall [] preT) $ inferFun fun - after <- withBinding ident (Forall [] next) $ inferFun fun - withBinding ident (Forall [] after) (inferStatements' xs $ return next) + inferStatements' stmts (return res) +inferStatements' ((StmtDecl (DeclVar (ident, t) val)) : stmts) _ = do + t' <- inferExpr val + t'' <- checkByAnnotation t' t + upt <- generalize t'' + withBinding ident upt $ inferStatements' stmts (return t'') +inferStatements' ((StmtDecl (DeclFun ident isRec fun)) : stmts) _ = do + funT <- + if isRec + then do + funT <- fresh + funT' <- withBinding ident (Forall [] funT) $ inferFun fun + withBinding ident (Forall [] funT') $ inferFun fun + else inferFun fun + funUT <- generalize funT + withBinding ident funUT $ inferStatements' stmts (return funT) inferExpr :: Expression -> Infer UType -inferExpr (ExprId x) = lookup (Var x) +inferExpr (ExprId ident) = lookup ident inferExpr (ExprPrimVal value) = case value of - PrimValUnit -> return UTyUnit - PrimValBool _ -> return UTyBool - PrimValInt _ -> return UTyInt + PrimValUnit -> return UTUnit + PrimValBool _ -> return UTBool + PrimValInt _ -> return UTInt inferExpr (ExprBinOp op lhs rhs) = do - utLhs <- inferExpr lhs - utRhs <- inferExpr rhs - withError (const $ ImpossibleBinOpApplication utLhs utRhs) $ do - ut <- utLhs =:= utRhs + lhsT <- inferExpr lhs + rhsT <- inferExpr rhs + withError (const $ ImpossibleBinOpApplication lhsT rhsT) $ do + valT <- lhsT =:= rhsT case op of - BoolOp _ -> ut =:= UTyBool - ArithOp _ -> ut =:= UTyInt - CompOp _ -> return UTyBool -inferExpr (ExprUnOp op x) = do - ut <- inferExpr x - withError (const $ ImpossibleUnOpApplication ut) $ case op of - UnMinusOp -> ut =:= UTyInt + BoolOp _ -> valT =:= UTBool + ArithOp _ -> valT =:= UTInt + CompOp _ -> return UTBool +inferExpr (ExprUnOp op val) = do + valT <- inferExpr val + withError (const $ ImpossibleUnOpApplication valT) $ case op of + UnMinusOp -> valT =:= UTInt inferExpr (ExprApp funExpr argExpr) = do - funUT <- inferExpr funExpr - argUT <- inferExpr argExpr - resUT <- fresh - _ <- funUT =:= UTyFun argUT resUT - return resUT + funT <- inferExpr funExpr + argT <- inferExpr argExpr + resT <- fresh + _ <- funT =:= UTFun argT resT + return resT inferExpr (ExprIte c t e) = do - _ <- check c UTyBool - t' <- inferExpr t - e' <- inferExpr e - t' =:= e' + _ <- check c UTBool + tT <- inferExpr t + eT <- inferExpr e + tT =:= eT inferExpr (ExprLetIn decl expr) = inferLetIn decl expr inferExpr (ExprFun fun) = inferFun fun inferLetIn :: Declaration -> Expression -> Infer UType -inferLetIn (DeclVar (x, Just pty) xdef) expr = do - let upty = toUPolytype (Forall [] $ toTypeF pty) - upty' <- skolemize upty - bl <- inferExpr xdef - _ <- bl =:= upty' - withBinding x upty $ inferExpr expr -inferLetIn (DeclVar (x, Nothing) xdef) expr = do - ty <- inferExpr xdef - pty <- generalize ty - withBinding x pty $ inferExpr expr -inferLetIn (DeclFun f False fun) expr = do - fdef <- inferFun fun - pfdef <- generalize fdef - withBinding f pfdef $ inferExpr expr -inferLetIn (DeclFun f True fun) expr = do - preT <- fresh - next <- withBinding f (Forall [] preT) $ inferFun fun - after <- withBinding f (Forall [] next) $ inferFun fun - inferredBlock <- withBinding f (Forall [] next) (inferExpr expr) - pfdef <- generalize after - withBinding f pfdef (return inferredBlock) +inferLetIn (DeclVar (ident, t) val) expr = do + t' <- inferExpr val + t'' <- checkByAnnotation t' t + upt <- generalize t'' + withBinding ident upt $ inferExpr expr +inferLetIn (DeclFun ident isRec fun) expr = do + funT <- + if isRec + then do + funT <- fresh + funT' <- withBinding ident (Forall [] funT) $ inferFun fun + withBinding ident (Forall [] funT') $ inferFun fun + else inferFun fun + funUT <- generalize funT + withBinding ident funUT $ inferExpr expr inferFun :: Fun -> Infer UType -inferFun (Fun args restype body) = inferFun' $ toList args +inferFun (Fun params resT body) = inferFun' $ toList params where - inferFun' args' = case args' of + inferFun' params' = case params' of [] -> do - inferredBody <- inferExpr body - case restype of - Just t -> fromTypeToUType t =:= inferredBody - Nothing -> return inferredBody - (ident, t) : ys -> do - t' <- maybe fresh (return . fromTypeToUType) t - withBinding ident (Forall [] t') $ UTyFun t' <$> inferFun' ys + bodyT <- inferExpr body + checkByAnnotation bodyT resT + (ident, t) : params'' -> do + t' <- maybe fresh (return . toUType) t + withBinding ident (Forall [] t') $ UTFun t' <$> inferFun' params'' --- Utils +-- ** Utils check :: Expression -> UType -> Infer UType -check e ty = do - ty' <- inferExpr e - ty =:= ty' - -withError :: (MonadError e m) => (e -> e) -> m a -> m a -withError f action = tryError action >>= either (throwError . f) pure - -tryError :: (MonadError e m) => m a -> m (Either e a) -tryError action = (Right <$> action) `catchError` (pure . Left) +check expr t = do + exprT <- inferExpr expr + t =:= exprT + +checkByAnnotation :: UType -> Maybe L.Type -> Infer UType +checkByAnnotation t ann = case ann of + Just annT -> toUType annT =:= t + Nothing -> return t diff --git a/miniml.cabal b/miniml.cabal index 4ed7fba..1d4a74d 100644 --- a/miniml.cabal +++ b/miniml.cabal @@ -62,7 +62,7 @@ library , file-embed , llvm-codegen , megaparsec >=9.2 - , mtl >=2.2.2 + , mtl >=2.3.0 , parser-combinators >=1.3.0 , process , recursion-schemes diff --git a/test/Unit/TypeInference/TypeInferenceTest.hs b/test/Unit/TypeInference/TypeInferenceTest.hs index 169c2ee..1ab0868 100644 --- a/test/Unit/TypeInference/TypeInferenceTest.hs +++ b/test/Unit/TypeInference/TypeInferenceTest.hs @@ -24,6 +24,7 @@ tests = test11, test12, test13, + test14, testredecalaration, testfib, testrec, @@ -175,6 +176,16 @@ test13 = expected @=? actual +test14 :: TestTree +test14 = + testCase + "[let f x = x;; print_int(f 4);; print_bool(f true);; f]" + $ do + let expected = "forall a7. a7 -> a7" + let actual = processTillTypeChecker "let f x = x;; print_int(f 4);; print_bool(f true);; f" + + expected @=? actual + testredecalaration :: TestTree testredecalaration = testCase