stack overflow issue #809

Open
jakubMitura14 opened this issue Sep 26, 2024 · 0 comments

Hello
I have a memory-constrained problem with a Lux.jl model that uses Zygote for most of the backpropagation.

I tried to approach this from the ChainRules perspective: I need to checkpoint each Lux.jl layer in the neural network. I tried to achieve it like this:

function ChainRulesCore.rrule(::typeof(Lux.apply), l::Lux.AbstractExplicitLayer, x, ps, st)
    y = Lux.apply(l, x, ps, st)

    function pullback_checkpointed(Δy)
        y, pb = Zygote.pullback(Lux.apply, l, x, ps, st)
        return NoTangent(), pb(Δy)
    end

    y, pullback_checkpointed
end

The rule gets invoked during backpropagation. However, for some reason it recursively tries to backpropagate through the first line

 y = Lux.apply(l, x, ps, st)

so I get a stack overflow error. How can I correct it?

I have also posted this question at https://discourse.julialang.org/t/avoid-storing-intermediate-results-from-the-forward-pass-by-default/119694/4?u=jakub_mitura
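
A possible explanation, offered as an assumption rather than a confirmed diagnosis: the rule above is defined on Lux.apply itself, so the Zygote.pullback(Lux.apply, ...) call inside the pullback may dispatch to the same rule again, and each pullback then starts another one. One way to avoid that kind of self-reference is to attach the checkpointing to a separate wrapper function instead of Lux.apply. Zygote ships such a wrapper, Zygote.checkpointed, which reruns the forward pass when the gradient is requested. The sketch below shows it on a single layer; the Dense(4 => 2) model, the input shape, and the loss function are hypothetical and chosen only for illustration.

```julia
using Lux, Zygote, Random

# Hypothetical toy setup: one Dense layer and a random input batch.
rng = Random.default_rng()
model = Dense(4 => 2)
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 4, 8)

# Checkpointed loss: Zygote.checkpointed calls Lux.apply in the forward pass
# and recomputes it during the backward pass instead of keeping its
# intermediates. No rrule is defined on Lux.apply, so the recomputation
# cannot re-enter a user-defined rule.
function loss_checkpointed(p)
    y, _ = Zygote.checkpointed(Lux.apply, model, x, p, st)
    return sum(abs2, y)
end

# Reference loss without checkpointing, used only to compare gradients.
function loss_plain(p)
    y, _ = Lux.apply(model, x, p, st)
    return sum(abs2, y)
end

grads = Zygote.gradient(loss_checkpointed, ps)
grads_ref = Zygote.gradient(loss_plain, ps)
```

In a larger model the same call would have to wrap each layer application that should be checkpointed, for example inside a custom container layer. Whether this actually reduces peak memory for a given model is something to measure rather than assume.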
