From 43e42eca8d5a6f004c36ef029dde292393c684d0 Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 27 Jun 2023 13:47:41 -0400 Subject: [PATCH] Make a `ToBindersAbs` class to help abstract away direct uses of `@@>` and `@>`. --- src/lib/Builder.hs | 2 +- src/lib/CheapReduction.hs | 22 ++++++++-------------- src/lib/CheckType.hs | 6 +++--- src/lib/Imp.hs | 8 ++++---- src/lib/Inference.hs | 4 ++-- src/lib/Lower.hs | 2 +- src/lib/Simplify.hs | 2 +- src/lib/Types/Core.hs | 23 +++++++++++++++++++++++ 8 files changed, 43 insertions(+), 26 deletions(-) diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 5f1c372de..039481cc7 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -858,7 +858,7 @@ zeroAt ty = liftEmitBuilder $ go ty where BaseTy bt -> return $ Con $ Lit $ zeroLit bt ProdTy tys -> ProdVal <$> mapM go tys TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> - go =<< instantiateTabPiTy (sink tabPi) (Var i) + go =<< instantiate (sink tabPi) [Var i] _ -> unreachable zeroLit bt = case bt of Scalar Float64Type -> Float64Lit 0.0 diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index 4c42bbda2..1476facca 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -11,11 +11,11 @@ module CheapReduction ( CheaplyReducibleE (..), cheapReduce, cheapReduceWithDecls, cheapNormalize , normalizeProj, asNaryProj, normalizeNaryProj , depPairLeftTy, instantiateTyConDef - , dataDefRep, instantiateDepPairTy, unwrapNewtypeType, repValAtom + , dataDefRep, unwrapNewtypeType, repValAtom , unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType , liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..) , visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2 - , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiatePiTy, instantiateTabPiTy + , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiate , bindersToVars, bindersToAtoms) where @@ -450,7 +450,7 @@ projType i ty x = case ty of DepPairTy t | i == 0 -> return $ depPairLeftTy t DepPairTy t | i == 1 -> do xFst <- normalizeProj (ProjectProduct 0) x - instantiateDepPairTy t xFst + instantiate t [xFst] _ -> error $ "Can't project type: " ++ pprint ty unwrapLeadingNewtypesType :: EnvReader m => CType n -> m n ([NewtypeCon n], CType n) @@ -470,13 +470,11 @@ instantiateTyConDef (TyConDef _ _ bs conDefs) (TyConParams _ xs) = do applySubst (bs @@> (SubstVal <$> xs)) conDefs {-# INLINE instantiateTyConDef #-} -instantiatePiTy :: (EnvReader m, IRRep r) => PiType r n -> [Atom r n] -> m n (EffTy r n) -instantiatePiTy (PiType bs effTy) xs = do - applySubst (bs @@> (SubstVal <$> xs)) effTy - -instantiateTabPiTy :: (EnvReader m, IRRep r) => TabPiType r n -> Atom r n -> m n (Type r n) -instantiateTabPiTy (TabPiType _ b resultTy) x = do - applySubst (b @> SubstVal x) resultTy +instantiate + :: (EnvReader m, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, ToBindersAbs e body r) + => e n -> [Atom r n] -> m n (body n) +instantiate e xs = case toAbs e of + Abs bs body -> applySubst (bs @@> (SubstVal <$> xs)) body -- Returns a representation type (type of an TypeCon-typed Newtype payload) -- given a list of instantiated DataConDefs. @@ -498,10 +496,6 @@ makeStructRepVal tyConName args = do _ -> error "wrong number of args" _ -> return $ ProdVal args -instantiateDepPairTy :: (IRRep r, EnvReader m) => DepPairType r n -> Atom r n -> m n (Type r n) -instantiateDepPairTy (DepPairType _ b rhsTy) x = applyAbs (Abs b rhsTy) (SubstVal x) -{-# INLINE instantiateDepPairTy #-} - -- === traversable terms === class Monad m => NonAtomRenamer m i o | m -> i, m -> o where diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 47cf2df19..d25ce524f 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -214,7 +214,7 @@ instance IRRep r => HasType r (Atom r) where DepPair l r ty -> do ty' <- checkTypeE TyKind ty l' <- checkTypeE (depPairLeftTy ty') l - rTy <- instantiateDepPairTy ty' l' + rTy <- instantiate ty' [l'] r |: rTy return $ DepPairTy ty' Con con -> typeCheckPrimCon con @@ -236,7 +236,7 @@ instance IRRep r => HasType r (Atom r) where DepPairTy t | i == 1 -> do x' <- renameM x xFst <- normalizeProj (ProjectProduct 0) x' - instantiateDepPairTy t xFst + instantiate t [xFst] _ -> throw TypeErr $ "Not a product type:" ++ pprint ty TypeAsAtom ty -> getTypeE ty @@ -275,7 +275,7 @@ instance IRRep r => HasType r (Type r) where DepPairTy t | i == 1 -> do x' <- renameM x xFst <- normalizeProj (ProjectProduct 0) x' - instantiateDepPairTy t xFst + instantiate t [xFst] _ -> throw TypeErr $ "Not a product type:" ++ pprint ty instance HasType CoreIR SimpInCore where diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index bfb73537c..c9fd6f42a 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -373,7 +373,7 @@ toImpRefOp refDest' m = do idx <- unsafeFromOrdinalImp (sink ixTy) i xElt <- liftBuilderImp $ tabApp (sink x) (sink idx) yElt <- liftBuilderImp $ tabApp (sink y) (sink idx) - eltTy <- instantiateTabPiTy (sink t) idx + eltTy <- instantiate (sink t) [idx] ithDest <- indexDest (sink accDest) idx liftMonoidCombine ithDest eltTy (sink bc) xElt yElt _ -> error $ "Base monoid type mismatch: can't lift " ++ @@ -584,7 +584,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do emitLoop noHint Fwd n \i -> do idx <- unsafeFromOrdinalImp (sink ixTy) i x' <- sinkM x - eltTy <- instantiateTabPiTy (sink t) idx + eltTy <- instantiate (sink t) [idx] ithDest <- indexDest (sink accDest) idx liftMonoidEmpty ithDest eltTy x' _ -> error $ "Base monoid type mismatch: can't lift " ++ @@ -1003,7 +1003,7 @@ buildGarbageVal ty = indexDest :: Emits n => Dest n -> SAtom n -> SubstImpM i n (Dest n) indexDest (Dest (TabPi tabTy) tree) i = do - eltTy <- instantiateTabPiTy tabTy i + eltTy <- instantiate tabTy [i] ord <- ordinalImp (tabIxType tabTy) i leafTys <- typeToTree $ TabPi tabTy Dest eltTy <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do @@ -1027,7 +1027,7 @@ indexRepValParam :: Emits n -> (IExpr n -> SubstImpM i n (IExpr n)) -> SubstImpM i n (SRepVal n) indexRepValParam (RepVal (TabPi tabTy) vals) i tyFunc func = do - eltTy <- instantiateTabPiTy tabTy i + eltTy <- instantiate tabTy [i] ord <- ordinalImp (tabIxType tabTy) i leafTys <- typeToTree (TabPi tabTy) vals' <- forM (zipTrees leafTys vals) \(leafTy, ptr) -> do diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 907f375a9..9da43fbba 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -887,7 +887,7 @@ checkSigma hint expr sTy = confuseGHC >>= \_ -> case sTy of -- TODO: check for the case that we're given some of the implicit dependent pair args explicitly lhsVal <- Var <$> freshInferenceName MiscInfVar lhsTy -- TODO: make an InfVarDesc case for dep pair instantiation - rhsTy <- instantiateDepPairTy depPairTy lhsVal + rhsTy <- instantiate depPairTy [lhsVal] rhsVal <- checkSigma noHint expr rhsTy return $ DepPair lhsVal rhsVal depPairTy _ -> fallback @@ -996,7 +996,7 @@ checkOrInferRho hint uExprWithSrc@(WithSrcE pos expr) reqTy = do case reqTy of Check (DepPairTy ty@(DepPairType _ (_ :> lhsTy) _)) -> do lhs' <- checkSigmaDependent noHint lhs lhsTy - rhsTy <- instantiateDepPairTy ty lhs' + rhsTy <- instantiate ty [lhs'] rhs' <- checkSigma noHint rhs rhsTy return $ DepPair lhs' rhs' ty _ -> throw TypeErr $ "Can't infer the type of a dependent pair; please annotate it" diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index bce5b8050..0c441b298 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -67,7 +67,7 @@ lowerFullySequential wantDestStyle (TopLam False piTy (LamExpr bs body)) = liftE True -> do refreshAbs (Abs bs body) \bs' body' -> do xs <- bindersToAtoms bs' - EffTy _ resultTy <- instantiatePiTy (sink piTy) xs + EffTy _ resultTy <- instantiate (sink piTy) xs Abs b body'' <- lowerFullySequentialBlock resultTy body' return $ LamExpr (bs' >>> UnaryNest b) body'' False -> do diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index a40b01fa5..34dbb2842 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -158,7 +158,7 @@ getRepType ty = go ty where IxType t' d' <- simplifyIxType ixTy withFreshBinder (getNameHint tabTy) t' \b' -> do x <- liftSimpAtom (sink $ ixTypeType ixTy) (Var $ binderVar b') - bodyTy' <- go =<< instantiateTabPiTy (sink tabTy) x + bodyTy' <- go =<< instantiate (sink tabTy) [x] return $ TabPi $ TabPiType d' b' bodyTy' NewtypeTyCon con -> do (_, ty') <- unwrapNewtypeType con diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index d09f7c69f..b43698f8b 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -231,6 +231,29 @@ type Dict = Atom CoreIR data NonDepNest r ann n l = NonDepNest (Nest (AtomNameBinder r) n l) [ann n] deriving (Generic) +-- === ToAtomAbs class === + +class ToBindersAbs (e::E) (body::E) (r::IR) | e -> body, e -> r where + toAbs :: e n -> Abs (Nest (Binder r)) body n + +instance ToBindersAbs CorePiType (EffTy CoreIR) CoreIR where + toAbs (CorePiType _ _ bs effTy) = Abs bs effTy + +instance ToBindersAbs (Abs (Nest (Binder r)) body) body r where + toAbs = id + +instance ToBindersAbs (PiType r) (EffTy r) r where + toAbs (PiType bs effTy) = Abs bs effTy + +instance ToBindersAbs (LamExpr r) (Block r) r where + toAbs (LamExpr bs body) = Abs bs body + +instance ToBindersAbs (TabPiType r) (Type r) r where + toAbs (TabPiType _ b eltTy) = Abs (UnaryNest b) eltTy + +instance ToBindersAbs (DepPairType r) (Type r) r where + toAbs (DepPairType _ b rhsTy) = Abs (UnaryNest b) rhsTy + -- === GenericOp class === class IsPrimOp (e::IR->E) where