Skip to content

Commit

Permalink
Support UnitRange for loops properly
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Jun 2, 2023
1 parent 5a4b699 commit a7e73a9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Checkpointing"
uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca"
authors = ["Michel Schanen <[email protected]>", "Sri Hari Krishna Narayanan <[email protected]>"]
version = "0.8.0"
version = "0.8.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
20 changes: 16 additions & 4 deletions src/Checkpointing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function set_zero!(nestedmodel::MT) where {MT}
end

function checkpoint_struct_for(body::Function, scheme::Scheme, model, range)
for i in range
for gensym() in range
body(model)
end
return model
Expand Down Expand Up @@ -159,10 +159,22 @@ adjoints and is created here. It is supposed to be initialized by ChainRules.
"""
macro checkpoint_struct(alg, model, loop)
if loop.head == :for
body = loop.args[2]
iterator = loop.args[1].args[1]
from = loop.args[1].args[2].args[2]
to = loop.args[1].args[2].args[3]
range = loop.args[1].args[2]
ex = quote
$model = Checkpointing.checkpoint_struct_for($alg, $model, $(loop.args[1])) do $model
$(loop.args[2])
nothing
let
if !isa($range, UnitRange{Int64})
error("Checkpointing.jl: Only UnitRange{Int64} is supported.")
end
$iterator = $from
$model = Checkpointing.checkpoint_struct_for($alg, $model, $(loop.args[1].args[2])) do $model
$body
$iterator += 1
nothing
end
end
end
elseif loop.head == :while
Expand Down

0 comments on commit a7e73a9

Please sign in to comment.