From cd8b208bb62076a08f11276fee45124b5537009c Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 22 Jun 2023 14:26:07 -0400 Subject: [PATCH 1/4] Remove type/effect annotations from `Block`. Blocks are now just a bunch of decls followed by an atom. The type and effects information is carried by the context -- e.g. a pi type or the other arguments to a built-in HOF. The reason is that we're about to add decls to binders and this way we can treat binder-decls block-decls uniformly. --- src/lib/Algebra.hs | 2 +- src/lib/Builder.hs | 103 ++++++------------- src/lib/CheapReduction.hs | 50 ++++----- src/lib/CheckType.hs | 109 +++++++++----------- src/lib/Core.hs | 4 +- src/lib/Export.hs | 9 +- src/lib/Imp.hs | 51 ++++----- src/lib/ImpToLLVM.hs | 2 +- src/lib/Inference.hs | 58 +++++++---- src/lib/Inline.hs | 30 ++---- src/lib/Linearize.hs | 19 ++-- src/lib/Lower.hs | 46 +++++---- src/lib/OccAnalysis.hs | 34 +++--- src/lib/Optimize.hs | 85 ++++++++------- src/lib/PPrint.hs | 24 ++--- src/lib/QueryType.hs | 104 ++++++++++--------- src/lib/QueryTypePure.hs | 21 ++-- src/lib/Simplify.hs | 165 ++++++++++++++++-------------- src/lib/TopLevel.hs | 53 ++++------ src/lib/Transpose.hs | 14 ++- src/lib/Types/Core.hs | 89 +++++++--------- src/lib/Vectorize.hs | 19 ++-- tests/uexpr-tests.dx | 2 +- tests/unit/ConstantCastingSpec.hs | 3 +- tests/unit/JaxADTSpec.hs | 3 +- tests/unit/OccAnalysisSpec.hs | 8 +- 26 files changed, 517 insertions(+), 590 deletions(-) diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs index b526eeace..65491714e 100644 --- a/src/lib/Algebra.hs +++ b/src/lib/Algebra.hs @@ -137,7 +137,7 @@ type BlockTraverserM i o a = SubstReaderT PolySubstVal (MaybeT1 (BuilderM SimpIR blockAsPoly :: (EnvExtender m, EnvReader m) => Block SimpIR n -> m n (Maybe (Polynomial n)) -blockAsPoly (Block _ decls result) = +blockAsPoly (Abs decls result) = liftBuilder $ runMaybeT1 $ runSubstReaderT idSubst $ blockAsPolyRec decls result blockAsPolyRec :: Nest (Decl SimpIR) i i' -> Atom SimpIR i' -> BlockTraverserM i o (Polynomial o) diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index dd4225b1e..b3012cc90 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -15,8 +15,9 @@ import Control.Monad.Reader import Control.Monad.Writer.Strict hiding (Alt) import Control.Monad.State.Strict (MonadState (..), StateT (..), runStateT) import qualified Data.Map.Strict as M +import Data.Foldable (fold) import Data.Graph (graphFromEdges, topSort) -import Data.Text.Prettyprint.Doc (Pretty (..), group, line, nest) +import Data.Text.Prettyprint.Doc (Pretty (..)) import Foreign.Ptr import qualified Unsafe.Coerce as TrulyUnsafe @@ -28,7 +29,6 @@ import IRVariants import MTL1 import Subst import Name -import PPrint (prettyBlock) import QueryType import Types.Core import Types.Imp @@ -88,11 +88,11 @@ emitUnOp op x = emitOp $ UnOp op x {-# INLINE emitUnOp #-} emitBlock :: (Builder r m, Emits n) => Block r n -> m n (Atom r n) -emitBlock (Block _ decls result) = emitDecls decls result +emitBlock = emitDecls emitDecls :: (Builder r m, Emits n, RenameE e, SinkableE e) - => Nest (Decl r) n l -> e l -> m n (e n) -emitDecls decls result = runSubstReaderT idSubst $ emitDecls' decls result + => WithDecls r e n -> m n (e n) +emitDecls (Abs decls result) = runSubstReaderT idSubst $ emitDecls' decls result emitDecls' :: (Builder r m, Emits o, RenameE e, SinkableE e) => Nest (Decl r) i i' -> e i' -> SubstReaderT Name m i o (e o) @@ -278,10 +278,9 @@ emitTopLet hint letAnn expr = do v <- emitBinding hint $ AtomNameBinding $ LetBound (DeclBinding letAnn expr) return $ AtomVar v ty -emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFunDef n -> LamExpr SimpIR n -> m n (TopFunName n) +emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFunDef n -> STopLam n -> m n (TopFunName n) emitTopFunBinding hint def f = do - ty <- return $ getLamExprType f - emitBinding hint $ TopFunBinding $ DexTopFun def ty f Waiting + emitBinding hint $ TopFunBinding $ DexTopFun def f Waiting emitSourceMap :: TopBuilder m => SourceMap n -> m n () emitSourceMap sm = emitLocalModuleEnv $ mempty {envSourceMap = sm} @@ -334,7 +333,7 @@ extendLinearizationCache s fs = queryObjCache :: EnvReader m => TopFunName n -> m n (Maybe (FunObjCodeName n)) queryObjCache v = lookupEnv v >>= \case - TopFunBinding (DexTopFun _ _ _ (Finished impl)) -> return $ Just $ topFunObjCode impl + TopFunBinding (DexTopFun _ _ (Finished impl)) -> return $ Just $ topFunObjCode impl _ -> return Nothing emitObjFile :: (Mut n, TopBuilder m) => CFunction n -> m n (FunObjCodeName n) @@ -468,7 +467,7 @@ liftEmitBuilder cont = do Distinct <- getDistinct let (result, decls, _) = runHardFail $ unsafeRunInplaceT (runBuilderT' cont) env emptyOutFrag Emits <- fabricateEmitsEvidenceM - emitDecls (unsafeCoerceB $ unRNest decls) result + emitDecls $ Abs (unsafeCoerceB $ unRNest decls) result instance (IRRep r, Fallible m) => ScopableBuilder r (BuilderT r m) where buildScoped cont = BuilderT do @@ -601,56 +600,14 @@ buildBlock :: ScopableBuilder r m => (forall l. (Emits l, DExt n l) => m l (Atom r l)) -> m n (Block r n) -buildBlock cont = buildScoped (cont >>= withType) >>= computeAbsEffects >>= absToBlock - -withType :: ((EnvReader m, IRRep r), HasType r e) => e l -> m l ((e `PairE` Type r) l) -withType e = do - ty <- {-# SCC blockTypeNormalization #-} cheapNormalize $ getType e - return $ e `PairE` ty -{-# INLINE withType #-} - -makeBlock :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> Atom r l -> Type r l -> Block r n -makeBlock decls effs atom ty = Block (BlockAnn (EffTy effs' ty')) decls atom where - ty' = ignoreHoistFailure $ hoist decls ty - effs' = ignoreHoistFailure $ hoist decls effs -{-# INLINE makeBlock #-} - -absToBlockInferringTypes :: (EnvReader m, IRRep r) => Abs (Nest (Decl r)) (Atom r) n -> m n (Block r n) -absToBlockInferringTypes ab = liftEnvReaderM do - abWithEffs <- computeAbsEffects ab - refreshAbs abWithEffs \decls (effs `PairE` result) -> do - ty <- cheapNormalize $ getType result - return $ ignoreExcept $ - absToBlock $ Abs decls (effs `PairE` (result `PairE` ty)) -{-# INLINE absToBlockInferringTypes #-} - -absToBlock - :: (Fallible m, IRRep r) - => Abs (Nest (Decl r)) (EffectRow r `PairE` (Atom r `PairE` Type r)) n -> m (Block r n) -absToBlock (Abs decls (effs `PairE` (result `PairE` ty))) = do - let msg = "Block:" <> nest 1 (prettyBlock decls result) <> line - <> group ("Of type:" <> nest 2 (line <> pretty ty)) <> line - <> group ("With effects:" <> nest 2 (line <> pretty effs)) - ty' <- liftHoistExcept' (docAsStr msg) $ hoist decls ty - effs' <- liftHoistExcept' (docAsStr msg) $ hoist decls effs - return $ Block (BlockAnn (EffTy effs' ty')) decls result -{-# INLINE absToBlock #-} - -makeBlockFromDecls :: (EnvReader m, IRRep r) => Abs (Nest (Decl r)) (Atom r) n -> m n (Block r n) -makeBlockFromDecls (Abs Empty result) = return $ AtomicBlock result -makeBlockFromDecls ab = liftEnvReaderM $ refreshAbs ab \decls result -> do - ty <- return $ getType result - effs <- declNestEffects decls - PairE ty' effs' <- return $ ignoreHoistFailure $ hoist decls $ PairE ty effs - return $ Block (BlockAnn (EffTy effs' ty')) decls result -{-# INLINE makeBlockFromDecls #-} +buildBlock = buildScoped coreLamExpr :: EnvReader m => AppExplicitness -> Abs (Nest (WithExpl CBinder)) (PairE (EffectRow CoreIR) CBlock) n -> m n (CoreLamExpr n) coreLamExpr appExpl ab = liftEnvReaderM do refreshAbs ab \bs' (PairE effs' body') -> do - resultTy <- return $ getType body' + EffTy _ resultTy <- blockEffTy body' let bs'' = fmapNest withoutExpl bs' return $ CoreLamExpr (CorePiType appExpl bs' (EffTy effs' resultTy)) (LamExpr bs'' body') @@ -736,12 +693,13 @@ buildLamExpr (Abs bs UnitE) cont = case bs of buildLamExpr rest' \vs -> cont $ sink v : vs return $ LamExpr (Nest b' bs') body' -buildLamExprFromPi +buildTopLamFromPi :: ScopableBuilder r m => PiType r n -> (forall l. (Emits l, Distinct l, DExt n l) => [AtomVar r l] -> m l (Atom r l)) - -> m n (LamExpr r n) -buildLamExprFromPi (PiType bs _) cont = buildLamExpr (EmptyAbs bs) cont + -> m n (TopLam r n) +buildTopLamFromPi piTy@(PiType bs _) cont = + TopLam False piTy <$> buildLamExpr (EmptyAbs bs) cont buildAlt :: ScopableBuilder r m @@ -789,7 +747,7 @@ buildCase' scrut resultTy indexedAltBody = do (alts, effs) <- unzip <$> forM (enumerate altBinderTys) \(i, bTy) -> do (Abs b' (body `PairE` eff')) <- buildAbs noHint bTy \x -> do blk <- buildBlock $ indexedAltBody i $ Var $ sink x - eff <- return $ getEffects blk + EffTy eff _ <- blockEffTy blk return $ blk `PairE` eff return (Abs b' body, ignoreHoistFailure $ hoist b' eff') return $ Case scrut alts $ EffTy (mconcat effs) resultTy @@ -1157,9 +1115,17 @@ mkDictAtom d = do ty <- typeOfDictExpr d return $ DictCon ty d +mkCase :: (EnvReader m, IRRep r) => Atom r n -> Type r n -> [Alt r n] -> m n (Expr r n) +mkCase scrut resultTy alts = liftEnvReaderM do + eff' <- fold <$> forM alts \alt -> refreshAbs alt \b body -> do + EffTy eff _ <- blockEffTy body + return $ ignoreHoistFailure $ hoist b eff + return $ Case scrut alts (EffTy eff' resultTy) + mkCatchException :: EnvReader m => CBlock n -> m n (Hof CoreIR n) mkCatchException body = do - resultTy <- makePreludeMaybeTy $ getType body + EffTy _ bodyTy <- blockEffTy body + resultTy <- makePreludeMaybeTy bodyTy return $ CatchException resultTy body app :: (CBuilder m, Emits n) => CAtom n -> CAtom n -> m n (CAtom n) @@ -1177,7 +1143,7 @@ naryTopAppInlined :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] -> naryTopAppInlined f xs = do TopFunBinding f' <- lookupEnv f case f' of - DexTopFun _ _ (LamExpr bs body) _ -> + DexTopFun _ (TopLam _ _ (LamExpr bs body)) _ -> applySubst (bs@@>(SubstVal<$>xs)) body >>= emitBlock _ -> naryTopApp f xs {-# INLINE naryTopAppInlined #-} @@ -1237,7 +1203,7 @@ applyIxMethod dict method args = case dict of IxDictSpecialized _ d params -> do SpecializedDict _ maybeFs <- lookupSpecDict d Just fs <- return maybeFs - LamExpr bs body <- return $ fs !! fromEnum method + TopLam _ _ (LamExpr bs body) <- return $ fs !! fromEnum method emitBlock =<< applySubst (bs @@> fmap SubstVal (params ++ args)) body unsafeFromOrdinal :: (SBuilder m, Emits n) => IxType SimpIR n -> Atom SimpIR n -> m n (Atom SimpIR n) @@ -1551,15 +1517,10 @@ type ExprVisitorNoEmits2 m r = forall i o. ExprVisitorNoEmits (m i o) r i o visitLamNoEmits :: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m) => LamExpr r i -> m i o (LamExpr r o) -visitLamNoEmits (LamExpr bs body) = - visitBinders bs \bs' -> LamExpr bs' <$> visitBlockNoEmits body - -visitBlockNoEmits - :: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m) - => Block r i -> m i o (Block r o) -visitBlockNoEmits (Block _ decls result) = - absToBlockInferringTypes =<< visitDeclsNoEmits decls \decls' -> do - Abs decls' <$> visitAtom result +visitLamNoEmits (LamExpr bs (Abs decls result)) = + visitBinders bs \bs' -> LamExpr bs' <$> + visitDeclsNoEmits decls \decls' -> Abs decls' <$> do + visitAtom result visitDeclsNoEmits :: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m) @@ -1602,7 +1563,7 @@ visitLamEmits (LamExpr bs body) = visitBinders bs \bs' -> LamExpr bs' <$> visitBlockEmits :: (ExprVisitorEmits2 m r, SubstReader AtomSubstVal m, EnvExtender2 m, IRRep r, Emits o) => Block r i -> m i o (Atom r o) -visitBlockEmits (Block _ decls result) = visitDeclsEmits decls $ visitAtom result +visitBlockEmits (Abs decls result) = visitDeclsEmits decls $ visitAtom result visitDeclsEmits :: (ExprVisitorEmits2 m r, SubstReader AtomSubstVal m, EnvExtender2 m, IRRep r, Emits o) diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index fde08b7be..b301fa6b7 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -15,7 +15,8 @@ module CheapReduction , unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType , liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..) , visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2 - , visitBinders, visitPiDefault, visitAlt, toAtomVar) + , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiatePiTy + , bindersToVars, bindersToAtoms) where import Control.Applicative @@ -261,9 +262,6 @@ instance CheaplyReducibleE CoreIR TyConParams TyConParams where instance (CheaplyReducibleE r e e', NiceE r e') => CheaplyReducibleE r (Abs (Nest (Decl r)) e) e' where cheapReduceE (Abs decls result) = cheapReduceWithDeclsB decls $ cheapReduceE result -instance (CheaplyReducibleE r (Atom r) e', NiceE r e') => CheaplyReducibleE r (Block r) e' where - cheapReduceE (Block _ decls result) = cheapReduceE $ Abs decls result - instance IRRep r => CheaplyReducibleE r (Expr r) (Atom r) where cheapReduceE expr = confuseGHC >>= \_ -> case expr of Atom atom -> cheapReduceE atom @@ -472,6 +470,10 @@ instantiateTyConDef (TyConDef _ bs conDefs) (TyConParams _ xs) = do applySubst (bs @@> (SubstVal <$> xs)) conDefs {-# INLINE instantiateTyConDef #-} +instantiatePiTy :: (EnvReader m, IRRep r) => PiType r n -> [Atom r n] -> m n (EffTy r n) +instantiatePiTy (PiType bs effTy) xs = do + applySubst (bs @@> (SubstVal <$> xs)) effTy + -- Returns a representation type (type of an TypeCon-typed Newtype payload) -- given a list of instantiated DataConDefs. dataDefRep :: DataConDefs n -> CType n @@ -517,10 +519,10 @@ instance VisitGeneric (Type r) r where visitGeneric = visitType instance VisitGeneric (LamExpr r) r where visitGeneric = visitLam instance VisitGeneric (PiType r) r where visitGeneric = visitPi -instance VisitGeneric (Block r) r where - visitGeneric b = visitGeneric (LamExpr Empty b) >>= \case - LamExpr Empty b' -> return b' - _ -> error "not a block" +visitBlock :: Visitor m r i o => Block r i -> m (Block r o) +visitBlock b = visitGeneric (LamExpr Empty b) >>= \case + LamExpr Empty b' -> return b' + _ -> error "not a block" visitAlt :: Visitor m r i o => Alt r i -> m (Alt r o) visitAlt (Abs b body) = do @@ -615,17 +617,11 @@ instance IRRep r => VisitGeneric (Expr r) r where -- TODO: should we reuse the original effects? Whether it's valid depends on -- the type-preservation requirements for a visitor. We should clarify what -- those are. - Case x alts (EffTy _ t) -> do + Case x alts effTy -> do x' <- visitGeneric x - t' <- visitGeneric t alts' <- mapM visitAlt alts - let effs' = foldMap altEffects alts' - return $ Case x' alts' $ EffTy effs' t' - where - altEffects :: Alt r n -> EffectRow r n - altEffects (Abs bs (Block ann _ _)) = case ann of - NoBlockAnn -> Pure - BlockAnn (EffTy effs _) -> ignoreHoistFailure $ hoist bs effs + effTy' <- visitGeneric effTy + return $ Case x' alts' effTy' Atom x -> Atom <$> visitGeneric x TabCon Nothing t xs -> TabCon Nothing <$> visitGeneric t <*> mapM visitGeneric xs TabCon (Just (WhenIRE d)) t xs -> TabCon <$> (Just . WhenIRE <$> visitGeneric d) <*> visitGeneric t <*> mapM visitGeneric xs @@ -654,10 +650,10 @@ instance IRRep r => VisitGeneric (Hof r) r where RunReader x body -> RunReader <$> visitGeneric x <*> visitGeneric body RunWriter dest bm body -> RunWriter <$> mapM visitGeneric dest <*> visitGeneric bm <*> visitGeneric body RunState dest s body -> RunState <$> mapM visitGeneric dest <*> visitGeneric s <*> visitGeneric body - While b -> While <$> visitGeneric b - RunIO b -> RunIO <$> visitGeneric b - RunInit b -> RunInit <$> visitGeneric b - CatchException t b -> CatchException <$> visitType t <*> visitGeneric b + While b -> While <$> visitBlock b + RunIO b -> RunIO <$> visitBlock b + RunInit b -> RunInit <$> visitBlock b + CatchException t b -> CatchException <$> visitType t <*> visitBlock b Linearize lam x -> Linearize <$> visitGeneric lam <*> visitGeneric x Transpose lam x -> Transpose <$> visitGeneric lam <*> visitGeneric x @@ -674,7 +670,7 @@ instance IRRep r => VisitGeneric (DAMOp r) r where instance VisitGeneric UserEffectOp CoreIR where visitGeneric = \case - Handle name xs body -> Handle <$> renameN name <*> mapM visitGeneric xs <*> visitGeneric body + Handle name xs body -> Handle <$> renameN name <*> mapM visitGeneric xs <*> visitBlock body Resume t x -> Resume <$> visitGeneric t <*> visitGeneric x Perform x i -> Perform <$> visitGeneric x <*> pure i @@ -798,6 +794,15 @@ toAtomVar v = do ty <- getType <$> lookupAtomName v return $ AtomVar v ty +bindersToVars :: (EnvReader m, IRRep r) => Nest (Binder r) n' n -> m n [AtomVar r n] +bindersToVars bs = do + withExtEvidence bs do + Distinct <- getDistinct + mapM toAtomVar $ nestToNames bs + +bindersToAtoms :: (EnvReader m, IRRep r) => Nest (Binder r) n' n -> m n [Atom r n] +bindersToAtoms bs = liftM (Var <$>) $ bindersToVars bs + newtype SubstVisitor i o a = SubstVisitor { runSubstVisitor :: Reader (Env o, Subst AtomSubstVal i o) a } deriving (Functor, Applicative, Monad, MonadReader (Env o, Subst AtomSubstVal i o)) @@ -889,7 +894,6 @@ instance IRRep r => SubstE AtomSubstVal (PrimOp r) instance IRRep r => SubstE AtomSubstVal (RefOp r) instance IRRep r => SubstE AtomSubstVal (EffTy r) instance IRRep r => SubstE AtomSubstVal (Expr r) -instance IRRep r => SubstE AtomSubstVal (Block r) instance IRRep r => SubstE AtomSubstVal (GenericOpRep const r) instance SubstE AtomSubstVal InstanceBody instance SubstE AtomSubstVal DictType diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 4aa800f17..f16a8b5bd 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -9,7 +9,7 @@ module CheckType ( CheckableE (..), CheckableB (..), checkTypes, checkTypesM, checkHasType, - checkExtends, tryGetType, isData, asFFIFunType, checkDestLam + checkExtends, tryGetType, isData, asFFIFunType, checkBlock ) where import Prelude hiding (id) @@ -52,13 +52,6 @@ checkHasType :: (EnvReader m, HasType r e) => e n -> Type r n -> m n (Except ()) checkHasType e ty = liftTyperT $ e |: ty {-# INLINE checkHasType #-} -checkDestLam :: (EnvReader m, Fallible1 m) => LamExpr SimpIR n -> m n () -checkDestLam lam = do - let allowedEffs = OneEffect InitEffect - PiType bs (EffTy effs _) <- return $ getDestLamExprType lam - let effs' = ignoreHoistFailure $ hoist bs effs - checkExtends allowedEffs effs' - -- === the type checking/querying monad === -- TODO: not clear why we need the explicit `Monad2` here since it should @@ -177,10 +170,23 @@ instance (CheckableB r b, CheckableE r e) => CheckableE r (Abs b e) where checkE (Abs b e) = checkB b \_ -> checkE e instance (IRRep r) => CheckableE r (LamExpr r) where - checkE (LamExpr bs body) = checkB bs \_ -> void $ getTypeE body + checkE (LamExpr bs body) = checkB bs \_ -> void $ checkBlock body -- === type checking core === +instance IRRep r => CheckableE r (TopLam r) where + checkE (TopLam _ piTy lam) = do + -- TODO: check destination-passing flag + checkE piTy + piTy' <- renameM piTy + piTy'' <- checkLamExpr lam + alphaEq piTy' piTy'' >>= \case + True -> return () + False -> throw TypeErr $ pprint piTy' ++ " != " ++ pprint piTy'' + +instance IRRep r => CheckableE r (PiType r) where + checkE piTy = void $ getTypeE piTy + instance IRRep r => CheckableE r (Atom r) where checkE atom = void $ getTypeE atom @@ -193,9 +199,6 @@ instance IRRep r => HasType r (AtomName r) where getType <$> lookupAtomName name' {-# INLINE getTypeE #-} -instance IRRep r => CheckableE r (Block r) where - checkE block = void $ getTypeE block - instance IRRep r => HasType r (Atom r) where getTypeE atom = case atom of Var name -> do @@ -332,25 +335,6 @@ typeCheckExpr effs expr = addContext ("Checking expr:\n" ++ pprint expr) case ex HoistFailure _ -> forM_ xs checkE return ty' -instance IRRep r => HasType r (Block r) where - getTypeE = \case - Block NoBlockAnn Empty atom -> getTypeE atom - Block (BlockAnn (EffTy effs' reqTy)) decls result -> do - effs <- renameM effs' - reqTy' <- renameM reqTy - go effs reqTy' decls result - return reqTy' - Block _ _ _ -> error "impossible" - where - go :: Typer m r => EffectRow r o -> Type r o -> Nest (Decl r) i i' -> Atom r i' -> m i o () - go _ reqTy Empty result = result |: reqTy - go effs reqTy (Nest (Let b rhs@(DeclBinding _ expr)) decls) result = do - void $ typeCheckExpr effs expr - rhs' <- renameM rhs - withFreshBinder (getNameHint b) rhs' \(b':>_) -> do - extendRenamer (b@>binderName b') do - go (sink effs) (sink reqTy) decls result - instance CheckableE CoreIR TyConParams where checkE (TyConParams _ params) = mapM_ checkE params @@ -392,6 +376,13 @@ instance HasType CoreIR CorePiType where resultTy|:TyKind return TyKind +instance IRRep r => HasType r (PiType r) where + getTypeE (PiType bs (EffTy eff resultTy)) = do + checkB bs \_ -> do + void $ checkE eff + resultTy|:TyKind + return TyKind + instance IRRep r => CheckableE r (IxType r) where checkE (IxType t _) = checkE t @@ -490,7 +481,7 @@ typeCheckPrimOp effs op = case op of MiscOp x -> typeCheckMiscOp effs x MemOp x -> typeCheckMemOp effs x DAMOp op' -> typeCheckDAMOp effs op' - UserEffectOp op' -> typeCheckUserEffect op' + UserEffectOp _ -> error "not implemented" RefOp ref m -> do TC (RefType h s) <- getTypeE ref case m of @@ -613,24 +604,6 @@ typeCheckVectorOp = \case unless (sbt == sbt') $ throw TypeErr "Scalar type mismatch" return $ RefTy heap ty' -typeCheckUserEffect :: Typer m CoreIR => UserEffectOp i -> m i o (CType o) -typeCheckUserEffect = \case - -- TODO(alex): check the argument - Resume retTy _argTy -> do - checkTypeE TyKind retTy - -- TODO(alex): actually check something here? this is a QueryType copy/paste - Handle hndName [] body -> do - hndName' <- renameM hndName - r <- getTypeE body - instantiateHandlerType hndName' r [] - -- TODO(alex): implement - Handle _ _ _ -> error "not implemented" - Perform eff i -> do - Eff (OneEffect (UserEffect effName)) <- return eff - EffectDef _ ops <- renameM effName >>= lookupEffectDef - let (_, EffectOpType _pol lamTy) = ops !! i - return lamTy - typeCheckPrimHof :: forall r m i o. (Typer m r, IRRep r) => EffectRow r o -> Hof r i -> m i o (Type r o) typeCheckPrimHof effs hof = addContext ("Checking HOF:\n" ++ pprint hof) case hof of For _ ixTy f -> do @@ -639,7 +612,7 @@ typeCheckPrimHof effs hof = addContext ("Checking HOF:\n" ++ pprint hof) case ho checkTypesEq t argTy return $ TabTy d (b:>t) eltTy While body -> do - condTy <- getTypeE body + condTy <- checkBlockWithEffs effs body checkTypesEq (BaseTy $ Scalar Word8Type) condTy return UnitTy Linearize f x -> do @@ -720,8 +693,7 @@ typeCheckDAMOp effs op = addContext ("Checking DAMOp:\n" ++ pprint op) case op o checkLamExpr :: (Typer m r, IRRep r) => LamExpr r i -> m i o (PiType r o) checkLamExpr (LamExpr bsTop body) = case bsTop of Empty -> do - resultTy <- getTypeE body - effs <- renameM $ getEffects body + EffTy effs resultTy <- checkBlock body return $ PiType Empty $ EffTy effs resultTy Nest (b:>ty) bs -> do ty' <- checkTypeE TyKind ty @@ -737,12 +709,31 @@ checkLamExprWithEffs allowedEffs lam = do checkExtends allowedEffs effs' return piTy -checkBlockWithEffs :: (Typer m r, IRRep r) => EffectRow r o -> Block r i -> m i o (Type r o) -checkBlockWithEffs allowedEffs block = do - ty <- getTypeE block - effs <- renameM $ getEffects block - checkExtends allowedEffs effs - return ty +checkBlockWithEffs :: forall i o r m. (Typer m r, IRRep r) => EffectRow r o -> Block r i -> m i o (Type r o) +checkBlockWithEffs allowedEffs (Abs decls result) = do + checkDecls allowedEffs decls \decls' -> do + resultTy <- getTypeE result + liftHoistExcept $ hoist decls' resultTy + +checkDecls + :: (Typer m r, IRRep r) + => EffectRow r o -> Decls r i i' + -> (forall o'. DExt o o' => Decls r o o' -> m i' o' a) + -> m i o a +checkDecls _ Empty cont = getDistinct >>= \Distinct -> cont Empty +checkDecls effs (Nest (Let b rhs@(DeclBinding _ expr)) decls) cont = do + void $ typeCheckExpr effs expr + rhs' <- renameM rhs + withFreshBinder (getNameHint b) rhs' \(b':>_) -> do + extendRenamer (b@>binderName b') do + let decl' = Let b' rhs' + checkDecls (sink effs) decls \decls' -> cont $ Nest decl' decls' + +checkBlock :: (Typer m r, IRRep r) => Block r i -> m i o (EffTy r o) +checkBlock block = do + EffTy effs _ <- blockEffTy =<< renameM block + ty <- checkBlockWithEffs effs block + return $ EffTy effs ty checkRWSAction :: (Typer m r, IRRep r) => EffectRow r o -> RWS -> LamExpr r i -> m i o (Type r o, Type r o) checkRWSAction effs rws f = do diff --git a/src/lib/Core.hs b/src/lib/Core.hs index c63021fe3..436aae869 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -423,8 +423,8 @@ getInstanceDicts name = do liftLamExpr :: (IRRep r, EnvReader m) => (forall l m2. EnvReader m2 => Block r l -> m2 l (Block r l)) - -> LamExpr r n -> m n (LamExpr r n) -liftLamExpr f (LamExpr bs body) = liftEnvReaderM $ + -> TopLam r n -> m n (TopLam r n) +liftLamExpr f (TopLam d ty (LamExpr bs body)) = liftM (TopLam d ty) $ liftEnvReaderM $ refreshAbs (Abs bs body) \bs' body' -> LamExpr bs' <$> f body' fromNaryForExpr :: IRRep r => Int -> Expr r n -> Maybe (Int, LamExpr r n) diff --git a/src/lib/Export.hs b/src/lib/Export.hs index c164d7bc6..eba7a718f 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -52,8 +52,8 @@ prepareFunctionForExport cc f = do HoistFailure _ -> throw TypeErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi HoistSuccess s -> return s - CoreLamExpr _ f' <- liftBuilder $ buildCoreLam naryPi \xs -> naryApp (sink f) (Var <$> xs) - fSimp <- simplifyTopFunction f' + f' <- liftBuilder $ buildCoreLam naryPi \xs -> naryApp (sink f) (Var <$> xs) + fSimp <- simplifyTopFunction $ coreLamToTopLam f' fImp <- compileTopLevelFun cc fSimp nativeFun <- toCFunction "userFunc" fImp >>= emitObjFile >>= loadObject return $ ExportNativeFunction nativeFun closedSig @@ -61,9 +61,8 @@ prepareFunctionForExport cc f = do {-# SCC prepareFunctionForExport #-} prepareSLamForExport :: (Mut n, Topper m) - => CallingConvention -> SLam n -> m n ExportNativeFunction -prepareSLamForExport cc f = do - let naryPi = getLamExprType f + => CallingConvention -> STopLam n -> m n ExportNativeFunction +prepareSLamForExport cc f@(TopLam _ naryPi _) = do sig <- liftExportSigM $ simpPiToExportSig cc naryPi closedSig <- case hoistToTop sig of HoistFailure _ -> diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index b6d1d3958..1a2ab54dd 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -46,16 +46,14 @@ import Types.Imp import Types.Primitives import Util (forMFilter, Tree (..), zipTrees, enumerate) --- XXX: The LamExpr should be in destination-passing style, with its last --- argument a reference to the result. toImpFunction :: EnvReader m - => CallingConvention -> LamExpr SimpIR n -> m n (ImpFunction n) -toImpFunction cc lam = do - (LamExpr bsAndRefB body) <- return lam + => CallingConvention -> STopLam n -> m n (ImpFunction n) +toImpFunction cc (TopLam True destTy lam) = do + LamExpr bsAndRefB body <- return lam PairB bs destB <- case popNest bsAndRefB of Just bsAndRefB' -> return bsAndRefB' Nothing -> error "expected a trailing reference binder" - ty <- return $ getDestLamExprType lam + let ty = piTypeWithoutDest destTy impArgTys <- getNaryLamImpArgTypesWithCC cc ty liftImpM $ buildImpFunction cc (zip (repeat noHint) impArgTys) \vs -> do case cc of @@ -75,6 +73,7 @@ toImpFunction cc lam = do extendSubst (destB @> SubstVal (destToAtom (sink resultDest))) do void $ translateBlock body return [] +toImpFunction _ (TopLam False _ _) = error "expected a lambda in destination-passing form" getNaryLamImpArgTypesWithCC :: EnvReader m => CallingConvention -> PiType SimpIR n -> m n [BaseType] @@ -270,7 +269,7 @@ liftImpM cont = do translateBlock :: forall i o. Emits o => SBlock i -> SubstImpM i o (SAtom o) -translateBlock (Block _ decls result) = translateDeclNest decls $ substM result +translateBlock (Abs decls result) = translateDeclNest decls $ substM result translateDeclNestSubst :: Emits o => Subst AtomSubstVal l o @@ -296,7 +295,7 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of f <- substM f' xs <- mapM substM xs' lookupTopFun f >>= \case - DexTopFun _ piTy _ _ -> emitCall piTy f $ toList xs + DexTopFun _ (TopLam _ piTy _) _ -> emitCall piTy f $ toList xs FFITopFun _ _ -> do scalarArgs <- liftM toList $ mapM fromScalarAtom xs results <- impCall f scalarArgs @@ -1161,12 +1160,9 @@ hoistDecls , BindsNames b, BindsEnv b, RenameB b, SinkableB b) => b n l -> SBlock l -> m n (Abs b SBlock n) hoistDecls b block = do - Abs hoistedDecls rest <- liftEnvReaderM $ - refreshAbs (Abs b block) \b' (Block _ decls result) -> + emitDecls =<< liftEnvReaderM do + refreshAbs (Abs b block) \b' (Abs decls result) -> hoistDeclsRec b' Empty decls result - ab <- emitDecls hoistedDecls rest - refreshAbs ab \b'' blockAbs' -> - Abs b'' <$> absToBlockInferringTypes blockAbs' {-# INLINE hoistDecls #-} hoistDeclsRec @@ -1409,30 +1405,27 @@ ordinalImp :: Emits n => IxType SimpIR n -> SAtom n -> SubstImpM i n (IExpr n) ordinalImp (IxType _ dict) i = fromScalarAtom =<< case dict of IxDictRawFin _ -> return i IxDictSpecialized _ d params -> do - SpecializedDict _ (Just fs) <- lookupSpecDict d - appSpecializedIxMethod (fs !! fromEnum Ordinal) (params ++ [i]) + appSpecializedIxMethod d Ordinal (params ++ [i]) unsafeFromOrdinalImp :: Emits n => IxType SimpIR n -> IExpr n -> SubstImpM i n (SAtom n) unsafeFromOrdinalImp (IxType _ dict) i = do let i' = toScalarAtom i case dict of IxDictRawFin _ -> return i' - IxDictSpecialized _ d params -> do - SpecializedDict _ (Just fs) <- lookupSpecDict d - appSpecializedIxMethod (fs !! fromEnum UnsafeFromOrdinal) (params ++ [i']) + IxDictSpecialized _ d params -> + appSpecializedIxMethod d UnsafeFromOrdinal (params ++ [i']) indexSetSizeImp :: Emits n => IxType SimpIR n -> SubstImpM i n (IExpr n) indexSetSizeImp (IxType _ dict) = do - ans <- case dict of + fromScalarAtom =<< case dict of IxDictRawFin n -> return n - IxDictSpecialized _ d params -> do - SpecializedDict _ (Just fs) <- lookupSpecDict d - appSpecializedIxMethod (fs !! fromEnum Size) (params ++ []) - fromScalarAtom ans - -appSpecializedIxMethod :: Emits n => LamExpr SimpIR n -> [SAtom n] -> SubstImpM i n (SAtom n) -appSpecializedIxMethod simpLam args = do - LamExpr bs body <- return simpLam + IxDictSpecialized _ d params -> + appSpecializedIxMethod d Size (params ++ []) + +appSpecializedIxMethod :: Emits n => SpecDictName n -> IxMethod -> [SAtom n] -> SubstImpM i n (SAtom n) +appSpecializedIxMethod d method args = do + SpecializedDict _ (Just fs) <- lookupSpecDict d + TopLam _ _ (LamExpr bs body) <- return $ fs !! fromEnum method dropSubst $ extendSubst (bs @@> map SubstVal args) $ translateBlock body -- === Abstracting link-time objects === @@ -1444,7 +1437,7 @@ abstractLinktimeObjects f = do let allVars = freeVarsE f (funVars, funTys) <- unzip <$> forMFilter (nameSetToList @TopFunNameC allVars) \v -> lookupTopFun v >>= \case - DexTopFun _ piTy _ _ -> do + DexTopFun _ (TopLam _ piTy _) _ -> do ty' <- getImpFunType StandardCC piTy return $ Just (v, ty') FFITopFun _ _ -> return Nothing @@ -1529,7 +1522,7 @@ impInstrTypes instr = case instr of DebugPrint _ _ -> return [] IQueryParallelism _ _ -> return [IIdxRepTy, IIdxRepTy] ICall f _ -> lookupTopFun f >>= \case - DexTopFun _ piTy _ _ -> do + DexTopFun _ (TopLam _ piTy _) _ -> do IFunType _ _ resultTys <- getImpFunType StandardCC piTy return resultTys FFITopFun _ (IFunType _ _ resultTys) -> return resultTys diff --git a/src/lib/ImpToLLVM.hs b/src/lib/ImpToLLVM.hs index 15474cd33..19905424b 100644 --- a/src/lib/ImpToLLVM.hs +++ b/src/lib/ImpToLLVM.hs @@ -501,7 +501,7 @@ compileInstr instr = case instr of return [] RenameOperandSubstVal v -> do lookupTopFun v >>= \case - DexTopFun _ _ _ _ -> error "Imp functions should be abstracted at this point" + DexTopFun _ _ _ -> error "Imp functions should be abstracted at this point" FFITopFun fname ty@(IFunType cc _ impResultTys) -> do let resultTys = map scalarTy impResultTys case cc of diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index de0d43854..63ab284de 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -10,7 +10,7 @@ module Inference ( inferTopUDecl, checkTopUType, inferTopUExpr - , trySynthTerm, generalizeDict + , trySynthTerm, generalizeDict, asTopBlock , synthTopE, UDeclInferenceResult (..)) where import Prelude hiding ((.), id) @@ -25,7 +25,7 @@ import Data.Foldable (toList, asum) import Data.Functor ((<&>)) import Data.List (sortOn) import Data.Maybe (fromJust, fromMaybe, catMaybes) -import Data.Text.Prettyprint.Doc (Pretty (..), (<+>), vcat) +import Data.Text.Prettyprint.Doc (Pretty (..), (<+>), vcat, group, line, nest) import Data.Word import qualified Data.HashMap.Strict as HM import qualified Data.Map.Strict as M @@ -47,7 +47,8 @@ import QueryType import Types.Core import Types.Primitives import Types.Source -import Util +import Util hiding (group) +import PPrint (prettyBlock) -- === Top-level interface === @@ -55,15 +56,15 @@ checkTopUType :: (Fallible1 m, EnvReader m) => UType n -> m n (CType n) checkTopUType ty = liftInfererM $ solveLocal $ withApplyDefaults $ checkUType ty {-# SCC checkTopUType #-} -inferTopUExpr :: (Fallible1 m, EnvReader m) => UExpr n -> m n (CBlock n) -inferTopUExpr e = liftInfererM do +inferTopUExpr :: (Fallible1 m, EnvReader m) => UExpr n -> m n (TopBlock CoreIR n) +inferTopUExpr e = asTopBlock =<< liftInfererM do solveLocal $ buildBlockInf $ withApplyDefaults $ inferSigma noHint e {-# SCC inferTopUExpr #-} data UDeclInferenceResult e n = UDeclResultDone (e n) -- used for UDataDefDecl, UInterface and UInstance - | UDeclResultBindName LetAnn (CBlock n) (Abs (UBinder (AtomNameC CoreIR)) e n) - | UDeclResultBindPattern NameHint (CBlock n) (ReconAbs CoreIR e n) + | UDeclResultBindName LetAnn (TopBlock CoreIR n) (Abs (UBinder (AtomNameC CoreIR)) e n) + | UDeclResultBindPattern NameHint (TopBlock CoreIR n) (ReconAbs CoreIR e n) inferTopUDecl :: (Mut n, Fallible1 m, TopBuilder m, SinkableE e, HoistableE e, RenameE e) => UTopDecl n l -> e l -> m n (UDeclInferenceResult e n) @@ -129,7 +130,8 @@ inferTopUDecl (ULocalDecl (WithSrcB src decl)) result = addSrcContext src case d WithSrcB _ (UPatBinder b) -> do block <- liftInfererM $ solveLocal $ buildBlockInf do checkMaybeAnnExpr (getNameHint b) tyAnn rhs <* applyDefaults - return $ UDeclResultBindName letAnn block (Abs b result) + topBlock <- asTopBlock block + return $ UDeclResultBindName letAnn topBlock (Abs b result) _ -> do PairE block recon <- liftInfererM $ solveLocal $ buildBlockInfWithRecon do val <- checkMaybeAnnExpr (getNameHint p) tyAnn rhs @@ -137,11 +139,17 @@ inferTopUDecl (ULocalDecl (WithSrcB src decl)) result = addSrcContext src case d bindLetPat p v do applyDefaults renameM result - return $ UDeclResultBindPattern (getNameHint p) block recon + topBlock <- asTopBlock block + return $ UDeclResultBindPattern (getNameHint p) topBlock recon inferTopUDecl (UEffectDecl _ _ _) _ = error "not implemented" inferTopUDecl (UHandlerDecl _ _ _ _ _ _ _) _ = error "not implemented" {-# SCC inferTopUDecl #-} +asTopBlock :: EnvReader m => CBlock n -> m n (TopBlock CoreIR n) +asTopBlock block = do + effTy <- blockEffTy block + return $ TopLam False (PiType Empty effTy) (LamExpr Empty block) + getInstanceType :: EnvReader m => InstanceDef n -> m n (CorePiType n) getInstanceType (InstanceDef className bs params _) = liftEnvReaderM do refreshAbs (Abs bs (ListE params)) \bs' (ListE params') -> do @@ -1454,7 +1462,8 @@ checkNamedArgValidity (Abs bs _) offeredNames = do inferPrimArg :: EmitsBoth o => UExpr i -> InfererM i o (CAtom o) inferPrimArg x = do xBlock <- buildBlockInf $ inferRho noHint x - case getType xBlock of + EffTy _ ty <- blockEffTy xBlock + case ty of TyKind -> cheapReduce xBlock >>= \case Just reduced -> return reduced _ -> throw CompilerErr "Type args to primops must be reducible" @@ -1628,7 +1637,7 @@ instanceFun instanceName expl = do args <- mapM toAtomVar $ nestToNames bs' let bs'' = fmapNest (\(RolePiBinder _ b) -> b) bs' result <- mkDictAtom $ InstanceDict (sink instanceName) (Var <$> args) - return $ Abs bs'' (PairE Pure (AtomicBlock result)) + return $ Abs bs'' (PairE Pure (WithoutDecls result)) Lam <$> coreLamExpr expl ab checkMaybeAnnExpr :: EmitsBoth o @@ -2823,7 +2832,7 @@ synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of ab' <- withGivenBinders ab \bs targetTy' -> do Abs bs <$> synthTerm (SynthDictType targetTy') reqMethodAccess Abs bs synthExpr <- return ab' - liftM Lam $ coreLamExpr ImplicitApp $ Abs bs $ PairE Pure (AtomicBlock synthExpr) + liftM Lam $ coreLamExpr ImplicitApp $ Abs bs $ PairE Pure (WithoutDecls synthExpr) SynthDictType dictTy -> case dictTy of DictType "Ix" _ [Type (NewtypeTyCon (Fin n))] -> return $ DictCon (DictTy dictTy) $ IxFin n DictType "Data" _ [Type t] -> do @@ -2991,6 +3000,9 @@ instance ExprVisitorNoEmits (DictSynthTraverserM i o) CoreIR i o where class DictSynthTraversable (e::E) where dsTraverse :: e i -> DictSynthTraverserM i o (e o) +instance DictSynthTraversable (TopLam CoreIR) where + dsTraverse (TopLam d ty lam) = TopLam d <$> visitPiDefault ty <*> visitLamNoEmits lam + instance DictSynthTraversable CAtom where dsTraverse atom = case atom of DictHole (AlwaysEqual ctx) ty access -> do @@ -2999,12 +3011,13 @@ instance DictSynthTraversable CAtom where case ans of Failure errs -> put (LiftE errs) >> renameM atom Success d -> return d - Lam (CoreLamExpr piTy@(CorePiType _ bsPi _) (LamExpr bsLam body)) -> do + Lam (CoreLamExpr piTy@(CorePiType _ bsPi _) (LamExpr bsLam (Abs decls result))) -> do Pi piTy' <- dsTraverse $ Pi piTy let (expls, _) = unzipExpls bsPi lam' <- dsTraverseExplBinders (zipExpls expls bsLam) \bsLamExpl' -> do - let (_, bsLam') = unzipExpls bsLamExpl' - LamExpr bsLam' <$> dsTraverse body + visitDeclsNoEmits decls \decls' -> do + let (_, bsLam') = unzipExpls bsLamExpl' + LamExpr bsLam' <$> Abs decls' <$> dsTraverse result return $ Lam $ CoreLamExpr piTy' lam' Var _ -> renameM atom SimpInCore _ -> renameM atom @@ -3021,8 +3034,6 @@ instance DictSynthTraversable CType where _ -> visitTypePartial ty instance DictSynthTraversable DataConDefs where dsTraverse = visitGeneric -instance DictSynthTraversable (Block CoreIR) where - dsTraverse = visitBlockNoEmits dsTraverseExplBinders :: Nest (WithExpl CBinder) i i' @@ -3057,7 +3068,15 @@ buildBlockInf :: EmitsInf n => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (CAtom l)) -> InfererM i n (CBlock n) -buildBlockInf cont = buildDeclsInf (cont >>= withType) >>= computeAbsEffects >>= absToBlock +buildBlockInf cont = do + Abs decls (PairE result ty) <- buildDeclsInf do + ans <- cont + ty <- cheapNormalize $ getType ans + return $ PairE ans ty + let msg = "Block:" <> nest 1 (prettyBlock decls result) <> line + <> group ("Of type:" <> nest 2 (line <> pretty ty)) <> line + void $ liftHoistExcept' (docAsStr msg) $ hoist decls ty + return $ Abs decls result {-# INLINE buildBlockInf #-} buildBlockInfWithRecon @@ -3066,10 +3085,9 @@ buildBlockInfWithRecon -> InfererM i n (PairE CBlock (ReconAbs CoreIR e) n) buildBlockInfWithRecon cont = do ab <- buildDeclsInfUnzonked cont - (declsResult, recon) <- refreshAbs ab \decls result -> do + (block, recon) <- refreshAbs ab \decls result -> do (newResult, recon) <- telescopicCapture decls result return (Abs decls newResult, recon) - block <- makeBlockFromDecls declsResult return $ PairE block recon {-# INLINE buildBlockInfWithRecon #-} diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index d250a8d4d..bcf5bdc64 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -22,14 +22,11 @@ import Types.Primitives -- === External API === -inlineBindings :: (EnvReader m) => SLam n -> m n (SLam n) -inlineBindings = liftLamExpr inlineBindingsBlock +inlineBindings :: (EnvReader m) => STopLam n -> m n (STopLam n) +inlineBindings = liftLamExpr \(Abs decls ans) -> liftInlineM $ + buildScoped $ inlineDecls decls $ inline Stop ans {-# INLINE inlineBindings #-} -inlineBindingsBlock :: (EnvReader m) => SBlock n -> m n (SBlock n) -inlineBindingsBlock blk = liftInlineM $ buildScopedAssumeNoDecls $ inline Stop blk -{-# SCC inlineBindingsBlock #-} - -- === Data Structure === data InlineExpr (r::IR) (o::S) where @@ -220,7 +217,7 @@ inlineDeclsSubst = \case ixDepthExpr _ = 0 ixDepthBlock :: Block SimpIR n -> Int ixDepthBlock (exprBlock -> (Just expr)) = ixDepthExpr expr - ixDepthBlock (AtomicBlock result) = ixDepthExpr $ Atom result + ixDepthBlock (Abs Empty result) = ixDepthExpr $ Atom result ixDepthBlock _ = 0 -- Should we decide to inline this binding wherever it appears, before we even @@ -316,9 +313,10 @@ instance Inlinable SType where inline ctx ty = visitTypePartial ty >>= reconstruct ctx instance Inlinable SLam where - inline ctx (LamExpr bs body) = do + inline ctx (LamExpr bs (Abs decls ans)) = do reconstruct ctx =<< withBinders bs \bs' -> do - LamExpr bs' <$> (buildScopedAssumeNoDecls $ inline Stop body) + (LamExpr bs' <$>) $ buildScoped $ + inlineDecls decls $ inline Stop ans withBinders :: Nest SBinder i i' @@ -337,18 +335,8 @@ instance Inlinable (PiType SimpIR) where effTy' <- buildScopedAssumeNoDecls $ inline Stop effTy return $ PiType bs' effTy' -instance Inlinable SBlock where - inline ctx (Block ann decls ans) = case (ann, decls) of - (NoBlockAnn, Empty) -> - (Block NoBlockAnn Empty <$> inline Stop ans) >>= reconstruct ctx - (NoBlockAnn, _) -> error "should be unreachable" - (BlockAnn effTy, _) -> do - (Abs decls' ans') <- buildScoped $ inlineDecls decls $ inline Stop ans - effTy' <- inline Stop effTy - reconstruct ctx $ Block (BlockAnn effTy') decls' ans' - inlineBlockEmits :: Emits o => Context SExpr e2 o -> SBlock i -> InlineM i o (e2 o) -inlineBlockEmits ctx (Block _ decls ans) = do +inlineBlockEmits ctx (Abs decls ans) = do inlineDecls decls $ inlineAtom ctx ans -- Still using InlineM because we may call back into inlining, and we wish to @@ -369,7 +357,7 @@ reconstructTabApp ctx expr [] = do reconstruct ctx expr reconstructTabApp ctx expr ixs = case fromNaryForExpr (length ixs) expr of - Just (bsCount, LamExpr bs (Block _ decls result)) -> do + Just (bsCount, LamExpr bs (Abs decls result)) -> do let (ixsPref, ixsRest) = splitAt bsCount ixs -- Note: There's a decision here. Is it ok to inline the atoms in -- `ixsPref` into the body `decls`? If so, should we pre-process them and diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 32944c5f7..3d6b91a3f 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -4,7 +4,7 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module Linearize (linearize, linearizeLam) where +module Linearize (linearize, linearizeTopLam) where import Control.Category ((>>>)) import Control.Monad.Reader @@ -269,10 +269,9 @@ linearizeBlockDefuncGeneral locals block = do WithTangent primalResult tangentFun <- linearizeBlock block lam <- tangentFunAsLambda tangentFun return $ PairE primalResult lam - (blockAbs, recon) <- refreshAbs (Abs decls result) \decls' (PairE primal lam) -> do + (block', recon) <- refreshAbs (Abs decls result) \decls' (PairE primal lam) -> do (primal', recon) <- capture (locals >>> toScopeFrag decls') primal lam return (Abs decls' primal', recon) - block' <- makeBlockFromDecls blockAbs return (block', recon) -- Inverse of tangentFunAsLambda. Should be used inside a returned tangent action. @@ -289,9 +288,9 @@ linearize :: Emits n => SLam n -> SAtom n -> DoubleBuilder SimpIR n (SAtom n, SL linearize f x = runPrimalMInit $ linearizeLambdaApp f x {-# SCC linearize #-} -linearizeLam :: SLam n -> [Active] -> DoubleBuilder SimpIR n (SLam n, SLam n) -linearizeLam (LamExpr bs body) actives = runPrimalMInit do - refreshBinders bs \bs' frag -> extendSubst frag do +linearizeTopLam :: STopLam n -> [Active] -> DoubleBuilder SimpIR n (STopLam n, STopLam n) +linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do + (primalFun, tangentFun) <- runPrimalMInit $ refreshBinders bs \bs' frag -> extendSubst frag do let allPrimals = nestToAtomVars bs' activeVs <- catMaybes <$> forM (zip actives allPrimals) \(active, v) -> case active of True -> return $ Just v @@ -312,6 +311,8 @@ linearizeLam (LamExpr bs body) actives = runPrimalMInit do emitBlock =<< applySubst substFrag tangentBody return $ LamExpr (bs' >>> BinaryNest bResidual bTangent) tangentBody' return (primalFun, tangentFun) + (,) <$> asTopLam primalFun <*> asTopLam tangentFun +linearizeTopLam (TopLam True _ _) _ = error "expected a non-destination-passing function" -- reify the tangent builder as a lambda linearizeLambdaApp :: Emits o => SLam i -> SAtom o -> PrimalM i o (SAtom o, SLam o) @@ -343,7 +344,7 @@ linearizeAtom atom = case atom of where emitZeroT = withZeroT $ renameM atom linearizeBlock :: Emits o => SBlock i -> LinM i o SAtom SAtom -linearizeBlock (Block _ decls result) = +linearizeBlock (Abs decls result) = linearizeDecls decls $ linearizeAtom result linearizeDecls :: Emits o => Nest SDecl i i' -> LinM i' o e1 e2 -> LinM i o e1 e2 @@ -624,7 +625,7 @@ linearizeHof hof = case hof of WithTangent sInit' sLin <- linearizeAtom sInit (lam', recon) <- linearizeEffectFun State lam (primalAux, sFinal) <- fromPair =<< emitHof (RunState Nothing sInit' lam') - referentTy <- return $ snd $ getTypeRWSAction lam' + referentTy <- snd <$> getTypeRWSAction lam' (primal, linLam) <- reconstruct primalAux recon return $ WithTangent (PairVal primal sFinal) do sLin' <- sLin @@ -639,7 +640,7 @@ linearizeHof hof = case hof of (lam', recon) <- linearizeEffectFun Writer lam (primalAux, wFinal) <- fromPair =<< emitHof (RunWriter Nothing bm' lam') (primal, linLam) <- reconstruct primalAux recon - referentTy <- return $ snd $ getTypeRWSAction lam' + referentTy <- snd <$> getTypeRWSAction lam' return $ WithTangent (PairVal primal wFinal) do bm'' <- sinkM bm' tt <- tangentType $ sink referentTy diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index 6c8f30df2..8eaf8cbeb 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -7,8 +7,7 @@ {-# LANGUAGE UndecidableInstances #-} module Lower - ( lowerFullySequential, lowerFullySequentialNoDest - , DestLamExpr, DestBlock + ( lowerFullySequential, DestBlock ) where import Prelude hiding ((.)) @@ -29,7 +28,7 @@ import Subst import QueryType import Types.Core import Types.Primitives -import Util (enumerate, foldMapM) +import Util (enumerate) -- === For loop resolution === @@ -60,30 +59,34 @@ import Util (enumerate, foldMapM) -- destination to a sub-block or sub-expression, hence "desintation -- passing style"). -type DestLamExpr = SLam type DestBlock = Abs (SBinder) SBlock -lowerFullySequential :: EnvReader m => SLam n -> m n (DestLamExpr n) -lowerFullySequential (LamExpr bs body) = liftEnvReaderM $ do - refreshAbs (Abs bs body) \bs' body' -> do - Abs b body'' <- lowerFullySequentialBlock body' - return $ LamExpr (bs' >>> UnaryNest b) body'' - -lowerFullySequentialBlock :: EnvReader m => SBlock n -> m n (DestBlock n) -lowerFullySequentialBlock b = liftAtomSubstBuilder do - resultDestTy <- RawRefTy <$> substM (getType b) +lowerFullySequential :: EnvReader m => Bool -> STopLam n -> m n (STopLam n) +lowerFullySequential wantDestStyle (TopLam False piTy (LamExpr bs body)) = liftEnvReaderM $ do + lam <- case wantDestStyle of + True -> do + refreshAbs (Abs bs body) \bs' body' -> do + xs <- bindersToAtoms bs' + EffTy _ resultTy <- instantiatePiTy (sink piTy) xs + Abs b body'' <- lowerFullySequentialBlock resultTy body' + return $ LamExpr (bs' >>> UnaryNest b) body'' + False -> do + refreshAbs (Abs bs body) \bs' body' -> do + body'' <- lowerFullySequentialBlockNoDest body' + return $ LamExpr bs' body'' + piTy' <- getLamExprType lam + return $ TopLam wantDestStyle piTy' lam +lowerFullySequential _ (TopLam True _ _) = error "already in destination style" + +lowerFullySequentialBlock :: EnvReader m => SType n -> SBlock n -> m n (DestBlock n) +lowerFullySequentialBlock resultTy b = liftAtomSubstBuilder do + let resultDestTy = RawRefTy resultTy withFreshBinder (getNameHint @String "ans") resultDestTy \destBinder -> do Abs destBinder <$> buildBlock do let dest = Var $ sink $ binderVar destBinder lowerBlockWithDest dest b $> UnitVal {-# SCC lowerFullySequentialBlock #-} -lowerFullySequentialNoDest :: EnvReader m => SLam n -> m n (SLam n) -lowerFullySequentialNoDest (LamExpr bs body) = liftEnvReaderM $ do - refreshAbs (Abs bs body) \bs' body' -> do - body'' <- lowerFullySequentialBlockNoDest body' - return $ LamExpr bs' body'' - lowerFullySequentialBlockNoDest :: EnvReader m => SBlock n -> m n (SBlock n) lowerFullySequentialBlockNoDest b = liftAtomSubstBuilder $ buildBlock $ lowerBlock b {-# SCC lowerFullySequentialBlockNoDest #-} @@ -190,8 +193,7 @@ lowerCase maybeDest scrut alts resultTy = do extendSubst (b @> Rename (atomVarName b')) $ buildBlock do lowerBlockWithDest (Var $ sink $ local_dest) body $> UnitVal - eff' <- foldMapM (pure . getEffects) alts' - void $ emitExpr $ Case (sink scrut') alts' (EffTy eff' UnitTy) + void $ mkCase (sink scrut') UnitTy alts' >>= emitExpr return UnitVal return $ PrimOp $ DAMOp $ Freeze dest' @@ -243,7 +245,7 @@ decomposeDest dest = \case _ -> return Nothing lowerBlockWithDest :: Emits o => Dest SimpIR o -> SBlock i -> LowerM i o (SAtom o) -lowerBlockWithDest dest (Block _ decls ans) = do +lowerBlockWithDest dest (Abs decls ans) = do decomposeDest dest ans >>= \case Nothing -> do ans' <- visitDeclsEmits decls $ visitAtom ans diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index 11e1be2c3..bff364f47 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -28,12 +28,12 @@ import QueryType -- annotation holding a summary of how that binding is used. It also eliminates -- unused pure bindings as it goes, since it has all the needed information. -analyzeOccurrences :: EnvReader m => SLam n -> m n (SLam n) +analyzeOccurrences :: EnvReader m => STopLam n -> m n (STopLam n) analyzeOccurrences = liftLamExpr analyzeOccurrencesBlock {-# INLINE analyzeOccurrences #-} analyzeOccurrencesBlock :: EnvReader m => SBlock n -> m n (SBlock n) -analyzeOccurrencesBlock = liftOCCM . occ accessOnce +analyzeOccurrencesBlock = liftOCCM . occNest accessOnce {-# SCC analyzeOccurrencesBlock #-} -- === Overview === @@ -254,7 +254,7 @@ occTy ty = occ accessOnce ty instance HasOCC SLam where occ a (LamExpr bs body) = do lam@(LamExpr bs' _) <- refreshAbs (Abs bs body) \bs' body' -> - LamExpr bs' <$> occ (sink a) body' + LamExpr bs' <$> occNest (sink a) body' countFreeVarsAsOccurrencesB bs' return lam @@ -269,15 +269,6 @@ instance HasOCC (PiType SimpIR) where countFreeVarsAsOccurrencesB bs' return piTy -instance HasOCC SBlock where - occ a (Block ann decls ans) = case (ann, decls) of - (NoBlockAnn , Empty) -> Block NoBlockAnn Empty <$> occ a ans - (NoBlockAnn , _ ) -> error "should be unreachable" - (BlockAnn effTy, _ ) -> do - Abs decls' ans' <- occNest a decls ans - effTy' <- occ a effTy - return $ Block (BlockAnn effTy') decls' ans' - instance HasOCC (EffTy SimpIR) where occ _ (EffTy effs ty) = do ty' <- occTy ty @@ -288,17 +279,17 @@ data ElimResult (n::S) where ElimSuccess :: Abs (Nest SDecl) SAtom n -> ElimResult n ElimFailure :: SDecl n l -> UsageInfo -> Abs (Nest SDecl) SAtom l -> ElimResult n -occNest :: Access n -> Nest SDecl n l -> SAtom l +occNest :: Access n -> Abs (Nest SDecl) SAtom n -> OCCM n (Abs (Nest SDecl) SAtom n) -occNest a decls ans = case decls of +occNest a (Abs decls ans) = case decls of Empty -> Abs Empty <$> occ a ans Nest d@(Let _ binding) ds -> do isPureDecl <- return $ isPure binding dceAttempt <- refreshAbs (Abs d (Abs ds ans)) - \d'@(Let b' (DeclBinding _ expr')) (Abs ds' ans') -> do + \d'@(Let b' (DeclBinding _ expr')) rest -> do exprIx <- summaryExpr $ sink expr' extend b' exprIx do - below <- occNest (sink a) ds' ans' + below <- occNest (sink a) rest checkAllFreeVariablesMentioned below accessInfo <- getAccessInfo $ binderName d' let usage = usageInfo accessInfo @@ -387,7 +378,7 @@ occAlt acc scrut alt = do -- case statement in that event. scrutIx <- unknown $ sink scrut extend nb scrutIx do - body' <- occ (sink acc) body + body' <- occNest (sink acc) body return $ Abs b body' ty' <- occTy ty return $ Abs (b':>ty') body' @@ -407,12 +398,12 @@ instance HasOCC (Hof SimpIR) where ixDict' <- inlinedLater ixDict occWithBinder (Abs b body) \b' body' -> do extend b' (Occ.Var $ binderName b') do - (body'', bodyFV) <- isolated (occ accessOnce body') + (body'', bodyFV) <- isolated (occNest accessOnce body') modify (<> abstractFor b' bodyFV) return $ For ann ixDict' (UnaryLamExpr b' body'') For _ _ _ -> error "For body should be a unary lambda expression" While body -> While <$> do - (body', bodyFV) <- isolated (occ accessOnce body) + (body', bodyFV) <- isolated $ occNest accessOnce body modify (<> useManyTimes bodyFV) return body' RunReader ini bd -> do @@ -451,14 +442,15 @@ instance HasOCC (Hof SimpIR) where return $ RunState Nothing ini' bd' RunState (Just _) _ _ -> error "Expecting to do occurrence analysis before destination passing." - RunIO bd -> RunIO <$> occ a bd + RunIO bd -> RunIO <$> occNest a bd RunInit _ -> -- Though this is probably not too hard to implement. Presumably -- the lambda is one-shot. error "Expecting to do occurrence analysis before lowering." oneShot :: Access n -> [IxExpr n] -> LamExpr SimpIR n -> OCCM n (LamExpr SimpIR n) -oneShot acc [] (LamExpr Empty body) = LamExpr Empty <$> occ acc body +oneShot acc [] (LamExpr Empty body) = + LamExpr Empty <$> occNest acc body oneShot acc (ix:ixs) (LamExpr (Nest b bs) body) = do occWithBinder (Abs b (LamExpr bs body)) \b' restLam -> extend b' (sink ix) do diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 262ef3ae0..e4331e484 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -31,7 +31,7 @@ import QueryType import Util (iota) import Err -optimize :: EnvReader m => SLam n -> m n (SLam n) +optimize :: EnvReader m => STopLam n -> m n (STopLam n) optimize = dceTop -- Clean up user code >=> unrollLoops >=> dceTop -- Clean up peephole-optimized code after unrolling @@ -208,7 +208,7 @@ peepholeExpr expr = case expr of -- === Loop unrolling === -unrollLoops :: EnvReader m => SLam n -> m n (SLam n) +unrollLoops :: EnvReader m => STopLam n -> m n (STopLam n) unrollLoops = liftLamExpr unrollLoopsBlock unrollLoopsBlock :: EnvReader m => SBlock n -> m n (SBlock n) @@ -240,7 +240,7 @@ ulBlock :: SBlock i -> ULM i o (SBlock o) ulBlock b = buildBlock $ visitBlockEmits b emitSubstBlock :: Emits o => SBlock i -> ULM i o (SAtom o) -emitSubstBlock (Block _ decls ans) = visitDeclsEmits decls $ visitAtom ans +emitSubstBlock (Abs decls ans) = visitDeclsEmits decls $ visitAtom ans -- TODO: Refine the cost accounting so that operations that will become -- constant-foldable after inlining don't count towards it. @@ -257,7 +257,7 @@ ulExpr expr = case expr of vals <- dropSubst $ forM (iota n) \i -> do extendSubst (b' @> SubstVal (IdxRepVal i)) $ emitSubstBlock block' inc $ fromIntegral n -- To account for the TabCon we emit below - case getLamExprType body' of + getLamExprType body' >>= \case PiType (UnaryNest (tb:>_)) (EffTy _ valTy) -> do let tabTy = TabPi $ TabPiType (IxDictRawFin (IdxRepVal n)) (tb:>IdxRepTy) valTy emitExpr $ TabCon Nothing tabTy vals @@ -305,7 +305,7 @@ hoistLoopInvariantBlock :: EnvReader m => SBlock n -> m n (SBlock n) hoistLoopInvariantBlock body = liftLICMM $ buildBlock $ visitBlockEmits body {-# SCC hoistLoopInvariantBlock #-} -hoistLoopInvariant :: EnvReader m => SLam n -> m n (SLam n) +hoistLoopInvariant :: EnvReader m => STopLam n -> m n (STopLam n) hoistLoopInvariant = liftLamExpr hoistLoopInvariantBlock {-# INLINE hoistLoopInvariant #-} @@ -317,11 +317,11 @@ licmExpr = \case let numCarriesOriginal = length dests' Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do -- First, traverse the block, to allow any Hofs inside it to hoist their own decls. - Block _ decls ans <- buildBlock $ visitBlockEmits body + Abs decls ans <- buildBlock $ visitBlockEmits body -- Now, we process the decls and decide which ones to hoist. liftEnvReaderM $ runSubstReaderT idSubst $ seqLICM REmpty mempty (asNameBinder b') REmpty decls ans - PairE (ListE extraDests) ab <- emitDecls hdecls destsAndBody + PairE (ListE extraDests) ab <- emitDecls $ Abs hdecls destsAndBody extraDests' <- mapM toAtomVar extraDests -- Append the destinations of hoisted Allocs as loop carried values. let dests'' = ProdVal $ dests' ++ (Var <$> extraDests') @@ -334,19 +334,19 @@ licmExpr = \case (oldCarries, newCarries) <- splitAt numCarriesOriginal <$> getUnpacked allCarries let oldLoopBinderVal = PairVal oldIx (ProdVal oldCarries) let s = extraDestBs @@> map SubstVal newCarries <.> lb @> SubstVal oldLoopBinderVal - block <- applySubst s bodyAbs >>= makeBlockFromDecls + block <- applySubst s bodyAbs return $ UnaryLamExpr lb' block emitSeq dir ix' dests'' body' PrimOp (Hof (TypedHof _ (For dir ix (LamExpr (UnaryNest b) body)))) -> do ix' <- substM ix Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do - Block _ decls ans <- buildBlock $ visitBlockEmits body + Abs decls ans <- buildBlock $ visitBlockEmits body liftEnvReaderM $ runSubstReaderT idSubst $ seqLICM REmpty mempty (asNameBinder b') REmpty decls ans - PairE (ListE []) (Abs lnb bodyAbs) <- emitDecls hdecls destsAndBody + PairE (ListE []) (Abs lnb bodyAbs) <- emitDecls $ Abs hdecls destsAndBody ixTy <- substM $ binderType b body' <- withFreshBinder noHint ixTy \i -> do - block <- applyRename (lnb@>binderName i) bodyAbs >>= makeBlockFromDecls + block <- applyRename (lnb@>binderName i) bodyAbs return $ UnaryLamExpr i block emitHof $ For dir ix' body' expr -> visitGeneric expr >>= emitExpr @@ -400,12 +400,12 @@ newtype DCEM n a = DCEM { runDCEM :: StateT1 FV EnvReaderM n a } deriving ( Functor, Applicative, Monad, EnvReader, ScopeReader , MonadState (FV n), EnvExtender) -dceTop :: EnvReader m => SLam n -> m n (SLam n) +dceTop :: EnvReader m => STopLam n -> m n (STopLam n) dceTop = liftLamExpr dceBlock {-# INLINE dceTop #-} dceBlock :: EnvReader m => SBlock n -> m n (SBlock n) -dceBlock b = liftEnvReaderM $ evalStateT1 (runDCEM $ dce b) mempty +dceBlock b = liftEnvReaderM $ evalStateT1 (runDCEM $ dceBlock' b) mempty {-# SCC dceBlock #-} class HasDCE (e::E) where @@ -432,32 +432,36 @@ instance HasDCE SAtom where instance HasDCE SType where dce = visitTypePartial instance HasDCE (PiType SimpIR) where dce (PiType bs effTy) = do - Abs bs' effTy' <- dce (Abs bs effTy) - return $ PiType bs' effTy' + dceBinders bs effTy \bs' effTy' -> PiType bs' <$> dce effTy' instance HasDCE (LamExpr SimpIR) where - dce (LamExpr bs e) = do - Abs bs' e' <- dce (Abs bs e) - return $ LamExpr bs' e' - -instance HasDCE SBlock where - dce (Block ann decls ans) = case (ann, decls) of - (NoBlockAnn , Empty) -> Block NoBlockAnn Empty <$> dce ans - (NoBlockAnn , _ ) -> error "should be unreachable" - (BlockAnn effTy, _ ) -> do - -- The free vars accumulated in the state of DCEM should correspond to - -- the free vars of the Abs of the block answer, by the decls traversed - -- so far. dceNest takes care to uphold this invariant, but we temporarily - -- reset the state to an empty map, just so that names from the surrounding - -- block don't end up influencing elimination decisions here. Note that we - -- restore the state (and accumulate free vars of the DCE'd block into it) - -- right after dceNest. - old <- get - put mempty - Abs decls' ans' <- dceNest decls ans - modify (<> old) - effTy' <- dce effTy - return $ Block (BlockAnn effTy') decls' ans' + dce (LamExpr bs e) = dceBinders bs e \bs' e' -> LamExpr bs' <$> dceBlock' e' + +dceBinders + :: (HoistableB b, BindsEnv b, RenameB b, RenameE e) + => b n l -> e l + -> (forall l'. b n l' -> e l' -> DCEM l' a) + -> DCEM n a +dceBinders b e cont = do + ans <- refreshAbs (Abs b e) \b' e' -> cont b' e' + modify (<>FV (freeVarsB b)) + return ans +{-# INLINE dceBinders #-} + +dceBlock' :: SBlock n -> DCEM n (SBlock n) +dceBlock' (Abs decls ans) = do + -- The free vars accumulated in the state of DCEM should correspond to + -- the free vars of the Abs of the block answer, by the decls traversed + -- so far. dceNest takes care to uphold this invariant, but we temporarily + -- reset the state to an empty map, just so that names from the surrounding + -- block don't end up influencing elimination decisions here. Note that we + -- restore the state (and accumulate free vars of the DCE'd block into it) + -- right after dceNest. + old <- get + put mempty + block <- dceNest decls ans + modify (<> old) + return block data CachedFVs e n = UnsafeCachedFVs { _cachedFVs :: (NameSet n), fromCachedFVs :: (e n) } instance HoistableE (CachedFVs e) where @@ -512,13 +516,6 @@ dceNest decls ans = case decls of modify (<>FV (freeVarsB b')) return $ Abs (Nest (Let b' decl'') bs'') ans'' -instance (BindsEnv b, RenameB b, HoistableB b, RenameE e, HasDCE e) => HasDCE (Abs b e) where - dce a = do - a'@(Abs b' _) <- refreshAbs a \b e -> Abs b <$> dce e - modify (<>FV (freeVarsB b')) - return a' - {-# INLINE dce #-} - instance HasDCE (EffectRow SimpIR) instance HasDCE (DeclBinding SimpIR) instance HasDCE (EffTy SimpIR) diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 2121eff5d..5979de66a 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -131,9 +131,9 @@ pArg :: PrettyPrec a => a -> Doc ann pArg a = prettyPrec a ArgPrec instance IRRep r => Pretty (Block r n) where - pretty (Block _ decls expr) = prettyBlock decls expr + pretty (Abs decls expr) = prettyBlock decls expr instance IRRep r => PrettyPrec (Block r n) where - prettyPrec (Block _ decls expr) = atPrec LowestPrec $ prettyBlock decls expr + prettyPrec (Abs decls expr) = atPrec LowestPrec $ prettyBlock decls expr prettyBlock :: (IRRep r, PrettyPrec (e l)) => Nest (Decl r) n l -> e l -> Doc ann prettyBlock Empty expr = group $ line <> pLowest expr @@ -357,18 +357,6 @@ prettyLam :: Pretty a => Doc ann -> a -> Doc ann prettyLam binders body = group $ group (nest 4 $ binders) <> group (nest 2 $ p body) -_inlineLastDeclBlock :: IRRep r => Block r n -> Abs (Nest (Decl r)) (Expr r) n -_inlineLastDeclBlock (Block _ decls expr) = inlineLastDecl decls expr - -inlineLastDecl :: IRRep r => Nest (Decl r) n l -> Atom r l -> Abs (Nest (Decl r)) (Expr r) n -inlineLastDecl Empty result = Abs Empty $ Atom result -inlineLastDecl (Nest (Let b (DeclBinding _ expr)) Empty) (Var (AtomVar v _)) - | v == binderName b = Abs Empty expr -inlineLastDecl (Nest decl rest) result = - case inlineLastDecl rest result of - Abs decls' result' -> - Abs (Nest decl decls') result' - instance IRRep r => Pretty (EffectRow r n) where pretty (EffectRow effs t) = braces $ hsep (punctuate "," (map p (eSetToList effs))) <> p t @@ -785,14 +773,16 @@ instance Pretty (TopFunDef n) where instance Pretty (TopFun n) where pretty = \case - DexTopFun def ty simp lowering -> + DexTopFun def lam lowering -> "Top-level Function" <> hardline <+> "definition:" <+> pretty def - <> hardline <+> "type:" <+> pretty ty - <> hardline <+> "simplified:" <+> pretty simp + <> hardline <+> "lambda:" <+> pretty lam <> hardline <+> "lowering:" <+> pretty lowering FFITopFun f _ -> p f +instance IRRep r => Pretty (TopLam r n) where + pretty (TopLam _ _ lam) = pretty lam + instance Pretty a => Pretty (EvalStatus a) where pretty = \case Waiting -> "" diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index cf6f98a26..a33c7b147 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -9,6 +9,7 @@ module QueryType (module QueryType, module QueryTypePure, toAtomVar) where import Control.Category ((>>>)) import Control.Monad import Data.List (elemIndex) +import Data.Functor ((<&>)) import Types.Primitives import Types.Core @@ -47,14 +48,30 @@ caseAltsBinderTys ty = case ty of extendEffect :: IRRep r => Effect r n -> EffectRow r n -> EffectRow r n extendEffect eff (EffectRow effs t) = EffectRow (effs <> eSetSingleton eff) t -getDestLamExprType :: LamExpr SimpIR n -> PiType SimpIR n -getDestLamExprType (LamExpr bsRefB body) = +blockEffTy :: (EnvReader m, IRRep r) => Block r n -> m n (EffTy r n) +blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do + effs <- declsEffects decls mempty + return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result + where + declsEffects :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l) + declsEffects Empty !acc = return acc + declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do + expr' <- sinkM expr + declsEffects rest $ acc <> getEffects expr' + +blockTy :: (EnvReader m, IRRep r) => Block r n -> m n (Type r n) +blockTy b = blockEffTy b <&> \(EffTy _ t) -> t + +piTypeWithoutDest :: PiType SimpIR n -> PiType SimpIR n +piTypeWithoutDest (PiType bsRefB _) = case popNest bsRefB of - Just (PairB bs (bDest:>RawRefTy ansTy)) -> do - let resultEffs = ignoreHoistFailure $ hoist bDest $ getEffects body - PiType bs $ EffTy resultEffs ansTy + Just (PairB bs (_:>RawRefTy ansTy)) -> do + PiType bs $ EffTy Pure ansTy -- XXX: we ignore the effects here _ -> error "expected trailing dest binder" +blockEff :: (EnvReader m, IRRep r) => Block r n -> m n (EffectRow r n) +blockEff b = blockEffTy b <&> \(EffTy eff _) -> eff + typeOfApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) typeOfApp (Pi (CorePiType _ bs (EffTy _ resultTy))) xs = do let subst = bs @@> fmap SubstVal xs @@ -137,40 +154,42 @@ effTyOfHof hof = EffTy <$> hofEffects hof <*> typeOfHof hof typeOfHof :: (EnvReader m, IRRep r) => Hof r n -> m n (Type r n) typeOfHof = \case - For _ ixTy f -> case getLamExprType f of + For _ ixTy f -> getLamExprType f >>= \case PiType (UnaryNest b) (EffTy _ eltTy) -> return $ TabTy (ixTypeDict ixTy) b eltTy _ -> error "expected a unary pi type" While _ -> return UnitTy - Linearize f _ -> case getLamExprType f of + Linearize f _ -> getLamExprType f >>= \case PiType (UnaryNest (binder:>a)) (EffTy Pure b) -> do let b' = ignoreHoistFailure $ hoist binder b let fLinTy = Pi $ nonDepPiType [a] Pure b' return $ PairTy b' fLinTy _ -> error "expected a unary pi type" - Transpose f _ -> case getLamExprType f of + Transpose f _ -> getLamExprType f >>= \case PiType (UnaryNest (_:>a)) _ -> return a _ -> error "expected a unary pi type" - RunReader _ f -> return resultTy - where (resultTy, _) = getTypeRWSAction f - RunWriter _ _ f -> return $ uncurry PairTy $ getTypeRWSAction f - RunState _ _ f -> return $ PairTy resultTy stateTy - where (resultTy, stateTy) = getTypeRWSAction f - RunIO f -> return $ getType f - RunInit f -> return $ getType f + RunReader _ f -> do + (resultTy, _) <- getTypeRWSAction f + return resultTy + RunWriter _ _ f -> uncurry PairTy <$> getTypeRWSAction f + RunState _ _ f -> do + (resultTy, stateTy) <- getTypeRWSAction f + return $ PairTy resultTy stateTy + RunIO f -> blockTy f + RunInit f -> blockTy f CatchException ty _ -> return ty hofEffects :: (EnvReader m, IRRep r) => Hof r n -> m n (EffectRow r n) hofEffects = \case For _ _ f -> functionEffs f - While body -> return $ getEffects body + While body -> blockEff body Linearize _ _ -> return Pure -- Body has to be a pure function Transpose _ _ -> return Pure -- Body has to be a pure function RunReader _ f -> rwsFunEffects Reader f RunWriter d _ f -> maybeInit d <$> rwsFunEffects Writer f RunState d _ f -> maybeInit d <$> rwsFunEffects State f - RunIO f -> return $ deleteEff IOEffect $ getEffects f - RunInit f -> return $ deleteEff InitEffect $ getEffects f - CatchException _ f -> return $ deleteEff ExceptionEffect $ getEffects f + RunIO f -> deleteEff IOEffect <$> blockEff f + RunInit f -> deleteEff InitEffect <$> blockEff f + CatchException _ f -> deleteEff ExceptionEffect <$> blockEff f where maybeInit :: IRRep r => Maybe (Atom r i) -> (EffectRow r o -> EffectRow r o) maybeInit d = case d of Just _ -> (<>OneEffect InitEffect); Nothing -> id @@ -316,53 +335,35 @@ makePreludeMaybeTy ty = do -- === computing effects === functionEffs :: (IRRep r, EnvReader m) => LamExpr r n -> m n (EffectRow r n) -functionEffs f = case getLamExprType f of +functionEffs f = getLamExprType f >>= \case PiType b (EffTy effs _) -> return $ ignoreHoistFailure $ hoist b effs rwsFunEffects :: (IRRep r, EnvReader m) => RWS -> LamExpr r n -> m n (EffectRow r n) -rwsFunEffects rws f = return case getLamExprType f of +rwsFunEffects rws f = getLamExprType f >>= \case PiType (BinaryNest h ref) et -> do let effs' = ignoreHoistFailure $ hoist ref (etEff et) let hVal = Var $ AtomVar (binderName h) (TC HeapType) let effs'' = deleteEff (RWSEffect rws hVal) effs' - ignoreHoistFailure $ hoist h effs'' + return $ ignoreHoistFailure $ hoist h effs'' _ -> error "Expected a binary function type" -getLamExprType :: IRRep r => LamExpr r n -> PiType r n -getLamExprType (LamExpr bs body) = PiType bs (EffTy (getEffects body) (getType body)) +getLamExprType :: (IRRep r, EnvReader m) => LamExpr r n -> m n (PiType r n) +getLamExprType (LamExpr bs body) = liftEnvReaderM $ + refreshAbs (Abs bs body) \bs' body' -> do + effTy <- blockEffTy body' + return $ PiType bs' effTy -getTypeRWSAction :: IRRep r => LamExpr r n -> (Type r n, Type r n) -getTypeRWSAction f = case getLamExprType f of +getTypeRWSAction :: (IRRep r, EnvReader m) => LamExpr r n -> m n (Type r n, Type r n) +getTypeRWSAction f = getLamExprType f >>= \case PiType (BinaryNest regionBinder refBinder) (EffTy _ resultTy) -> do case binderType refBinder of RefTy _ referentTy -> do let referentTy' = ignoreHoistFailure $ hoist regionBinder referentTy let resultTy' = ignoreHoistFailure $ hoist (PairB regionBinder refBinder) resultTy - (resultTy', referentTy') + return (resultTy', referentTy') _ -> error "expected a ref" _ -> error "expected a pi type" -computeAbsEffects :: (IRRep r, EnvExtender m, RenameE e) - => Abs (Nest (Decl r)) e n -> m n (Abs (Nest (Decl r)) (EffectRow r `PairE` e) n) -computeAbsEffects it = refreshAbs it \decls result -> do - effs <- declNestEffects decls - return $ Abs decls (effs `PairE` result) -{-# INLINE computeAbsEffects #-} - -declNestEffects :: (IRRep r, EnvReader m) => Nest (Decl r) n l -> m l (EffectRow r l) -declNestEffects decls = liftEnvReaderM $ declNestEffectsRec decls mempty -{-# INLINE declNestEffects #-} - -declNestEffectsRec :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l) -declNestEffectsRec Empty !acc = return acc -declNestEffectsRec n@(Nest decl rest) !acc = withExtEvidence n do - expr <- sinkM $ declExpr decl - acc' <- sinkM $ acc <> (getEffects expr) - declNestEffectsRec rest acc' - where - declExpr :: Decl r n l -> Expr r n - declExpr (Let _ (DeclBinding _ expr)) = expr - instantiateHandlerType :: EnvReader m => HandlerName n -> CType n -> [CAtom n] -> m n (CType n) instantiateHandlerType hndName r args = do HandlerDef _ rb bs _effs retTy _ _ <- lookupHandlerDef hndName @@ -385,9 +386,14 @@ getSuperclassTys (DictType _ className params) = do getTypeTopFun :: EnvReader m => TopFunName n -> m n (PiType SimpIR n) getTypeTopFun f = lookupTopFun f >>= \case - DexTopFun _ piTy _ _ -> return piTy + DexTopFun _ (TopLam _ piTy _) _ -> return piTy FFITopFun _ iTy -> liftIFunType iTy +asTopLam :: (EnvReader m, IRRep r) => LamExpr r n -> m n (TopLam r n) +asTopLam lam = do + piTy <- getLamExprType lam + return $ TopLam False piTy lam + liftIFunType :: (IRRep r, EnvReader m) => IFunType -> m n (PiType r n) liftIFunType (IFunType _ argTys resultTys) = liftEnvReaderM $ go argTys where go :: IRRep r => [BaseType] -> EnvReaderM n (PiType r n) diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index e71769003..9f439bc46 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -236,6 +236,12 @@ nonDepTabPiType (IxType t d) resultTy = case toConstAbsPure resultTy of Abs b resultTy' -> TabPiType d (b:>t) resultTy' +corePiTypeToPiType :: CorePiType n -> PiType CoreIR n +corePiTypeToPiType (CorePiType _ bs effTy) = PiType (fmapNest withoutExpl bs) effTy + +coreLamToTopLam :: CoreLamExpr n -> TopLam CoreIR n +coreLamToTopLam (CoreLamExpr ty f) = TopLam False (corePiTypeToPiType ty) f + (==>) :: IRRep r => IxType r n -> Type r n -> Type r n a ==> b = TabPi $ nonDepTabPiType a b @@ -253,11 +259,6 @@ ixTyFromDict ixDict = flip IxType ixDict $ case ixDict of IxDictRawFin _ -> IdxRepTy IxDictSpecialized n _ _ -> n -instance IRRep r => HasType r (Block r) where - getType (Block NoBlockAnn Empty result) = getType result - getType (Block (BlockAnn (EffTy _ ty)) _ _) = ty - getType _ = error "impossible" - -- === querying effects implementation === instance IRRep r => HasEffects (Expr r) r where @@ -318,13 +319,3 @@ instance IRRep r => HasEffects (PrimOp r) r where Freeze _ -> Pure -- is this correct? Hof (TypedHof (EffTy eff _) _) -> eff {-# INLINE getEffects #-} - - -instance IRRep r => HasEffects (Block r) r where - getEffects (Block (BlockAnn (EffTy effs _)) _ _) = effs - getEffects (Block NoBlockAnn _ _) = Pure - {-# INLINE getEffects #-} - -instance IRRep r => HasEffects (Alt r) r where - getEffects (Abs bs body) = ignoreHoistFailure $ hoist bs (getEffects body) - {-# INLINE getEffects #-} diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index b3167ef0b..34da148e1 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -7,8 +7,8 @@ {-# LANGUAGE UndecidableInstances #-} module Simplify - ( simplifyTopBlock, simplifyTopFunction, SimplifiedBlock (..), ReconstructAtom (..), applyReconTop, - linearizeTopFun) where + ( simplifyTopBlock, simplifyTopFunction, ReconstructAtom (..), applyReconTop, + linearizeTopFun, SimplifiedTopLam (..)) where import Control.Applicative import Control.Category ((>>>)) @@ -19,7 +19,7 @@ import Data.Text.Prettyprint.Doc (Pretty (..), hardline) import Builder import CheapReduction -import CheckType (CheckableE (..), isData) +import CheckType (CheckableE (..), isData, checkBlock) import Core import Err import Generalize @@ -106,9 +106,8 @@ tryAsDataAtom atom = do forceTabLam :: Emits n => TabLamExpr n -> SimplifyM i n (SAtom n) forceTabLam (PairE ixTy (Abs b ab)) = buildFor (getNameHint b) Fwd ixTy \v -> do - Abs decls result <- applyRename (b@>(atomVarName v)) ab - result' <- emitDecls decls result - toDataAtomIgnoreRecon result' + result <- applyRename (b@>(atomVarName v)) ab >>= emitDecls + toDataAtomIgnoreRecon result type NaryTabLamExpr = Abs (Nest SBinder) (Abs (Nest SDecl) CAtom) @@ -246,18 +245,24 @@ instance ScopableBuilder SimpIR (SimplifyM i) where -- === Top-level API === +data SimplifiedTopLam n = SimplifiedTopLam (STopLam n) (ReconstructAtom n) data SimplifiedBlock n = SimplifiedBlock (SBlock n) (ReconstructAtom n) --- TODO: extend this to work on functions instead of blocks (with blocks still --- accessible as nullary functions) -simplifyTopBlock :: (TopBuilder m, Mut n) => Block CoreIR n -> m n (SimplifiedBlock n) -simplifyTopBlock block = liftSimplifyM $ buildSimplifiedBlock $ simplifyBlock block +simplifyTopBlock + :: (TopBuilder m, Mut n) => TopBlock CoreIR n -> m n (SimplifiedTopLam n) +simplifyTopBlock (TopLam _ _ (LamExpr Empty body)) = do + SimplifiedBlock block recon <- liftSimplifyM $ buildSimplifiedBlock $ simplifyBlock body + topLam <- asTopLam $ LamExpr Empty block + return $ SimplifiedTopLam topLam recon +simplifyTopBlock _ = error "not a block (nullary lambda)" {-# SCC simplifyTopBlock #-} -simplifyTopFunction :: (TopBuilder m, Mut n) => LamExpr CoreIR n -> m n (LamExpr SimpIR n) -simplifyTopFunction f = liftSimplifyM do - (lam, CoerceReconAbs) <- simplifyLam f - return lam +simplifyTopFunction :: (TopBuilder m, Mut n) => CTopLam n -> m n (STopLam n) +simplifyTopFunction (TopLam False _ f) = do + asTopLam =<< liftSimplifyM do + (lam, CoerceReconAbs) <- simplifyLam f + return lam +simplifyTopFunction _ = error "shouldn't be in destination-passing style already" {-# SCC simplifyTopFunction #-} applyReconTop :: (EnvReader m, Fallible1 m) => ReconstructAtom n -> SAtom n -> m n (CAtom n) @@ -276,12 +281,23 @@ instance HoistableE SimplifiedBlock instance CheckableE SimpIR SimplifiedBlock where checkE (SimplifiedBlock block _) = -- TODO: CheckableE instance for the recon too - checkE block + void $ checkBlock block instance Pretty (SimplifiedBlock n) where pretty (SimplifiedBlock block recon) = pretty block <> hardline <> pretty recon +instance SinkableE SimplifiedTopLam where + sinkingProofE = todoSinkableProof + +instance CheckableE SimpIR SimplifiedTopLam where + checkE (SimplifiedTopLam lam _) = do + -- TODO: CheckableE instance for the recon too + checkE lam + +instance Pretty (SimplifiedTopLam n) where + pretty (SimplifiedTopLam lam recon) = pretty lam <> hardline <> pretty recon + -- === All the bits of IR === simplifyDecls :: Emits o => Nest (Decl CoreIR) i i' -> SimplifyM i' o a -> SimplifyM i o a @@ -349,13 +365,6 @@ simplifyRefOp op ref = case op of ProjRef _ UnwrapNewtype -> return ref where emitRefOp op' = emitOp $ RefOp ref op' -caseComputingEffs - :: forall m n r. (MonadFail1 m, EnvReader m, IRRep r) - => Atom r n -> [Alt r n] -> Type r n -> m n (Expr r n) -caseComputingEffs scrut alts resultTy = do - return $ Case scrut alts (EffTy (foldMap getEffects alts) resultTy) -{-# INLINE caseComputingEffs #-} - defuncCaseCore :: Emits o => Atom CoreIR o -> Type CoreIR o -> (forall o'. (Emits o', DExt o o') => Int -> CAtom o' -> SimplifyM i o' (CAtom o')) @@ -396,17 +405,16 @@ defuncCase scrut resultTy cont = do alts' <- forM (enumerate altBinderTys) \(i, bTy) -> do buildAbs noHint bTy \x -> do buildBlock $ cont i (sink $ Var x) >>= toDataAtomIgnoreRecon - caseExpr <- caseComputingEffs scrut alts' resultTyData + caseExpr <- mkCase scrut resultTyData alts' emitExpr caseExpr >>= liftSimpAtom resultTy Nothing -> do split <- splitDataComponents resultTy - (alts', recons) <- unzip <$> forM (enumerate altBinderTys) \(i, bTy) -> do + (alts', closureTys, recons) <- unzip3 <$> forM (enumerate altBinderTys) \(i, bTy) -> do simplifyAlt split bTy $ cont i - closureTys <- mapM getAltNonDataTy alts' let closureSumTy = SumTy closureTys let newNonDataTy = nonDataTy split alts'' <- forM (enumerate alts') \(i, alt) -> injectAltResult closureTys i alt - caseExpr <- caseComputingEffs scrut alts'' (PairTy (dataTy split) closureSumTy) + caseExpr <- mkCase scrut (PairTy (dataTy split) closureSumTy) alts'' caseResult <- emitExpr $ caseExpr (dataVal, sumVal) <- fromPair caseResult reconAlts <- forM (zip closureTys recons) \(ty, recon) -> @@ -419,11 +427,11 @@ simplifyAlt :: SplitDataNonData n -> SType o -> (forall o'. (Emits o', DExt o o') => SAtom o' -> SimplifyM i o' (CAtom o')) - -> SimplifyM i o (Alt SimpIR o, ReconstructAtom o) -simplifyAlt split ty cont = fromPairE <$> do + -> SimplifyM i o (Alt SimpIR o, SType o, ReconstructAtom o) +simplifyAlt split ty cont = do withFreshBinder noHint ty \b -> do ab <- buildScoped $ cont $ sink $ Var $ binderVar b - (resultWithDecls, recon) <- refreshAbs ab \decls result -> do + (body, recon) <- refreshAbs ab \decls result -> do let locals = toScopeFrag b >>> toScopeFrag decls -- TODO: this might be too cautious. The type only needs to -- be hoistable above the decls. In principle it can still @@ -432,15 +440,9 @@ simplifyAlt split ty cont = fromPairE <$> do (resultData, resultNonData) <- toSplit split result (newResult, reconAbs) <- telescopicCapture locals resultNonData return (Abs decls (PairVal resultData newResult), LamRecon reconAbs) - block <- makeBlockFromDecls resultWithDecls - return $ PairE (Abs b block) recon - -getAltNonDataTy :: EnvReader m => Alt SimpIR n -> m n (SType n) -getAltNonDataTy (Abs bs body) = liftSubstEnvReaderM @AtomSubstVal do - substBinders bs \bs' -> do - ~(PairTy _ ty) <- substM $ getType body - -- Result types of simplified abs should be hoistable past binder - return $ ignoreHoistFailure $ hoist bs' ty + EffTy _ (PairTy _ nonDataType) <- blockEffTy body + let nonDataType' = ignoreHoistFailure $ hoist b nonDataType + return (Abs b body, nonDataType', recon) simplifyApp :: forall i o. Emits o => NameHint -> CType o -> CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o) @@ -531,9 +533,9 @@ emitSpecialization s = do extendSpecializationCache s name return name -specializedFunCoreDefinition :: (Mut n, TopBuilder m) => SpecializationSpec n -> m n (LamExpr CoreIR n) +specializedFunCoreDefinition :: (Mut n, TopBuilder m) => SpecializationSpec n -> m n (TopLam CoreIR n) specializedFunCoreDefinition (AppSpecialization f (Abs bs staticArgs)) = do - liftBuilder $ buildLamExpr (EmptyAbs bs) \runtimeArgs -> do + (asTopLam =<<) $ liftBuilder $ buildLamExpr (EmptyAbs bs) \runtimeArgs -> do -- This avoids an infinite loop. Otherwise, in simplifyTopFunction, -- where we eta-expand and try to simplify `App f args`, we would see `f` as a -- "noinline" function and defer its simplification. @@ -547,12 +549,12 @@ simplifyTabApp f [] = return f simplifyTabApp f@(SimpInCore sic) xs = case sic of TabLam _ _ -> do case fromNaryTabLam (length xs) f of - Just (bsCount, Abs bs declsAtom) -> do + Just (bsCount, Abs bs block) -> do let (xsPref, xsRest) = splitAt bsCount xs xsPref' <- mapM toDataAtomIgnoreRecon xsPref - Abs decls atom <- applySubst (bs@@>(SubstVal <$> xsPref')) declsAtom - atom' <- emitDecls decls atom - simplifyTabApp atom' xsRest + block' <- applySubst (bs@@>(SubstVal <$> xsPref')) block + atom <- emitDecls block' + simplifyTabApp atom xsRest Nothing -> error "should never happen" ACase e alts ty -> dropSubst do resultTy <- typeOfTabApp ty xs @@ -604,10 +606,10 @@ requireIxDictCache dictAbs = do Nothing -> error "Couldn't hoist specialized dictionary" {-# INLINE requireIxDictCache #-} -simplifyDictMethod :: Mut n => AbsDict n -> IxMethod -> TopBuilderM n (SLam n) +simplifyDictMethod :: Mut n => AbsDict n -> IxMethod -> TopBuilderM n (TopLam SimpIR n) simplifyDictMethod absDict@(Abs bs dict) method = do ty <- liftEnvReaderM $ ixMethodType method absDict - lamExpr <- liftBuilder $ buildLamExprFromPi ty \allArgs -> do + lamExpr <- liftBuilder $ buildTopLamFromPi ty \allArgs -> do let (extraArgs, methodArgs) = splitAt (nestLength bs) allArgs dict' <- applyRename (bs @@> (atomVarName <$> extraArgs)) dict emitExpr =<< mkApplyMethod dict' (fromEnum method) (Var <$> methodArgs) @@ -716,15 +718,13 @@ buildSimplifiedBlock cont = do return $ RightE (dataResult `PairE` ansTy) case eitherResult of LeftE ans -> do - (declsResult, recon) <- refreshAbs (Abs decls ans) \decls' ans' -> do + (block, recon) <- refreshAbs (Abs decls ans) \decls' ans' -> do (newResult, reconAbs) <- telescopicCapture (toScopeFrag decls') ans' return (Abs decls' newResult, LamRecon reconAbs) - block <- makeBlockFromDecls declsResult return $ SimplifiedBlock block recon RightE (ans `PairE` ty) -> do - block <- makeBlockFromDecls $ Abs decls ans let ty' = ignoreHoistFailure $ hoist (toScopeFrag decls) ty - return $ SimplifiedBlock block (CoerceRecon ty') + return $ SimplifiedBlock (Abs decls ans) (CoerceRecon ty') simplifyOp :: Emits o => NameHint -> PrimOp CoreIR i -> SimplifyM i o (CAtom o) simplifyOp hint op = case op of @@ -879,7 +879,9 @@ simplifyHof _hint resultTy = \case liftSimpAtom resultTy result CatchException _ body-> do SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyBlock body - block <- liftBuilder $ runSubstReaderT idSubst $ buildBlock $ exceptToMaybeBlock body' + simplifiedResultTy <- blockTy body' + block <- liftBuilder $ runSubstReaderT idSubst $ buildBlock $ + exceptToMaybeBlock (sink simplifiedResultTy) body' result <- emitBlock block case recon of CoerceRecon ty -> do @@ -925,7 +927,7 @@ preludeMaybeNewtypeCon ty = do return $ UserADTData sn tyConName params simplifyBlock :: Emits o => Block CoreIR i -> SimplifyM i o (CAtom o) -simplifyBlock (Block _ decls result) = simplifyDecls decls $ simplifyAtom result +simplifyBlock (Abs decls result) = simplifyDecls decls $ simplifyAtom result -- === simplifying custom linearizations === @@ -941,10 +943,10 @@ linearizeTopFun spec = do linearizeTopFunNoCache :: (Mut n, TopBuilder m) => LinearizationSpec n -> m n (TopFunName n, TopFunName n) linearizeTopFunNoCache spec@(LinearizationSpec f actives) = do - TopFunBinding ~(DexTopFun _ _ lam _) <- lookupEnv f + TopFunBinding ~(DexTopFun _ lam _) <- lookupEnv f PairE fPrimal fTangent <- liftSimplifyM $ tryGetCustomRule (sink f) >>= \case Just (absParams, rule) -> simplifyCustomLinearization (sink absParams) actives (sink rule) - Nothing -> liftM toPairE $ liftDoubleBuilderToSimplifyM $ linearizeLam (sink lam) actives + Nothing -> liftM toPairE $ liftDoubleBuilderToSimplifyM $ linearizeTopLam (sink lam) actives fTangentT <- transposeTopFun fTangent fPrimal' <- emitTopFunBinding "primal" (LinearizationPrimal spec) fPrimal fTangent' <- emitTopFunBinding "tangent" (LinearizationTangent spec) fTangent @@ -956,7 +958,7 @@ tryGetCustomRule :: EnvReader m => TopFunName n -> m n (Maybe (Abstracted CoreIR tryGetCustomRule f' = do ~(TopFunBinding f) <- lookupEnv f' case f of - DexTopFun def _ _ _ -> case def of + DexTopFun def _ _ -> case def of Specialization (AppSpecialization fCore absParams) -> fmap (absParams,) <$> lookupCustomRules (atomVarName fCore) _ -> return Nothing @@ -969,10 +971,10 @@ type Linearized = Abs (Nest SBinder) -- primal args simplifyCustomLinearization :: Abstracted CoreIR (ListE CAtom) n -> [Active] -> AtomRules n - -> SimplifyM i n (PairE SLam SLam n) + -> SimplifyM i n (PairE STopLam STopLam n) simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do CustomLinearize nImplicit nExplicit zeros fCustom <- return rule - defuncLinearized =<< withSimplifiedBinders runtimeBs \runtimeBs' runtimeArgs -> do + linearized <- withSimplifiedBinders runtimeBs \runtimeBs' runtimeArgs -> do Abs runtimeBs' <$> buildScoped do ListE staticArgs' <- applySubst (runtimeBs @@> (SubstVal . sink <$> runtimeArgs)) staticArgs fCustom' <- sinkM fCustom @@ -1003,7 +1005,11 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do tangentResult <- dropSubst $ simplifyApp noHint resultTyTangent fLin' tangentArgs' toDataAtomIgnoreRecon tangentResult return $ PairE primalResult' fLin' - where + PairE primalFun tangentFun <- defuncLinearized linearized + primalFun' <- asTopLam primalFun + tangentFun' <- asTopLam tangentFun + return $ PairE primalFun' tangentFun' + where buildTangentArgs :: Emits n => SymbolicZeros -> [(SType n, Active)] -> [SAtom n] -> SimplifyM i n [SAtom n] buildTangentArgs _ [] [] = return [] buildTangentArgs zeros ((t, False):tys) activeArgs = do @@ -1038,7 +1044,7 @@ defuncLinearized ab = liftBuilder $ refreshAbs ab \bs ab' -> do return $ Abs (Nest rB tBs') UnitE residualsTangentsBs' <- return $ ignoreHoistFailure $ hoist decls residualsTangentsBs return (Abs decls (PairVal primalResult residuals), reconAbs, residualsTangentsBs') - primalFun <- LamExpr bs <$> makeBlockFromDecls declsAndResult + let primalFun = LamExpr bs declsAndResult LamExpr residualAndTangentBs tangentBody <- buildLamExpr residualsTangentsBs \(residuals:tangents) -> do LamExpr tangentBs' body <- applyReconAbs (sink reconAbs) (Var residuals) applyRename (tangentBs' @@> (atomVarName <$> tangents)) body >>= emitBlock @@ -1049,25 +1055,20 @@ defuncLinearized ab = liftBuilder $ refreshAbs ab \bs ab' -> do type HandlerM = SubstReaderT AtomSubstVal (BuilderM SimpIR) -exceptToMaybeBlock :: Emits o => SBlock i -> HandlerM i o (SAtom o) -exceptToMaybeBlock (Block (BlockAnn (EffTy _ ty)) decls result) = do - ty' <- substM ty - exceptToMaybeDecls ty' decls $ Atom result -exceptToMaybeBlock (Block NoBlockAnn Empty result) = exceptToMaybeExpr $ Atom result -exceptToMaybeBlock _ = error "impossible" - -exceptToMaybeDecls :: Emits o => SType o -> Nest SDecl i i' -> SExpr i' -> HandlerM i o (SAtom o) -exceptToMaybeDecls _ Empty result = exceptToMaybeExpr result -exceptToMaybeDecls resultTy (Nest (Let b (DeclBinding _ rhs)) decls) finalResult = do +exceptToMaybeBlock :: Emits o => SType o -> SBlock i -> HandlerM i o (SAtom o) +exceptToMaybeBlock ty (Abs Empty result) = do + result' <- substM result + return $ JustAtom ty result' +exceptToMaybeBlock resultTy (Abs (Nest (Let b (DeclBinding _ rhs)) decls) finalResult) = do maybeResult <- exceptToMaybeExpr rhs case maybeResult of -- This case is just an optimization (but an important one!) JustAtom _ x -> - extendSubst (b@> SubstVal x) $ exceptToMaybeDecls resultTy decls finalResult + extendSubst (b@> SubstVal x) $ exceptToMaybeBlock resultTy (Abs decls finalResult) _ -> emitMaybeCase maybeResult (MaybeTy resultTy) (return $ NothingAtom $ sink resultTy) (\v -> extendSubst (b@> SubstVal v) $ - exceptToMaybeDecls (sink resultTy) decls finalResult) + exceptToMaybeBlock (sink resultTy) (Abs decls finalResult)) exceptToMaybeExpr :: Emits o => SExpr i -> HandlerM i o (SAtom o) exceptToMaybeExpr expr = case expr of @@ -1076,15 +1077,19 @@ exceptToMaybeExpr expr = case expr of resultTy' <- substM $ MaybeTy resultTy buildCase e' resultTy' \i v -> do Abs b body <- return $ alts !! i - extendSubst (b @> SubstVal v) $ exceptToMaybeBlock body + extendSubst (b @> SubstVal v) do + blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type + exceptToMaybeBlock blockResultTy body Atom x -> do x' <- substM x let ty = getType x' return $ JustAtom ty x' PrimOp (Hof (TypedHof _ (For ann ixTy' (UnaryLamExpr b body)))) -> do ixTy <- substM ixTy' - maybes <- buildForAnn (getNameHint b) ann ixTy \i -> - extendSubst (b@>Rename (atomVarName i)) $ exceptToMaybeBlock body + maybes <- buildForAnn (getNameHint b) ann ixTy \i -> do + extendSubst (b@>Rename (atomVarName i)) do + blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type + exceptToMaybeBlock blockResultTy body catMaybesE maybes PrimOp (MiscOp (ThrowException _)) -> do ty <- substM $ getType expr @@ -1094,7 +1099,8 @@ exceptToMaybeExpr expr = case expr of BinaryLamExpr h ref body <- return lam result <- emitRunState noHint s' \h' ref' -> extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) do - exceptToMaybeBlock body + blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type + exceptToMaybeBlock blockResultTy body (maybeAns, newState) <- fromPair result a <- substM $ getType expr emitMaybeCase maybeAns (MaybeTy a) @@ -1104,14 +1110,17 @@ exceptToMaybeExpr expr = case expr of monoid' <- substM monoid PairTy _ accumTy <- substM resultTy result <- emitRunWriter noHint accumTy monoid' \h' ref' -> - extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) $ - exceptToMaybeBlock body + extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) do + blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type + exceptToMaybeBlock blockResultTy body (maybeAns, accumResult) <- fromPair result a <- substM $ getType expr emitMaybeCase maybeAns (MaybeTy a) (return $ NothingAtom $ sink a) (\ans -> return $ JustAtom (sink a) $ PairVal ans (sink accumResult)) - PrimOp (Hof (TypedHof _ (While body))) -> runMaybeWhile $ exceptToMaybeBlock body + PrimOp (Hof (TypedHof _ (While body))) -> do + blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type + runMaybeWhile $ exceptToMaybeBlock (sink blockResultTy) body _ -> do expr' <- substM expr case hasExceptions expr' of diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index a07a9da47..9b08665ec 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -43,7 +43,7 @@ import qualified LLVM.AST import AbstractSyntax import Builder -import CheckType ( CheckableE (..), asFFIFunType, checkHasType, checkExtends, checkDestLam) +import CheckType ( CheckableE (..), asFFIFunType, checkHasType) #ifdef DEX_DEBUG import CheckType (checkTypesM) #endif @@ -536,37 +536,32 @@ whenOpt x act = getConfig <&> optLevel >>= \case NoOptimize -> return x Optimize -> act x -evalBlock :: (Topper m, Mut n) => CBlock n -> m n (CAtom n) +evalBlock :: (Topper m, Mut n) => TopBlock CoreIR n -> m n (CAtom n) evalBlock typed = do -- Be careful when adding new compilation passes here. If you do, be sure to -- also check compileTopLevelFun, below, and Export.prepareFunctionForExport. -- In most cases it should be easiest to add new passes to simpOptimizations or -- loweredOptimizations, below, because those are reused in all three places. - checkEffects Pure typed synthed <- checkPass SynthPass $ synthTopE typed - simplifiedBlock <- checkPass SimpPass $ simplifyTopBlock synthed - SimplifiedBlock simp recon <- return simplifiedBlock - checkEffects Pure simp - NullaryLamExpr opt <- simpOptimizations $ NullaryLamExpr simp - checkEffects Pure opt + SimplifiedTopLam simp recon <- checkPass SimpPass $ simplifyTopBlock synthed + opt <- simpOptimizations simp simpResult <- case opt of - AtomicBlock result -> return result + TopLam _ _ (LamExpr Empty (WithoutDecls result)) -> return result _ -> do - lowered <- checkPass LowerPass $ lowerFullySequential $ NullaryLamExpr opt - checkDestLam lowered - lOpt <- loweredOptimizations lowered - checkDestLam lOpt + lowered <- checkPass LowerPass $ lowerFullySequential True opt + lOpt <- checkPass OptPass $ loweredOptimizations lowered cc <- getEntryFunCC impOpt <- checkPass ImpPass $ toImpFunction cc lOpt llvmOpt <- packageLLVMCallable impOpt resultVals <- liftIO $ callEntryFun llvmOpt [] - PiType bs (EffTy _ resultTy') <- return $ getDestLamExprType lOpt + TopLam _ destTy _ <- return lOpt + PiType bs (EffTy _ resultTy') <- return $ piTypeWithoutDest destTy let resultTy = ignoreHoistFailure $ hoist bs resultTy' repValAtom =<< repValFromFlatList resultTy resultVals applyReconTop recon simpResult {-# SCC evalBlock #-} -simpOptimizations :: Topper m => SLam n -> m n (SLam n) +simpOptimizations :: Topper m => STopLam n -> m n (STopLam n) simpOptimizations simp = do analyzed <- whenOpt simp $ checkPass OccAnalysisPass . analyzeOccurrences inlined <- whenOpt analyzed $ checkPass InlinePass . inlineBindings @@ -574,7 +569,7 @@ simpOptimizations simp = do inlined2 <- whenOpt analyzed2 $ checkPass InlinePass . inlineBindings whenOpt inlined2 $ checkPass OptPass . optimize -loweredOptimizations :: Topper m => DestLamExpr n -> m n (DestLamExpr n) +loweredOptimizations :: Topper m => STopLam n -> m n (STopLam n) loweredOptimizations lowered = do lopt <- whenOpt lowered $ checkPass LowerOptPass . (dceTop >=> hoistLoopInvariant) @@ -584,7 +579,7 @@ loweredOptimizations lowered = do logFiltered l VectPass $ return [TextOut $ pprint errs] checkPass VectPass $ return vo -loweredOptimizationsNoDest :: Topper m => SLam n -> m n (SLam n) +loweredOptimizationsNoDest :: Topper m => STopLam n -> m n (STopLam n) loweredOptimizationsNoDest lowered = do lopt <- whenOpt lowered $ checkPass LowerOptPass . (dceTop >=> hoistLoopInvariant) @@ -594,7 +589,7 @@ loweredOptimizationsNoDest lowered = do evalSpecializations :: (Topper m, Mut n) => [TopFunName n] -> m n () evalSpecializations fs = do fSimps <- toposortAnnVars <$> catMaybes <$> forM fs \f -> lookupTopFun f >>= \case - DexTopFun _ _ simp Waiting -> return $ Just (f, simp) + DexTopFun _ simp Waiting -> return $ Just (f, simp) _ -> return Nothing forM_ fSimps \(f, simp) -> do -- Prevents infinite loop in case compiling `v` ends up requiring `v` @@ -608,14 +603,14 @@ evalSpecializations fs = do evalDictSpecializations :: (Topper m, Mut n) => [SpecDictName n] -> m n () evalDictSpecializations ds = do - -- TODO Do we have to do these in order, like evalSpecializations, or are they - -- independent enough not to need it? - -- TODO Do we need to gate the status of these, too? + -- -- TODO Do we have to do these in order, like evalSpecializations, or are they + -- -- independent enough not to need it? + -- -- TODO Do we need to gate the status of these, too? forM_ ds \dName -> do SpecializedDict _ (Just fs) <- lookupSpecDict dName fs' <- forM fs \lam -> do opt <- simpOptimizations lam - lowered <- checkPass LowerPass $ lowerFullySequentialNoDest opt + lowered <- checkPass LowerPass $ lowerFullySequential False opt loweredOptimizationsNoDest lowered updateTopEnv $ LowerDictSpecialization dName fs' return () @@ -647,10 +642,10 @@ execUDecl mname decl = do {-# SCC execUDecl #-} compileTopLevelFun :: (Topper m, Mut n) - => CallingConvention -> SLam n -> m n (ImpFunction n) + => CallingConvention -> STopLam n -> m n (ImpFunction n) compileTopLevelFun cc fSimp = do fOpt <- simpOptimizations fSimp - fLower <- checkPass LowerPass $ lowerFullySequential fOpt + fLower <- checkPass LowerPass $ lowerFullySequential True fOpt flOpt <- loweredOptimizations fLower checkPass ImpPass $ toImpFunction cc flOpt {-# SCC compileTopLevelFun #-} @@ -659,7 +654,8 @@ printCodegen :: (Topper m, Mut n) => CAtom n -> m n String printCodegen x = do block <- liftBuilder $ buildBlock do emitExpr $ PrimOp $ MiscOp $ ShowAny $ sink x - getDexString =<< evalBlock block + topBlock <- asTopBlock block + getDexString =<< evalBlock topBlock loadObject :: (Topper m, Mut n) => FunObjCodeName n -> m n NativeFunction loadObject fname = @@ -733,7 +729,7 @@ funNameToObj :: (EnvReader m, Fallible1 m) => ImpFunName n -> m n (FunObjCodeName n) funNameToObj v = do lookupEnv v >>= \case - TopFunBinding (DexTopFun _ _ _ (Finished impl)) -> return $ topFunObjCode impl + TopFunBinding (DexTopFun _ _ (Finished impl)) -> return $ topFunObjCode impl b -> error $ "couldn't find object cache entry for " ++ pprint v ++ "\ngot:\n" ++ pprint b withCompileTime :: MonadIO m => m Result -> m Result @@ -756,11 +752,6 @@ checkPass name cont = do #endif return result -checkEffects :: (Topper m, HasEffects e r, IRRep r) => EffectRow r n -> e n -> m n () -checkEffects allowedEffs e = do - let actualEffs = getEffects e - checkExtends allowedEffs actualEffs - addResultCtx :: SourceBlock -> Result -> Result addResultCtx block (Result outs errs) = Result outs (addSrcTextContext (sbOffset block) (sbText block) errs) diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 29032ff49..904e608d1 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -42,8 +42,8 @@ runTransposeM cont = runReaderT1 (ListE []) $ runSubstReaderT idSubst $ cont transposeTopFun :: (MonadFail1 m, EnvReader m) - => LamExpr SimpIR n -> m n (LamExpr SimpIR n) -transposeTopFun lam = liftBuilder $ runTransposeM do + => STopLam n -> m n (STopLam n) +transposeTopFun (TopLam False _ lam) = liftBuilder $ runTransposeM do (Abs bsNonlin (Abs bLin body), Abs bsNonlin'' outTy) <- unpackLinearLamExpr lam refreshBinders bsNonlin \bsNonlin' substFrag -> extendRenamer substFrag do outTy' <- applyRename (bsNonlin''@@> nestToNames bsNonlin') outTy @@ -54,7 +54,11 @@ transposeTopFun lam = liftBuilder $ runTransposeM do withAccumulator inTy \refSubstVal -> extendSubst (bLin @> refSubstVal) $ transposeBlock body (sink ct) - return $ LamExpr (bsNonlin' >>> UnaryNest bCT) body' + EffTy _ bodyTy <- blockEffTy body' + let piTy = PiType (bsNonlin' >>> UnaryNest bCT) (EffTy Pure bodyTy) + let lamT = LamExpr (bsNonlin' >>> UnaryNest bCT) body' + return $ TopLam False piTy lamT +transposeTopFun (TopLam True _ _) = error "shouldn't be transposing in destination passing style" unpackLinearLamExpr :: (MonadFail1 m, EnvReader m) => LamExpr SimpIR n @@ -63,7 +67,7 @@ unpackLinearLamExpr unpackLinearLamExpr lam@(LamExpr bs body) = do let numNonlin = nestLength bs - 1 PairB bsNonlin (UnaryNest bLin) <- return $ splitNestAt numNonlin bs - PiType bsTy (EffTy _ resultTy) <- return $ getLamExprType lam + PiType bsTy (EffTy _ resultTy) <- getLamExprType lam PairB bsNonlinTy (UnaryNest bLinTy) <- return $ splitNestAt numNonlin bsTy let resultTy' = ignoreHoistFailure $ hoist bLinTy resultTy return ( Abs bsNonlin $ Abs bLin body @@ -154,7 +158,7 @@ extendLinRegions v cont = local (\(ListE vs) -> ListE (v:vs)) cont -- === actual pass === transposeBlock :: Emits o => SBlock i -> SAtom o -> TransposeM i o () -transposeBlock (Block _ decls result) ct = transposeWithDecls decls result ct +transposeBlock (Abs decls result) ct = transposeWithDecls decls result ct transposeWithDecls :: Emits o => Nest SDecl i i' -> SAtom i' -> SAtom o -> TransposeM i o () transposeWithDecls Empty atom ct = transposeAtom atom ct diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 8953b1944..394919a82 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -175,19 +175,13 @@ data ParamRole = TypeParam | DictParam | DataParam deriving (Show, Generic, Eq) data TyConParams n = TyConParams [Explicitness] [Atom CoreIR n] deriving (Show, Generic) --- The Type is the type of the result expression (and thus the type of the --- block). It's given by querying the result expression's type, and checking --- that it doesn't have any free names bound by the decls in the block. We store --- it separately as an optimization, to avoid having to traverse the block. --- If the decls are empty we can skip the type annotation, because then we can --- cheaply query the result, and, more importantly, there's no risk of having a --- type that mentions local variables. -data Block (r::IR) (n::S) where - Block :: BlockAnn r n l -> Nest (Decl r) n l -> Atom r l -> Block r n - -data BlockAnn r n l where - BlockAnn :: EffTy r n -> BlockAnn r n l - NoBlockAnn :: BlockAnn r n n +type WithDecls (r::IR) = Abs (Decls r) :: E -> E +type Block (r::IR) = WithDecls r (Atom r) :: E + +type TopBlock = TopLam -- used for nullary lambda +type IsDestLam = Bool +data TopLam (r::IR) (n::S) = TopLam IsDestLam (PiType r n) (LamExpr r n) + deriving (Show, Generic) data LamExpr (r::IR) (n::S) where LamExpr :: Nest (Binder r) n l -> Block r l -> LamExpr r n @@ -416,6 +410,7 @@ type CDecl = Decl CoreIR type CDecls = Decls CoreIR type CAtomName = AtomName CoreIR type CAtomVar = AtomVar CoreIR +type CTopLam = TopLam CoreIR type SAtom = Atom SimpIR type SType = Type SimpIR @@ -429,6 +424,7 @@ type SAtomVar = AtomVar SimpIR type SBinder = Binder SimpIR type SRepVal = RepVal SimpIR type SLam = LamExpr SimpIR +type STopLam = TopLam SimpIR -- === newtypes === @@ -591,8 +587,8 @@ data TopEnvUpdate n = | AddCustomRule (CAtomName n) (AtomRules n) | UpdateLoadedModules ModuleSourceName (ModuleName n) | UpdateLoadedObjects (FunObjCodeName n) NativeFunction - | FinishDictSpecialization (SpecDictName n) [LamExpr SimpIR n] - | LowerDictSpecialization (SpecDictName n) [LamExpr SimpIR n] + | FinishDictSpecialization (SpecDictName n) [TopLam SimpIR n] + | LowerDictSpecialization (SpecDictName n) [TopLam SimpIR n] | UpdateTopFunEvalStatus (TopFunName n) (TopFunEvalStatus n) | UpdateInstanceDef (InstanceName n) (InstanceDef n) | UpdateTyConDef (TyConName n) (TyConDef n) @@ -756,7 +752,7 @@ deriving instance Show (EffectOpType n) deriving via WrapE EffectOpType n instance Generic (EffectOpType n) instance GenericE SpecializedDictDef where - type RepE SpecializedDictDef = AbsDict `PairE` MaybeE (ListE (LamExpr SimpIR)) + type RepE SpecializedDictDef = AbsDict `PairE` MaybeE (ListE (TopLam SimpIR)) fromE (SpecializedDict ab methods) = ab `PairE` methods' where methods' = case methods of Just xs -> LeftE (ListE xs) Nothing -> RightE UnitE @@ -777,7 +773,7 @@ data EvalStatus a = Waiting | Running | Finished a type TopFunEvalStatus n = EvalStatus (TopFunLowerings n) data TopFun (n::S) = - DexTopFun (TopFunDef n) (PiType SimpIR n) (LamExpr SimpIR n) (TopFunEvalStatus n) + DexTopFun (TopFunDef n) (TopLam SimpIR n) (TopFunEvalStatus n) | FFITopFun String IFunType deriving (Show, Generic) @@ -929,7 +925,7 @@ data SpecializedDictDef n = -- Methods (thunked if nullary), if they're available. -- We create specialized dict names during simplification, but we don't -- actually simplify/lower them until we return to TopLevel - (Maybe [LamExpr SimpIR n]) + (Maybe [TopLam SimpIR n]) deriving (Show, Generic) -- TODO: extend with AD-oriented specializations, backend-specific specializations etc. @@ -1125,12 +1121,11 @@ pattern UnaryLamExpr b body = LamExpr (UnaryNest b) body pattern BinaryLamExpr :: Binder r n l1 -> Binder r l1 l2 -> Block r l2 -> LamExpr r n pattern BinaryLamExpr b1 b2 body = LamExpr (BinaryNest b1 b2) body -pattern AtomicBlock :: Atom r n -> Block r n -pattern AtomicBlock atom <- Block _ Empty atom - where AtomicBlock atom = Block NoBlockAnn Empty atom +pattern WithoutDecls :: e n -> WithDecls r e n +pattern WithoutDecls x = Abs Empty x exprBlock :: IRRep r => Block r n -> Maybe (Expr r n) -exprBlock (Block _ (Nest (Let b (DeclBinding _ expr)) Empty) (Var (AtomVar n _))) +exprBlock (Abs (Nest (Let b (DeclBinding _ expr)) Empty) (Var (AtomVar n _))) | n == binderName b = Just expr exprBlock _ = Nothing {-# INLINE exprBlock #-} @@ -1890,27 +1885,6 @@ instance IRRep r => AlphaEqE (TC r) instance IRRep r => AlphaHashableE (TC r) instance IRRep r => RenameE (TC r) -instance IRRep r => GenericE (Block r) where - type RepE (Block r) = PairE (MaybeE (EffTy r)) (Abs (Nest (Decl r)) (Atom r)) - fromE (Block (BlockAnn effTy) decls result) = PairE (JustE effTy) (Abs decls result) - fromE (Block NoBlockAnn Empty result) = PairE NothingE (Abs Empty result) - fromE _ = error "impossible" - {-# INLINE fromE #-} - toE (PairE (JustE effTy) (Abs decls result)) = Block (BlockAnn effTy) decls result - toE (PairE NothingE (Abs Empty result)) = Block NoBlockAnn Empty result - toE _ = error "impossible" - {-# INLINE toE #-} - -deriving instance IRRep r => Show (BlockAnn r n l) - -instance IRRep r => SinkableE (Block r) -instance IRRep r => HoistableE (Block r) -instance IRRep r => AlphaEqE (Block r) -instance IRRep r => AlphaHashableE (Block r) -instance IRRep r => RenameE (Block r) -deriving instance IRRep r => Show (Block r n) -deriving via WrapE (Block r) n instance IRRep r => Generic (Block r n) - instance IRRep r => GenericB (NonDepNest r ann) where type RepB (NonDepNest r ann) = (LiftB (ListE ann)) `PairB` Nest (AtomNameBinder r) fromB (NonDepNest bs anns) = LiftB (ListE anns) `PairB` bs @@ -2291,16 +2265,29 @@ instance RenameE TopFunDef instance AlphaEqE TopFunDef instance AlphaHashableE TopFunDef +instance IRRep r => GenericE (TopLam r) where + type RepE (TopLam r) = LiftE Bool `PairE` PiType r `PairE` LamExpr r + fromE (TopLam d x y) = LiftE d `PairE` x `PairE` y + {-# INLINE fromE #-} + toE (LiftE d `PairE` x `PairE` y) = TopLam d x y + {-# INLINE toE #-} + +instance IRRep r => SinkableE (TopLam r) +instance IRRep r => HoistableE (TopLam r) +instance IRRep r => RenameE (TopLam r) +instance IRRep r => AlphaEqE (TopLam r) +instance IRRep r => AlphaHashableE (TopLam r) + instance GenericE TopFun where type RepE TopFun = EitherE - (TopFunDef `PairE` PiType SimpIR `PairE` LamExpr SimpIR `PairE` ComposeE EvalStatus TopFunLowerings) + (TopFunDef `PairE` TopLam SimpIR `PairE` ComposeE EvalStatus TopFunLowerings) (LiftE (String, IFunType)) fromE = \case - DexTopFun def ty simp status -> LeftE (def `PairE` ty `PairE` simp `PairE` ComposeE status) + DexTopFun def lam status -> LeftE (def `PairE` lam `PairE` ComposeE status) FFITopFun name ty -> RightE (LiftE (name, ty)) {-# INLINE fromE #-} toE = \case - LeftE (def `PairE` ty `PairE` simp `PairE` ComposeE status) -> DexTopFun def ty simp status + LeftE (def `PairE` lam `PairE` ComposeE status) -> DexTopFun def lam status RightE (LiftE (name, ty)) -> FFITopFun name ty {-# INLINE toE #-} @@ -2574,8 +2561,8 @@ instance GenericE TopEnvUpdate where {- UpdateLoadedModules -} (LiftE ModuleSourceName `PairE` ModuleName) {- UpdateLoadedObjects -} (FunObjCodeName `PairE` LiftE NativeFunction) ) ( EitherE6 - {- FinishDictSpecialization -} (SpecDictName `PairE` ListE (LamExpr SimpIR)) - {- LowerDictSpecialization -} (SpecDictName `PairE` ListE (LamExpr SimpIR)) + {- FinishDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) + {- LowerDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) {- UpdateTopFunEvalStatus -} (TopFunName `PairE` ComposeE EvalStatus TopFunLowerings) {- UpdateInstanceDef -} (InstanceName `PairE` InstanceDef) {- UpdateTyConDef -} (TyConName `PairE` TyConDef) @@ -2685,8 +2672,8 @@ applyUpdate e = \case updateEnv dName newBinding e UpdateTopFunEvalStatus f s -> do case lookupEnvPure e f of - TopFunBinding (DexTopFun def ty simp _) -> - updateEnv f (TopFunBinding $ DexTopFun def ty simp s) e + TopFunBinding (DexTopFun def lam _) -> + updateEnv f (TopFunBinding $ DexTopFun def lam s) e _ -> error "can't update ffi function impl" UpdateInstanceDef name def -> do case lookupEnvPure e name of @@ -2842,7 +2829,7 @@ instance Store (TyConParams n) instance Store (DataConDefs n) instance Store (TyConDef n) instance Store (DataConDef n) -instance IRRep r => Store (Block r n) +instance IRRep r => Store (TopLam r n) instance IRRep r => Store (LamExpr r n) instance IRRep r => Store (IxType r n) instance Store (CorePiType n) diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index 69b7404e5..112060821 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -19,7 +19,7 @@ import Core import Err import CheapReduction import IRVariants -import Lower (DestBlock, DestLamExpr) +import Lower (DestBlock) import MTL1 import Name import Subst @@ -87,13 +87,13 @@ newtype TopVectorizeM (i::S) (o::S) (a:: *) = TopVectorizeM , EnvExtender, Builder SimpIR, ScopableBuilder SimpIR, Catchable , SubstReader Name) -vectorizeLoops :: EnvReader m => Word32 -> DestLamExpr n -> m n (DestLamExpr n, Errs) -vectorizeLoops width (LamExpr bsDestB body) = liftEnvReaderM do +vectorizeLoops :: EnvReader m => Word32 -> STopLam n -> m n (STopLam n, Errs) +vectorizeLoops width (TopLam d ty (LamExpr bsDestB body)) = liftEnvReaderM do case popNest bsDestB of Just (PairB bs b) -> refreshAbs (Abs bs (Abs b body)) \bs' body' -> do (Abs b'' body'', errs) <- liftTopVectorizeM width $ vectorizeLoopsDestBlock body' - return $ (LamExpr (bs' >>> UnaryNest b'') body'', errs) + return $ (TopLam d ty (LamExpr (bs' >>> UnaryNest b'') body''), errs) Nothing -> error "expected a trailing dest binder" liftTopVectorizeM :: (EnvReader m) @@ -139,7 +139,7 @@ vectorizeLoopsDestBlock (Abs (destb:>destTy) body) = do vectorizeLoopsBlock :: (Emits o) => Block SimpIR i -> TopVectorizeM i o (SAtom o) -vectorizeLoopsBlock (Block _ decls ans) = +vectorizeLoopsBlock (Abs decls ans) = vectorizeLoopsDecls decls $ renameM ans vectorizeLoopsDecls :: (Emits o) @@ -360,7 +360,7 @@ vectorizeLamExpr (LamExpr bs body) argStabilities = case (bs, argStabilities) of _ -> error "Zip error" vectorizeBlock :: Emits o => SBlock i -> VectorizeM i o (VAtom o) -vectorizeBlock block@(Block _ decls (ans :: SAtom i')) = +vectorizeBlock block@(Abs decls (ans :: SAtom i')) = addVectErrCtx "vectorizeBlock" ("Block:\n" ++ pprint block) $ go decls where @@ -497,11 +497,10 @@ vectorizePrimOp op = case op of -- complain about FFI calls and the like. Hof (TypedHof _ (RunIO body)) -> do -- TODO: buildBlockAux? - Abs decls (LiftE vy `PairE` yWithTy) <- buildScoped do + Abs decls (LiftE vy `PairE` y) <- buildScoped do VVal vy y <- vectorizeBlock body - PairE (LiftE vy) <$> withType y - body' <- absToBlock =<< computeAbsEffects (Abs decls yWithTy) - VVal vy <$> emitHof (RunIO body') + return $ PairE (LiftE vy) y + VVal vy <$> emitHof (RunIO $ Abs decls y) _ -> throwVectErr $ "Can't vectorize op: " ++ pprint op vectorizeType :: SType i -> VectorizeM i o (SType o) diff --git a/tests/uexpr-tests.dx b/tests/uexpr-tests.dx index c32378219..6583ce5bd 100644 --- a/tests/uexpr-tests.dx +++ b/tests/uexpr-tests.dx @@ -321,7 +321,7 @@ def bug(n|Data) -> () = > v#1:((RangeFrom n v#0) => ()) = for i:(RangeFrom n v#0). () > v#1 > Of type: ((RangeFrom n v#0) => ()) -> With effects: {} +> > > for w':n. > ^^^^^^^^^^ diff --git a/tests/unit/ConstantCastingSpec.hs b/tests/unit/ConstantCastingSpec.hs index 499946ef8..fe9abab12 100644 --- a/tests/unit/ConstantCastingSpec.hs +++ b/tests/unit/ConstantCastingSpec.hs @@ -24,6 +24,7 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source +import QueryType castOp :: ScalarBaseType -> (SAtom n) -> PrimOp SimpIR n castOp ty x = MiscOp $ CastOp (BaseTy (Scalar ty)) x @@ -43,7 +44,7 @@ exprToBlock expr = do compile :: (Topper m, Mut n) => ScalarBaseType -> ScalarBaseType -> m n LLVMCallable compile fromTy toTy = do - sLam <- liftEnvReaderM $ castLam fromTy toTy + sLam <- liftEnvReaderM (castLam fromTy toTy) >>= asTopLam compileTopLevelFun (EntryFunCC CUDANotRequired) sLam >>= packageLLVMCallable arbLitVal :: ScalarBaseType -> Gen LitVal diff --git a/tests/unit/JaxADTSpec.hs b/tests/unit/JaxADTSpec.hs index ee2725788..40c0ff785 100644 --- a/tests/unit/JaxADTSpec.hs +++ b/tests/unit/JaxADTSpec.hs @@ -18,6 +18,7 @@ import TopLevel import Types.Imp import Types.Primitives hiding (Sin) import Types.Source hiding (SourceName) +import QueryType x_nm, y_nm :: JSourceName x_nm = JSourceName 0 0 "x" @@ -48,7 +49,7 @@ compile jaxpr = do -- the jaxpr instead of just coercing it. Distinct <- getDistinct jRename <- liftRenameM $ renameJaxpr (unsafeCoerceE jaxpr) - jSimp <- liftJaxSimpM $ simplifyJaxpr jRename + jSimp <- liftJaxSimpM (simplifyJaxpr jRename) >>= asTopLam compileTopLevelFun (EntryFunCC CUDANotRequired) jSimp >>= packageLLVMCallable spec :: Spec diff --git a/tests/unit/OccAnalysisSpec.hs b/tests/unit/OccAnalysisSpec.hs index 06007fb1c..20cd6dc88 100644 --- a/tests/unit/OccAnalysisSpec.hs +++ b/tests/unit/OccAnalysisSpec.hs @@ -25,6 +25,7 @@ import Types.Imp (Backend (..)) import Types.Primitives import Types.Source import TopLevel +import QueryType sourceTextToBlocks :: (Topper m, Mut n) => Text -> m n [SBlock n] sourceTextToBlocks source = do @@ -44,11 +45,11 @@ uExprToBlock expr = do renamed <- renameSourceNamesUExpr expr typed <- inferTopUExpr renamed synthed <- synthTopE typed - (SimplifiedBlock block (CoerceRecon _)) <- simplifyTopBlock synthed + SimplifiedTopLam (TopLam _ _ (LamExpr Empty block)) (CoerceRecon _) <- simplifyTopBlock synthed return block findRunIOAnnotation :: SBlock n -> LetAnn -findRunIOAnnotation (Block _ decls _) = go decls where +findRunIOAnnotation (Abs decls _) = go decls where go :: Nest SDecl n l -> LetAnn go (Nest (Let _ (DeclBinding ann (PrimOp (Hof (TypedHof _ (RunIO _)))))) _) = ann go (Nest _ rest) = go rest @@ -57,7 +58,8 @@ findRunIOAnnotation (Block _ decls _) = go decls where analyze :: EvalConfig -> TopStateEx -> [Text] -> IO LetAnn analyze cfg env code = fst <$> runTopperM cfg env do [block] <- sourceTextToBlocks $ unlines code - NullaryLamExpr block' <- analyzeOccurrences $ NullaryLamExpr block + lam <- asTopLam $ LamExpr Empty block + TopLam _ _ (LamExpr Empty block') <- analyzeOccurrences lam -- The RunIO is generated by simplifying `unreachable()` in the examples -- below. If we want compound examples that have more than one RunIO block, -- we will need better pattern-matching. From c7fef435e917075a31c9c1d68acd342fb7c3d90c Mon Sep 17 00:00:00 2001 From: Dougal Date: Fri, 23 Jun 2023 23:58:30 -0400 Subject: [PATCH 2/4] Remove some user-defined effect stuff. Just to make the decls-in-binders update easier. --- src/lib/AbstractSyntax.hs | 1 - src/lib/CheapReduction.hs | 9 ---- src/lib/CheckType.hs | 2 - src/lib/Core.hs | 12 ----- src/lib/Inference.hs | 1 - src/lib/PPrint.hs | 12 ----- src/lib/QueryType.hs | 5 --- src/lib/QueryTypePure.hs | 8 ---- src/lib/Simplify.hs | 1 - src/lib/SourceRename.hs | 1 - src/lib/Types/Core.hs | 94 +++------------------------------------ src/lib/Types/Source.hs | 1 - 12 files changed, 7 insertions(+), 140 deletions(-) diff --git a/src/lib/AbstractSyntax.hs b/src/lib/AbstractSyntax.hs index bf65fda87..5a5805b0c 100644 --- a/src/lib/AbstractSyntax.hs +++ b/src/lib/AbstractSyntax.hs @@ -338,7 +338,6 @@ effect (Binary JuxtaposeWithSpace (Identifier "State") (Identifier h)) = return $ URWSEffect State $ fromString h effect (Identifier "Except") = return UExceptionEffect effect (Identifier "IO") = return UIOEffect -effect (Identifier effName) = return $ UUserEffect (fromString effName) effect _ = throw SyntaxErr "Unexpected effect form; expected one of `Read h`, `Accum h`, `State h`, `Except`, `IO`, or the name of a user-defined effect." aMethod :: CSDecl -> SyntaxM (Maybe (UMethodDef VoidS)) diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index b301fa6b7..bce6dcbed 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -638,7 +638,6 @@ instance IRRep r => VisitGeneric (PrimOp r) r where MiscOp op -> MiscOp <$> visitGeneric op Hof op -> Hof <$> visitGeneric op DAMOp op -> DAMOp <$> visitGeneric op - UserEffectOp op -> UserEffectOp <$> visitGeneric op RefOp r op -> RefOp <$> visitGeneric r <*> traverseOp op visitGeneric visitGeneric visitGeneric instance IRRep r => VisitGeneric (TypedHof r) r where @@ -668,18 +667,11 @@ instance IRRep r => VisitGeneric (DAMOp r) r where Place x y -> Place <$> visitGeneric x <*> visitGeneric y Freeze x -> Freeze <$> visitGeneric x -instance VisitGeneric UserEffectOp CoreIR where - visitGeneric = \case - Handle name xs body -> Handle <$> renameN name <*> mapM visitGeneric xs <*> visitBlock body - Resume t x -> Resume <$> visitGeneric t <*> visitGeneric x - Perform x i -> Perform <$> visitGeneric x <*> pure i - instance IRRep r => VisitGeneric (Effect r) r where visitGeneric = \case RWSEffect rws h -> RWSEffect rws <$> visitGeneric h ExceptionEffect -> pure ExceptionEffect IOEffect -> pure IOEffect - UserEffect name -> UserEffect <$> renameN name InitEffect -> pure InitEffect instance IRRep r => VisitGeneric (EffectRow r) r where @@ -881,7 +873,6 @@ instance IRRep r => SubstE AtomSubstVal (RepVal r) instance SubstE AtomSubstVal TyConParams instance SubstE AtomSubstVal DataConDef instance IRRep r => SubstE AtomSubstVal (BaseMonoid r) -instance SubstE AtomSubstVal UserEffectOp instance IRRep r => SubstE AtomSubstVal (DAMOp r) instance IRRep r => SubstE AtomSubstVal (TypedHof r) instance IRRep r => SubstE AtomSubstVal (Hof r) diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index f16a8b5bd..9e088be81 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -481,7 +481,6 @@ typeCheckPrimOp effs op = case op of MiscOp x -> typeCheckMiscOp effs x MemOp x -> typeCheckMemOp effs x DAMOp op' -> typeCheckDAMOp effs op' - UserEffectOp _ -> error "not implemented" RefOp ref m -> do TC (RefType h s) <- getTypeE ref case m of @@ -958,7 +957,6 @@ instance IRRep r => CheckableE r (EffectRow r) where RWSEffect _ v -> v |: TC HeapType ExceptionEffect -> return () IOEffect -> return () - UserEffect _ -> return () InitEffect -> return () case effTail of NoTail -> return () diff --git a/src/lib/Core.hs b/src/lib/Core.hs index 436aae869..88725c1ca 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -347,18 +347,6 @@ lookupInstanceTy :: EnvReader m => InstanceName n -> m n (CorePiType n) lookupInstanceTy name = lookupEnv name >>= \case InstanceBinding _ ty -> return ty {-# INLINE lookupInstanceTy #-} -lookupEffectDef :: EnvReader m => EffectName n -> m n (EffectDef n) -lookupEffectDef name = lookupEnv name >>= \case EffectBinding x -> return x -{-# INLINE lookupEffectDef #-} - -lookupEffectOpDef :: EnvReader m => EffectOpName n -> m n (EffectOpDef n) -lookupEffectOpDef name = lookupEnv name >>= \case EffectOpBinding x -> return x -{-# INLINE lookupEffectOpDef #-} - -lookupHandlerDef :: EnvReader m => HandlerName n -> m n (HandlerDef n) -lookupHandlerDef name = lookupEnv name >>= \case HandlerBinding x -> return x -{-# INLINE lookupHandlerDef #-} - lookupSourceMapPure :: SourceMap n -> SourceName -> [SourceNameDef n] lookupSourceMapPure (SourceMap m) v = M.findWithDefault [] v m {-# INLINE lookupSourceMapPure #-} diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 63ab284de..6f07dddf9 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -1975,7 +1975,6 @@ checkUEff eff = case eff of return $ RWSEffect rws (Var region') UExceptionEffect -> return ExceptionEffect UIOEffect -> return IOEffect - UUserEffect ~(SIInternalName _ name _ _) -> UserEffect <$> renameM name constrainVarTy :: EmitsInf o => CAtomVar o -> CType o -> InfererM i o () constrainVarTy v tyReq = do diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 5979de66a..c68ea0390 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -170,11 +170,6 @@ instance IRRep r => PrettyPrec (Expr r n) where prettyPrec (PrimOp op) = prettyPrec op prettyPrec (ApplyMethod _ d i xs) = atPrec AppPrec $ "applyMethod" <+> p d <+> p i <+> p xs -instance Pretty (UserEffectOp n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UserEffectOp n) where - prettyPrec (Handle v args body) = atPrec LowestPrec $ p v <+> p args <+> prettyLam "\\_." body - prettyPrec _ = error "not implemented" - prettyPrecCase :: IRRep r => Doc ann -> Atom r n -> [Alt r n] -> EffectRow r n -> DocPrec ann prettyPrecCase name e alts effs = atPrec LowestPrec $ name <+> pApp e <+> "of" <> @@ -371,7 +366,6 @@ instance IRRep r => Pretty (Effect r n) where RWSEffect rws h -> p rws <+> p h ExceptionEffect -> "Except" IOEffect -> "IO" - UserEffect name -> p name InitEffect -> "Init" instance Pretty (UEffect n) where @@ -379,7 +373,6 @@ instance Pretty (UEffect n) where URWSEffect rws h -> p rws <+> p h UExceptionEffect -> "Except" UIOEffect -> "IO" - UUserEffect name -> p name instance PrettyPrec (Name s n) where prettyPrec = atPrec ArgPrec . pretty @@ -423,10 +416,6 @@ instance Pretty (Binding c n) where FunObjCodeBinding _ -> "" ModuleBinding _ -> "" PtrBinding _ _ -> "" - -- TODO(alex): do something actually useful here - EffectBinding _ -> "" - HandlerBinding _ -> "" - EffectOpBinding _ -> "" SpecializedDictBinding _ -> "" ImpNameBinding ty -> "Imp name of type: " <+> p ty @@ -917,7 +906,6 @@ instance IRRep r => PrettyPrec (PrimOp r n) where MemOp op -> prettyPrec op VectorOp op -> prettyPrec op DAMOp op -> prettyPrec op - UserEffectOp op -> prettyPrec op Hof (TypedHof _ hof) -> prettyPrec hof RefOp ref eff -> atPrec LowestPrec case eff of MAsk -> "ask" <+> pApp ref diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index a33c7b147..eaad95cba 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -364,11 +364,6 @@ getTypeRWSAction f = getLamExprType f >>= \case _ -> error "expected a ref" _ -> error "expected a pi type" -instantiateHandlerType :: EnvReader m => HandlerName n -> CType n -> [CAtom n] -> m n (CType n) -instantiateHandlerType hndName r args = do - HandlerDef _ rb bs _effs retTy _ _ <- lookupHandlerDef hndName - applySubst (rb @> (SubstVal (Type r)) <.> bs @@> (map SubstVal args)) retTy - getSuperclassDicts :: EnvReader m => CAtom n -> m n ([CAtom n]) getSuperclassDicts dict = do case getType dict of diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 9f439bc46..3bda0d8fd 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -144,12 +144,6 @@ instance IRRep r => HasType r (DAMOp r) where Seq _ _ _ cinit _ -> getType cinit RememberDest _ d _ -> getType d -instance HasType CoreIR UserEffectOp where - getType = \case - Handle _ _ _ -> undefined - Perform _ _ -> undefined - Resume retTy _ -> retTy - instance IRRep r => HasType r (PrimOp r) where getType primOp = case primOp of BinOp op x _ -> TC $ BaseType $ typeBinOp op $ getTypeBaseType x @@ -159,7 +153,6 @@ instance IRRep r => HasType r (PrimOp r) where MiscOp op -> getType op VectorOp op -> getType op DAMOp op -> getType op - UserEffectOp op -> getType op RefOp ref m -> case getType ref of TC (RefType _ s) -> case m of MGet -> s @@ -310,7 +303,6 @@ instance IRRep r => HasEffects (PrimOp r) r where IndexRef _ _ -> Pure ProjRef _ _ -> Pure _ -> error "not a ref" - UserEffectOp _ -> undefined DAMOp op -> case op of Place _ _ -> OneEffect InitEffect Seq eff _ _ _ _ -> eff diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 34da148e1..2e820ee71 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -728,7 +728,6 @@ buildSimplifiedBlock cont = do simplifyOp :: Emits o => NameHint -> PrimOp CoreIR i -> SimplifyM i o (CAtom o) simplifyOp hint op = case op of - UserEffectOp _ -> error "not implemented" Hof (TypedHof (EffTy _ ty) hof) -> do ty' <- substM ty simplifyHof hint ty' hof diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index 9ea111ae3..45fafc01e 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -224,7 +224,6 @@ instance SourceRenamableE UEffect where sourceRenameE (URWSEffect rws name) = URWSEffect rws <$> sourceRenameE name sourceRenameE UExceptionEffect = return UExceptionEffect sourceRenameE UIOEffect = return UIOEffect - sourceRenameE (UUserEffect name) = UUserEffect <$> sourceRenameE name instance SourceRenamableE a => SourceRenamableE (WithSrcE a) where sourceRenameE (WithSrcE pos e) = addSrcContext pos $ diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 394919a82..9f22018b8 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -131,8 +131,6 @@ type ClassName = Name ClassNameC type TyConName = Name TyConNameC type DataConName = Name DataConNameC type EffectName = Name EffectNameC -type EffectOpName = Name EffectOpNameC -type HandlerName = Name HandlerNameC type InstanceName = Name InstanceNameC type MethodName = Name MethodNameC type ModuleName = Name ModuleNameC @@ -312,8 +310,7 @@ data PrimOp (r::IR) (n::S) where MiscOp :: MiscOp r n -> PrimOp r n Hof :: TypedHof r n -> PrimOp r n RefOp :: Atom r n -> RefOp r n -> PrimOp r n - DAMOp :: DAMOp SimpIR n -> PrimOp SimpIR n - UserEffectOp :: UserEffectOp n -> PrimOp CoreIR n + DAMOp :: DAMOp SimpIR n -> PrimOp SimpIR n deriving instance IRRep r => Show (PrimOp r n) deriving via WrapE (PrimOp r) n instance IRRep r => Generic (PrimOp r n) @@ -393,12 +390,6 @@ data RefOp r n = | ProjRef (Type r n) Projection deriving (Show, Generic) -data UserEffectOp n = - Handle (HandlerName n) [CAtom n] (CBlock n) - | Resume (CType n) (CAtom n) -- Resume from effect handler (type, arg) - | Perform (CAtom n) Int -- Call an effect operation (effect name) (op #) - deriving (Show, Generic) - -- === IR variants === type CAtom = Atom CoreIR @@ -656,9 +647,6 @@ data Binding (c::C) (n::S) where ClassBinding :: ClassDef n -> Binding ClassNameC n InstanceBinding :: InstanceDef n -> CorePiType n -> Binding InstanceNameC n MethodBinding :: ClassName n -> Int -> Binding MethodNameC n - EffectBinding :: EffectDef n -> Binding EffectNameC n - HandlerBinding :: HandlerDef n -> Binding HandlerNameC n - EffectOpBinding :: EffectOpDef n -> Binding EffectOpNameC n TopFunBinding :: TopFun n -> Binding TopFunNameC n FunObjCodeBinding :: CFunction n -> Binding FunObjCodeNameC n ModuleBinding :: Module n -> Binding ModuleNameC n @@ -707,33 +695,6 @@ instance RenameE EffectDef deriving instance Show (EffectDef n) deriving via WrapE EffectDef n instance Generic (EffectDef n) -data HandlerDef (n::S) where - HandlerDef :: EffectName n - -> CBinder n r -- body type arg - -> RolePiBinders r l - -> EffectRow CoreIR l - -> CType l -- return type - -> [Block CoreIR l] -- effect operations - -> Block CoreIR l -- return body - -> HandlerDef n - -instance GenericE HandlerDef where - type RepE HandlerDef = - EffectName `PairE` Abs (CBinder `PairB` RolePiBinders) - (EffectRow CoreIR `PairE` CType `PairE` ListE (Block CoreIR) `PairE` Block CoreIR) - fromE (HandlerDef name bodyTyArg bs effs ty ops ret) = - name `PairE` Abs (bodyTyArg `PairB` bs) (effs `PairE` ty `PairE` ListE ops `PairE` ret) - toE (name `PairE` Abs (bodyTyArg `PairB` bs) (effs `PairE` ty `PairE` ListE ops `PairE` ret)) = - HandlerDef name bodyTyArg bs effs ty ops ret - -instance SinkableE HandlerDef -instance HoistableE HandlerDef -instance AlphaEqE HandlerDef -instance AlphaHashableE HandlerDef -instance RenameE HandlerDef -deriving instance Show (HandlerDef n) -deriving via WrapE HandlerDef n instance Generic (HandlerDef n) - data EffectOpType (n::S) where EffectOpType :: UResumePolicy -> CType n -> EffectOpType n @@ -862,7 +823,6 @@ data Effect (r::IR) (n::S) = RWSEffect RWS (Atom r n) | ExceptionEffect | IOEffect - | UserEffect (Name EffectNameC n) | InitEffect -- Internal effect modeling writing to a destination. deriving (Generic, Show) @@ -1319,29 +1279,6 @@ instance IRRep r => RenameE (BaseMonoid r) instance IRRep r => AlphaEqE (BaseMonoid r) instance IRRep r => AlphaHashableE (BaseMonoid r) -instance GenericE UserEffectOp where - type RepE UserEffectOp = EitherE3 - {- Handle -} (HandlerName `PairE` ListE CAtom `PairE` CBlock) - {- Resume -} (CType `PairE` CAtom) - {- Perform -} (CAtom `PairE` LiftE Int) - fromE = \case - Handle name args body -> Case0 $ name `PairE` ListE args `PairE` body - Resume x y -> Case1 $ x `PairE` y - Perform x i -> Case2 $ x `PairE` LiftE i - {-# INLINE fromE #-} - toE = \case - Case0 (name `PairE` ListE args `PairE` body) -> Handle name args body - Case1 (x `PairE` y) -> Resume x y - Case2 (x `PairE` LiftE i) -> Perform x i - _ -> error "impossible" - {-# INLINE toE #-} - -instance SinkableE UserEffectOp -instance HoistableE UserEffectOp -instance RenameE UserEffectOp -instance AlphaEqE UserEffectOp -instance AlphaHashableE UserEffectOp - instance IRRep r => GenericE (DAMOp r) where type RepE (DAMOp r) = EitherE5 {- Seq -} (EffectRow r `PairE` LiftE Direction `PairE` IxType r `PairE` Atom r `PairE` LamExpr r) @@ -1685,11 +1622,10 @@ instance IRRep r => GenericE (PrimOp r) where {- MemOp -} (MemOp r) {- VectorOp -} (VectorOp r) {- MiscOp -} (MiscOp r) - ) (EitherE4 + ) (EitherE3 {- Hof -} (TypedHof r) {- RefOp -} (Atom r `PairE` RefOp r) {- DAMOp -} (WhenSimp r (DAMOp SimpIR)) - {- UserEffectOp -} (WhenCore r UserEffectOp) ) fromE = \case UnOp op x -> Case0 $ Case0 $ LiftE op `PairE` x @@ -1700,7 +1636,6 @@ instance IRRep r => GenericE (PrimOp r) where Hof op -> Case1 $ Case0 op RefOp r op -> Case1 $ Case1 $ r `PairE` op DAMOp op -> Case1 $ Case2 $ WhenIRE op - UserEffectOp op -> Case1 $ Case3 $ WhenIRE op {-# INLINE fromE #-} toE = \case @@ -1715,7 +1650,6 @@ instance IRRep r => GenericE (PrimOp r) where Case0 op -> Hof op Case1 (r `PairE` op) -> RefOp r op Case2 (WhenIRE op) -> DAMOp op - Case3 (WhenIRE op) -> UserEffectOp op _ -> error "impossible" _ -> error "impossible" {-# INLINE toE #-} @@ -2358,14 +2292,11 @@ instance GenericE (Binding c) where (WhenC ClassNameC c (ClassDef)) (WhenC InstanceNameC c (InstanceDef `PairE` CorePiType)) (WhenC MethodNameC c (ClassName `PairE` LiftE Int))) - (EitherE7 + (EitherE4 (WhenC TopFunNameC c (TopFun)) (WhenC FunObjCodeNameC c (CFunction)) (WhenC ModuleNameC c (Module)) - (WhenC PtrNameC c (LiftE (PtrType, PtrLitVal))) - (WhenC EffectNameC c (EffectDef)) - (WhenC HandlerNameC c (HandlerDef)) - (WhenC EffectOpNameC c (EffectOpDef))) + (WhenC PtrNameC c (LiftE (PtrType, PtrLitVal)))) (EitherE2 (WhenC SpecializedDictNameC c (SpecializedDictDef)) (WhenC ImpNameC c (LiftE BaseType))) @@ -2381,9 +2312,6 @@ instance GenericE (Binding c) where FunObjCodeBinding cFun -> Case1 $ Case1 $ WhenC $ cFun ModuleBinding m -> Case1 $ Case2 $ WhenC $ m PtrBinding ty p -> Case1 $ Case3 $ WhenC $ LiftE (ty,p) - EffectBinding effDef -> Case1 $ Case4 $ WhenC $ effDef - HandlerBinding hDef -> Case1 $ Case5 $ WhenC $ hDef - EffectOpBinding opDef -> Case1 $ Case6 $ WhenC $ opDef SpecializedDictBinding def -> Case2 $ Case0 $ WhenC $ def ImpNameBinding ty -> Case2 $ Case1 $ WhenC $ LiftE ty {-# INLINE fromE #-} @@ -2399,9 +2327,6 @@ instance GenericE (Binding c) where Case1 (Case1 (WhenC (f))) -> FunObjCodeBinding f Case1 (Case2 (WhenC (m))) -> ModuleBinding m Case1 (Case3 (WhenC ((LiftE (ty,p))))) -> PtrBinding ty p - Case1 (Case4 (WhenC (effDef))) -> EffectBinding effDef - Case1 (Case5 (WhenC (hDef))) -> HandlerBinding hDef - Case1 (Case6 (WhenC (opDef))) -> EffectOpBinding opDef Case2 (Case0 (WhenC (def))) -> SpecializedDictBinding def Case2 (Case1 (WhenC ((LiftE ty)))) -> ImpNameBinding ty _ -> error "impossible" @@ -2458,23 +2383,20 @@ instance IRRep r => BindsNames (Decl r) instance IRRep r => GenericE (Effect r) where type RepE (Effect r) = - EitherE4 (PairE (LiftE RWS) (Atom r)) + EitherE3 (PairE (LiftE RWS) (Atom r)) (LiftE (Either () ())) - (Name EffectNameC) UnitE fromE = \case RWSEffect rws h -> Case0 (PairE (LiftE rws) h) ExceptionEffect -> Case1 (LiftE (Left ())) IOEffect -> Case1 (LiftE (Right ())) - UserEffect name -> Case2 name - InitEffect -> Case3 UnitE + InitEffect -> Case2 UnitE {-# INLINE fromE #-} toE = \case Case0 (PairE (LiftE rws) h) -> RWSEffect rws h Case1 (LiftE (Left ())) -> ExceptionEffect Case1 (LiftE (Right ())) -> IOEffect - Case2 name -> UserEffect name - Case3 UnitE -> InitEffect + Case2 UnitE -> InitEffect _ -> error "unreachable" {-# INLINE toE #-} @@ -2845,7 +2767,6 @@ instance Store (DictExpr n) instance Store (EffectDef n) instance Store (EffectOpDef n) instance Store (RolePiBinder n l) -instance Store (HandlerDef n) instance Store (EffectOpType n) instance Store (EffectOpIdx) instance Store (SynthCandidates n) @@ -2869,7 +2790,6 @@ instance IRRep r => Store (RefOp r n) instance IRRep r => Store (BaseMonoid r n) instance IRRep r => Store (DAMOp r n) instance IRRep r => Store (IxDict r n) -instance Store (UserEffectOp n) instance Store (NewtypeCon n) instance Store (NewtypeTyCon n) instance Store (DotMethods n) diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 78bac3ddc..a27b02268 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -201,7 +201,6 @@ data UEffect (n::S) = URWSEffect RWS (SourceOrInternalName (AtomNameC CoreIR) n) | UExceptionEffect | UIOEffect - | UUserEffect (SourceOrInternalName EffectNameC n) deriving (Generic) data UEffectRow (n::S) = From 75eacbf9387824f7d79a72834da2d4087b39919f Mon Sep 17 00:00:00 2001 From: Dougal Date: Mon, 26 Jun 2023 21:38:22 -0400 Subject: [PATCH 3/4] Unbundle binders from their role/expl attributes. Fancy B-kinded things are a pain and they're about to get worse when we add decls to binders. An earlier attempt at adding decls without doing this forced me to create lots of complicated type classes to handle all the `WithExpl` and `RolePiBinder` variants. --- src/lib/AbstractSyntax.hs | 49 ++-- src/lib/Builder.hs | 15 +- src/lib/CheapReduction.hs | 14 +- src/lib/CheckType.hs | 31 +- src/lib/Core.hs | 12 +- src/lib/Export.hs | 10 +- src/lib/Generalize.hs | 21 +- src/lib/Inference.hs | 561 +++++++++++++++++++----------------- src/lib/JAX/ToSimp.hs | 4 +- src/lib/Name.hs | 45 ++- src/lib/PPrint.hs | 42 +-- src/lib/QueryType.hs | 91 +++--- src/lib/QueryTypePure.hs | 6 +- src/lib/RuntimePrint.hs | 5 +- src/lib/Simplify.hs | 10 +- src/lib/SourceRename.hs | 27 +- src/lib/Subst.hs | 13 +- src/lib/TopLevel.hs | 6 +- src/lib/Types/Core.hs | 82 ++---- src/lib/Types/Primitives.hs | 57 ---- src/lib/Types/Source.hs | 17 +- 21 files changed, 543 insertions(+), 575 deletions(-) diff --git a/src/lib/AbstractSyntax.hs b/src/lib/AbstractSyntax.hs index 5a5805b0c..e2143fd9a 100644 --- a/src/lib/AbstractSyntax.hs +++ b/src/lib/AbstractSyntax.hs @@ -110,23 +110,23 @@ topDecl = dropSrc topDecl' where topDecl' (CSDecl ann d) = ULocalDecl <$> decl ann (WithSrc emptySrcPosCtx d) topDecl' (CData name tyConParams givens constructors) = do tyConParams' <- aExplicitParams tyConParams - givens' <- toNest <$> fromMaybeM givens [] aGivens + givens' <- aOptGivens givens constructors' <- forM constructors \(v, ps) -> do ps' <- toNest <$> mapM tyOptBinder ps return (v, ps') return $ UDataDefDecl - (UDataDef name (givens' >>> tyConParams') $ + (UDataDef name (catUOptAnnExplBinders givens' tyConParams') $ map (\(name', cons) -> (name', UDataDefTrail cons)) constructors') (fromString name) (toNest $ map (fromString . fst) constructors') topDecl' (CStruct name params givens fields defs) = do params' <- aExplicitParams params - givens' <- toNest <$> fromMaybeM givens [] aGivens + givens' <- aOptGivens givens fields' <- forM fields \(v, ty) -> (v,) <$> expr ty methods <- forM defs \(ann, d) -> do (methodName, lam) <- aDef d return (ann, methodName, Abs (UBindSource emptySrcPosCtx "self") lam) - return $ UStructDecl (fromString name) (UStructDef name (givens' >>> params') fields' methods) + return $ UStructDecl (fromString name) (UStructDef name (catUOptAnnExplBinders givens' params') fields' methods) topDecl' (CInterface name params methods) = do params' <- aExplicitParams params (methodNames, methodTys) <- unzip <$> forM methods \(methodName, ty) -> do @@ -153,7 +153,7 @@ aInstanceDef :: CInstanceDef -> SyntaxM (UTopDecl VoidS VoidS) aInstanceDef (CInstanceDef clName args givens methods instNameAndParams) = do let clName' = fromString clName args' <- mapM expr args - givens' <- toNest <$> fromMaybeM givens [] aGivens + givens' <- aOptGivens givens methods' <- catMaybes <$> mapM aMethod methods case instNameAndParams of Nothing -> return $ UInstance clName' givens' args' methods' NothingB ImplicitApp @@ -162,7 +162,7 @@ aInstanceDef (CInstanceDef clName args givens methods instNameAndParams) = do case optParams of Just params -> do params' <- aExplicitParams params - return $ UInstance clName' (givens' >>> params') args' methods' instName' ExplicitApp + return $ UInstance clName' (catUOptAnnExplBinders givens' params') args' methods' instName' ExplicitApp Nothing -> return $ UInstance clName' givens' args' methods' instName' ImplicitApp aDef :: CDef -> SyntaxM (SourceName, ULamExpr VoidS) @@ -173,19 +173,27 @@ aDef (CDef name params optRhs optGivens body) = do effs <- fromMaybeM optEffs UPure aEffects resultTy' <- expr resultTy return (expl, Just effs, Just resultTy') - implicitParams <- toNest <$> fromMaybeM optGivens [] aGivens - let allParams = implicitParams >>> explicitParams + implicitParams <- aOptGivens optGivens + let allParams = catUOptAnnExplBinders implicitParams explicitParams body' <- block body return (name, ULamExpr allParams expl effs resultTy body') +catUOptAnnExplBinders :: UOptAnnExplBinders n l -> UOptAnnExplBinders l l' -> UOptAnnExplBinders n l' +catUOptAnnExplBinders (expls, bs) (expls', bs') = (expls <> expls', bs >>> bs') + stripParens :: Group -> Group stripParens (WithSrc _ (CParens [g])) = stripParens g stripParens g = g -aExplicitParams :: ExplicitParams -> SyntaxM (Nest (WithExpl UOptAnnBinder) VoidS VoidS) +aExplicitParams :: ExplicitParams -> SyntaxM ([Explicitness], Nest UOptAnnBinder VoidS VoidS) aExplicitParams gs = generalBinders DataParam Explicit gs -aGivens :: GivenClause -> SyntaxM [WithExpl UOptAnnBinder VoidS VoidS] +aOptGivens :: Maybe GivenClause -> SyntaxM (UOptAnnExplBinders VoidS VoidS) +aOptGivens optGivens = do + (expls, implicitParams) <- unzip <$> fromMaybeM optGivens [] aGivens + return (expls, toNest implicitParams) + +aGivens :: GivenClause -> SyntaxM [(Explicitness, UOptAnnBinder VoidS VoidS)] aGivens (implicits, optConstraints) = do implicits' <- mapM (generalBinder DataParam (Inferred Nothing Unify)) implicits constraints <- fromMaybeM optConstraints [] \gs -> do @@ -194,23 +202,24 @@ aGivens (implicits, optConstraints) = do generalBinders :: ParamStyle -> Explicitness -> [Group] - -> SyntaxM (Nest (WithExpl UOptAnnBinder) VoidS VoidS) -generalBinders paramStyle expl params = toNest . concat <$> - forM params \case + -> SyntaxM ([Explicitness], Nest UOptAnnBinder VoidS VoidS) +generalBinders paramStyle expl params = do + (expls, bs) <- unzip . concat <$> forM params \case WithSrc _ (CGivens gs) -> aGivens gs p -> (:[]) <$> generalBinder paramStyle expl p + return (expls, toNest bs) generalBinder :: ParamStyle -> Explicitness -> Group - -> SyntaxM (WithExpl UOptAnnBinder VoidS VoidS) + -> SyntaxM (Explicitness, UOptAnnBinder VoidS VoidS) generalBinder paramStyle expl g = case expl of - Inferred _ (Synth _) -> WithExpl expl <$> tyOptBinder g + Inferred _ (Synth _) -> (expl,) <$> tyOptBinder g Inferred _ Unify -> do b <- binderOptTy g expl' <- return case b of UAnnBinder (UBindSource _ s) _ _ -> Inferred (Just s) Unify _ -> expl - return $ WithExpl expl' b - Explicit -> WithExpl expl <$> case paramStyle of + return (expl', b) + Explicit -> (expl,) <$> case paramStyle of TypeParam -> tyOptBinder g DataParam -> binderOptTy g @@ -347,7 +356,7 @@ aMethod (WithSrc src d) = Just . WithSrcE src <$> addSrcContext src case d of (name, lam) <- aDef def return $ UMethodDef (fromString name) lam CLet (WithSrc _ (CIdentifier name)) rhs -> do - rhs' <- ULamExpr Empty ImplicitApp Nothing Nothing <$> block rhs + rhs' <- ULamExpr ([], Empty) ImplicitApp Nothing Nothing <$> block rhs return $ UMethodDef (fromString name) rhs' _ -> throw SyntaxErr "Unexpected method definition. Expected `def` or `x = ...`." @@ -368,10 +377,10 @@ blockDecls [WithSrc src d] = addSrcContext src case d of CExpr g -> (Empty,) <$> expr g _ -> throw SyntaxErr "Block must end in expression" blockDecls (WithSrc pos (CBind b rhs):ds) = do - WithExpl _ b' <- generalBinder DataParam Explicit b + (_, b') <- generalBinder DataParam Explicit b rhs' <- asExpr <$> block rhs body <- block $ IndentedBlock ds - let lam = ULam $ ULamExpr (UnaryNest (WithExpl Explicit b')) ExplicitApp Nothing Nothing body + let lam = ULam $ ULamExpr ([Explicit], UnaryNest b') ExplicitApp Nothing Nothing body return (Empty, WithSrcE pos $ extendAppRight rhs' (ns lam)) blockDecls (d:ds) = do d' <- decl PlainLet d diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index b3012cc90..8e7563511 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -602,22 +602,13 @@ buildBlock -> m n (Block r n) buildBlock = buildScoped -coreLamExpr :: EnvReader m => AppExplicitness - -> Abs (Nest (WithExpl CBinder)) (PairE (EffectRow CoreIR) CBlock) n - -> m n (CoreLamExpr n) -coreLamExpr appExpl ab = liftEnvReaderM do - refreshAbs ab \bs' (PairE effs' body') -> do - EffTy _ resultTy <- blockEffTy body' - let bs'' = fmapNest withoutExpl bs' - return $ CoreLamExpr (CorePiType appExpl bs' (EffTy effs' resultTy)) (LamExpr bs'' body') - buildCoreLam :: ScopableBuilder CoreIR m => CorePiType n -> (forall l. (Emits l, DExt n l) => [CAtomVar l] -> m l (CAtom l)) -> m n (CoreLamExpr n) -buildCoreLam piTy@(CorePiType _ bs _) cont = do - lam <- buildLamExpr (EmptyAbs $ fmapNest withoutExpl bs) cont +buildCoreLam piTy@(CorePiType _ _ bs _) cont = do + lam <- buildLamExpr (EmptyAbs bs) cont return $ CoreLamExpr piTy lam buildAbs @@ -1083,7 +1074,7 @@ projectStructRef i x = do getStructProjections :: EnvReader m => Int -> CType n -> m n [Projection] getStructProjections i (NewtypeTyCon (UserADTType _ tyConName _)) = do - TyConDef _ _ ~(StructFields fields) <- lookupTyCon tyConName + TyConDef _ _ _ ~(StructFields fields) <- lookupTyCon tyConName return case fields of [_] | i == 0 -> [UnwrapNewtype] | otherwise -> error "bad index" diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index bce6dcbed..9a35c2ed5 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -240,7 +240,7 @@ cheapReduceDictExpr resultTy d = case d of cheapReduceE child >>= \case DictCon _ (InstanceDict instanceName args) -> dropSubst do args' <- mapM cheapReduceE args - InstanceDef _ bs _ body <- lookupInstanceDef instanceName + InstanceDef _ _ bs _ body <- lookupInstanceDef instanceName let InstanceBody superclasses _ = body applySubst (bs@@>(SubstVal <$> args')) (superclasses !! superclassIx) child' -> return $ DictCon resultTy $ SuperclassProj child' superclassIx @@ -285,7 +285,7 @@ instance IRRep r => CheaplyReducibleE r (Expr r) (Atom r) where cheapReduceE dict >>= \case DictCon _ (InstanceDict instanceName args) -> dropSubst do args' <- mapM cheapReduceE args - InstanceDef _ bs _ (InstanceBody _ methods) <- lookupInstanceDef instanceName + InstanceDef _ _ bs _ (InstanceBody _ methods) <- lookupInstanceDef instanceName let method = methods !! i extendSubst (bs@@>(SubstVal <$> args')) do method' <- cheapReduceE method @@ -466,7 +466,7 @@ wrapNewtypesData [] x = x wrapNewtypesData (c:cs) x = NewtypeCon c $ wrapNewtypesData cs x instantiateTyConDef :: EnvReader m => TyConDef n -> TyConParams n -> m n (DataConDefs n) -instantiateTyConDef (TyConDef _ bs conDefs) (TyConParams _ xs) = do +instantiateTyConDef (TyConDef _ _ bs conDefs) (TyConParams _ xs) = do applySubst (bs @@> (SubstVal <$> xs)) conDefs {-# INLINE instantiateTyConDef #-} @@ -487,7 +487,7 @@ dataDefRep (StructFields fields) = case map snd fields of makeStructRepVal :: (Fallible1 m, EnvReader m) => TyConName n -> [CAtom n] -> m n (CAtom n) makeStructRepVal tyConName args = do - TyConDef _ _ (StructFields fields) <- lookupTyCon tyConName + TyConDef _ _ _ (StructFields fields) <- lookupTyCon tyConName case fields of [_] -> case args of [arg] -> return arg @@ -725,11 +725,9 @@ instance VisitGeneric CoreLamExpr CoreIR where visitGeneric (CoreLamExpr t lam) = CoreLamExpr <$> visitGeneric t <*> visitGeneric lam instance VisitGeneric CorePiType CoreIR where - visitGeneric (CorePiType app bsExpl effty) = do - let (expls, bs) = unzipExpls bsExpl + visitGeneric (CorePiType app expl bs effty) = do PiType bs' effty' <- visitGeneric $ PiType bs effty - let bsExpl' = zipExpls expls bs' - return $ CorePiType app bsExpl' effty' + return $ CorePiType app expl bs' effty' instance IRRep r => VisitGeneric (TabPiType r) r where visitGeneric (TabPiType d b eltTy) = do diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 9e088be81..671067d04 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -256,7 +256,7 @@ instance IRRep r => HasType r (Type r) where TC tyCon -> typeCheckPrimTC tyCon DepPairTy ty -> getTypeE ty DictTy (DictType _ className params) -> do - ClassDef _ _ _ paramBs _ _ <- renameM className >>= lookupClassDef + ClassDef _ _ _ _ paramBs _ _ <- renameM className >>= lookupClassDef params' <- mapM renameM params checkArgTys paramBs params' return TyKind @@ -293,9 +293,6 @@ instance (ToBinding ann c, Color c, CheckableE r ann) => CheckableB r (BinderP c extendRenamer (b@>binderName b') $ cont b' -instance (BindsNames b, CheckableB r b) => CheckableB r (WithExpl b) where - checkB (WithExpl expl b) cont = checkB b \b' -> cont (WithExpl expl b') - typeCheckExpr :: (Typer m r, IRRep r) => EffectRow r o -> Expr r i -> m i o (Type r o) typeCheckExpr effs expr = addContext ("Checking expr:\n" ++ pprint expr) case expr of App (EffTy _ reqTy) f xs -> do @@ -318,7 +315,7 @@ typeCheckExpr effs expr = addContext ("Checking expr:\n" ++ pprint expr) case ex return resultTy' ApplyMethod (EffTy _ reqTy) dict i args -> do DictTy (DictType _ className params) <- getTypeE dict - ClassDef _ _ _ paramBs classBs methodTys <- lookupClassDef className + ClassDef _ _ _ _ paramBs classBs methodTys <- lookupClassDef className let methodTy = methodTys !! i superclassDicts <- getSuperclassDicts =<< renameM dict let subst = ( paramBs @@> map SubstVal params @@ -342,8 +339,8 @@ dictExprType :: Typer m CoreIR => DictExpr i -> m i o (CType o) dictExprType e = case e of InstanceDict instanceName args -> do instanceName' <- renameM instanceName - InstanceDef className bs params _ <- lookupInstanceDef instanceName' - ClassDef sourceName _ _ _ _ _ <- lookupClassDef className + InstanceDef className _ bs params _ <- lookupInstanceDef instanceName' + ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef className args' <- mapM renameM args checkArgTys bs args' ListE params' <- applySubst (bs@@>(SubstVal<$>args')) (ListE params) @@ -353,7 +350,7 @@ dictExprType e = case e of checkApp Pure givenTy (toList args) SuperclassProj d i -> do DictTy (DictType _ className params) <- getTypeE d - ClassDef _ _ _ bs superclasses _ <- lookupClassDef className + ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className let scType = getSuperclassType REmpty superclasses i checkedApplyNaryAbs (Abs bs scType) params IxFin n -> do @@ -370,7 +367,7 @@ instance IRRep r => HasType r (DepPairType r) where return TyKind instance HasType CoreIR CorePiType where - getTypeE (CorePiType _ bs (EffTy eff resultTy)) = do + getTypeE (CorePiType _ _ bs (EffTy eff resultTy)) = do checkB bs \_ -> do void $ checkE eff resultTy|:TyKind @@ -407,14 +404,14 @@ checkAgainstGiven givenTy computedTy = do return givenTy' checkCoreLam :: Typer m CoreIR => CorePiType o -> LamExpr CoreIR i -> m i o () -checkCoreLam (CorePiType _ Empty (EffTy effs resultTy)) (LamExpr Empty body) = do +checkCoreLam (CorePiType _ _ Empty (EffTy effs resultTy)) (LamExpr Empty body) = do resultTy' <- checkBlockWithEffs effs body checkTypesEq resultTy resultTy' -checkCoreLam (CorePiType expl (Nest piB piBs) effTy) (LamExpr (Nest lamB lamBs) body) = do +checkCoreLam (CorePiType expl (_:expls) (Nest piB piBs) effTy) (LamExpr (Nest lamB lamBs) body) = do argTy <- renameM $ binderType lamB checkTypesEq (binderType piB) argTy withFreshBinder (getNameHint lamB) argTy \b -> do - piTy <- applyRename (piB@>binderName b) (CorePiType expl piBs effTy) + piTy <- applyRename (piB@>binderName b) (CorePiType expl expls piBs effTy) extendRenamer (lamB@>binderName b) do checkCoreLam piTy (LamExpr lamBs body) checkCoreLam _ _ = throw TypeErr "zip error" @@ -446,7 +443,7 @@ typeCheckNewtypeCon con x = case con of FinCon n -> n|:NatTy >> x|:NatTy >> renameM (Fin n) UserADTData _ d params -> do d' <- renameM d - def@(TyConDef sn _ _) <- lookupTyCon d' + def@(TyConDef sn _ _ _) <- lookupTyCon d' params' <- renameM params void $ checkedInstantiateTyConDef def params' return $ UserADTType sn d' params' @@ -773,7 +770,7 @@ checkAlt resultTyReq bTyReq effs (Abs b body) = do checkApp :: (Typer m r, IRRep r) => EffectRow r o -> Type r o -> [Atom r i] -> m i o (Type r o) checkApp allowedEffs fTy xs = case fTy of - Pi (CorePiType _ bs effTy) -> do + Pi (CorePiType _ _ bs effTy) -> do xs' <- mapM renameM xs checkArgTys bs xs' let subst = bs @@> fmap SubstVal xs' @@ -929,7 +926,7 @@ checkUnOp op x = do checkedInstantiateTyConDef :: (EnvReader m, Fallible1 m) => TyConDef n -> TyConParams n -> m n (DataConDefs n) -checkedInstantiateTyConDef (TyConDef _ bs cons) (TyConParams _ xs) = do +checkedInstantiateTyConDef (TyConDef _ _ bs cons) (TyConParams _ xs) = do checkedApplyNaryAbs (Abs bs cons) xs checkedApplyNaryAbs @@ -995,7 +992,7 @@ asFFIFunType ty = return do return (impTy, piTy) checkFFIFunTypeM :: Fallible m => CorePiType n -> m IFunType -checkFFIFunTypeM (CorePiType appExpl (Nest b bs) effTy) = do +checkFFIFunTypeM (CorePiType appExpl (_:expls) (Nest b bs) effTy) = do argTy <- checkScalar $ binderType b case bs of Empty -> do @@ -1006,7 +1003,7 @@ checkFFIFunTypeM (CorePiType appExpl (Nest b bs) effTy) = do _ -> FFIMultiResultCC return $ IFunType cc [argTy] resultTys Nest b' rest -> do - let naryPiRest = CorePiType appExpl (Nest b' rest) effTy + let naryPiRest = CorePiType appExpl expls (Nest b' rest) effTy IFunType cc argTys resultTys <- checkFFIFunTypeM naryPiRest return $ IFunType cc (argTy:argTys) resultTys checkFFIFunTypeM _ = error "expected at least one argument" diff --git a/src/lib/Core.hs b/src/lib/Core.hs index 88725c1ca..f6fb57452 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -218,17 +218,13 @@ instance BindsEnv EnvFrag where toEnvFrag frag = frag {-# INLINE toEnvFrag #-} -instance BindsEnv b => BindsEnv (WithExpl b) where - toEnvFrag (WithExpl _ b) = toEnvFrag b - {-# INLINE toEnvFrag #-} - -instance BindsEnv RolePiBinder where - toEnvFrag (RolePiBinder _ b) = toEnvFrag b - {-# INLINE toEnvFrag #-} - instance BindsEnv (RecSubstFrag Binding) where toEnvFrag frag = EnvFrag frag +instance BindsEnv b => BindsEnv (WithAttrB a b) where + toEnvFrag (WithAttrB _ b) = toEnvFrag b + {-# INLINE toEnvFrag #-} + instance (BindsEnv b1, BindsEnv b2) => (BindsEnv (PairB b1 b2)) where toEnvFrag (PairB b1 b2) = do diff --git a/src/lib/Export.hs b/src/lib/Export.hs index eba7a718f..42dcc7ba1 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -100,11 +100,11 @@ liftExportSigM cont = do corePiToExportSig :: CallingConvention -> CorePiType i -> ExportSigM CoreIR i o (ExportedSignature o) -corePiToExportSig cc (CorePiType _ tbs (EffTy effs resultTy)) = do +corePiToExportSig cc (CorePiType _ expls tbs (EffTy effs resultTy)) = do case effs of Pure -> return () _ -> throw TypeErr "Only pure functions can be exported" - goArgs cc Empty [] tbs resultTy + goArgs cc Empty [] (zipAttrs expls tbs) resultTy simpPiToExportSig :: CallingConvention -> PiType SimpIR i -> ExportSigM SimpIR i o (ExportedSignature o) @@ -112,14 +112,14 @@ simpPiToExportSig cc (PiType bs (EffTy effs resultTy)) = do case effs of Pure -> return () _ -> throw TypeErr "Only pure functions can be exported" - bs' <- return $ fmapNest (\b -> WithExpl Explicit b) bs + bs' <- return $ fmapNest (\b -> WithAttrB Explicit b) bs goArgs cc Empty [] bs' resultTy goArgs :: (IRRep r) => CallingConvention -> Nest ExportArg o o' -> [CAtomName o'] - -> Nest (WithExpl (Binder r)) i i' + -> Nest (WithAttrB Explicitness (Binder r)) i i' -> Type r i' -> ExportSigM r i o' (ExportedSignature o) goArgs cc argSig argVs piBs piRes = case piBs of @@ -128,7 +128,7 @@ goArgs cc argSig argVs piBs piRes = case piBs of StandardCC -> (fromListE $ sink $ ListE argVs) ++ nestToList (sink . binderName) resSig XLACC -> [] _ -> error $ "calling convention not supported: " ++ show cc - Nest (WithExpl expl (b:>ty)) bs -> do + Nest (WithAttrB expl (b:>ty)) bs -> do ety <- toExportType ty withFreshBinder (getNameHint b) ety \(v:>_) -> extendSubst (b @> Rename (binderName v)) $ do diff --git a/src/lib/Generalize.hs b/src/lib/Generalize.hs index 78037c742..dacb584fb 100644 --- a/src/lib/Generalize.hs +++ b/src/lib/Generalize.hs @@ -20,6 +20,9 @@ import Subst import MTL1 import Types.Primitives +type RolePiBinder = WithAttrB RoleExpl CBinder +type RolePiBinders = Nest RolePiBinder + generalizeIxDict :: EnvReader m => Atom CoreIR n -> m n (Generalized CoreIR CAtom n) generalizeIxDict dict = liftGeneralizerM do dict' <- sinkM dict @@ -31,12 +34,12 @@ generalizeIxDict dict = liftGeneralizerM do generalizeArgs ::EnvReader m => CorePiType n -> [Atom CoreIR n] -> m n (Generalized CoreIR (ListE CAtom) n) generalizeArgs fTy argsTop = liftGeneralizerM $ runSubstReaderT idSubst do - PairE (CorePiType _ bs _) (ListE argsTop') <- sinkM $ PairE fTy (ListE argsTop) - ListE <$> go bs argsTop' + PairE (CorePiType _ expls bs _) (ListE argsTop') <- sinkM $ PairE fTy (ListE argsTop) + ListE <$> go (zipAttrs expls bs) argsTop' where - go :: Nest (WithExpl CBinder) i i' -> [Atom CoreIR n] + go :: Nest (WithAttrB Explicitness CBinder) i i' -> [Atom CoreIR n] -> SubstReaderT AtomSubstVal GeneralizerM i n [Atom CoreIR n] - go (Nest (WithExpl expl b) bs) (arg:args) = do + go (Nest (WithAttrB expl b) bs) (arg:args) = do ty' <- substM $ binderType b arg' <- case (ty', expl) of (TyKind, _) -> liftSubstReaderT case arg of @@ -172,7 +175,7 @@ traverseRoleBinders f allBinders allParams = go :: forall i i'. RolePiBinders i i' -> [Atom CoreIR n] -> SubstReaderT AtomSubstVal m i n [Atom CoreIR n] go Empty [] = return [] - go (Nest (RolePiBinder role b) bs) (param:params) = do + go (Nest (WithAttrB (role, _) b) bs) (param:params) = do ty' <- substM $ binderType b Distinct <- getDistinct param' <- liftSubstReaderT $ f role ty' param @@ -183,14 +186,14 @@ traverseRoleBinders f allBinders allParams = getDataDefRoleBinders :: EnvReader m => TyConName n -> m n (Abs RolePiBinders UnitE n) getDataDefRoleBinders def = do - TyConDef _ bs _ <- lookupTyCon def - return $ Abs bs UnitE + TyConDef _ attrs bs _ <- lookupTyCon def + return $ Abs (zipAttrs attrs bs) UnitE {-# INLINE getDataDefRoleBinders #-} getClassRoleBinders :: EnvReader m => ClassName n -> m n (Abs RolePiBinders UnitE n) getClassRoleBinders def = do - ClassDef _ _ _ bs _ _ <- lookupClassDef def - return $ Abs bs UnitE + ClassDef _ _ _ roleExpls bs _ _ <- lookupClassDef def + return $ Abs (zipAttrs roleExpls bs) UnitE {-# INLINE getClassRoleBinders #-} -- === instances === diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 6f07dddf9..907f375a9 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -74,7 +74,7 @@ inferTopUDecl (UStructDecl tc def) result = do extendRenamer (tc@>sink tc') $ inferStructDef def def'' <- synthTyConDef def' updateTopEnv $ UpdateTyConDef tc' def'' - UStructDef _ paramBs _ methods <- return def + UStructDef _ (_, paramBs) _ methods <- return def forM_ methods \(letAnn, methodName, methodDef) -> do method <- liftInfererM $ solveLocal $ extendRenamer (tc@>sink tc') $ @@ -85,7 +85,7 @@ inferTopUDecl (UStructDecl tc def) result = do UDeclResultDone <$> applyRename (tc @> tc') result inferTopUDecl (UDataDefDecl def tc dcs) result = do tcDef <- liftInfererM $ solveLocal $ inferTyConDef def - tcDef'@(TyConDef _ _ (ADTCons dataCons)) <- synthTyConDef tcDef + tcDef'@(TyConDef _ _ _ (ADTCons dataCons)) <- synthTyConDef tcDef tc' <- emitBinding (getNameHint tcDef') $ TyConBinding (Just tcDef') (DotMethods mempty) dcs' <- forM (enumerate dataCons) \(i, dcDef) -> emitBinding (getNameHint dcDef) $ DataConBinding tc' i @@ -104,14 +104,16 @@ inferTopUDecl (UInterface paramBs methodTys className methodNames) result = do inferTopUDecl (UInstance className instanceBs params methods maybeName expl) result = do let (InternalName _ _ className') = className ab <- liftInfererM $ solveLocal do - withRoleUBinders instanceBs \_ -> do - ClassDef _ _ _ paramBinders _ _ <- lookupClassDef (sink className') - params' <- checkInstanceParams paramBinders params + withRoleUBinders instanceBs do + ClassDef _ _ _ roleExpls paramBinders _ _ <- lookupClassDef (sink className') + let expls = snd <$> roleExpls + params' <- checkInstanceParams expls paramBinders params className'' <- sinkM className' body <- checkInstanceBody className'' params' methods return (ListE params' `PairE` body) Abs bs' (ListE params' `PairE` body) <- return ab - let def = InstanceDef className' bs' params' body + let (roleExpls, bs'') = unzipAttrs bs' + let def = InstanceDef className' roleExpls bs'' params' body UDeclResultDone <$> case maybeName of RightB UnitB -> do void $ synthInstanceDefAndAddSynthCandidate def @@ -151,13 +153,12 @@ asTopBlock block = do return $ TopLam False (PiType Empty effTy) (LamExpr Empty block) getInstanceType :: EnvReader m => InstanceDef n -> m n (CorePiType n) -getInstanceType (InstanceDef className bs params _) = liftEnvReaderM do +getInstanceType (InstanceDef className roleExpls bs params _) = liftEnvReaderM do refreshAbs (Abs bs (ListE params)) \bs' (ListE params') -> do className' <- sinkM className - ClassDef classSourceName _ _ _ _ _ <- lookupClassDef className' + ClassDef classSourceName _ _ _ _ _ _ <- lookupClassDef className' let dTy = DictTy $ DictType classSourceName className' params' - let bs'' = fmapNest (\(RolePiBinder _ b) -> b) bs' - return $ CorePiType ImplicitApp bs'' $ EffTy Pure dTy + return $ CorePiType ImplicitApp (snd <$> roleExpls) bs' $ EffTy Pure dTy -- === Inferer interface === @@ -178,19 +179,40 @@ class ( MonadFail1 m, Fallible1 m, Catchable1 m, CtxReader1 m, Builder CoreIR m => EmitsInf n => NameHint -> Explicitness -> CType n -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> m l (e l)) - -> m n (Abs (WithExpl CBinder) e n) + -> m n (Abs CBinder e n) + +buildAbsInfWithExpl + :: (InfBuilder m, SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e) + => EmitsInf n + => NameHint -> Explicitness -> CType n + -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> m l (e l)) + -> m n (Abs (WithExpl CBinder) e n) +buildAbsInfWithExpl hint expl ty cont = do + Abs b e <- buildAbsInf hint expl ty cont + return $ Abs (WithAttrB expl b) e + +buildNaryAbsInfWithExpl + :: (Inferer m, SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e, Inferer m) + => EmitsInf n + => [Explicitness] -> EmptyAbs (Nest CBinder) n + -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> m i l (e l)) + -> m i n (Abs (Nest (WithExpl CBinder)) e n) +buildNaryAbsInfWithExpl expls bs cont = do + Abs bs' e <- buildNaryAbsInf expls bs cont + return $ Abs (zipAttrs expls bs') e buildNaryAbsInf :: (SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e, Inferer m) => EmitsInf n - => EmptyAbs (Nest (WithExpl CBinder)) n + => [Explicitness] -> EmptyAbs (Nest CBinder) n -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> m i l (e l)) - -> m i n (Abs (Nest (WithExpl CBinder)) e n) -buildNaryAbsInf (Abs Empty UnitE) cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] -buildNaryAbsInf (Abs (Nest (WithExpl expl (b:>ty)) bs) UnitE) cont = + -> m i n (Abs (Nest CBinder) e n) +buildNaryAbsInf [] (Abs Empty UnitE) cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] +buildNaryAbsInf (expl:expls) (Abs (Nest (b:>ty) bs) UnitE) cont = prependAbs <$> buildAbsInf (getNameHint b) expl ty \v -> do bs' <- applyRename (b@>atomVarName v) (Abs bs UnitE) - buildNaryAbsInf bs' \vs -> cont (sink v:vs) + buildNaryAbsInf expls bs' \vs -> cont (sink v:vs) +buildNaryAbsInf _ _ _ = error "zip error" buildDeclsInf :: (SubstE AtomSubstVal e, RenameE e, Solver m, InfBuilder m) @@ -522,7 +544,7 @@ instance InfBuilder (InfererM i) where ++ "\n" ++ pprint infFrag Abs b e <- return ab ty' <- zonk ty - return $ Abs (WithExpl expl (b:>ty')) e + return $ Abs (b:>ty') e dceInfFrag :: (EnvReader m, EnvExtender m, Fallible1 m, RenameE e, HoistableE e) @@ -831,11 +853,12 @@ extendSynthCandidates (Inferred _ (Synth _)) v (Env topEnv (ModuleEnv a b scs)) extendSynthCandidates _ _ env = env {-# INLINE extendSynthCandidates #-} -extendSynthCandidatess :: Distinct n => RolePiBinders n' n -> Env n -> Env n -extendSynthCandidatess (Nest (RolePiBinder _ (WithExpl expl b)) rest) env = - extendSynthCandidatess rest env' - where env' = extendSynthCandidates expl (withExtEvidence rest $ sink $ binderName b) env -extendSynthCandidatess Empty env = env +extendSynthCandidatess :: Distinct n => [Explicitness] -> Nest CBinder n' n -> Env n -> Env n +extendSynthCandidatess (expl:expls) (Nest b bs) env = + extendSynthCandidatess expls bs env' + where env' = extendSynthCandidates expl (withExtEvidence bs $ sink $ binderName b) env +extendSynthCandidatess [] Empty env = env +extendSynthCandidatess _ _ _ = error "zip error" {-# INLINE extendSynthCandidatess #-} -- === actual inference pass === @@ -848,8 +871,8 @@ data RequiredTy (e::E) (n::S) = checkSigma :: EmitsBoth o => NameHint -> UExpr i -> CType o -> InfererM i o (CAtom o) checkSigma hint expr sTy = confuseGHC >>= \_ -> case sTy of - Pi piTy@(CorePiType _ bs _) -> do - if all (== Explicit) (nestToList getExpl bs) + Pi piTy@(CorePiType _ expls _ _) -> do + if all (== Explicit) expls then fallback else case expr of WithSrcE src (ULam lam) -> addSrcContext src $ Lam <$> checkULam lam piTy @@ -949,7 +972,8 @@ checkOrInferRho hint uExprWithSrc@(WithSrcE pos expr) reqTy = do -- TODO: check explicitness constraints ab <- withUBinders bs \_ -> EffTy <$> checkUEffRow effs <*> checkUType ty Abs bs' effTy' <- return ab - matchRequirement $ Type $ Pi $ CorePiType appExpl bs' effTy' + let (expls, bs'') = unzipAttrs bs' + matchRequirement $ Type $ Pi $ CorePiType appExpl expls bs'' effTy' UTabPi (UTabPiExpr (UAnnBinder b ann cs) ty) -> do unless (null cs) $ throw TypeErr "`=>` shouldn't have constraints" ann' <- asIxType =<< checkAnn (getSourceName b) ann @@ -1157,11 +1181,11 @@ getFieldDefs ty = case ty of instantiateSigma :: forall i o. EmitsBoth o => SigmaAtom o -> InfererM i o (CAtom o) instantiateSigma sigmaAtom = case getType sigmaAtom of - Pi piTy@(CorePiType ExplicitApp _ _) -> do + Pi piTy@(CorePiType ExplicitApp _ _ _) -> do Lam <$> etaExpandExplicits fDesc piTy \args -> applySigmaAtom (sink sigmaAtom) args - Pi (CorePiType ImplicitApp bs (EffTy _ resultTy)) -> do - args <- inferMixedArgs @UExpr fDesc (Abs bs resultTy) [] [] + Pi (CorePiType ImplicitApp expls bs (EffTy _ resultTy)) -> do + args <- inferMixedArgs @UExpr fDesc expls (Abs bs resultTy) [] [] applySigmaAtom sigmaAtom args DepPairTy (DepPairType ImplicitDepPair _ _) -> -- TODO: we should probably call instantiateSigma again here in case @@ -1198,53 +1222,55 @@ etaExpandExplicits :: EmitsInf o => SourceName -> CorePiType o -> (forall o'. (EmitsBoth o', DExt o o') => [CAtom o'] -> InfererM i o' (CAtom o')) -> InfererM i o (CoreLamExpr o) -etaExpandExplicits fSourceName (CorePiType _ bsTop (EffTy effs _)) contTop = do - ab <- go bsTop \xs -> do +etaExpandExplicits fSourceName (CorePiType _ explsTop bsTop (EffTy effs _)) contTop = do + Abs bs body <- go explsTop bsTop \xs -> do effs' <- applySubst (bsTop@@>(SubstVal<$>xs)) effs withAllowedEffects effs' do body <- buildBlockInf $ contTop $ sinkList xs return $ PairE effs' body - coreLamExpr ExplicitApp ab + let (expls, bs') = unzipAttrs bs + coreLamExpr ExplicitApp expls $ Abs bs' body where go :: (EmitsInf o, SinkableE e, RenameE e, SubstE AtomSubstVal e, HoistableE e ) - => Nest (WithExpl CBinder) o any + => [Explicitness] -> Nest CBinder o any -> (forall o'. (EmitsInf o', DExt o o') => [CAtom o'] -> InfererM i o' (e o')) -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) - go Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] - go (Nest (WithExpl expl (b:>ty)) rest) cont = case expl of + go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] + go (expl:expls) (Nest (b:>ty) rest) cont = case expl of Explicit -> do - prependAbs <$> buildAbsInf (getNameHint b) expl ty \v -> do + prependAbs <$> buildAbsInfWithExpl (getNameHint b) expl ty \v -> do Abs rest' UnitE <- applyRename (b@>atomVarName v) $ Abs rest UnitE - go rest' \args -> cont (sink (Var v) : args) + go expls rest' \args -> cont (sink (Var v) : args) Inferred argSourceName infMech -> do arg <- getImplicitArg (fSourceName, fromMaybe "_" argSourceName) infMech ty Abs rest' UnitE <- applySubst (b@>SubstVal arg) $ Abs rest UnitE - go rest' \args -> cont (sink arg : args) + go expls rest' \args -> cont (sink arg : args) + go _ _ _ = error "zip error" buildLamInf :: EmitsInf o => CorePiType o -> (forall o' . (EmitsBoth o', DExt o o') => [(Explicitness, CAtom o')] -> CType o' -> InfererM i o' (CAtom o')) -> InfererM i o (CoreLamExpr o) -buildLamInf (CorePiType appExpl bsTop effTy) contTop = do - ab <- go bsTop \xs -> do +buildLamInf (CorePiType appExpl explsTop bsTop effTy) contTop = do + ab <- go explsTop bsTop \xs -> do let (expls, xs') = unzip xs EffTy effs' resultTy' <- applySubst (bsTop@@>(SubstVal<$>xs')) effTy withAllowedEffects effs' do body <- buildBlockInf $ contTop (zip expls $ sinkList xs') (sink resultTy') return $ PairE effs' body - coreLamExpr appExpl ab + coreLamExpr appExpl explsTop ab where go :: (EmitsInf o, HoistableE e, SinkableE e, SubstE AtomSubstVal e, RenameE e) - => Nest (WithExpl CBinder) o any - -> (forall o'. (EmitsInf o', DExt o o') - => [(Explicitness, CAtom o')] -> InfererM i o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) - go Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] - go (Nest (WithExpl expl b) rest) cont = do + => [Explicitness] -> Nest CBinder o any + -> (forall o'. (EmitsInf o', DExt o o') => [(Explicitness, CAtom o')] -> InfererM i o' (e o')) + -> InfererM i o (Abs (Nest CBinder) e o) + go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] + go (expl:expls) (Nest b rest) cont = do prependAbs <$> buildAbsInf (getNameHint b) expl (binderType b) \v -> do Abs rest' UnitE <- applyRename (b@>atomVarName v) $ Abs rest UnitE - go rest' \args -> cont ((expl, sink (Var v)) : args) + go expls rest' \args -> cont $ (expl, sink $ Var v) : args + go _ _ _ = error "zip error" class ExplicitArg (e::E) where checkExplicitArg :: EmitsBoth o => IsDependent -> e i -> CType o -> InfererM i o (CAtom o) @@ -1274,14 +1300,14 @@ checkOrInferApp checkOrInferApp f' posArgs namedArgs reqTy = do f <- maybeInterpretPunsAsTyCons reqTy f' case getType f of - Pi (CorePiType appExpl bs effTy) -> case appExpl of + Pi (CorePiType appExpl expls bs effTy) -> case appExpl of ExplicitApp -> do - checkArity bs posArgs - args' <- inferMixedArgs fDesc (Abs bs effTy) posArgs namedArgs + checkArity expls posArgs + args' <- inferMixedArgs fDesc expls (Abs bs effTy) posArgs namedArgs applySigmaAtom f args' >>= matchRequirement ImplicitApp -> do -- TODO: should this already have been done by the time we get `f`? - implicitArgs <- inferMixedArgs @UExpr fDesc (Abs bs effTy) [] [] + implicitArgs <- inferMixedArgs @UExpr fDesc expls (Abs bs effTy) [] [] f'' <- SigmaAtom (Just fDesc) <$> applySigmaAtom f implicitArgs checkOrInferApp f'' posArgs namedArgs Infer >>= matchRequirement -- TODO: special-case error for when `fTy` can't possibly be a function @@ -1328,24 +1354,24 @@ applySigmaAtom (SigmaUVar _ _ f) args = case f of f'' <- toAtomVar f' emitExprWithEffects =<< mkApp (Var f'') args UTyConVar f' -> do - TyConDef sn bs _ <- lookupTyCon f' - let expls = nestToList (\(RolePiBinder _ (WithExpl expl _)) -> expl) bs + TyConDef sn roleExpls _ _ <- lookupTyCon f' + let expls = snd <$> roleExpls return $ Type $ NewtypeTyCon $ UserADTType sn f' (TyConParams expls args) UDataConVar v -> do (tyCon, i) <- lookupDataCon v applyDataCon tyCon i args UPunVar tc -> do - TyConDef sn _ _ <- lookupTyCon tc + TyConDef sn _ _ _ <- lookupTyCon tc -- interpret as a data constructor by default (params, dataArgs) <- splitParamPrefix tc args repVal <- makeStructRepVal tc dataArgs return $ NewtypeCon (UserADTData sn tc params) repVal UClassVar f' -> do - ClassDef sourceName _ _ _ _ _ <- lookupClassDef f' + ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef f' return $ Type $ DictTy $ DictType sourceName f' args UMethodVar f' -> do MethodBinding className methodIdx <- lookupEnv f' - ClassDef _ _ _ paramBs _ _ <- lookupClassDef className + ClassDef _ _ _ _ paramBs _ _ <- lookupClassDef className let numParams = nestLength paramBs -- params aren't needed because they're already implied by the dict argument let (dictArg:args') = drop numParams args @@ -1357,14 +1383,14 @@ applySigmaAtom (SigmaPartialApp _ f prevArgs) args = splitParamPrefix :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n, [CAtom n]) splitParamPrefix tc args = do - TyConDef _ paramBs _ <- lookupTyCon tc + TyConDef _ _ paramBs _ <- lookupTyCon tc let (paramArgs, dataArgs) = splitAt (nestLength paramBs) args params <- makeTyConParams tc paramArgs return (params, dataArgs) applyDataCon :: Emits o => TyConName o -> Int -> [CAtom o] -> InfererM i o (CAtom o) applyDataCon tc conIx topArgs = do - tyDef@(TyConDef sn _ _) <- lookupTyCon tc + tyDef@(TyConDef sn _ _ _) <- lookupTyCon tc (params, dataArgs) <- splitParamPrefix tc topArgs ADTCons conDefs <- instantiateTyConDef tyDef params DataConDef _ _ repTy _ <- return $ conDefs !! conIx @@ -1398,9 +1424,9 @@ emitExprWithEffects expr = do addEffects $ getEffects expr emitExpr expr -checkArity :: BindsNames b => Nest (WithExpl b) n l -> [a] -> InfererM i o () -checkArity bs args = do - let arity = length [() | Explicit <- nestToList (\(WithExpl expl _) -> expl) bs] +checkArity :: [Explicitness] -> [a] -> InfererM i o () +checkArity expls args = do + let arity = length [() | Explicit <- expls] let numArgs = length args when (numArgs /= arity) do throw TypeErr $ "Wrong number of positional arguments provided. Expected " ++ @@ -1410,24 +1436,25 @@ checkArity bs args = do inferMixedArgs :: forall arg i o e . (ExplicitArg arg, EmitsBoth o, SubstE (SubstVal Atom) e, SinkableE e, HoistableE e) - => SourceName - -> Abs (Nest (WithExpl CBinder)) e o -> [arg i] -> [(SourceName, arg i)] + => SourceName -> [Explicitness] + -> Abs (Nest CBinder) e o -> [arg i] -> [(SourceName, arg i)] -> InfererM i o [CAtom o] -inferMixedArgs fSourceName bsAbs posArgs namedArgs = do - checkNamedArgValidity bsAbs (map fst namedArgs) - liftM fst $ runStreamReaderT1 posArgs $ go bsAbs +inferMixedArgs fSourceName explsTop bsAbs posArgs namedArgs = do + checkNamedArgValidity explsTop (map fst namedArgs) + liftM fst $ runStreamReaderT1 posArgs $ go explsTop bsAbs where go :: (EmitsBoth o, SubstE (SubstVal Atom) e, SinkableE e, HoistableE e) - => Abs (Nest (WithExpl CBinder)) e o + => [Explicitness] -> Abs (Nest CBinder) e o -> StreamReaderT1 (arg i) (InfererM i) o [CAtom o] - go (Abs Empty _) = return [] - go (Abs (Nest (WithExpl expl b) bs) result) = do + go [] (Abs Empty _) = return [] + go (expl:expls) (Abs (Nest b bs) result) = do let rest = Abs bs result let isDependent = binderName b `isFreeIn` rest arg <- inferMixedArg isDependent (binderType b) expl arg' <- lift11 $ zonk arg rest' <- applySubst (b @> SubstVal arg') rest - (arg:) <$> go rest' + (arg:) <$> go expls rest' + go _ _ = error "zip error" inferMixedArg :: EmitsBoth o => IsDependent -> CType o -> Explicitness -> StreamReaderT1 (arg i) (InfererM i) o (CAtom o) @@ -1445,12 +1472,12 @@ inferMixedArgs fSourceName bsAbs posArgs namedArgs = do lookupNamedArg Nothing = Nothing lookupNamedArg (Just v) = lookup v namedArgs -checkNamedArgValidity :: (BindsNames b, Fallible m) => Abs (Nest (WithExpl b)) e any -> [SourceName] -> m () -checkNamedArgValidity (Abs bs _) offeredNames = do +checkNamedArgValidity :: Fallible m => [Explicitness] -> [SourceName] -> m () +checkNamedArgValidity expls offeredNames = do let explToMaybeName = \case Explicit -> Nothing Inferred v _ -> v - let acceptedNames = catMaybes $ nestToList (explToMaybeName . getExpl) bs + let acceptedNames = catMaybes $ map explToMaybeName expls let duplicates = repeated offeredNames when (not $ null duplicates) do throw TypeErr $ "Repeated names offered" ++ pprint duplicates @@ -1618,7 +1645,7 @@ buildSortedCase scrut alts resultTy = do scrutTy <- return $ getType scrut case scrutTy of TypeCon _ defName _ -> do - TyConDef _ _ (ADTCons cons) <- lookupTyCon defName + TyConDef _ _ _ (ADTCons cons) <- lookupTyCon defName case cons of [] -> error "case of void?" -- Single constructor ADTs are not sum types, so elide the case. @@ -1631,14 +1658,13 @@ buildSortedCase scrut alts resultTy = do -- TODO: cache this with the instance def (requires a recursive binding) instanceFun :: EnvReader m => InstanceName n -> AppExplicitness -> m n (CAtom n) -instanceFun instanceName expl = do - InstanceDef _ bs _ _ <- lookupInstanceDef instanceName +instanceFun instanceName appExpl = do + InstanceDef _ expls bs _ _ <- lookupInstanceDef instanceName ab <- liftEnvReaderM $ refreshAbs (Abs bs UnitE) \bs' UnitE -> do args <- mapM toAtomVar $ nestToNames bs' - let bs'' = fmapNest (\(RolePiBinder _ b) -> b) bs' result <- mkDictAtom $ InstanceDict (sink instanceName) (Var <$> args) - return $ Abs bs'' (PairE Pure (WithoutDecls result)) - Lam <$> coreLamExpr expl ab + return $ Abs bs' (PairE Pure (WithoutDecls result)) + Lam <$> coreLamExpr appExpl (snd<$>expls) ab checkMaybeAnnExpr :: EmitsBoth o => NameHint -> Maybe (UType i) -> UExpr i -> InfererM i o (CAtom o) @@ -1664,53 +1690,55 @@ inferRole ty = \case inferTyConDef :: EmitsInf o => UDataDef i -> InfererM i o (TyConDef o) inferTyConDef (UDataDef tyConName paramBs dataCons) = do Abs paramBs' dataCons' <- - withRoleUBinders paramBs \_ -> do + withRoleUBinders paramBs do ADTCons <$> mapM inferDataCon dataCons - return (TyConDef tyConName paramBs' dataCons') + let (roleExpls, paramBs'') = unzipAttrs paramBs' + return (TyConDef tyConName roleExpls paramBs'' dataCons') inferStructDef :: EmitsInf o => UStructDef i -> InfererM i o (TyConDef o) inferStructDef (UStructDef tyConName paramBs fields _) = do let (fieldNames, fieldTys) = unzip fields - Abs paramBs' dataConDefs <- withRoleUBinders paramBs \_ -> do + Abs paramBs' dataConDefs <- withRoleUBinders paramBs do tys <- mapM checkUType fieldTys return $ StructFields $ zip fieldNames tys - return $ TyConDef tyConName paramBs' dataConDefs + let (roleExpls, paramBs'') = unzipAttrs paramBs' + return $ TyConDef tyConName roleExpls paramBs'' dataConDefs inferDotMethod :: EmitsInf o => TyConName o - -> Abs (Nest (WithExpl UOptAnnBinder)) (Abs UAtomBinder ULamExpr) i + -> Abs (Nest UOptAnnBinder) (Abs UAtomBinder ULamExpr) i -> InfererM i o (CoreLamExpr o) inferDotMethod tc (Abs uparamBs (Abs selfB lam)) = do - TyConDef sn paramBs _ <- lookupTyCon tc - let paramBs' = fmapNest (\(RolePiBinder _ b) -> b) paramBs - ab <- buildNaryAbsInf (Abs paramBs' UnitE) \paramVs -> do - let expls = nestToList (\(WithExpl expl _) -> expl) paramBs' + TyConDef sn roleExpls paramBs _ <- lookupTyCon tc + let expls = snd <$> roleExpls + ab <- buildNaryAbsInfWithExpl expls (Abs paramBs UnitE) \paramVs -> do let paramVs' = catMaybes $ zip expls paramVs <&> \(expl, v) -> case expl of Inferred _ (Synth _) -> Nothing _ -> Just v extendRenamer (uparamBs @@> (atomVarName <$> paramVs')) do let selfTy = NewtypeTyCon $ UserADTType sn (sink tc) (TyConParams expls (Var <$> paramVs)) - buildAbsInf "self" Explicit selfTy \vSelf -> + buildAbsInfWithExpl "self" Explicit selfTy \vSelf -> extendRenamer (selfB @> atomVarName vSelf) $ inferULam lam Abs paramBs'' (Abs selfB' lam') <- return ab return $ prependCoreLamExpr (paramBs'' >>> UnaryNest selfB') lam' prependCoreLamExpr :: Nest (WithExpl CBinder) n l -> CoreLamExpr l -> CoreLamExpr n prependCoreLamExpr bs e = case e of - CoreLamExpr (CorePiType appExpl piBs effTy) (LamExpr lamBs body) -> do - let piType = CorePiType appExpl (bs >>> piBs) effTy - let lamExpr = LamExpr (fmapNest withoutExpl bs >>> lamBs) body + CoreLamExpr (CorePiType appExpl piExpls piBs effTy) (LamExpr lamBs body) -> do + let (expls, bs') = unzipAttrs bs + let piType = CorePiType appExpl (expls <> piExpls) (bs' >>> piBs) effTy + let lamExpr = LamExpr (fmapNest withoutAttr bs >>> lamBs) body CoreLamExpr piType lamExpr inferDataCon :: EmitsInf o => (SourceName, UDataDefTrail i) -> InfererM i o (DataConDef o) inferDataCon (sourceName, UDataDefTrail argBs) = do - let argBsExpls = addExpls Explicit argBs - Abs argBs' UnitE <- withUBinders argBsExpls \_ -> return UnitE - let argBs'' = Abs (fmapNest withoutExpl argBs') UnitE + let expls = nestToList (const Explicit) argBs + Abs argBs' UnitE <- withUBinders (expls, argBs) \_ -> return UnitE + let argBs'' = Abs (fmapNest withoutAttr argBs') UnitE let (repTy, projIdxs) = dataConRepTy argBs'' return $ DataConDef sourceName argBs'' repTy projIdxs -dataConRepTy :: EmptyAbs (Nest (Binder CoreIR)) n -> (CType n, [[Projection]]) +dataConRepTy :: EmptyAbs (Nest CBinder) n -> (CType n, [[Projection]]) dataConRepTy (Abs topBs UnitE) = case topBs of Empty -> (UnitTy, []) _ -> go [] [UnwrapNewtype] topBs @@ -1738,47 +1766,49 @@ dataConRepTy (Abs topBs UnitE) = case topBs of inferClassDef :: EmitsInf o => SourceName -> [SourceName] - -> Nest (WithExpl UOptAnnBinder) i i' + -> UOptAnnExplBinders i i' -> [UType i'] -> InfererM i o (ClassDef o) -inferClassDef className methodNames paramBs methods = do +inferClassDef className methodNames paramBs@(expls, paramBs') methods = do + let paramBsWithAttrBs = zipWithNest paramBs' expls \b expl -> WithAttrB expl b let paramNames = catMaybes $ nestToList - (\(WithExpl expl (UAnnBinder b _ _)) -> case expl of + (\(WithAttrB expl (UAnnBinder b _ _)) -> case expl of Inferred _ (Synth _) -> Nothing - _ -> Just $ Just $ getSourceName b) paramBs - ab <- withRoleUBinders paramBs \_ -> do + _ -> Just $ Just $ getSourceName b) paramBsWithAttrBs + ab <- withRoleUBinders paramBs do ListE <$> forM methods \m -> do checkUType m >>= \case Pi t -> return t - t -> return $ CorePiType ImplicitApp Empty (EffTy Pure t) + t -> return $ CorePiType ImplicitApp [] Empty (EffTy Pure t) Abs (PairB bs scs) (ListE mtys) <- identifySuperclasses ab - return $ ClassDef className methodNames paramNames bs scs mtys + let (roleExpls, bs') = unzipAttrs bs + return $ ClassDef className methodNames paramNames roleExpls bs' scs mtys --- TODO: this is just partitioning the binders. We could write a more general function like this: --- partitionBinders :: Nest b n l -> (forall n l. b i i' -> EitherB b1 b2 i i') --- -> Except (PairB (Nest b1) (Nest b2)) n l identifySuperclasses - :: RenameE e => Abs RolePiBinders e n - -> InfererM i n (Abs (PairB RolePiBinders (Nest CBinder)) e n) -identifySuperclasses ab = refreshAbs ab \bs e -> do - bs' <- partitionBinders bs \b@(RolePiBinder _ (WithExpl expl b')) -> case expl of - Explicit -> return $ LeftB b - Inferred _ Unify -> throw TypeErr "Interfaces can't have implicit parameters" - Inferred _ (Synth _) -> return $ RightB b' - return $ Abs bs' e + :: RenameE e => Abs (Nest (WithRoleExpl CBinder)) e n + -> InfererM i n (Abs (PairB (Nest (WithRoleExpl CBinder)) (Nest CBinder)) e n) +identifySuperclasses ab = do + refreshAbs ab \bs e -> do + bs' <- partitionBinders bs \b@(WithAttrB (_, expl) b') -> case expl of + Explicit -> return $ LeftB b + Inferred _ Unify -> throw TypeErr "Interfaces can't have implicit parameters" + Inferred _ (Synth _) -> return $ RightB b' + return $ Abs bs' e withUBinders :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, SinkableE e) - => Nest (WithExpl (UAnnBinder req)) i i' + => UAnnExplBinders req i i' -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) withUBinders bs cont = case bs of - Empty -> getDistinct >>= \Distinct -> Abs Empty <$> cont [] - Nest (WithExpl expl (UAnnBinder b ann cs)) rest -> do + ([], Empty) -> getDistinct >>= \Distinct -> Abs Empty <$> cont [] + (expl:expls, Nest (UAnnBinder b ann cs) rest) -> do ann' <- checkAnn (getSourceName b) ann - prependAbs <$> buildAbsInf (getNameHint b) expl ann' \v -> + prependAbs <$> buildAbsInfWithExpl (getNameHint b) expl ann' \v -> concatAbs <$> withConstraintBinders cs v do - extendSubst (b@>sink (atomVarName v)) $ withUBinders rest \vs -> cont (sink v : vs) + extendSubst (b@>sink (atomVarName v)) $ withUBinders (expls, rest) \vs -> + cont (sink v : vs) + _ -> error "zip error" withConstraintBinders :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, RenameE e, SinkableE e) @@ -1791,24 +1821,26 @@ withConstraintBinders (c:cs) v cont = do Type dictTy <- withReducibleEmissions "Can't reduce interface constraint" do c' <- inferWithoutInstantiation c >>= zonk dropSubst $ checkOrInferApp c' [Var $ sink v] [] (Check TyKind) - prependAbs <$> buildAbsInf "d" (Inferred Nothing (Synth Full)) dictTy \_ -> + prependAbs <$> buildAbsInfWithExpl "d" (Inferred Nothing (Synth Full)) dictTy \_ -> withConstraintBinders cs (sink v) cont withRoleUBinders :: forall i i' o e req. (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, SinkableE e) - => Nest (WithExpl (UAnnBinder req)) i i' - -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) - -> InfererM i o (Abs RolePiBinders e o) -withRoleUBinders bs cont = case bs of - Empty -> getDistinct >>= \Distinct -> Abs Empty <$> cont [] - Nest (WithExpl expl (UAnnBinder b ann cs)) rest -> do + => UAnnExplBinders req i i' + -> (forall o'. (EmitsInf o', DExt o o') => InfererM i' o' (e o')) + -> InfererM i o (Abs (Nest (WithRoleExpl CBinder)) e o) +withRoleUBinders roleBs cont = case roleBs of + ([], Empty) -> getDistinct >>= \Distinct -> Abs Empty <$> cont + (expl:expls, Nest (UAnnBinder b ann cs) rest) -> do ann' <- checkAnn (getSourceName b) ann Abs b' (Abs bs' e) <- buildAbsInf (getNameHint b) expl ann' \v -> do Abs ds (Abs bs' e) <- withConstraintBinders cs v $ - extendSubst (b@>sink (atomVarName v)) $ withRoleUBinders rest \vs -> cont (sink v : vs) - return $ Abs (fmapNest (RolePiBinder DictParam) ds >>> bs') e + extendSubst (b@>sink (atomVarName v)) $ withRoleUBinders (expls, rest) cont + let ds' = fmapNest (\(WithAttrB expl' b') -> WithAttrB (DictParam, expl') b') ds + return $ Abs (ds' >>> bs') e role <- inferRole (binderType b') expl - return $ Abs (Nest (RolePiBinder role b') bs') e + return $ Abs (Nest (WithAttrB (role,expl) b') bs') e + _ -> error "zip error" inferULam :: EmitsInf o => ULamExpr i -> InfererM i o (CoreLamExpr o) inferULam (ULamExpr bs appExpl effs resultTy body) = do @@ -1823,12 +1855,13 @@ inferULam (ULamExpr bs appExpl effs resultTy body) = do checkSigma noHint result (sink resultTy'') return (PairE effs' body') Abs bs' (PairE effs' body') <- return ab + let (expls, bs'') = unzipAttrs bs' case appExpl of - ImplicitApp -> checkImplicitLamRestrictions bs' effs' + ImplicitApp -> checkImplicitLamRestrictions bs'' effs' ExplicitApp -> return () - coreLamExpr appExpl $ Abs bs' $ PairE effs' body' + coreLamExpr appExpl expls $ Abs bs'' $ PairE effs' body' -checkImplicitLamRestrictions :: Nest (WithExpl CBinder) o o' -> EffectRow CoreIR o' -> InfererM i o () +checkImplicitLamRestrictions :: Nest CBinder o o' -> EffectRow CoreIR o' -> InfererM i o () checkImplicitLamRestrictions _ _ = return () -- TODO checkUForExpr :: EmitsBoth o => UForExpr i -> TabPiType CoreIR o -> InfererM i o (LamExpr CoreIR o) @@ -1845,7 +1878,7 @@ checkUForExpr (UForExpr (UAnnBinder bFor ann cs) body) tabPi@(TabPiType _ bPi _) buildBlockInf do withBlockDecls body \result -> checkSigma noHint result $ sink resultTy' - return $ LamExpr (UnaryNest $ withoutExpl b) body' + return $ LamExpr (UnaryNest b) body' inferUForExpr :: EmitsBoth o => UForExpr i -> InfererM i o (LamExpr CoreIR o) inferUForExpr (UForExpr (UAnnBinder bFor ann cs) body) = do @@ -1855,15 +1888,15 @@ inferUForExpr (UForExpr (UAnnBinder bFor ann cs) body) = do extendRenamer (bFor@>atomVarName i) $ buildBlockInf $ withBlockDecls body \result -> checkOrInferRho noHint result Infer - return $ LamExpr (UnaryNest $ withoutExpl b) body' + return $ LamExpr (UnaryNest b) body' checkULam :: EmitsInf o => ULamExpr i -> CorePiType o -> InfererM i o (CoreLamExpr o) -checkULam (ULamExpr lamBs lamAppExpl lamEffs lamResultTy body) - (CorePiType piAppExpl piBs effTy) = do - checkArity piBs (nestToList (const ()) lamBs) +checkULam (ULamExpr (_, lamBs) lamAppExpl lamEffs lamResultTy body) + (CorePiType piAppExpl expls piBs effTy) = do + checkArity expls (nestToList (const ()) lamBs) when (piAppExpl /= lamAppExpl) $ throw TypeErr $ "Wrong arrow. Expected " ++ pprint piAppExpl ++ " got " ++ pprint lamAppExpl - ab <- checkLamBinders piBs lamBs \vs -> do + Abs explBs body' <- checkLamBinders expls piBs lamBs \vs -> do EffTy piEffs' piResultTy' <- applyRename (piBs@@>map atomVarName vs) effTy case lamResultTy of Nothing -> return () @@ -1877,47 +1910,44 @@ checkULam (ULamExpr lamBs lamAppExpl lamEffs lamResultTy body) withBlockDecls body \result -> checkSigma noHint result piResultTy'' return $ PairE piEffs' body' - coreLamExpr piAppExpl ab + let (expls', bs') = unzipAttrs explBs + coreLamExpr piAppExpl expls' $ Abs bs' body' checkLamBinders :: (EmitsInf o, SinkableE e, HoistableE e, SubstE AtomSubstVal e, RenameE e) - => Nest (WithExpl CBinder) o any - -> Nest (WithExpl UOptAnnBinder) i i' + => [Explicitness] -> Nest CBinder o any + -> Nest UOptAnnBinder i i' -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) -checkLamBinders Empty Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] -checkLamBinders (Nest (WithExpl piExpl (piB:>piAnn)) piBs) lamBs cont = do +checkLamBinders [] Empty Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] +checkLamBinders (piExpl:piExpls) (Nest (piB:>piAnn) piBs) lamBs cont = do prependAbs <$> case piExpl of Inferred _ _ -> - buildAbsInf (getNameHint piB) piExpl piAnn \v -> do + buildAbsInfWithExpl (getNameHint piB) piExpl piAnn \v -> do Abs piBs' UnitE <- applyRename (piB@>atomVarName v) $ Abs piBs UnitE - checkLamBinders piBs' lamBs \vs -> + checkLamBinders piExpls piBs' lamBs \vs -> cont (sink v:vs) Explicit -> case lamBs of - Nest (WithExpl Explicit (UAnnBinder lamB ann cs)) lamBsRest -> do + Nest (UAnnBinder lamB ann cs) lamBsRest -> do case ann of UAnn lamAnn -> checkUType lamAnn >>= constrainTypesEq piAnn UNoAnn -> return () - buildAbsInf (getNameHint lamB) Explicit piAnn \v -> do + buildAbsInfWithExpl (getNameHint lamB) Explicit piAnn \v -> do concatAbs <$> withConstraintBinders cs v do Abs piBs' UnitE <- applyRename (piB@>sink (atomVarName v)) $ Abs piBs UnitE - extendRenamer (lamB@>sink (atomVarName v)) $ checkLamBinders piBs' lamBsRest \vs -> + extendRenamer (lamB@>sink (atomVarName v)) $ checkLamBinders piExpls piBs' lamBsRest \vs -> cont (sink v:vs) - Nest (WithExpl (Inferred _ _) _) _ -> - -- TODO(dougalm): I don't think this case is reachable, but if it is - -- then we can check for it in `checkULam` and fall back to `inferULam`. - error "shouldn't be able to check lambda terms with implicit binders" Empty -> error "zip error" -checkLamBinders _ _ _ = error "zip error" +checkLamBinders _ _ _ _ = error "zip error" -checkInstanceParams :: EmitsInf o => RolePiBinders o any -> [UExpr i] -> InfererM i o [CAtom o] -checkInstanceParams bsTop paramsTop = do - checkArity (fmapNest (\(RolePiBinder _ b) -> b) bsTop) paramsTop +checkInstanceParams :: EmitsInf o => [Explicitness] -> Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] +checkInstanceParams expls bsTop paramsTop = do + checkArity expls paramsTop go bsTop paramsTop where - go :: EmitsInf o => Nest RolePiBinder o any -> [UExpr i] -> InfererM i o [CAtom o] + go :: EmitsInf o => Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] go Empty [] = return [] - go (Nest (RolePiBinder _ (WithExpl _ (b:>ty))) bs) (x:xs) = do + go (Nest (b:>ty) bs) (x:xs) = do x' <- checkUParam ty x Abs bs' UnitE <- applySubst (b@>SubstVal x') $ Abs bs UnitE (x':) <$> go bs' xs @@ -1927,7 +1957,7 @@ checkInstanceBody :: EmitsInf o => ClassName o -> [CAtom o] -> [UMethodDef i] -> InfererM i o (InstanceBody o) checkInstanceBody className params methods = do - ClassDef _ methodNames _ paramBs scBs methodTys <- lookupClassDef className + ClassDef _ methodNames _ _ paramBs scBs methodTys <- lookupClassDef className Abs scBs' methodTys' <- applySubst (paramBs @@> (SubstVal <$> params)) $ Abs scBs $ ListE methodTys superclassTys <- superclassDictTys scBs' superclassDicts <- mapM (flip trySynthTerm Full) superclassTys @@ -1952,7 +1982,7 @@ checkMethodDef className methodTys (WithSrcE src m) = addSrcContext src do UMethodDef ~(InternalName _ sourceName v) rhs <- return m MethodBinding className' i <- renameM v >>= lookupEnv when (className /= className') do - ClassBinding (ClassDef classSourceName _ _ _ _ _) <- lookupEnv className + ClassBinding (ClassDef classSourceName _ _ _ _ _ _) <- lookupEnv className throw TypeErr $ pprint sourceName ++ " is not a method of " ++ pprint classSourceName (i,) <$> Lam <$> checkULam rhs (methodTys !! i) @@ -2008,12 +2038,12 @@ checkCasePat :: EmitsBoth o checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat of UPatCon ~(InternalName _ _ conName) ps -> do (dataDefName, con) <- renameM conName >>= lookupDataCon - TyConDef sourceName paramBs (ADTCons cons) <- lookupTyCon dataDefName + TyConDef sourceName roleExpls paramBs (ADTCons cons) <- lookupTyCon dataDefName DataConDef _ _ repTy idxs <- return $ cons !! con when (length idxs /= nestLength ps) $ throw TypeErr $ "Unexpected number of pattern binders. Expected " ++ show (length idxs) ++ " got " ++ show (nestLength ps) - (params, repTy') <- inferParams sourceName (Abs paramBs repTy) + (params, repTy') <- inferParams sourceName roleExpls (Abs paramBs repTy) constrainTypesEq scrutineeTy $ TypeCon sourceName dataDefName params buildAltInf repTy' \arg -> do args <- forM idxs \projs -> do @@ -2023,22 +2053,23 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" inferParams :: (EmitsBoth o, HasNamesE e, SinkableE e, SubstE AtomSubstVal e) - => SourceName -> Abs RolePiBinders e o -> InfererM i o (TyConParams o, e o) -inferParams sourceName (Abs paramBs bodyTop) = do - (params, e') <- go (Abs paramBs bodyTop) - let expls = nestToList (\(RolePiBinder _ (WithExpl expl _)) -> expl) paramBs + => SourceName -> [RoleExpl] -> Abs (Nest CBinder) e o -> InfererM i o (TyConParams o, e o) +inferParams sourceName roleExpls (Abs paramBs bodyTop) = do + let expls = snd <$> roleExpls + (params, e') <- go expls (Abs paramBs bodyTop) return (TyConParams expls params, e') where go :: (EmitsBoth o, HasNamesE e, SinkableE e, SubstE AtomSubstVal e) - => Abs (Nest RolePiBinder) e o -> InfererM i o ([CAtom o], e o) - go (Abs Empty body) = return ([], body) - go (Abs (Nest (RolePiBinder _ (WithExpl expl (b:>ty))) bs) body) = do + => [Explicitness] -> Abs (Nest CBinder) e o -> InfererM i o ([CAtom o], e o) + go [] (Abs Empty body) = return ([], body) + go (expl:expls) (Abs (Nest (b:>ty) bs) body) = do x <- case expl of Explicit -> Var <$> freshInferenceName (TypeInstantiationInfVar sourceName) ty Inferred argName infMech -> getImplicitArg (sourceName, fromMaybe "_" argName) infMech ty rest <- applySubst (b@>SubstVal x) $ Abs bs body - (params, body') <- go rest + (params, body') <- go expls rest return (x:params, body') + go _ _ = error "zip error" bindLetPats :: EmitsBoth o => Nest UPat i i' -> [CAtomVar o] -> InfererM i' o a -> InfererM i o a @@ -2069,13 +2100,13 @@ bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of cont UPatCon ~(InternalName _ _ conName) ps -> do (dataDefName, _) <- lookupDataCon =<< renameM conName - TyConDef sourceName paramBs cons <- lookupTyCon dataDefName + TyConDef sourceName roleExpls paramBs cons <- lookupTyCon dataDefName case cons of ADTCons [DataConDef _ _ _ idxss] -> do when (length idxss /= nestLength ps) $ throw TypeErr $ "Unexpected number of pattern binders. Expected " ++ show (length idxss) ++ " got " ++ show (nestLength ps) - (params, UnitE) <- inferParams sourceName (Abs paramBs UnitE) + (params, UnitE) <- inferParams sourceName roleExpls (Abs paramBs UnitE) constrainVarTy v $ TypeCon sourceName dataDefName params x <- cheapNormalize =<< zonk (Var v) xs <- forM idxss \idxs -> normalizeNaryProj idxs x >>= emit . Atom @@ -2140,7 +2171,7 @@ inferTabCon hint xs reqTy = do withFreshBinder noHint finTy \b' -> do elemTy' <- applyRename (b@>binderName b') elemTy dTy <- DictTy <$> dataDictType elemTy' - return $ Pi $ CorePiType ImplicitApp (UnaryNest (WithExpl (Inferred Nothing Unify) b')) (EffTy Pure dTy) + return $ Pi $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest b') (EffTy Pure dTy) liftM Var $ emitHinted hint $ TabCon (dataDictHole dTy) tabTy xs' -- Bool flag is just to tweak the reported error message @@ -2494,19 +2525,19 @@ unifyEq e1 e2 = guard =<< alphaEq e1 e2 {-# INLINE unifyEq #-} instance Unifiable CorePiType where - unifyZonked (CorePiType appExpl1 bsTop1 effTy1) - (CorePiType appExpl2 bsTop2 effTy2) = do + unifyZonked (CorePiType appExpl1 expls1 bsTop1 effTy1) + (CorePiType appExpl2 expls2 bsTop2 effTy2) = do unless (appExpl1 == appExpl2) empty + unless (expls1 == expls2) empty go (Abs bsTop1 effTy1) (Abs bsTop2 effTy2) where go :: EmitsInf n - => Abs (Nest (WithExpl CBinder)) (EffTy CoreIR) n - -> Abs (Nest (WithExpl CBinder)) (EffTy CoreIR) n + => Abs (Nest CBinder) (EffTy CoreIR) n + -> Abs (Nest CBinder) (EffTy CoreIR) n -> SolverM n () go (Abs Empty (EffTy e1 t1)) (Abs Empty (EffTy e2 t2)) = unify t1 t2 >> unify e1 e2 - go (Abs (Nest (WithExpl expl1 (b1:>t1)) bs1) rest1) - (Abs (Nest (WithExpl expl2 (b2:>t2)) bs2) rest2) = do - unless (expl1 == expl2) empty + go (Abs (Nest (b1:>t1) bs1) rest1) + (Abs (Nest (b2:>t2) bs2) rest2) = do unify t1 t2 v <- freshSkolemName t1 ab1 <- zonk =<< applySubst (b1@>SubstVal (Var v)) (Abs bs1 rest1) @@ -2585,13 +2616,9 @@ synthTopE block = do {-# SCC synthTopE #-} synthTyConDef :: (EnvReader m, Fallible1 m) => TyConDef n -> m n (TyConDef n) -synthTyConDef (TyConDef sn rbs body) = (liftExcept =<<) $ liftDictSynthTraverserM do - let bs = fmapNest (\(RolePiBinder _ b) -> b) rbs - let roles = nestToList (\(RolePiBinder role _) -> role) rbs - dsTraverseExplBinders bs \bs' -> do - body' <- dsTraverse body - let rbs' = zipWithNest bs' roles \b role -> RolePiBinder role b - return $ TyConDef sn rbs' body' +synthTyConDef (TyConDef sn roleExpls bs body) = (liftExcept =<<) $ liftDictSynthTraverserM do + dsTraverseExplBinders (snd <$> roleExpls) bs \bs' -> + TyConDef sn roleExpls bs' <$> dsTraverse body {-# SCC synthTyConDef #-} -- Given a simplified dict (an Atom of type `DictTy _` in the @@ -2624,8 +2651,8 @@ generalizeDictRec dict = do DictCon _ dict' <- cheapNormalize dict mkDictAtom =<< case dict' of InstanceDict instanceName args -> do - InstanceDef _ bs _ _ <- lookupInstanceDef instanceName - args' <- generalizeInstanceArgs bs args + InstanceDef _ roleExpls bs _ _ <- lookupInstanceDef instanceName + args' <- generalizeInstanceArgs roleExpls bs args return $ InstanceDict instanceName args' IxFin _ -> IxFin <$> Var <$> freshInferenceName MiscInfVar NatTy InstantiatedGiven _ _ -> notSimplifiedDict @@ -2633,9 +2660,9 @@ generalizeDictRec dict = do DataData ty -> DataData <$> TyVar <$> freshInferenceName MiscInfVar ty where notSimplifiedDict = error $ "Not a simplified dict: " ++ pprint dict -generalizeInstanceArgs :: EmitsInf n => RolePiBinders n l -> [CAtom n] -> SolverM n [CAtom n] -generalizeInstanceArgs Empty [] = return [] -generalizeInstanceArgs (Nest (RolePiBinder role (WithExpl _ (b:>ty))) bs) (arg:args) = do +generalizeInstanceArgs :: EmitsInf n => [RoleExpl] -> Nest CBinder n l -> [CAtom n] -> SolverM n [CAtom n] +generalizeInstanceArgs [] Empty [] = return [] +generalizeInstanceArgs ((role,_):expls) (Nest (b:>ty) bs) (arg:args) = do arg' <- case role of -- XXX: for `TypeParam` we can just emit a fresh inference name rather than -- traversing the whole type like we do in `Generalize.hs`. The reason is @@ -2646,21 +2673,21 @@ generalizeInstanceArgs (Nest (RolePiBinder role (WithExpl _ (b:>ty))) bs) (arg:a DictParam -> generalizeDictAndUnify ty arg DataParam -> Var <$> freshInferenceName MiscInfVar ty Abs bs' UnitE <- applySubst (b@>SubstVal arg') (Abs bs UnitE) - args' <- generalizeInstanceArgs bs' args + args' <- generalizeInstanceArgs expls bs' args return $ arg':args' -generalizeInstanceArgs _ _ = error "zip error" +generalizeInstanceArgs _ _ _ = error "zip error" synthInstanceDefAndAddSynthCandidate :: (Mut n, TopBuilder m, EnvReader m, Fallible1 m) => InstanceDef n -> m n (InstanceName n) -synthInstanceDefAndAddSynthCandidate def@(InstanceDef className bs params (InstanceBody superclasses _)) = do - let emptyDef = InstanceDef className bs params $ InstanceBody superclasses [] +synthInstanceDefAndAddSynthCandidate def@(InstanceDef className expls bs params (InstanceBody superclasses _)) = do + let emptyDef = InstanceDef className expls bs params $ InstanceBody superclasses [] instanceName <- emitInstanceDef emptyDef addInstanceSynthCandidate className instanceName synthInstanceDefRec instanceName def return instanceName emitInstanceDef :: (Mut n, TopBuilder m) => InstanceDef n -> m n (Name InstanceNameC n) -emitInstanceDef instanceDef@(InstanceDef className _ _ _) = do +emitInstanceDef instanceDef@(InstanceDef className _ _ _ _) = do ty <- getInstanceType instanceDef emitBinding (getNameHint className) $ InstanceBinding instanceDef ty @@ -2672,46 +2699,47 @@ pattern InstanceDefAbsBody :: [CAtom n] -> [CAtom n] -> [CAtom n] -> [CAtom n] pattern InstanceDefAbsBody params superclasses doneMethods todoMethods = ListE params `PairE` (ListE superclasses) `PairE` (ListE doneMethods) `PairE` (ListE todoMethods) -type InstanceDefAbsT = Abs (Nest RolePiBinder) InstanceDefAbsBodyT +type InstanceDefAbsT n = ([RoleExpl], Abs (Nest CBinder) InstanceDefAbsBodyT n) -pattern InstanceDefAbs :: Nest RolePiBinder h n -> [CAtom n] -> [CAtom n] -> [CAtom n] -> [CAtom n] +pattern InstanceDefAbs :: [RoleExpl] -> Nest CBinder h n -> [CAtom n] -> [CAtom n] -> [CAtom n] -> [CAtom n] -> InstanceDefAbsT h -pattern InstanceDefAbs bs params superclasses doneMethods todoMethods = - Abs bs (InstanceDefAbsBody params superclasses doneMethods todoMethods) +pattern InstanceDefAbs expls bs params superclasses doneMethods todoMethods = + (expls, Abs bs (InstanceDefAbsBody params superclasses doneMethods todoMethods)) synthInstanceDefRec :: (Mut n, TopBuilder m, EnvReader m, Fallible1 m) => InstanceName n -> InstanceDef n -> m n () -synthInstanceDefRec instanceName (InstanceDef className bs params (InstanceBody superclasses methods)) = do - let ab = InstanceDefAbs bs params superclasses [] methods +synthInstanceDefRec instanceName def = do + InstanceDef className roleExplsTop bs params (InstanceBody superclasses methods) <- return def + let ab = InstanceDefAbs roleExplsTop bs params superclasses [] methods recur ab className instanceName where recur :: (Mut n, TopBuilder m, EnvReader m, Fallible1 m) => InstanceDefAbsT n -> ClassName n -> InstanceName n -> m n () - recur (InstanceDefAbs _ _ _ _ []) _ _ = return () - recur ab cname iname = do - (def, ab') <- liftExceptEnvReaderM $ refreshAbs ab + recur (InstanceDefAbs _ _ _ _ _ []) _ _ = return () + recur (roleExpls, ab) cname iname = do + (def', ab') <- liftExceptEnvReaderM $ refreshAbs ab \bs' (InstanceDefAbsBody ps scs doneMethods (m:ms)) -> do EnvReaderT $ ReaderT \(Distinct, env) -> do - let env' = extendSynthCandidatess bs' env + let env' = extendSynthCandidatess (snd<$>roleExpls) bs' env flip runReaderT (Distinct, env') $ runEnvReaderT' do m' <- synthTopE m let doneMethods' = doneMethods ++ [m'] - let ab' = InstanceDefAbs bs' ps scs doneMethods' ms - let def = InstanceDef cname bs' ps $ InstanceBody scs doneMethods' - return (def, ab') - updateTopEnv $ UpdateInstanceDef iname def + let ab' = InstanceDefAbs roleExpls bs' ps scs doneMethods' ms + let def' = InstanceDef cname roleExpls bs' ps $ InstanceBody scs doneMethods' + return (def', ab') + updateTopEnv $ UpdateInstanceDef iname def' recur ab' cname iname synthInstanceDef :: (EnvReader m, Fallible1 m) => InstanceDef n -> m n (InstanceDef n) -synthInstanceDef (InstanceDef className bs params body) = do +synthInstanceDef (InstanceDef className expls bs params body) = do liftExceptEnvReaderM $ refreshAbs (Abs bs (ListE params `PairE` body)) \bs' (ListE params' `PairE` InstanceBody superclasses methods) -> do EnvReaderT $ ReaderT \(Distinct, env) -> do - let env' = extendSynthCandidatess bs' env + let env' = extendSynthCandidatess (snd<$>expls) bs' env flip runReaderT (Distinct, env') $ runEnvReaderT' do methods' <- mapM synthTopE methods - return $ InstanceDef className bs' params' $ InstanceBody superclasses methods' + return $ InstanceDef className expls bs' params' $ InstanceBody superclasses methods' -- main entrypoint to dictionary synthesizer trySynthTerm :: (Fallible1 m, EnvReader m) => CType n -> RequiredMethodAccess -> m n (SynthAtom n) @@ -2728,7 +2756,7 @@ trySynthTerm ty reqMethodAccess = do {-# SCC trySynthTerm #-} type SynthAtom = CAtom -type SynthPiType = Abs (Nest (WithExpl CBinder)) DictType +type SynthPiType n = ([Explicitness], Abs (Nest CBinder) DictType n) data SynthType n = SynthDictType (DictType n) | SynthPiType (SynthPiType n) @@ -2781,7 +2809,7 @@ getSynthType x = ignoreExcept $ typeAsSynthType (getType x) typeAsSynthType :: CType n -> Except (SynthType n) typeAsSynthType = \case DictTy dictTy -> return $ SynthDictType dictTy - Pi (CorePiType ImplicitApp bs (EffTy Pure (DictTy d))) -> return $ SynthPiType (Abs bs d) + Pi (CorePiType ImplicitApp expls bs (EffTy Pure (DictTy d))) -> return $ SynthPiType (expls, Abs bs d) ty -> Failure $ Errs [Err TypeErr mempty $ "Can't synthesize terms of type: " ++ pprint ty] {-# SCC typeAsSynthType #-} @@ -2827,11 +2855,11 @@ getSuperclassClosurePure env givens newGivens = synthTerm :: SynthType n -> RequiredMethodAccess -> SyntherM n (SynthAtom n) synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of - SynthPiType ab -> do - ab' <- withGivenBinders ab \bs targetTy' -> do + SynthPiType (expls, ab) -> do + ab' <- withGivenBinders expls ab \bs targetTy' -> do Abs bs <$> synthTerm (SynthDictType targetTy') reqMethodAccess Abs bs synthExpr <- return ab' - liftM Lam $ coreLamExpr ImplicitApp $ Abs bs $ PairE Pure (WithoutDecls synthExpr) + liftM Lam $ coreLamExpr ImplicitApp expls $ Abs bs $ PairE Pure (WithoutDecls synthExpr) SynthDictType dictTy -> case dictTy of DictType "Ix" _ [Type (NewtypeTyCon (Fin n))] -> return $ DictCon (DictTy dictTy) $ IxFin n DictType "Data" _ [Type t] -> do @@ -2848,21 +2876,29 @@ synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of _ -> return dict {-# SCC synthTerm #-} +coreLamExpr :: EnvReader m => AppExplicitness + -> [Explicitness] -> Abs (Nest CBinder) (PairE (EffectRow CoreIR) CBlock) n + -> m n (CoreLamExpr n) +coreLamExpr appExpl expls ab = liftEnvReaderM do + refreshAbs ab \bs' (PairE effs' body') -> do + EffTy _ resultTy <- blockEffTy body' + return $ CoreLamExpr (CorePiType appExpl expls bs' (EffTy effs' resultTy)) (LamExpr bs' body') + withGivenBinders - :: (SinkableE e, RenameE e) => Abs (Nest (WithExpl CBinder)) e n - -> (forall l. DExt n l => Nest (WithExpl CBinder) n l -> e l -> SyntherM l a) + :: (SinkableE e, RenameE e) => [Explicitness] -> Abs (Nest CBinder) e n + -> (forall l. DExt n l => Nest CBinder n l -> e l -> SyntherM l a) -> SyntherM n a -withGivenBinders (Abs bsTop e) contTop = - runSubstReaderT idSubst $ go bsTop \bsTop' -> do +withGivenBinders explsTop (Abs bsTop e) contTop = + runSubstReaderT idSubst $ go explsTop bsTop \bsTop' -> do e' <- renameM e liftSubstReaderT $ contTop bsTop' e' where - go :: Nest (WithExpl CBinder) i i' - -> (forall o'. DExt o o' => Nest (WithExpl CBinder) o o' -> SubstReaderT Name SyntherM i' o' a) + go :: [Explicitness] -> Nest CBinder i i' + -> (forall o'. DExt o o' => Nest CBinder o o' -> SubstReaderT Name SyntherM i' o' a) -> SubstReaderT Name SyntherM i o a - go bs cont = case bs of - Empty -> getDistinct >>= \Distinct -> cont Empty - Nest (WithExpl expl b) rest -> do + go expls bs cont = case (expls, bs) of + ([], Empty) -> getDistinct >>= \Distinct -> cont Empty + (expl:explsRest, Nest b rest) -> do argTy <- renameM $ binderType b withFreshBinder (getNameHint b) argTy \b' -> do givens <- case expl of @@ -2871,13 +2907,14 @@ withGivenBinders (Abs bsTop e) contTop = s <- getSubst liftSubstReaderT $ extendGivens givens $ runSubstReaderT (s <>> b@>binderName b') $ - go rest \rest' -> cont (Nest (WithExpl expl b') rest') + go explsRest rest \rest' -> cont (Nest b' rest') + _ -> error "zip error" isMethodAccessAllowedBy :: EnvReader m => RequiredMethodAccess -> InstanceName n -> m n Bool isMethodAccessAllowedBy access instanceName = do - InstanceDef className _ _ (InstanceBody _ methods) <- lookupInstanceDef instanceName + InstanceDef className _ _ _ (InstanceBody _ methods) <- lookupInstanceDef instanceName let numInstanceMethods = length methods - ClassDef _ _ _ _ _ methodTys <- lookupClassDef className + ClassDef _ _ _ _ _ _ methodTys <- lookupClassDef className let numClassMethods = length methodTys case access of Full -> return $ numClassMethods == numInstanceMethods @@ -2899,34 +2936,35 @@ synthDictFromInstance :: DictType n -> SyntherM n (SynthAtom n) synthDictFromInstance targetTy@(DictType _ targetClass _) = do instances <- getInstanceDicts targetClass asum $ instances <&> \candidate -> do - CorePiType _ bs (EffTy _ (DictTy candidateTy)) <- lookupInstanceTy candidate - args <- instantiateSynthArgs targetTy $ Abs bs candidateTy + CorePiType _ expls bs (EffTy _ (DictTy candidateTy)) <- lookupInstanceTy candidate + args <- instantiateSynthArgs targetTy (expls, Abs bs candidateTy) return $ DictCon (DictTy targetTy) $ InstanceDict candidate args instantiateSynthArgs :: DictType n -> SynthPiType n -> SyntherM n [CAtom n] -instantiateSynthArgs targetTop (Abs bsTop resultTyTop) = do +instantiateSynthArgs targetTop (explsTop, Abs bsTop resultTyTop) = do ListE args <- (liftExceptAlt =<<) $ liftSolverM $ solveLocal do - args <- runSubstReaderT idSubst $ go (sink targetTop) (sink $ Abs bsTop resultTyTop) + args <- runSubstReaderT idSubst $ go (sink targetTop) explsTop (sink $ Abs bsTop resultTyTop) zonk $ ListE args forM args \case DictHole _ argTy req -> liftExceptAlt (typeAsSynthType argTy) >>= flip synthTerm req arg -> return arg where go :: EmitsInf o - => DictType o -> Abs (Nest (WithExpl CBinder)) DictType i + => DictType o -> [Explicitness] -> Abs (Nest CBinder) DictType i -> SubstReaderT AtomSubstVal SolverM i o [CAtom o] - go target (Abs bs proposed) = case bs of - Empty -> do + go target allExpls (Abs bs proposed) = case (allExpls, bs) of + ([], Empty) -> do proposed' <- substM proposed liftSubstReaderT $ unify target proposed' return [] - Nest (WithExpl expl b) rest -> do + (expl:expls, Nest b rest) -> do argTy <- substM $ binderType b arg <- liftSubstReaderT case expl of Explicit -> error "instances shouldn't have explicit args" Inferred _ Unify -> Var <$> freshInferenceName MiscInfVar argTy Inferred _ (Synth req) -> return $ DictHole (AlwaysEqual emptySrcPosCtx) argTy req - liftM (arg:) $ extendSubst (b@>SubstVal arg) $ go target (Abs rest proposed) + liftM (arg:) $ extendSubst (b@>SubstVal arg) $ go target expls (Abs rest proposed) + _ -> error "zip error" synthDictForData :: forall n. DictType n -> SyntherM n (SynthAtom n) synthDictForData dictTy@(DictType "Data" dName [Type ty]) = case ty of @@ -3010,12 +3048,10 @@ instance DictSynthTraversable CAtom where case ans of Failure errs -> put (LiftE errs) >> renameM atom Success d -> return d - Lam (CoreLamExpr piTy@(CorePiType _ bsPi _) (LamExpr bsLam (Abs decls result))) -> do + Lam (CoreLamExpr piTy@(CorePiType _ expls _ _) (LamExpr bsLam (Abs decls result))) -> do Pi piTy' <- dsTraverse $ Pi piTy - let (expls, _) = unzipExpls bsPi - lam' <- dsTraverseExplBinders (zipExpls expls bsLam) \bsLamExpl' -> do + lam' <- dsTraverseExplBinders expls bsLam \bsLam' -> do visitDeclsNoEmits decls \decls' -> do - let (_, bsLam') = unzipExpls bsLamExpl' LamExpr bsLam' <$> Abs decls' <$> dsTraverse result return $ Lam $ CoreLamExpr piTy' lam' Var _ -> renameM atom @@ -3025,9 +3061,9 @@ instance DictSynthTraversable CAtom where instance DictSynthTraversable CType where dsTraverse ty = case ty of - Pi (CorePiType appExpl bs (EffTy effs resultTy)) -> Pi <$> - dsTraverseExplBinders bs \bs' -> do - CorePiType appExpl bs' <$> (EffTy <$> renameM effs <*> dsTraverse resultTy) + Pi (CorePiType appExpl expls bs (EffTy effs resultTy)) -> Pi <$> + dsTraverseExplBinders expls bs \bs' -> do + CorePiType appExpl expls bs' <$> (EffTy <$> renameM effs <*> dsTraverse resultTy) TyVar _ -> renameM ty ProjectEltTy _ _ _ -> renameM ty _ -> visitTypePartial ty @@ -3035,16 +3071,17 @@ instance DictSynthTraversable CType where instance DictSynthTraversable DataConDefs where dsTraverse = visitGeneric dsTraverseExplBinders - :: Nest (WithExpl CBinder) i i' - -> (forall o'. DExt o o' => Nest (WithExpl CBinder) o o' -> DictSynthTraverserM i' o' a) + :: [Explicitness] -> Nest CBinder i i' + -> (forall o'. DExt o o' => Nest CBinder o o' -> DictSynthTraverserM i' o' a) -> DictSynthTraverserM i o a -dsTraverseExplBinders Empty cont = getDistinct >>= \Distinct -> cont Empty -dsTraverseExplBinders (Nest (WithExpl expl b) bs) cont = do +dsTraverseExplBinders [] Empty cont = getDistinct >>= \Distinct -> cont Empty +dsTraverseExplBinders (expl:expls) (Nest b bs) cont = do ty <- dsTraverse $ binderType b withFreshBinder (getNameHint b) ty \b' -> do let v = binderName b' extendSynthCandidatesDict expl v $ extendRenamer (b@>v) do - dsTraverseExplBinders bs \bs' -> cont $ Nest (WithExpl expl b') bs' + dsTraverseExplBinders expls bs \bs' -> cont $ Nest b' bs' +dsTraverseExplBinders _ _ _ = error "zip error" extendSynthCandidatesDict :: Explicitness -> CAtomName n -> DictSynthTraverserM i n a -> DictSynthTraverserM i n a extendSynthCandidatesDict c v cont = DictSynthTraverserM do @@ -3063,6 +3100,9 @@ extendSynthCandidatesDict c v cont = DictSynthTraverserM do -- the needs of inference, like adding `SubstE AtomSubstVal e` constraints in -- various places. +type WithExpl = WithAttrB Explicitness +type WithRoleExpl = WithAttrB RoleExpl + buildBlockInf :: EmitsInf n => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (CAtom l)) @@ -3096,10 +3136,8 @@ buildTabPiInf -> (forall l. (EmitsInf l, Ext n l) => CAtomVar l -> InfererM i l (CType l)) -> InfererM i n (TabPiType CoreIR n) buildTabPiInf hint (IxType t d) body = do - Abs (WithExpl _ (b:>_)) resultTy <- - buildAbsInf hint Explicit t \v -> - withoutEffects $ body v - return $ TabPiType d (b:>t) resultTy + Abs b resultTy <- buildAbsInf hint Explicit t \v -> withoutEffects $ body v + return $ TabPiType d b resultTy buildDepPairTyInf :: EmitsInf n @@ -3108,7 +3146,7 @@ buildDepPairTyInf -> InfererM i n (DepPairType CoreIR n) buildDepPairTyInf hint expl ty body = do Abs b resultTy <- buildAbsInf hint Explicit ty body - return $ DepPairType expl (withoutExpl b) resultTy + return $ DepPairType expl b resultTy buildAltInf :: EmitsInf n @@ -3116,11 +3154,10 @@ buildAltInf -> (forall l. (EmitsBoth l, Ext n l) => CAtomVar l -> InfererM i l (CAtom l)) -> InfererM i n (Alt CoreIR n) buildAltInf ty body = do - Abs b body' <- buildAbsInf noHint Explicit ty \v -> + buildAbsInf noHint Explicit ty \v -> buildBlockInf do Distinct <- getDistinct body $ sink v - return $ Abs (withoutExpl b) body' -- === EmitsInf predicate === @@ -3190,11 +3227,11 @@ instance BindsEnv InfOutFrag where toEnvFrag (InfOutFrag frag _ _) = toEnvFrag frag instance GenericE SynthType where - type RepE SynthType = EitherE2 DictType (Abs (Nest (WithExpl CBinder)) DictType) + type RepE SynthType = EitherE2 DictType (PairE (LiftE [Explicitness]) (Abs (Nest CBinder) DictType)) fromE (SynthDictType d) = Case0 d - fromE (SynthPiType t) = Case1 t + fromE (SynthPiType (expl, t)) = Case1 (PairE (LiftE expl) t) toE (Case0 d) = SynthDictType d - toE (Case1 t) = SynthPiType t + toE (Case1 (PairE (LiftE expl) t)) = SynthPiType (expl, t) toE _ = error "impossible" instance AlphaEqE SynthType diff --git a/src/lib/JAX/ToSimp.hs b/src/lib/JAX/ToSimp.hs index a3b012ea6..e2e183955 100644 --- a/src/lib/JAX/ToSimp.hs +++ b/src/lib/JAX/ToSimp.hs @@ -30,8 +30,8 @@ liftJaxSimpM :: (EnvReader m) => JaxSimpM n n (e n) -> m n (e n) liftJaxSimpM act = liftBuilder $ runSubstReaderT idSubst $ runJaxSimpM act {-# INLINE liftJaxSimpM #-} -simplifyClosedJaxpr :: ClosedJaxpr i -> JaxSimpM i o (LamExpr SimpIR o) -simplifyClosedJaxpr ClosedJaxpr{jaxpr, consts=[]} = simplifyJaxpr jaxpr +simplifyClosedJaxpr :: ClosedJaxpr i -> JaxSimpM i o (TopLam SimpIR o) +simplifyClosedJaxpr ClosedJaxpr{jaxpr, consts=[]} = asTopLam =<< simplifyJaxpr jaxpr simplifyClosedJaxpr _ = error "TODO Support consts" simplifyJaxpr :: Jaxpr i -> JaxSimpM i o (LamExpr SimpIR o) diff --git a/src/lib/Name.hs b/src/lib/Name.hs index ddb01ab8d..68d1ad2f8 100644 --- a/src/lib/Name.hs +++ b/src/lib/Name.hs @@ -521,6 +521,20 @@ data PairB (b1::B) (b2::B) (n::S) (l::S) where PairB :: b1 n l' -> b2 l' l -> PairB b1 b2 n l deriving instance (ShowB b1, ShowB b2) => Show (PairB b1 b2 n l) +data WithAttrB (a:: *) (b::B) (n::S) (l::S) = + WithAttrB {getAttr :: a , withoutAttr :: b n l } + deriving (Show, Generic) + +unzipAttrs :: Nest (WithAttrB a b) n l -> ([a], Nest b n l) +unzipAttrs Empty = ([], Empty) +unzipAttrs (Nest (WithAttrB a b) rest) = (a:as, Nest b bs) + where (as, bs) = unzipAttrs rest + +zipAttrs :: [a] -> Nest b n l -> Nest (WithAttrB a b) n l +zipAttrs [] Empty = Empty +zipAttrs (a:as) (Nest b bs) = Nest (WithAttrB a b) (zipAttrs as bs) +zipAttrs _ _ = error "zip error" + data EitherB (b1::B) (b2::B) (n::S) (l::S) = LeftB (b1 n l) | RightB (b2 n l) @@ -655,7 +669,7 @@ forNest :: Nest b i i' -> Nest b' i i' forNest n f = fmapNest f n -zipWithNest :: Nest b n l -> [a] +zipWithNest :: Nest b n l -> [a] -> (forall n1 n2. b n1 n2 -> a -> b' n1 n2) -> Nest b' n l zipWithNest Empty [] _ = Empty @@ -3195,6 +3209,35 @@ instance Monad HoistExcept where HoistSuccess x >>= f = f x {-# INLINE (>>=) #-} +instance (Store a, Store (b n l)) => Store (WithAttrB a b n l) + +instance (Eq a, AlphaEqB b) => AlphaEqB (WithAttrB a b) where + withAlphaEqB (WithAttrB a1 b1) (WithAttrB a2 b2) cont = do + unless (a1 == a2) zipErr + withAlphaEqB b1 b2 cont + +instance (Hashable a, AlphaHashableB b) => AlphaHashableB (WithAttrB a b) where + hashWithSaltB env salt (WithAttrB expl b) = do + let h = hashWithSalt salt expl + hashWithSaltB env h b + +instance BindsNames b => ProvesExt (WithAttrB a b) where +instance BindsNames b => BindsNames (WithAttrB a b) where + toScopeFrag (WithAttrB _ b) = toScopeFrag b + +instance (SinkableB b) => SinkableB (WithAttrB a b) where + sinkingProofB fresh (WithAttrB a b) cont = + sinkingProofB fresh b \fresh' b' -> + cont fresh' (WithAttrB a b') + +instance (BindsNames b, RenameB b) => RenameB (WithAttrB a b) where + renameB env (WithAttrB a b) cont = + renameB env b \env' b' -> + cont env' $ WithAttrB a b' + +instance HoistableB b => HoistableB (WithAttrB a b) where + freeVarsB (WithAttrB _ b) = freeVarsB b + -- === extra data structures === -- A map from names in some scope to values that do not contain names. This is diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index c68ea0390..23bc7ea60 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -296,26 +296,26 @@ forStr Fwd = "for" forStr Rev = "rof" instance Pretty (CorePiType n) where - pretty (CorePiType appExpl bs (EffTy eff resultTy)) = - prettyBindersWithExpl bs <+> p appExpl <> prettyEff <> p resultTy + pretty (CorePiType appExpl expls bs (EffTy eff resultTy)) = + prettyBindersWithExpl expls bs <+> p appExpl <> prettyEff <> p resultTy where prettyEff = case eff of Pure -> space _ -> space <> pretty eff <> space prettyBindersWithExpl :: forall b n l ann. PrettyB b - => Nest (WithExpl b) n l -> Doc ann -prettyBindersWithExpl bs = do - let groups = groupByExpl $ fromNest bs + => [Explicitness] -> Nest b n l -> Doc ann +prettyBindersWithExpl expls bs = do + let groups = groupByExpl $ zip expls (fromNest bs) let groups' = case groups of [] -> [(Explicit, [])] _ -> groups mconcat [withExplParens expl $ commaSep bsGroup | (expl, bsGroup) <- groups'] -groupByExpl :: [WithExpl b UnsafeS UnsafeS] -> [(Explicitness, [b UnsafeS UnsafeS])] +groupByExpl :: [(Explicitness, b UnsafeS UnsafeS)] -> [(Explicitness, [b UnsafeS UnsafeS])] groupByExpl [] = [] -groupByExpl (WithExpl expl b:bs) = do - let (matches, rest) = span (\(WithExpl expl' _) -> expl == expl') bs - let matches' = map withoutExpl matches +groupByExpl ((expl, b):bs) = do + let (matches, rest) = span (\(expl', _) -> expl == expl') bs + let matches' = map snd matches (expl, b:matches') : groupByExpl rest withExplParens :: Explicitness -> Doc ann -> Doc ann @@ -431,36 +431,29 @@ instance Pretty (TyConParams n) where pretty (TyConParams _ _) = undefined instance Pretty (TyConDef n) where - pretty (TyConDef name bs cons) = - "data" <+> p name <+> (p $ map (\(RolePiBinder _ b) -> b) $ fromNest bs) <> pretty cons + pretty (TyConDef name _ bs cons) = "data" <+> p name <+> p bs <> pretty cons instance Pretty (DataConDefs n) where pretty = undefined -instance Pretty (RolePiBinder n l) where - pretty (RolePiBinder _ b) = pretty b - instance Pretty (DataConDef n) where pretty (DataConDef name _ repTy _) = p name <+> ":" <+> p repTy instance Pretty (ClassDef n) where - pretty (ClassDef classSourceName methodNames _ params superclasses methodTys) = + pretty (ClassDef classSourceName methodNames _ _ params superclasses methodTys) = "Class:" <+> pretty classSourceName <+> pretty methodNames <> indented ( - line <> "parameter binders:" <+> prettyRolePiBinders params <> + line <> "parameter binders:" <+> pretty params <> line <> "superclasses:" <+> pretty superclasses <> line <> "methods:" <+> pretty methodTys) instance Pretty ParamRole where pretty r = p (show r) -prettyRolePiBinders :: RolePiBinders n l -> Doc ann -prettyRolePiBinders = undefined - instance Pretty (InstanceDef n) where - pretty (InstanceDef className bs params _) = - "Instance" <+> p className <+> prettyRolePiBinders bs <+> p params + pretty (InstanceDef className _ bs params _) = + "Instance" <+> p className <+> pretty bs <+> p params deriving instance (forall c n. Pretty (v c n)) => Pretty (RecSubst v o) @@ -629,14 +622,11 @@ instance Pretty FieldName' where instance Pretty (UAlt n) where pretty (UAlt pat body) = p pat <+> "->" <+> p body -instance PrettyB b => Pretty (WithExpl b n l) where - pretty (WithExpl _ b) = pretty b - instance Pretty (UTopDecl n l) where - pretty (UDataDefDecl (UDataDef nm bs dataCons) bTyCon bDataCons) = + pretty (UDataDefDecl (UDataDef nm (_, bs) dataCons) bTyCon bDataCons) = "data" <+> p bTyCon <+> p nm <+> spaced (fromNest bs) <+> "where" <> nest 2 (prettyLines (zip (toList $ fromNest bDataCons) dataCons)) - pretty (UStructDecl bTyCon (UStructDef nm bs fields defs)) = + pretty (UStructDecl bTyCon (UStructDef nm (_, bs) fields defs)) = "struct" <+> p bTyCon <+> p nm <+> spaced (fromNest bs) <+> "where" <> nest 2 (prettyLines fields <> prettyLines defs) pretty (UInterface params methodTys interfaceName methodNames) = diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index eaad95cba..f5952402b 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -73,7 +73,7 @@ blockEff :: (EnvReader m, IRRep r) => Block r n -> m n (EffectRow r n) blockEff b = blockEffTy b <&> \(EffTy eff _) -> eff typeOfApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -typeOfApp (Pi (CorePiType _ bs (EffTy _ resultTy))) xs = do +typeOfApp (Pi (CorePiType _ _ bs (EffTy _ resultTy))) xs = do let subst = bs @@> fmap SubstVal xs applySubst subst resultTy typeOfApp _ _ = error "expected a pi type" @@ -93,14 +93,14 @@ typeOfApplyMethod d i args = do typeOfDictExpr :: EnvReader m => DictExpr n -> m n (CType n) typeOfDictExpr e = liftM ignoreExcept $ liftEnvReaderT $ case e of InstanceDict instanceName args -> do - InstanceDef className bs params _ <- lookupInstanceDef instanceName - ClassDef sourceName _ _ _ _ _ <- lookupClassDef className + InstanceDef className _ bs params _ <- lookupInstanceDef instanceName + ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef className ListE params' <- applySubst (bs @@> map SubstVal args) $ ListE params return $ DictTy $ DictType sourceName className params' InstantiatedGiven given args -> typeOfApp (getType given) args SuperclassProj d i -> do DictTy (DictType _ className params) <- return $ getType d - ClassDef _ _ _ bs superclasses _ <- lookupClassDef className + ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className applySubst (bs @@> map SubstVal params) $ getSuperclassType REmpty superclasses i IxFin n -> liftM DictTy $ ixDictType $ NewtypeTyCon $ Fin n @@ -131,20 +131,21 @@ typeOfProjRef (TC (RefType h s)) p = do typeOfProjRef _ _ = error "expected a reference" appEffTy :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (EffTy r n) -appEffTy (Pi (CorePiType _ bs effTy)) xs = do +appEffTy (Pi (CorePiType _ _ bs effTy)) xs = do let subst = bs @@> fmap SubstVal xs applySubst subst effTy appEffTy t _ = error $ "expected a pi type, got: " ++ pprint t partialAppType :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -partialAppType (Pi (CorePiType expl bs effTy)) xs = do +partialAppType (Pi (CorePiType appExpl expls bs effTy)) xs = do + (_, expls2) <- return $ splitAt (length xs) expls PairB bs1 bs2 <- return $ splitNestAt (length xs) bs let subst = bs1 @@> fmap SubstVal xs - applySubst subst $ Pi $ CorePiType expl bs2 effTy + applySubst subst $ Pi $ CorePiType appExpl expls2 bs2 effTy partialAppType _ _ = error "expected a pi type" appEffects :: (EnvReader m, IRRep r) => Type r n -> [Atom r n] -> m n (EffectRow r n) -appEffects (Pi (CorePiType _ bs (EffTy effs _))) xs = do +appEffects (Pi (CorePiType _ _ bs (EffTy effs _))) xs = do let subst = bs @@> fmap SubstVal xs applySubst subst effs appEffects _ _ = error "expected a pi type" @@ -198,7 +199,7 @@ deleteEff eff (EffectRow effs t) = EffectRow (effs `eSetDifference` eSetSingleto getMethodIndex :: EnvReader m => ClassName n -> SourceName -> m n Int getMethodIndex className methodSourceName = do - ClassDef _ methodNames _ _ _ _ <- lookupClassDef className + ClassDef _ methodNames _ _ _ _ _ <- lookupClassDef className case elemIndex methodSourceName methodNames of Nothing -> error $ methodSourceName ++ " is not a method of " ++ pprint className Just i -> return i @@ -211,9 +212,8 @@ getUVarType = \case UDataConVar v -> getDataConNameType v UPunVar v -> getStructDataConType v UClassVar v -> do - ClassDef _ _ _ bs _ _ <- lookupClassDef v - let bs' = fmapNest (\(RolePiBinder _ b) -> b) bs - return $ Pi $ CorePiType ExplicitApp bs' $ EffTy Pure TyKind + ClassDef _ _ _ roleExpls bs _ _ <- lookupClassDef v + return $ Pi $ CorePiType ExplicitApp (map snd roleExpls) bs $ EffTy Pure TyKind UMethodVar v -> getMethodNameType v UEffectVar _ -> error "not implemented" UEffectOpVar _ -> error "not implemented" @@ -221,23 +221,22 @@ getUVarType = \case getMethodNameType :: EnvReader m => MethodName n -> m n (CType n) getMethodNameType v = liftEnvReaderM $ lookupEnv v >>= \case MethodBinding className i -> do - ClassDef _ _ paramNames paramBs scBinders methodTys <- lookupClassDef className - let paramBs' = zipWithNest paramBs paramNames \(RolePiBinder _ (WithExpl _ b)) paramName -> - WithExpl (Inferred paramName Unify) b - refreshAbs (Abs paramBs' $ Abs scBinders (methodTys !! i)) \paramBs'' (Abs scBinders' piTy) -> do - let params = Var <$> nestToAtomVars (fmapNest withoutExpl paramBs'') + ClassDef _ _ paramNames _ paramBs scBinders methodTys <- lookupClassDef className + refreshAbs (Abs paramBs $ Abs scBinders (methodTys !! i)) \paramBs' (Abs scBinders' piTy) -> do + let params = Var <$> nestToAtomVars paramBs' dictTy <- DictTy <$> dictType (sink className) params withFreshBinder noHint dictTy \dictB -> do scDicts <- getSuperclassDicts (Var $ binderVar dictB) piTy' <- applySubst (scBinders'@@>(SubstVal<$>scDicts)) piTy - CorePiType appExpl methodBs effTy <- return piTy' - let dictBs = UnaryNest $ WithExpl (Inferred Nothing (Synth $ Partial $ succ i)) dictB - return $ Pi $ CorePiType appExpl (paramBs'' >>> dictBs >>> methodBs) effTy + CorePiType appExpl methodExpls methodBs effTy <- return piTy' + let paramExpls = paramNames <&> \name -> Inferred name Unify + let expls = paramExpls <> [Inferred Nothing (Synth $ Partial $ succ i)] <> methodExpls + return $ Pi $ CorePiType appExpl expls (paramBs' >>> UnaryNest dictB >>> methodBs) effTy getMethodType :: EnvReader m => Dict n -> Int -> m n (CorePiType n) getMethodType dict i = do ~(DictTy (DictType _ className params)) <- return $ getType dict - ClassDef _ _ _ paramBs classBs methodTys <- lookupClassDef className + ClassDef _ _ _ _ paramBs classBs methodTys <- lookupClassDef className let methodTy = methodTys !! i superclassDicts <- getSuperclassDicts dict let subst = ( paramBs @@> map SubstVal params @@ -245,60 +244,56 @@ getMethodType dict i = do applySubst subst methodTy {-# INLINE getMethodType #-} - getTyConNameType :: EnvReader m => TyConName n -> m n (Type CoreIR n) getTyConNameType v = do - TyConDef _ bs _ <- lookupTyCon v + TyConDef _ expls bs _ <- lookupTyCon v case bs of Empty -> return TyKind - _ -> do - let bs' = fmapNest (\(RolePiBinder _ b) -> b) bs - return $ Pi $ CorePiType ExplicitApp bs' $ EffTy Pure TyKind + _ -> return $ Pi $ CorePiType ExplicitApp (snd <$> expls) bs $ EffTy Pure TyKind getDataConNameType :: EnvReader m => DataConName n -> m n (Type CoreIR n) getDataConNameType dataCon = liftEnvReaderM do (tyCon, i) <- lookupDataCon dataCon lookupTyCon tyCon >>= \case - tyConDef@(TyConDef tcSn paramBs ~(ADTCons dataCons)) -> - buildDataConType tyConDef \paramBs' paramVs params -> do + tyConDef@(TyConDef tcSn _ paramBs ~(ADTCons dataCons)) -> + buildDataConType tyConDef \expls paramBs' paramVs params -> do DataConDef _ ab _ _ <- applyRename (paramBs @@> paramVs) (dataCons !! i) refreshAbs ab \dataBs UnitE -> do let appExpl = case dataBs of Empty -> ImplicitApp _ -> ExplicitApp let resultTy = NewtypeTyCon $ UserADTType tcSn (sink tyCon) (sink params) - let dataBs' = fmapNest (WithExpl Explicit) dataBs - return $ Pi $ CorePiType appExpl (paramBs' >>> dataBs') (EffTy Pure resultTy) + let dataExpls = nestToList (const $ Explicit) dataBs + return $ Pi $ CorePiType appExpl (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy) getStructDataConType :: EnvReader m => TyConName n -> m n (CType n) getStructDataConType tyCon = liftEnvReaderM do - tyConDef@(TyConDef tcSn paramBs ~(StructFields fields)) <- lookupTyCon tyCon - buildDataConType tyConDef \paramBs' paramVs params -> do + tyConDef@(TyConDef tcSn _ paramBs ~(StructFields fields)) <- lookupTyCon tyCon + buildDataConType tyConDef \expls paramBs' paramVs params -> do fieldTys <- forM fields \(_, t) -> applyRename (paramBs @@> paramVs) t let resultTy = NewtypeTyCon $ UserADTType tcSn (sink tyCon) params Abs dataBs resultTy' <- return $ typesAsBinderNest fieldTys resultTy - let dataBs' = fmapNest (WithExpl Explicit) dataBs - return $ Pi $ CorePiType ExplicitApp (paramBs' >>> dataBs') (EffTy Pure resultTy') + let dataExpls = nestToList (const Explicit) dataBs + return $ Pi $ CorePiType ExplicitApp (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy') buildDataConType :: (EnvReader m, EnvExtender m) => TyConDef n - -> (forall l. DExt n l => Nest (WithExpl CBinder) n l -> [CAtomName l] -> TyConParams l -> m l a) + -> (forall l. DExt n l => [Explicitness] -> Nest CBinder n l -> [CAtomName l] -> TyConParams l -> m l a) -> m n a -buildDataConType (TyConDef _ bs _) cont = do - bs' <- return $ forNest bs \(RolePiBinder _ (WithExpl expl b)) -> case expl of - Explicit -> WithExpl (Inferred Nothing Unify) b - _ -> WithExpl expl b - refreshAbs (Abs bs' UnitE) \bs'' UnitE -> do - let expls = nestToList (\(RolePiBinder _ b) -> getExpl b) bs - let vs = nestToNames bs'' +buildDataConType (TyConDef _ roleExpls bs _) cont = do + let expls = snd <$> roleExpls + expls' <- forM expls \case + Explicit -> return $ Inferred Nothing Unify + expl -> return $ expl + refreshAbs (Abs bs UnitE) \bs' UnitE -> do + let vs = nestToNames bs' vs' <- mapM toAtomVar vs - cont bs'' vs $ TyConParams expls (Var <$> vs') + cont expls' bs' vs $ TyConParams expls (Var <$> vs') makeTyConParams :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n) makeTyConParams tc params = do - TyConDef _ paramBs _ <- lookupTyCon tc - let expls = nestToList (\(RolePiBinder _ b) -> getExpl b) paramBs - return $ TyConParams expls params + TyConDef _ expls _ _ <- lookupTyCon tc + return $ TyConParams (map snd expls) params getDataClassName :: (Fallible1 m, EnvReader m) => m n (ClassName n) getDataClassName = lookupSourceMap "Data" >>= \case @@ -319,7 +314,7 @@ getIxClassName = lookupSourceMap "Ix" >>= \case dictType :: EnvReader m => ClassName n -> [CAtom n] -> m n (DictType n) dictType className params = do - ClassDef sourceName _ _ _ _ _ <- lookupClassDef className + ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef className return $ DictType sourceName className params ixDictType :: (Fallible1 m, EnvReader m) => CType n -> m n (DictType n) @@ -374,7 +369,7 @@ getSuperclassDicts dict = do getSuperclassTys :: EnvReader m => DictType n -> m n [CType n] getSuperclassTys (DictType _ className params) = do - ClassDef _ _ _ bs superclasses _ <- lookupClassDef className + ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className forM [0 .. nestLength superclasses - 1] \i -> do applySubst (bs @@> map SubstVal params) $ getSuperclassType REmpty superclasses i diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 3bda0d8fd..2501cbf8f 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -221,8 +221,8 @@ typesAsBinderNest types body = toConstBinderNest types body nonDepPiType :: [CType n] -> EffectRow CoreIR n -> CType n -> CorePiType n nonDepPiType argTys eff resultTy = case typesAsBinderNest argTys (PairE eff resultTy) of Abs bs (PairE eff' resultTy') -> do - let bs' = fmapNest (WithExpl Explicit) bs - CorePiType ExplicitApp bs' $ EffTy eff' resultTy' + let expls = nestToList (const Explicit) bs + CorePiType ExplicitApp expls bs $ EffTy eff' resultTy' nonDepTabPiType :: IRRep r => IxType r n -> Type r n -> TabPiType r n nonDepTabPiType (IxType t d) resultTy = @@ -230,7 +230,7 @@ nonDepTabPiType (IxType t d) resultTy = Abs b resultTy' -> TabPiType d (b:>t) resultTy' corePiTypeToPiType :: CorePiType n -> PiType CoreIR n -corePiTypeToPiType (CorePiType _ bs effTy) = PiType (fmapNest withoutExpl bs) effTy +corePiTypeToPiType (CorePiType _ _ bs effTy) = PiType bs effTy coreLamToTopLam :: CoreLamExpr n -> TopLam CoreIR n coreLamToTopLam (CoreLamExpr ty f) = TopLam False (corePiTypeToPiType ty) f diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index d1b17c792..3255773ad 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -169,8 +169,9 @@ withBuffer cont = do body <- buildBlock do cont $ sink $ Var $ binderVar b return UnitVal - let piBinders = BinaryNest (WithExpl (Inferred Nothing Unify) h) (WithExpl Explicit b) - let piTy = CorePiType ExplicitApp piBinders $ EffTy eff UnitTy + let binders = BinaryNest h b + let expls = [Inferred Nothing Unify, Explicit] + let piTy = CorePiType ExplicitApp expls binders $ EffTy eff UnitTy let lam = LamExpr (BinaryNest h b) body return $ Lam $ CoreLamExpr piTy lam applyPreludeFunction "with_stack_internal" [lam] diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 2e820ee71..c151937b6 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -618,8 +618,8 @@ simplifyDictMethod absDict@(Abs bs dict) method = do ixMethodType :: IxMethod -> AbsDict n -> EnvReaderM n (PiType CoreIR n) ixMethodType method absDict = do refreshAbs absDict \extraArgBs dict -> do - CorePiType _ methodArgs (EffTy _ resultTy) <- getMethodType dict (fromEnum method) - let allBs = extraArgBs >>> fmapNest withoutExpl methodArgs + CorePiType _ _ methodArgs (EffTy _ resultTy) <- getMethodType dict (fromEnum method) + let allBs = extraArgBs >>> methodArgs return $ PiType allBs (EffTy Pure resultTy) -- TODO: do we even need this, or is it just a glorified `SubstM`? @@ -788,7 +788,7 @@ applyDictMethod resultTy d i methodArgs = do cheapNormalize d >>= \case DictCon _ (InstanceDict instanceName instanceArgs) -> dropSubst do instanceArgs' <- mapM simplifyAtom instanceArgs - InstanceDef _ bsInstance _ body <- lookupInstanceDef instanceName + InstanceDef _ _ bsInstance _ body <- lookupInstanceDef instanceName let InstanceBody _ methods = body let method = methods !! i extendSubst (bsInstance @@> (SubstVal <$> instanceArgs')) do @@ -921,7 +921,7 @@ preludeNothingVal ty = do preludeMaybeNewtypeCon :: EnvReader m => CType n -> m n (NewtypeCon n) preludeMaybeNewtypeCon ty = do ~(Just (UTyConVar tyConName)) <- lookupSourceMap "Maybe" - TyConDef sn _ _ <- lookupTyCon tyConName + TyConDef sn _ _ _ <- lookupTyCon tyConName let params = TyConParams [Explicit] [Type ty] return $ UserADTData sn tyConName params @@ -997,7 +997,7 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do -- a custom linearization defined for a function on ADTs will -- not work. fLin' <- sinkM fLin - Pi (CorePiType _ bs _) <- return $ getType fLin' + Pi (CorePiType _ _ bs _) <- return $ getType fLin' let tangentCoreTys = fromNonDepNest bs tangentArgs' <- zipWithM liftSimpAtom tangentCoreTys tangentArgs resultTyTangent <- typeOfApp (getType fLin') tangentArgs' diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index 45fafc01e..3ee3b13b1 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -180,9 +180,9 @@ instance SourceRenamableE UExpr' where UVar v -> UVar <$> sourceRenameE v ULit l -> return $ ULit l ULam lam -> ULam <$> sourceRenameE lam - UPi (UPiExpr pats appExpl eff body) -> + UPi (UPiExpr (attrs, pats) appExpl eff body) -> sourceRenameB pats \pats' -> - UPi <$> (UPiExpr pats' <$> pure appExpl <*> sourceRenameE eff <*> sourceRenameE body) + UPi <$> (UPiExpr (attrs, pats') <$> pure appExpl <*> sourceRenameE eff <*> sourceRenameE body) UApp f xs ys -> UApp <$> sourceRenameE f <*> forM xs sourceRenameE <*> forM ys (\(name, y) -> (name,) <$> sourceRenameE y) @@ -245,20 +245,20 @@ instance SourceRenamableB UTopDecl where sourceRenameUBinder UPunVar tyConName \tyConName' -> do structDef' <- sourceRenameE structDef cont $ UStructDecl tyConName' structDef' - UInterface paramBs methodTys className methodNames -> do + UInterface (attrs, paramBs) methodTys className methodNames -> do Abs paramBs' (ListE methodTys') <- sourceRenameB paramBs \paramBs' -> do methodTys' <- mapM sourceRenameE methodTys return $ Abs paramBs' $ ListE methodTys' sourceRenameUBinder UClassVar className \className' -> sourceRenameUBinderNest UMethodVar methodNames \methodNames' -> - cont $ UInterface paramBs' methodTys' className' methodNames' - UInstance className conditions params methodDefs instanceName expl -> do + cont $ UInterface (attrs, paramBs') methodTys' className' methodNames' + UInstance className (roleExpls, conditions) params methodDefs instanceName expl -> do className' <- sourceRenameE className Abs conditions' (PairE (ListE params') (ListE methodDefs')) <- sourceRenameE $ Abs conditions (PairE (ListE params) $ ListE methodDefs) sourceRenameB instanceName \instanceName' -> - cont $ UInstance className' conditions' params' methodDefs' instanceName' expl + cont $ UInstance className' (roleExpls, conditions') params' methodDefs' instanceName' expl UEffectDecl opTypes effName opNames -> do opTypes' <- mapM (\(UEffectOpType p ty) -> (UEffectOpType p) <$> sourceRenameE ty) opTypes sourceRenameUBinder UEffectVar effName \effName' -> @@ -277,8 +277,8 @@ instance SourceRenamableB UDecl' where UPass -> cont UPass instance SourceRenamableE ULamExpr where - sourceRenameE (ULamExpr args expl effs resultTy body) = - sourceRenameB args \args' -> ULamExpr args' + sourceRenameE (ULamExpr (expls, args) expl effs resultTy body) = + sourceRenameB args \args' -> ULamExpr (expls, args') <$> pure expl <*> mapM sourceRenameE effs <*> mapM sourceRenameE resultTy @@ -304,9 +304,6 @@ instance (SourceRenamableB b1, SourceRenamableB b2) => SourceRenamableB (PairB b sourceRenameB b2 \b2' -> cont $ PairB b1' b2' -instance SourceRenamableB b => SourceRenamableB (WithExpl b) where - sourceRenameB (WithExpl x b) cont = sourceRenameB b \b' -> cont $ WithExpl x b' - sourceRenameUBinderNest :: (Color c, Renamer m, Distinct o) => (forall l. Name c l -> UVar l) @@ -339,15 +336,15 @@ sourceRenameUBinder asUVar ubinder cont = case ubinder of UIgnore -> cont UIgnore instance SourceRenamableE UDataDef where - sourceRenameE (UDataDef tyConName paramBs dataCons) = do + sourceRenameE (UDataDef tyConName (expls, paramBs) dataCons) = do sourceRenameB paramBs \paramBs' -> do dataCons' <- forM dataCons \(dataConName, argBs) -> do argBs' <- sourceRenameE argBs return (dataConName, argBs') - return $ UDataDef tyConName paramBs' dataCons' + return $ UDataDef tyConName (expls, paramBs') dataCons' instance SourceRenamableE UStructDef where - sourceRenameE (UStructDef tyConName paramBs fields methods) = do + sourceRenameE (UStructDef tyConName (expls, paramBs) fields methods) = do sourceRenameB paramBs \paramBs' -> do fields' <- forM fields \(fieldName, ty) -> do ty' <- sourceRenameE ty @@ -355,7 +352,7 @@ instance SourceRenamableE UStructDef where methods' <- forM methods \(ann, methodName, lam) -> do lam' <- sourceRenameE lam return (ann, methodName, lam') - return $ UStructDef tyConName paramBs' fields' methods' + return $ UStructDef tyConName (expls, paramBs') fields' methods' instance SourceRenamableE UDataDefTrail where sourceRenameE (UDataDefTrail args) = sourceRenameB args \args' -> diff --git a/src/lib/Subst.hs b/src/lib/Subst.hs index 5e29db46a..5b13ef624 100644 --- a/src/lib/Subst.hs +++ b/src/lib/Subst.hs @@ -18,7 +18,6 @@ import Control.Monad.State.Strict import Name import IRVariants import Types.Core -import Types.Primitives import Core import qualified RawName as R import Err @@ -444,20 +443,16 @@ instance (BindsNames b, SubstB v b, SinkableV v) instance FromName v => SubstE v UnitE where substE _ UnitE = UnitE +instance SubstB v b => SubstB v (WithAttrB a b) where + substB env (WithAttrB x b) cont = + substB env b \env' b' -> cont env' $ WithAttrB x b' + instance (Traversable f, SubstE v e) => SubstE v (ComposeE f e) where substE env (ComposeE xs) = ComposeE $ fmap (substE env) xs instance (SubstE v e1, SubstE v e2) => SubstE v (PairE e1 e2) where substE env (PairE x y) = PairE (substE env x) (substE env y) -instance SubstB v b => SubstB v (WithExpl b) where - substB env (WithExpl x b) cont = - substB env b \env' b' -> cont env' $ WithExpl x b' - -instance (FromName v, SubstB v CBinder) => SubstB v RolePiBinder where - substB env (RolePiBinder role b) cont = - substB env b \env' b' -> cont env' $ RolePiBinder role b' - instance (SubstE v e1, SubstE v e2) => SubstE v (EitherE e1 e2) where substE env (LeftE x) = LeftE $ substE env x substE env (RightE x) = RightE $ substE env x diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 9b08665ec..fdcafcc51 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -937,8 +937,8 @@ instance Generic TopStateEx where getLinearizationType :: SymbolicZeros -> CType n -> EnvReaderT Except n (Int, Int, CType n) getLinearizationType zeros = \case - Pi (CorePiType ExplicitApp bs (EffTy Pure resultTy)) -> do - (numIs, numEs) <- getNumImplicits $ fst $ unzipExpls bs + Pi (CorePiType ExplicitApp expls bs (EffTy Pure resultTy)) -> do + (numIs, numEs) <- getNumImplicits expls refreshAbs (Abs bs resultTy) \bs' resultTy' -> do PairB _ bsE <- return $ splitNestAt numIs bs' let explicitArgTys = nestToList (\b -> sink $ binderType b) bsE @@ -951,7 +951,7 @@ getLinearizationType zeros = \case Just rtt -> return rtt Nothing -> throw TypeErr $ "No tangent type for: " ++ pprint resultTy' let tanFunTy = Pi $ nonDepPiType argTanTys Pure resultTanTy - let fullTy = CorePiType ExplicitApp bs' $ EffTy Pure (PairTy resultTy' tanFunTy) + let fullTy = CorePiType ExplicitApp expls bs' $ EffTy Pure (PairTy resultTy' tanFunTy) return (numIs, numEs, Pi fullTy) _ -> throw TypeErr $ "Can't define a custom linearization for implicit or impure functions" where diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 9f22018b8..9e9062d3f 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -151,7 +151,8 @@ data TyConDef n where -- binder name is in UExpr and Env TyConDef :: SourceName - -> RolePiBinders n l + -> [RoleExpl] + -> Nest CBinder n l -> DataConDefs l -> TyConDef n @@ -213,10 +214,8 @@ data TabPiType (r::IR) (n::S) where data PiType (r::IR) (n::S) where PiType :: Nest (Binder r) n l -> EffTy r l -> PiType r n -type CoreBinders = Nest (WithExpl CBinder) - data CorePiType (n::S) where - CorePiType :: AppExplicitness -> CoreBinders n l -> EffTy CoreIR l -> CorePiType n + CorePiType :: AppExplicitness -> [Explicitness] -> Nest CBinder n l -> EffTy CoreIR l -> CorePiType n data DepPairType (r::IR) (n::S) where DepPairType :: DepPairExplicitness -> Binder r n l -> Type r l -> DepPairType r n @@ -443,16 +442,15 @@ isSumCon = \case -- === type classes === -data RolePiBinder (n::S) (l::S) = RolePiBinder ParamRole (WithExpl CBinder n l) - deriving (Show, Generic) -type RolePiBinders = Nest RolePiBinder +type RoleExpl = (ParamRole, Explicitness) data ClassDef (n::S) where ClassDef :: SourceName -- name of class -> [SourceName] -- method source names -> [Maybe SourceName] -- parameter source names - -> RolePiBinders n1 n2 -- parameters + -> [RoleExpl] -- parameter info + -> Nest CBinder n1 n2 -- parameters -> Nest CBinder n2 n3 -- superclasses -> [CorePiType n3] -- method types -> ClassDef n1 @@ -460,7 +458,8 @@ data ClassDef (n::S) where data InstanceDef (n::S) where InstanceDef :: ClassName n1 - -> RolePiBinders n1 n2 -- parameters (types and dictionaries) + -> [RoleExpl] -- parameter info + -> Nest CBinder n1 n2 -- parameters (types and dictionaries) -> [CAtom n2] -- class parameters -> InstanceBody n2 -> InstanceDef n1 @@ -921,10 +920,6 @@ instance IRRep r => BindsOneAtomName r (BinderP (AtomNameC r) (Type r)) where binderType (_ :> ty) = ty binderVar (b:>t) = AtomVar (binderName b) (sink t) -instance BindsOneAtomName CoreIR b => BindsOneAtomName CoreIR (WithExpl b) where - binderType (WithExpl _ b) = binderType b - binderVar (WithExpl _ b) = binderVar b - toBinderNest :: BindsOneAtomName r b => Nest b n l -> Nest (Binder r) n l toBinderNest Empty = Empty toBinderNest (Nest b bs) = Nest (asNameBinder b :> binderType b) (toBinderNest bs) @@ -1183,10 +1178,10 @@ instance AlphaEqE DataConDefs instance AlphaHashableE DataConDefs instance GenericE TyConDef where - type RepE TyConDef = PairE (LiftE SourceName) (Abs RolePiBinders DataConDefs) - fromE (TyConDef sourceName bs cons) = PairE (LiftE sourceName) (Abs bs cons) + type RepE TyConDef = PairE (LiftE (SourceName, [RoleExpl])) (Abs (Nest CBinder) DataConDefs) + fromE (TyConDef sourceName expls bs cons) = PairE (LiftE (sourceName, expls)) (Abs bs cons) {-# INLINE fromE #-} - toE (PairE (LiftE sourceName) (Abs bs cons)) = TyConDef sourceName bs cons + toE (PairE (LiftE (sourceName, expls)) (Abs bs cons)) = TyConDef sourceName expls bs cons {-# INLINE toE #-} deriving instance Show (TyConDef n) @@ -1198,7 +1193,7 @@ instance AlphaEqE TyConDef instance AlphaHashableE TyConDef instance HasNameHint (TyConDef n) where - getNameHint (TyConDef v _ _) = getNameHint v + getNameHint (TyConDef v _ _ _) = getNameHint v instance GenericE DataConDef where type RepE DataConDef = (LiftE (SourceName, [[Projection]])) @@ -1834,39 +1829,15 @@ instance (IRRep r, AlphaEqE ann) => AlphaEqB (NonDepNest r ann) instance (IRRep r, AlphaHashableE ann) => AlphaHashableB (NonDepNest r ann) deriving instance (Show (ann n)) => IRRep r => Show (NonDepNest r ann n l) -instance GenericB RolePiBinder where - type RepB RolePiBinder = PairB (LiftB (LiftE ParamRole)) (WithExpl CBinder) - fromB (RolePiBinder role b) = PairB (LiftB (LiftE role)) b - toB (PairB (LiftB (LiftE role)) b) = RolePiBinder role b - -instance BindsAtMostOneName RolePiBinder (AtomNameC CoreIR) where - RolePiBinder _ b @> x = b @> x - {-# INLINE (@>) #-} - -instance BindsOneName RolePiBinder (AtomNameC CoreIR) where - binderName (RolePiBinder _ b) = binderName b - -instance BindsOneAtomName CoreIR RolePiBinder where - binderType (RolePiBinder _ b) = binderType b - binderVar (RolePiBinder _ b) = binderVar b - -instance ProvesExt RolePiBinder -instance BindsNames RolePiBinder -instance SinkableB RolePiBinder -instance HoistableB RolePiBinder -instance RenameB RolePiBinder -instance AlphaEqB RolePiBinder -instance AlphaHashableB RolePiBinder - instance GenericE ClassDef where type RepE ClassDef = - LiftE (SourceName, [SourceName], [Maybe SourceName]) - `PairE` Abs RolePiBinders (Abs (Nest CBinder) (ListE CorePiType)) - fromE (ClassDef name names paramNames b scs tys) = - LiftE (name, names, paramNames) `PairE` Abs b (Abs scs (ListE tys)) + LiftE (SourceName, [SourceName], [Maybe SourceName], [RoleExpl]) + `PairE` Abs (Nest CBinder) (Abs (Nest CBinder) (ListE CorePiType)) + fromE (ClassDef name names paramNames roleExpls b scs tys) = + LiftE (name, names, paramNames, roleExpls) `PairE` Abs b (Abs scs (ListE tys)) {-# INLINE fromE #-} - toE (LiftE (name, names, paramNames) `PairE` Abs b (Abs scs (ListE tys))) = - ClassDef name names paramNames b scs tys + toE (LiftE (name, names, paramNames, roleExpls) `PairE` Abs b (Abs scs (ListE tys))) = + ClassDef name names paramNames roleExpls b scs tys {-# INLINE toE #-} instance SinkableE ClassDef @@ -1879,11 +1850,11 @@ deriving via WrapE ClassDef n instance Generic (ClassDef n) instance GenericE InstanceDef where type RepE InstanceDef = - ClassName `PairE` Abs RolePiBinders (ListE CAtom `PairE` InstanceBody) - fromE (InstanceDef name bs params body) = - name `PairE` Abs bs (ListE params `PairE` body) - toE (name `PairE` Abs bs (ListE params `PairE` body)) = - InstanceDef name bs params body + ClassName `PairE` LiftE [RoleExpl] `PairE` Abs (Nest CBinder) (ListE CAtom `PairE` InstanceBody) + fromE (InstanceDef name expls bs params body) = + name `PairE` LiftE expls `PairE` Abs bs (ListE params `PairE` body) + toE (name `PairE` LiftE expls `PairE` Abs bs (ListE params `PairE` body)) = + InstanceDef name expls bs params body instance SinkableE InstanceDef instance HoistableE InstanceDef @@ -2015,10 +1986,10 @@ deriving instance Show (CoreLamExpr n) deriving via WrapE CoreLamExpr n instance Generic (CoreLamExpr n) instance GenericE CorePiType where - type RepE CorePiType = LiftE AppExplicitness `PairE` Abs CoreBinders (EffTy CoreIR) - fromE (CorePiType ex b effTy) = LiftE ex `PairE` Abs b effTy + type RepE CorePiType = LiftE (AppExplicitness, [Explicitness]) `PairE` Abs (Nest CBinder) (EffTy CoreIR) + fromE (CorePiType ex exs b effTy) = LiftE (ex, exs) `PairE` Abs b effTy {-# INLINE fromE #-} - toE (LiftE ex `PairE` Abs b effTy) = CorePiType ex b effTy + toE (LiftE (ex, exs) `PairE` Abs b effTy) = CorePiType ex exs b effTy {-# INLINE toE #-} instance SinkableE CorePiType @@ -2766,7 +2737,6 @@ instance Store (DictType n) instance Store (DictExpr n) instance Store (EffectDef n) instance Store (EffectOpDef n) -instance Store (RolePiBinder n l) instance Store (EffectOpType n) instance Store (EffectOpIdx) instance Store (SynthCandidates n) diff --git a/src/lib/Types/Primitives.hs b/src/lib/Types/Primitives.hs index a9230a3ed..002a6d09a 100644 --- a/src/lib/Types/Primitives.hs +++ b/src/lib/Types/Primitives.hs @@ -22,9 +22,7 @@ module Types.Primitives ( module Types.Primitives, UnOp (..), BinOp (..), CmpOp (..), Projection (..)) where -import Name import qualified Data.ByteString as BS -import Control.Monad import Data.Int import Data.Word import Data.Hashable @@ -35,7 +33,6 @@ import Foreign.Ptr import GHC.Generics (Generic (..)) import Occurrence -import Util (zipErr) import Types.OpNames (UnOp (..), BinOp (..), CmpOp (..), Projection (..)) type SourceName = String @@ -58,23 +55,6 @@ data Explicitness = data AppExplicitness = ExplicitApp | ImplicitApp deriving (Show, Generic, Eq) data DepPairExplicitness = ExplicitDepPair | ImplicitDepPair deriving (Show, Generic, Eq) -data WithExpl (b::B) (n::S) (l::S) = - WithExpl { getExpl :: Explicitness , withoutExpl :: b n l } - deriving (Show, Generic) - -unzipExpls :: Nest (WithExpl b) n l -> ([Explicitness], Nest b n l) -unzipExpls Empty = ([], Empty) -unzipExpls (Nest (WithExpl expl b) rest) = (expl:expls, Nest b bs) - where (expls, bs) = unzipExpls rest - -zipExpls :: [Explicitness] -> Nest b n l -> Nest (WithExpl b) n l -zipExpls [] Empty = Empty -zipExpls (expl:expls) (Nest b bs) = Nest (WithExpl expl b) (zipExpls expls bs) -zipExpls _ _ = error "zip error" - -addExpls :: Explicitness -> Nest b n l -> Nest (WithExpl b) n l -addExpls expl bs = fmapNest (\b -> WithExpl expl b) bs - data RequiredMethodAccess = Full | Partial Int deriving (Show, Eq, Ord, Generic) data LetAnn = @@ -225,40 +205,3 @@ instance Hashable AppExplicitness instance Hashable DepPairExplicitness instance Hashable InferenceMechanism instance Hashable RequiredMethodAccess - -instance Store (b n l) => Store (WithExpl b n l) - -instance (Color c, BindsOneName b c) => BindsOneName (WithExpl b) c where - binderName (WithExpl _ b) = binderName b - asNameBinder (WithExpl _ b) = asNameBinder b - -instance (Color c, BindsAtMostOneName b c) => BindsAtMostOneName (WithExpl b) c where - WithExpl _ b @> x = b @> x - {-# INLINE (@>) #-} - -instance AlphaEqB b => AlphaEqB (WithExpl b) where - withAlphaEqB (WithExpl a1 b1) (WithExpl a2 b2) cont = do - unless (a1 == a2) zipErr - withAlphaEqB b1 b2 cont - -instance AlphaHashableB b => AlphaHashableB (WithExpl b) where - hashWithSaltB env salt (WithExpl expl b) = do - let h = hashWithSalt salt expl - hashWithSaltB env h b - -instance BindsNames b => ProvesExt (WithExpl b) where -instance BindsNames b => BindsNames (WithExpl b) where - toScopeFrag (WithExpl _ b) = toScopeFrag b - -instance (SinkableB b) => SinkableB (WithExpl b) where - sinkingProofB fresh (WithExpl a b) cont = - sinkingProofB fresh b \fresh' b' -> - cont fresh' (WithExpl a b') - -instance (BindsNames b, RenameB b) => RenameB (WithExpl b) where - renameB env (WithExpl a b) cont = - renameB env b \env' b' -> - cont env' $ WithExpl a b' - -instance HoistableB b => HoistableB (WithExpl b) where - freeVarsB (WithExpl _ b) = freeVarsB b diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index a27b02268..0c361236d 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -274,9 +274,12 @@ data FieldName' = | FieldNum Int deriving (Show, Eq, Ord) +type UAnnExplBinders req n l = ([Explicitness], Nest (UAnnBinder req) n l) +type UOptAnnExplBinders n l = UAnnExplBinders AnnOptional n l + data ULamExpr (n::S) where ULamExpr - :: Nest (WithExpl UOptAnnBinder) n l -- args + :: UOptAnnExplBinders n l -- args -> AppExplicitness -> Maybe (UEffectRow l) -- optional effect -> Maybe (UType l) -- optional result type @@ -284,7 +287,7 @@ data ULamExpr (n::S) where -> ULamExpr n data UPiExpr (n::S) where - UPiExpr :: Nest (WithExpl UOptAnnBinder) n l -> AppExplicitness -> UEffectRow l -> UType l -> UPiExpr n + UPiExpr :: UOptAnnExplBinders n l -> AppExplicitness -> UEffectRow l -> UType l -> UPiExpr n data UTabPiExpr (n::S) where UTabPiExpr :: UOptAnnBinder n l -> UType l -> UTabPiExpr n @@ -297,14 +300,14 @@ type UConDef (n::S) (l::S) = (SourceName, Nest UReqAnnBinder n l) data UDataDef (n::S) where UDataDef :: SourceName -- source name for pretty printing - -> Nest (WithExpl UOptAnnBinder) n l + -> UOptAnnExplBinders n l -> [(SourceName, UDataDefTrail l)] -- data constructor types -> UDataDef n data UStructDef (n::S) where UStructDef :: SourceName -- source name for pretty printing - -> Nest (WithExpl UOptAnnBinder) n l + -> UOptAnnExplBinders n l -> [(SourceName, UType l)] -- named payloads -> [(LetAnn, SourceName, Abs UAtomBinder ULamExpr l)] -- named methods (initial binder is for `self`) -> UStructDef n @@ -324,14 +327,14 @@ data UTopDecl (n::S) (l::S) where -> UStructDef l -- actual definition -> UTopDecl n l UInterface - :: Nest (WithExpl UOptAnnBinder) n p -- parameter binders + :: UOptAnnExplBinders n p -- parameter binders -> [UType p] -- method types -> UBinder ClassNameC n l' -- class name -> Nest (UBinder MethodNameC) l' l -- method names -> UTopDecl n l UInstance :: SourceNameOr (Name ClassNameC) n -- class name - -> Nest (WithExpl UOptAnnBinder) n l' + -> UOptAnnExplBinders n l' -> [UExpr l'] -- class parameters -> [UMethodDef l'] -- method definitions -- Maybe we should make a separate color (namespace) for instance names? @@ -346,7 +349,7 @@ data UTopDecl (n::S) (l::S) where UHandlerDecl :: SourceNameOr (Name EffectNameC) n -- effect name -> UAtomBinder n b -- body type argument - -> Nest (WithExpl UOptAnnBinder) b l' -- type args + -> UOptAnnExplBinders b l' -- type args -> UEffectRow l' -- returning effect -> UType l' -- returning type -> [UEffectOpDef l'] -- operation definitions From b274115cc7d4d03483fe74dc8e5cba6e77e17776 Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 27 Jun 2023 11:45:42 -0400 Subject: [PATCH 4/4] Avoid some uses of `:>` and `@>`. --- src/lib/Builder.hs | 12 ++++++------ src/lib/CheapReduction.hs | 6 +++++- src/lib/CheckType.hs | 15 +++++---------- src/lib/Export.hs | 4 ++-- src/lib/Imp.hs | 36 ++++++++++++++++++------------------ src/lib/Lower.hs | 6 +++--- src/lib/QueryTypePure.hs | 7 +++++-- src/lib/RuntimePrint.hs | 8 ++++---- src/lib/Simplify.hs | 10 +++++----- src/lib/Types/Core.hs | 33 ++++++++------------------------- 10 files changed, 61 insertions(+), 76 deletions(-) diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 8e7563511..5f1c372de 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -786,8 +786,8 @@ buildMap :: (Emits n, ScopableBuilder r m) -> (forall l. (Emits l, DExt n l) => Atom r l -> m l (Atom r l)) -> m n (Atom r n) buildMap xs f = do - TabTy d (_:>t) _ <- return $ getType xs - buildFor noHint Fwd (IxType t d) \i -> + TabPi t <- return $ getType xs + buildFor noHint Fwd (tabIxType t) \i -> tabApp (sink xs) (Var i) >>= f unzipTab :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n, Atom r n) @@ -857,8 +857,8 @@ zeroAt ty = liftEmitBuilder $ go ty where go = \case BaseTy bt -> return $ Con $ Lit $ zeroLit bt ProdTy tys -> ProdVal <$> mapM go tys - TabTy d (b:>t) bodyTy -> buildFor (getNameHint b) Fwd (IxType t d) \i -> - go =<< applySubst (b @> SubstVal (Var i)) bodyTy + TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> + go =<< instantiateTabPiTy (sink tabPi) (Var i) _ -> unreachable zeroLit bt = case bt of Scalar Float64Type -> Float64Lit 0.0 @@ -902,8 +902,8 @@ tangentBaseMonoidFor ty = do addTangent :: (Emits n, SBuilder m) => SAtom n -> SAtom n -> m n (SAtom n) addTangent x y = do case getType x of - TabTy d (b:>t) _ -> - liftEmitBuilder $ buildFor (getNameHint b) Fwd (IxType t d) \i -> do + TabPi t -> + liftEmitBuilder $ buildFor (getNameHint t) Fwd (tabIxType t) \i -> do bindM2 addTangent (tabApp (sink x) (Var i)) (tabApp (sink y) (Var i)) TC con -> case con of BaseType (Scalar _) -> emitOp $ BinOp FAdd x y diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index 9a35c2ed5..4c42bbda2 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -15,7 +15,7 @@ module CheapReduction , unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType , liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..) , visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2 - , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiatePiTy + , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiatePiTy, instantiateTabPiTy , bindersToVars, bindersToAtoms) where @@ -474,6 +474,10 @@ instantiatePiTy :: (EnvReader m, IRRep r) => PiType r n -> [Atom r n] -> m n (Ef instantiatePiTy (PiType bs effTy) xs = do applySubst (bs @@> (SubstVal <$> xs)) effTy +instantiateTabPiTy :: (EnvReader m, IRRep r) => TabPiType r n -> Atom r n -> m n (Type r n) +instantiateTabPiTy (TabPiType _ b resultTy) x = do + applySubst (b @> SubstVal x) resultTy + -- Returns a representation type (type of an TypeCon-typed Newtype payload) -- given a list of instantiated DataConDefs. dataDefRep :: DataConDefs n -> CType n diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 671067d04..47cf2df19 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -788,11 +788,7 @@ checkTabApp ty (i:rest) = do resultTy' <- applySubst (b@>SubstVal i') resultTy checkTabApp resultTy' rest -checkArgTys - :: (Typer m r, SubstB AtomSubstVal b, BindsNames b, BindsOneAtomName r b, IRRep r) - => Nest b o o' - -> [Atom r o] - -> m i o () +checkArgTys :: (Typer m r, IRRep r) => Nest (Binder r) o o' -> [Atom r o] -> m i o () checkArgTys Empty [] = return () checkArgTys (Nest b bs) (x:xs) = do dropSubst $ x |: binderType b @@ -930,15 +926,14 @@ checkedInstantiateTyConDef (TyConDef _ _ bs cons) (TyConParams _ xs) = do checkedApplyNaryAbs (Abs bs cons) xs checkedApplyNaryAbs - :: forall b r e o m - . ( BindsOneAtomName r b, EnvReader m, Fallible1 m, SinkableE e - , SubstE AtomSubstVal e, IRRep r, SubstB AtomSubstVal b) - => Abs (Nest b) e o -> [Atom r o] -> m o (e o) + :: forall r e o m + . ( EnvReader m, Fallible1 m, SinkableE e , SubstE AtomSubstVal e, IRRep r) + => Abs (Nest (Binder r)) e o -> [Atom r o] -> m o (e o) checkedApplyNaryAbs (Abs bsTop e) xsTop = do go (EmptyAbs bsTop) xsTop applySubst (bsTop@@>(SubstVal<$>xsTop)) e where - go :: EmptyAbs (Nest b) o -> [Atom r o] -> m o () + go :: EmptyAbs (Nest (Binder r)) o -> [Atom r o] -> m o () go (Abs Empty UnitE) [] = return () go (Abs (Nest b bs) UnitE) (x:xs) = do checkAlphaEq (binderType b) (getType x) diff --git a/src/lib/Export.hs b/src/lib/Export.hs index 42dcc7ba1..f7ab3184d 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -175,8 +175,8 @@ parseTabTy = go [] NewtypeTyCon Nat -> return $ Just $ RectContArrayPtr IdxRepScalarBaseTy shape TabTy d (b:>ixty) a -> do maybeN <- case IxType ixty d of - (IxType (NewtypeTyCon (Fin n)) _) -> return $ Just n - (IxType _ (IxDictRawFin n)) -> return $ Just n + IxType (NewtypeTyCon (Fin n)) _ -> return $ Just n + IxType _ (IxDictRawFin n) -> return $ Just n _ -> return Nothing maybeDim <- case maybeN of Just (Var v) -> do diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 1a2ab54dd..bfb73537c 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -366,16 +366,16 @@ toImpRefOp refDest' m = do ans <- liftBuilderImp $ emitBlock (sink body') storeAtom accDest ans False -> case accTy of - TabTy d (b:>t) eltTy -> do - let ixTy = IxType t d + TabPi t -> do + let ixTy = tabIxType t n <- indexSetSizeImp ixTy emitLoop noHint Fwd n \i -> do idx <- unsafeFromOrdinalImp (sink ixTy) i xElt <- liftBuilderImp $ tabApp (sink x) (sink idx) yElt <- liftBuilderImp $ tabApp (sink y) (sink idx) - eltTy' <- applySubst (b@>SubstVal idx) eltTy + eltTy <- instantiateTabPiTy (sink t) idx ithDest <- indexDest (sink accDest) idx - liftMonoidCombine ithDest eltTy' (sink bc) xElt yElt + liftMonoidCombine ithDest eltTy (sink bc) xElt yElt _ -> error $ "Base monoid type mismatch: can't lift " ++ pprint baseTy ++ " to " ++ pprint accTy @@ -578,15 +578,15 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do alphaEq xTy accTy >>= \case True -> storeAtom accDest x False -> case accTy of - TabTy d (b:>t) eltTy -> do - let ixTy = IxType t d + TabPi t -> do + let ixTy = tabIxType t n <- indexSetSizeImp ixTy emitLoop noHint Fwd n \i -> do idx <- unsafeFromOrdinalImp (sink ixTy) i x' <- sinkM x - eltTy' <- applySubst (b@>SubstVal idx) eltTy + eltTy <- instantiateTabPiTy (sink t) idx ithDest <- indexDest (sink accDest) idx - liftMonoidEmpty ithDest eltTy' x' + liftMonoidEmpty ithDest eltTy x' _ -> error $ "Base monoid type mismatch: can't lift " ++ pprint xTy ++ " to " ++ pprint accTy @@ -1002,11 +1002,11 @@ buildGarbageVal ty = -- === Operations on dests === indexDest :: Emits n => Dest n -> SAtom n -> SubstImpM i n (Dest n) -indexDest (Dest destValTy@(TabTy d (b:>t) eltTy) tree) i = do - eltTy' <- applySubst (b@>SubstVal i) eltTy - ord <- ordinalImp (IxType t d) i - leafTys <- typeToTree destValTy - Dest eltTy' <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do +indexDest (Dest (TabPi tabTy) tree) i = do + eltTy <- instantiateTabPiTy tabTy i + ord <- ordinalImp (tabIxType tabTy) i + leafTys <- typeToTree $ TabPi tabTy + Dest eltTy <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do BufferType ixStruct _ <- return $ getRefBufferType leafTy offset <- computeOffsetImp ixStruct ord impOffset ptr offset @@ -1026,10 +1026,10 @@ indexRepValParam :: Emits n => SRepVal n -> SAtom n -> (SType n -> SType n) -> (IExpr n -> SubstImpM i n (IExpr n)) -> SubstImpM i n (SRepVal n) -indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i tyFunc func = do - eltTy' <- applySubst (b@>SubstVal i) eltTy - ord <- ordinalImp (IxType t d) i - leafTys <- typeToTree tabTy +indexRepValParam (RepVal (TabPi tabTy) vals) i tyFunc func = do + eltTy <- instantiateTabPiTy tabTy i + ord <- ordinalImp (tabIxType tabTy) i + leafTys <- typeToTree (TabPi tabTy) vals' <- forM (zipTrees leafTys vals) \(leafTy, ptr) -> do BufferPtr (BufferType ixStruct _) <- return $ getIExprInterpretation leafTy offset <- computeOffsetImp ixStruct ord @@ -1041,7 +1041,7 @@ indexRepValParam (RepVal tabTy@(TabPi (TabPiType d (b:>t) eltTy)) vals) i tyFunc _ -> func ptr' -- `func` may have changed the types of the `vals'`. The caller must also -- supply `tyFunc` to reflect that change in the SType. - return $ RepVal (tyFunc eltTy') vals' + return $ RepVal (tyFunc eltTy) vals' indexRepValParam _ _ _ _ = error "expected table type" {-# INLINE indexRepValParam #-} diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index 8eaf8cbeb..bce5b8050 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -153,12 +153,12 @@ lowerFor _ _ _ _ _ = error "expected a unary lambda expression" lowerTabCon :: forall i o. Emits o => Maybe (Dest SimpIR o) -> SType i -> [SAtom i] -> LowerM i o (SExpr o) lowerTabCon maybeDest tabTy elems = do - tabTy'@(TabPi (TabPiType dict (_:>t) _)) <- substM tabTy + TabPi tabTy' <- substM tabTy dest <- case maybeDest of Just d -> return d - Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest tabTy' + Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest $ TabPi tabTy' Abs bord ufoBlock <- buildAbs noHint IdxRepTy \ord -> do - buildBlock $ unsafeFromOrdinal (sink $ IxType t dict) $ Var $ sink ord + buildBlock $ unsafeFromOrdinal (sink $ tabIxType tabTy') $ Var $ sink ord -- This is emitting a chain of RememberDest ops to force `dest` to be used -- linearly, and to force reads of the `Freeze dest'` result not to be -- reordered in front of the writes. diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 2501cbf8f..9be267241 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -119,8 +119,8 @@ instance IRRep r => HasType r (Con r) where getSuperclassType :: RNest CBinder n l -> Nest CBinder l l' -> Int -> CType n getSuperclassType _ Empty = error "bad index" -getSuperclassType bsAbove (Nest b bs) = \case - 0 -> ignoreHoistFailure $ hoist bsAbove $ binderType b +getSuperclassType bsAbove (Nest b@(_:>t) bs) = \case + 0 -> ignoreHoistFailure $ hoist bsAbove t i -> getSuperclassType (RNest bsAbove b) bs (i-1) instance IRRep r => HasType r (Expr r) where @@ -213,6 +213,9 @@ rawStrType = case newName "n" of rawFinTabType :: IRRep r => Atom r n -> Type r n -> Type r n rawFinTabType n eltTy = IxType IdxRepTy (IxDictRawFin n) ==> eltTy +tabIxType :: TabPiType r n -> IxType r n +tabIxType (TabPiType d (_:>t) _) = IxType t d + typesAsBinderNest :: (SinkableE e, HoistableE e, IRRep r) => [Type r n] -> e n -> Abs (Nest (Binder r)) e n diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index 3255773ad..4a4c2c6a5 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -185,8 +185,8 @@ bufferTy h = do extendBuffer :: (Emits n, CBuilder m) => CAtom n -> CAtom n -> m n () extendBuffer buf tab = do RefTy h _ <- return $ getType buf - TabTy d (_:>t) _ <- return $ getType tab - n <- applyIxMethodCore Size (IxType t d) [] + TabPi t <- return $ getType tab + n <- applyIxMethodCore Size (tabIxType t) [] void $ applyPreludeFunction "stack_extend_internal" [n, h, buf, tab] -- argument has type `Word8` @@ -237,8 +237,8 @@ forEachTabElt -> (forall l. (Emits l, DExt n l) => CAtom l -> CAtom l -> m l ()) -> m n () forEachTabElt tab cont = do - TabTy d (_:>t) _ <- return $ getType tab - let ixTy = IxType t d + TabPi t <- return $ getType tab + let ixTy = tabIxType t void $ buildFor "i" Fwd ixTy \i -> do x <- tabApp (sink tab) (Var i) i' <- applyIxMethodCore Ordinal (sink ixTy) [Var i] diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index c151937b6..a40b01fa5 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -153,12 +153,12 @@ getRepType ty = go ty where x <- liftSimpAtom (sink l) (Var $ binderVar b') r' <- go =<< applySubst (b@>SubstVal x) r return $ DepPairTy $ DepPairType expl b' r' - TabPi (TabPiType d (b:>t) bodyTy) -> do - let ixTy = IxType t d + TabPi tabTy -> do + let ixTy = tabIxType tabTy IxType t' d' <- simplifyIxType ixTy - withFreshBinder (getNameHint b) t' \b' -> do + withFreshBinder (getNameHint tabTy) t' \b' -> do x <- liftSimpAtom (sink $ ixTypeType ixTy) (Var $ binderVar b') - bodyTy' <- go =<< applySubst (b@>SubstVal x) bodyTy + bodyTy' <- go =<< instantiateTabPiTy (sink tabTy) x return $ TabPi $ TabPiType d' b' bodyTy' NewtypeTyCon con -> do (_, ty') <- unwrapNewtypeType con @@ -1025,7 +1025,7 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do return $ activeArg':rest buildTangentArgs _ _ _ = error "zip error" - fromNonDepNest :: (HoistableB b, BindsOneAtomName CoreIR b) => Nest b n l -> [CType n] + fromNonDepNest :: Nest CBinder n l -> [CType n] fromNonDepNest Empty = [] fromNonDepNest (Nest b bs) = case ignoreHoistFailure $ hoist b (Abs bs UnitE) of diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 9e9062d3f..d09f7c69f 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -897,17 +897,13 @@ data LinearizationSpec (n::S) = LinearizationSpec (TopFunName n) [Active] deriving (Show, Generic) --- === BindsOneAtomName === +-- === Binder utils === -class BindsOneName b (AtomNameC r) => BindsOneAtomName (r::IR) (b::B) | b -> r where - binderType :: b n l -> Type r n - binderVar :: DExt n l => b n l -> AtomVar r l +binderType :: Binder r n l -> Type r n +binderType (_:>ty) = ty -bindersTypes :: (IRRep r, Distinct l, ProvesExt b, BindsNames b, BindsOneAtomName r b) - => Nest b n l -> [Type r l] -bindersTypes Empty = [] -bindersTypes n@(Nest b bs) = ty : bindersTypes bs - where ty = withExtEvidence n $ sink (binderType b) +binderVar :: (IRRep r, DExt n l) => Binder r n l -> AtomVar r l +binderVar (b:>ty) = AtomVar (binderName b) (sink ty) nestToAtomVars :: (Distinct l, Ext n l, IRRep r) => Nest (Binder r) n l -> [AtomVar r l] @@ -916,14 +912,6 @@ nestToAtomVars = \case Nest b bs -> withExtEvidence b $ withSubscopeDistinct bs $ sink (binderVar b) : nestToAtomVars bs -instance IRRep r => BindsOneAtomName r (BinderP (AtomNameC r) (Type r)) where - binderType (_ :> ty) = ty - binderVar (b:>t) = AtomVar (binderName b) (sink t) - -toBinderNest :: BindsOneAtomName r b => Nest b n l -> Nest (Binder r) n l -toBinderNest Empty = Empty -toBinderNest (Nest b bs) = Nest (asNameBinder b :> binderType b) (toBinderNest bs) - -- === ToBinding === atomBindingToBinding :: AtomBinding r n -> Binding (AtomNameC r) n @@ -957,14 +945,6 @@ instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where toBinding (LeftE e) = toBinding e toBinding (RightE e) = toBinding e --- === HasArgType === - -class HasArgType (e::E) (r::IR) | e -> r where - argType :: e n -> Type r n - -instance HasArgType (TabPiType r) r where - argType (TabPiType _ (_:>ty) _) = ty - -- === Pattern synonyms === -- XXX: only use this pattern when you're actually expecting a type. If it's @@ -2055,6 +2035,9 @@ instance IRRep r => AlphaEqE (TabPiType r) where instance IRRep r => AlphaHashableE (TabPiType r) where hashWithSaltE env salt (TabPiType _ b t) = hashWithSaltE env salt $ Abs b t +instance HasNameHint (TabPiType r n) where + getNameHint (TabPiType _ b _) = getNameHint b + instance IRRep r => SinkableE (TabPiType r) instance IRRep r => HoistableE (TabPiType r) instance IRRep r => RenameE (TabPiType r)