Hi!
Let's say my model has the following structure:
I need to set up the following rematerialization scheme: save the inputs and outputs of all dots, but only in `subblock_1` and `subblock_2`, not in `subblock_no_remat`. Is there a way to achieve this in JAX?

I was hoping that wrapping the whole model in `remat(policy=dots_saveable)` and `subblock_no_remat` in `remat(policy=nothing_saveable)` would do the job, but from the looks of it that's not how rematerialization nesting works. (Actually, I can't figure out exactly what it does, but I can see that the internals of `subblock_no_remat` are being saved for each layer.)

Is there a way to achieve what I need?
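The nested setup described above can be sketched roughly as follows. The subblock bodies, the `layer` wrapper, and all shapes are hypothetical stand-ins, since the actual model structure is not shown; `jax.checkpoint` is the same transformation as `jax.remat`. `jax.ad_checkpoint.print_saved_residuals` prints each value saved for the backward pass along with where it came from, which is one way to check which internals of `subblock_no_remat` end up stored.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

# Hypothetical stand-ins for the real subblocks: a matmul plus a nonlinearity.
def subblock_1(w, x):
    return jnp.tanh(x @ w)

def subblock_2(w, x):
    return jnp.sin(x @ w)

def _subblock_no_remat(w, x):
    return jnp.cos(x @ w)

# Inner remat: the intent is to save nothing from inside this subblock.
subblock_no_remat = jax.checkpoint(
    _subblock_no_remat, policy=jax.checkpoint_policies.nothing_saveable
)

def layer(params, x):
    w1, w2, w3 = params
    x = subblock_1(w1, x)
    x = subblock_no_remat(w2, x)
    x = subblock_2(w3, x)
    return x

# Outer remat over the whole layer, with a policy that saves dot outputs.
layer_remat = jax.checkpoint(layer, policy=jax.checkpoint_policies.dots_saveable)

def loss(params, x):
    return jnp.sum(layer_remat(params, x))

key = jax.random.PRNGKey(0)
k1, k2, k3, kx = jax.random.split(key, 4)
params = tuple(jax.random.normal(k, (8, 8)) for k in (k1, k2, k3))
x = jax.random.normal(kx, (4, 8))

# Print every residual saved for the backward pass, with its provenance.
print_saved_residuals(loss, params, x)

grads = jax.grad(loss)(params, x)
```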