Skip to content

Commit

Permalink
Make a ToBindersAbs class to help abstract away direct uses of `@@>…
Browse files Browse the repository at this point in the history
…` and `@>`.
  • Loading branch information
dougalm committed Jun 27, 2023
1 parent cfab914 commit 43e42ec
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ zeroAt ty = liftEmitBuilder $ go ty where
BaseTy bt -> return $ Con $ Lit $ zeroLit bt
ProdTy tys -> ProdVal <$> mapM go tys
TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i ->
go =<< instantiateTabPiTy (sink tabPi) (Var i)
go =<< instantiate (sink tabPi) [Var i]
_ -> unreachable
zeroLit bt = case bt of
Scalar Float64Type -> Float64Lit 0.0
Expand Down
22 changes: 8 additions & 14 deletions src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ module CheapReduction
( CheaplyReducibleE (..), cheapReduce, cheapReduceWithDecls, cheapNormalize
, normalizeProj, asNaryProj, normalizeNaryProj
, depPairLeftTy, instantiateTyConDef
, dataDefRep, instantiateDepPairTy, unwrapNewtypeType, repValAtom
, dataDefRep, unwrapNewtypeType, repValAtom
, unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType
, liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..)
, visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2
, visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiatePiTy, instantiateTabPiTy
, visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiate
, bindersToVars, bindersToAtoms)
where

Expand Down Expand Up @@ -450,7 +450,7 @@ projType i ty x = case ty of
DepPairTy t | i == 0 -> return $ depPairLeftTy t
DepPairTy t | i == 1 -> do
xFst <- normalizeProj (ProjectProduct 0) x
instantiateDepPairTy t xFst
instantiate t [xFst]
_ -> error $ "Can't project type: " ++ pprint ty

unwrapLeadingNewtypesType :: EnvReader m => CType n -> m n ([NewtypeCon n], CType n)
Expand All @@ -470,13 +470,11 @@ instantiateTyConDef (TyConDef _ _ bs conDefs) (TyConParams _ xs) = do
applySubst (bs @@> (SubstVal <$> xs)) conDefs
{-# INLINE instantiateTyConDef #-}

instantiatePiTy :: (EnvReader m, IRRep r) => PiType r n -> [Atom r n] -> m n (EffTy r n)
instantiatePiTy (PiType bs effTy) xs = do
applySubst (bs @@> (SubstVal <$> xs)) effTy

instantiateTabPiTy :: (EnvReader m, IRRep r) => TabPiType r n -> Atom r n -> m n (Type r n)
instantiateTabPiTy (TabPiType _ b resultTy) x = do
applySubst (b @> SubstVal x) resultTy
instantiate
:: (EnvReader m, IRRep r, SubstE (SubstVal Atom) body, SinkableE body, ToBindersAbs e body r)
=> e n -> [Atom r n] -> m n (body n)
instantiate e xs = case toAbs e of
Abs bs body -> applySubst (bs @@> (SubstVal <$> xs)) body

-- Returns a representation type (type of an TypeCon-typed Newtype payload)
-- given a list of instantiated DataConDefs.
Expand All @@ -498,10 +496,6 @@ makeStructRepVal tyConName args = do
_ -> error "wrong number of args"
_ -> return $ ProdVal args

instantiateDepPairTy :: (IRRep r, EnvReader m) => DepPairType r n -> Atom r n -> m n (Type r n)
instantiateDepPairTy (DepPairType _ b rhsTy) x = applyAbs (Abs b rhsTy) (SubstVal x)
{-# INLINE instantiateDepPairTy #-}

-- === traversable terms ===

class Monad m => NonAtomRenamer m i o | m -> i, m -> o where
Expand Down
6 changes: 3 additions & 3 deletions src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ instance IRRep r => HasType r (Atom r) where
DepPair l r ty -> do
ty' <- checkTypeE TyKind ty
l' <- checkTypeE (depPairLeftTy ty') l
rTy <- instantiateDepPairTy ty' l'
rTy <- instantiate ty' [l']
r |: rTy
return $ DepPairTy ty'
Con con -> typeCheckPrimCon con
Expand All @@ -236,7 +236,7 @@ instance IRRep r => HasType r (Atom r) where
DepPairTy t | i == 1 -> do
x' <- renameM x
xFst <- normalizeProj (ProjectProduct 0) x'
instantiateDepPairTy t xFst
instantiate t [xFst]
_ -> throw TypeErr $ "Not a product type:" ++ pprint ty
TypeAsAtom ty -> getTypeE ty

Expand Down Expand Up @@ -275,7 +275,7 @@ instance IRRep r => HasType r (Type r) where
DepPairTy t | i == 1 -> do
x' <- renameM x
xFst <- normalizeProj (ProjectProduct 0) x'
instantiateDepPairTy t xFst
instantiate t [xFst]
_ -> throw TypeErr $ "Not a product type:" ++ pprint ty

instance HasType CoreIR SimpInCore where
Expand Down
8 changes: 4 additions & 4 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ toImpRefOp refDest' m = do
idx <- unsafeFromOrdinalImp (sink ixTy) i
xElt <- liftBuilderImp $ tabApp (sink x) (sink idx)
yElt <- liftBuilderImp $ tabApp (sink y) (sink idx)
eltTy <- instantiateTabPiTy (sink t) idx
eltTy <- instantiate (sink t) [idx]
ithDest <- indexDest (sink accDest) idx
liftMonoidCombine ithDest eltTy (sink bc) xElt yElt
_ -> error $ "Base monoid type mismatch: can't lift " ++
Expand Down Expand Up @@ -584,7 +584,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do
emitLoop noHint Fwd n \i -> do
idx <- unsafeFromOrdinalImp (sink ixTy) i
x' <- sinkM x
eltTy <- instantiateTabPiTy (sink t) idx
eltTy <- instantiate (sink t) [idx]
ithDest <- indexDest (sink accDest) idx
liftMonoidEmpty ithDest eltTy x'
_ -> error $ "Base monoid type mismatch: can't lift " ++
Expand Down Expand Up @@ -1003,7 +1003,7 @@ buildGarbageVal ty =

indexDest :: Emits n => Dest n -> SAtom n -> SubstImpM i n (Dest n)
indexDest (Dest (TabPi tabTy) tree) i = do
eltTy <- instantiateTabPiTy tabTy i
eltTy <- instantiate tabTy [i]
ord <- ordinalImp (tabIxType tabTy) i
leafTys <- typeToTree $ TabPi tabTy
Dest eltTy <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do
Expand All @@ -1027,7 +1027,7 @@ indexRepValParam :: Emits n
-> (IExpr n -> SubstImpM i n (IExpr n))
-> SubstImpM i n (SRepVal n)
indexRepValParam (RepVal (TabPi tabTy) vals) i tyFunc func = do
eltTy <- instantiateTabPiTy tabTy i
eltTy <- instantiate tabTy [i]
ord <- ordinalImp (tabIxType tabTy) i
leafTys <- typeToTree (TabPi tabTy)
vals' <- forM (zipTrees leafTys vals) \(leafTy, ptr) -> do
Expand Down
4 changes: 2 additions & 2 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ checkSigma hint expr sTy = confuseGHC >>= \_ -> case sTy of
-- TODO: check for the case that we're given some of the implicit dependent pair args explicitly
lhsVal <- Var <$> freshInferenceName MiscInfVar lhsTy
-- TODO: make an InfVarDesc case for dep pair instantiation
rhsTy <- instantiateDepPairTy depPairTy lhsVal
rhsTy <- instantiate depPairTy [lhsVal]
rhsVal <- checkSigma noHint expr rhsTy
return $ DepPair lhsVal rhsVal depPairTy
_ -> fallback
Expand Down Expand Up @@ -996,7 +996,7 @@ checkOrInferRho hint uExprWithSrc@(WithSrcE pos expr) reqTy = do
case reqTy of
Check (DepPairTy ty@(DepPairType _ (_ :> lhsTy) _)) -> do
lhs' <- checkSigmaDependent noHint lhs lhsTy
rhsTy <- instantiateDepPairTy ty lhs'
rhsTy <- instantiate ty [lhs']
rhs' <- checkSigma noHint rhs rhsTy
return $ DepPair lhs' rhs' ty
_ -> throw TypeErr $ "Can't infer the type of a dependent pair; please annotate it"
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Lower.hs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ lowerFullySequential wantDestStyle (TopLam False piTy (LamExpr bs body)) = liftE
True -> do
refreshAbs (Abs bs body) \bs' body' -> do
xs <- bindersToAtoms bs'
EffTy _ resultTy <- instantiatePiTy (sink piTy) xs
EffTy _ resultTy <- instantiate (sink piTy) xs
Abs b body'' <- lowerFullySequentialBlock resultTy body'
return $ LamExpr (bs' >>> UnaryNest b) body''
False -> do
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ getRepType ty = go ty where
IxType t' d' <- simplifyIxType ixTy
withFreshBinder (getNameHint tabTy) t' \b' -> do
x <- liftSimpAtom (sink $ ixTypeType ixTy) (Var $ binderVar b')
bodyTy' <- go =<< instantiateTabPiTy (sink tabTy) x
bodyTy' <- go =<< instantiate (sink tabTy) [x]
return $ TabPi $ TabPiType d' b' bodyTy'
NewtypeTyCon con -> do
(_, ty') <- unwrapNewtypeType con
Expand Down
23 changes: 23 additions & 0 deletions src/lib/Types/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,29 @@ type Dict = Atom CoreIR
data NonDepNest r ann n l = NonDepNest (Nest (AtomNameBinder r) n l) [ann n]
deriving (Generic)

-- === ToAtomAbs class ===

class ToBindersAbs (e::E) (body::E) (r::IR) | e -> body, e -> r where
toAbs :: e n -> Abs (Nest (Binder r)) body n

instance ToBindersAbs CorePiType (EffTy CoreIR) CoreIR where
toAbs (CorePiType _ _ bs effTy) = Abs bs effTy

instance ToBindersAbs (Abs (Nest (Binder r)) body) body r where
toAbs = id

instance ToBindersAbs (PiType r) (EffTy r) r where
toAbs (PiType bs effTy) = Abs bs effTy

instance ToBindersAbs (LamExpr r) (Block r) r where
toAbs (LamExpr bs body) = Abs bs body

instance ToBindersAbs (TabPiType r) (Type r) r where
toAbs (TabPiType _ b eltTy) = Abs (UnaryNest b) eltTy

instance ToBindersAbs (DepPairType r) (Type r) r where
toAbs (DepPairType _ b rhsTy) = Abs (UnaryNest b) rhsTy

-- === GenericOp class ===

class IsPrimOp (e::IR->E) where
Expand Down

0 comments on commit 43e42ec

Please sign in to comment.