Skip to content

Refactor type checker #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cabal.project.freeze
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ constraints: any.Cabal ==3.8.1.0,
distributive +semigroups +tagged,
any.dlist ==1.0,
dlist -werror,
any.exceptions ==0.10.5,
any.exceptions ==0.10.8,
exceptions +transformers-0-4,
any.extra ==1.7.16,
any.file-embed ==0.0.16.0,
any.filepath ==1.4.2.2,
Expand All @@ -54,7 +55,7 @@ constraints: any.Cabal ==3.8.1.0,
any.megaparsec ==9.6.1,
megaparsec -dev,
any.mmorph ==1.2.0,
any.mtl ==2.2.2,
any.mtl ==2.2.2 || ==2.3.1,
any.optparse-applicative ==0.18.1.0,
optparse-applicative +process,
any.os-string ==2.0.3,
Expand Down
181 changes: 87 additions & 94 deletions lib/TypeChecker/HindleyMilner.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,25 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module TypeChecker.HindleyMilner where
module TypeChecker.HindleyMilner
( Infer,
TypeError (..),
UType,
Polytype,
applyBindings,
generalize,
toPolytype,
toUType,
withBinding,
fresh,
Poly (..),
UTerm (UTVar, UTUnit, UTBool, UTInt, UTFun),
(=:=),
lookup,
TypeF (..),
mkVarName,
)
where

import Control.Monad.Except
import Control.Monad.Reader
Expand All @@ -27,86 +45,78 @@ import Data.Set (Set, (\\))
import qualified Data.Set as S
import Data.Text (pack)
import GHC.Generics (Generic1)
import Trees.Common (Identifier, Type (..))
import qualified Trees.Common as L -- Lang
import Prelude hiding (lookup)

data HType a
= TyVarF Identifier
| TyUnitF
| TyBoolF
| TyIntF
| TyFunF a a
deriving (Show, Eq, Functor, Foldable, Traversable, Generic1, Unifiable)

type TypeF = Fix HType

type UType = UTerm HType IntVar
-- * Type

data Poly t = Forall [Identifier] t
deriving (Eq, Show, Functor)
type Type = Fix TypeF

type Polytype = Poly TypeF
data TypeF a
= TVarF L.Identifier
| TUnitF
| TBoolF
| TIntF
| TFunF a a
deriving (Show, Eq, Functor, Foldable, Traversable, Generic1, Unifiable)

type UPolytype = Poly UType
-- * UType

-- TypeF
type UType = UTerm TypeF IntVar

pattern TyVar :: Identifier -> TypeF
pattern TyVar v = Fix (TyVarF v)
pattern UTVar :: L.Identifier -> UType
pattern UTVar var = UTerm (TVarF var)

pattern TyUnit :: TypeF
pattern TyUnit = Fix TyUnitF
pattern UTUnit :: UType
pattern UTUnit = UTerm TUnitF

pattern TyBool :: TypeF
pattern TyBool = Fix TyBoolF
pattern UTBool :: UType
pattern UTBool = UTerm TBoolF

pattern TyInt :: TypeF
pattern TyInt = Fix TyIntF
pattern UTInt :: UType
pattern UTInt = UTerm TIntF

pattern TyFun :: TypeF -> TypeF -> TypeF
pattern TyFun t1 t2 = Fix (TyFunF t1 t2)
pattern UTFun :: UType -> UType -> UType
pattern UTFun funT argT = UTerm (TFunF funT argT)

-- UType
-- * Polytype

pattern UTyVar :: Identifier -> UType
pattern UTyVar v = UTerm (TyVarF v)
data Poly t = Forall [L.Identifier] t
deriving (Eq, Show, Functor)

pattern UTyUnit :: UType
pattern UTyUnit = UTerm TyUnitF
type Polytype = Poly Type

pattern UTyBool :: UType
pattern UTyBool = UTerm TyBoolF
type UPolytype = Poly UType

pattern UTyInt :: UType
pattern UTyInt = UTerm TyIntF
-- * Converters

pattern UTyFun :: UType -> UType -> UType
pattern UTyFun t1 t2 = UTerm (TyFunF t1 t2)
toUType :: L.Type -> UType
toUType = \case
L.TUnit -> UTUnit
L.TBool -> UTBool
L.TInt -> UTInt
L.TFun funT argT -> UTFun (toUType funT) (toUType argT)

toTypeF :: Type -> TypeF
toTypeF = \case
TUnit -> TyUnit
TBool -> TyBool
TInt -> TyInt
TFun t1 t2 -> TyFun (toTypeF t1) (toTypeF t2)
toPolytype :: UPolytype -> Polytype
toPolytype = fmap (fromJust . freeze)

fromTypeToUType :: Type -> UType
fromTypeToUType = \case
TUnit -> UTyUnit
TBool -> UTyBool
TInt -> UTyInt
TFun t1 t2 -> UTyFun (fromTypeToUType t1) (fromTypeToUType t2)
-- * Infer

type Infer = ReaderT Ctx (ExceptT TypeError (IntBindingT HType Identity))
type Infer = ReaderT Ctx (ExceptT TypeError (IntBindingT TypeF Identity))

type Ctx = Map Identifier UPolytype
type Ctx = Map L.Identifier UPolytype

lookup :: LookUpType -> Infer UType
lookup (Var v) = do
ctx <- ask
maybe (throwError $ UnboundVar v) instantiate (M.lookup v ctx)
lookup :: L.Identifier -> Infer UType
lookup var = do
varUPT <- asks $ M.lookup var
maybe (throwError $ UnboundVar var) instantiate varUPT
where
instantiate :: UPolytype -> Infer UType
instantiate (Forall xs uty) = do
xs' <- mapM (const fresh) xs
return $ substU (M.fromList (zip (map Left xs) xs')) uty

withBinding :: (MonadReader Ctx m) => Identifier -> UPolytype -> m a -> m a
withBinding :: (MonadReader Ctx m) => L.Identifier -> UPolytype -> m a -> m a
withBinding x ty = local (M.insert x ty)

ucata :: (Functor t) => (v -> a) -> (t a -> a) -> UTerm t v -> a
Expand All @@ -115,16 +125,18 @@ ucata f g (UTerm t) = g (fmap (ucata f g) t)

deriving instance Ord IntVar

-- * FreeVars

class FreeVars a where
freeVars :: a -> Infer (Set (Either Identifier IntVar))
freeVars :: a -> Infer (Set (Either L.Identifier IntVar))

instance FreeVars UType where
freeVars ut = do
fuvs <- fmap (S.fromList . map Right) . lift . lift $ getFreeVars ut
let ftvs =
ucata
(const S.empty)
(\case TyVarF x -> S.singleton (Left x); f -> fold f)
(\case TVarF x -> S.singleton (Left x); f -> fold f)
ut
return $ fuvs `S.union` ftvs

Expand All @@ -134,67 +146,48 @@ instance FreeVars UPolytype where
instance FreeVars Ctx where
freeVars = fmap S.unions . mapM freeVars . M.elems

newtype LookUpType = Var Identifier
fresh :: Infer UType
fresh = UVar <$> lift (lift freeVar)

-- * Errors

data TypeError where
Unreachable :: TypeError
UnboundVar :: Identifier -> TypeError
UnboundVar :: L.Identifier -> TypeError
Infinite :: IntVar -> UType -> TypeError
ImpossibleBinOpApplication :: UType -> UType -> TypeError
ImpossibleUnOpApplication :: UType -> TypeError
Mismatch :: HType UType -> HType UType -> TypeError
Mismatch :: TypeF UType -> TypeF UType -> TypeError
deriving (Show)

instance Fallible HType IntVar TypeError where
instance Fallible TypeF IntVar TypeError where
occursFailure = Infinite
mismatchFailure = Mismatch

fresh :: Infer UType
fresh = UVar <$> lift (lift freeVar)

(=:=) :: UType -> UType -> Infer UType
s =:= t = lift $ s U.=:= t

applyBindings :: UType -> Infer UType
applyBindings = lift . U.applyBindings

instantiate :: UPolytype -> Infer UType
instantiate (Forall xs uty) = do
xs' <- mapM (const fresh) xs
return $ substU (M.fromList (zip (map Left xs) xs')) uty

substU :: Map (Either Identifier IntVar) UType -> UType -> UType
substU :: Map (Either L.Identifier IntVar) UType -> UType -> UType
substU m =
ucata
(\v -> fromMaybe (UVar v) (M.lookup (Right v) m))
( \case
TyVarF v -> fromMaybe (UTyVar v) (M.lookup (Left v) m)
TVarF v -> fromMaybe (UTVar v) (M.lookup (Left v) m)
f -> UTerm f
)

skolemize :: UPolytype -> Infer UType
skolemize (Forall xs uty) = do
xs' <- mapM (const fresh) xs
return $ substU (M.fromList (zip (map Left xs) (map toSkolem xs'))) uty
where
toSkolem (UVar v) = UTyVar (mkVarName "s" v)
toSkolem _ = undefined -- We can't reach another situation, because we previously give `fresh` variable

mkVarName :: String -> IntVar -> Identifier
mkVarName nm (IntVar v) = pack (nm ++ show (v + (maxBound :: Int) + 1))
mkVarName :: String -> IntVar -> L.Identifier
mkVarName nm (IntVar v) = pack (nm <> show (v + (maxBound :: Int) + 1))

generalize :: UType -> Infer UPolytype
generalize uty = do
uty' <- applyBindings uty
ctx <- ask
tmfvs <- freeVars uty'
ctxfvs <- freeVars ctx
let fvs = S.toList $ tmfvs \\ ctxfvs
tmFreeVars <- freeVars uty'
ctxFreeVars <- freeVars ctx
let fvs = S.toList $ tmFreeVars \\ ctxFreeVars
xs = map (either id (mkVarName "a")) fvs
return $ Forall xs (substU (M.fromList (zip fvs (map UTyVar xs))) uty')

toUPolytype :: Polytype -> UPolytype
toUPolytype = fmap unfreeze

fromUPolytype :: UPolytype -> Polytype
fromUPolytype = fmap (fromJust . freeze)
return $ Forall xs (substU (M.fromList (zip fvs (map UTVar xs))) uty')
12 changes: 6 additions & 6 deletions lib/TypeChecker/PrettyPrinter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ type Prec = Int
instance (Pretty (t (Fix t))) => Pretty (Fix t) where
prettyPrec p = prettyPrec p . unFix

instance (Pretty t) => Pretty (HType t) where
prettyPrec _ (TyVarF x) = unpack x
prettyPrec _ TyUnitF = "unit"
prettyPrec _ TyBoolF = "bool"
prettyPrec _ TyIntF = "int"
prettyPrec p (TyFunF ty1 ty2) =
instance (Pretty t) => Pretty (TypeF t) where
prettyPrec _ (TVarF x) = unpack x
prettyPrec _ TUnitF = "unit"
prettyPrec _ TBoolF = "bool"
prettyPrec _ TIntF = "int"
prettyPrec p (TFunF ty1 ty2) =
mparens (p > 0) $ prettyPrec 1 ty1 ++ " -> " ++ prettyPrec 0 ty2

instance (Pretty (t (UTerm t v)), Pretty v) => Pretty (UTerm t v) where
Expand Down
Loading