-
Notifications
You must be signed in to change notification settings - Fork 24
Backend switching for Mooncake #768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #768 +/- ##
==========================================
- Coverage 97.93% 97.05% -0.88%
==========================================
Files 128 126 -2
Lines 7693 7705 +12
==========================================
- Hits 7534 7478 -56
- Misses 159 227 +68
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
removed the code that piggybacks off the Chainrules wrapper. This is specifically now a Mooncake generic rule which handles backend switching. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this first draft!
I think there are some changes necessary, and most importantly you need to test it, first locally and then during CI (try not to run CI before having tested your changes locally, the process is very expensive since it tests a dozen different backends for like half an hour each).
For the testing, start with manual tests, and then once your code works you can add AutoMooncake()
to this line
...tionInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
...tionInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
sorry i got preoccupied with some other work, hence the incomplete PR. This would be on route now. |
Please keep in mind that every commit costs around 6 hours of CI budget. I suggest you make as many modifications as possible locally and add tests first before pushing |
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, we're getting closer!
Unfortunately I think my existing tests are not enough to capture everything that can go wrong in a Mooncake rule. Perhaps the Mooncake test utilities should be brought in, or more sophisticated tests should be written.
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Outdated
Show resolved
Hide resolved
DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl
Show resolved
Hide resolved
|
||
using DifferentiationInterface, DifferentiationInterfaceTest | ||
import DifferentiationInterfaceTest as DIT | ||
using FiniteDiff: FiniteDiff | ||
using ForwardDiff: ForwardDiff | ||
using Zygote: Zygote | ||
using Mooncake: Mooncake |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this file we need to add tests that are specific to Mooncake. Ideally I should have done that with the other backends too.
Basically, what we test now is the projection of the Mooncake rule you wrote onto the subset of stuff that DI cares about. But we should also check that the rule is correct from the Mooncake perspective. Probably the best tool for that is Mooncake.TestUtils.test_rule
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yup, will check it out
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay so it fails for some of the general primal, functions. But for this PR its maybe okay? as DifferentiateWith
is exclusive to DI, so the user is anyways limited to DI when using the Mooncake substitute backend.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it fails for some of the general primal, functions.
To clarify, these (primal) functions are permitted by the DI interface, right?
Assuming that is true, @gdalle, I think this is okay.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@AstitvaAggarwal is it possible to add these tests (excluding those not supported by DI) to this PR?
Define
Mooncake.rrule!!
forDI.DifferentiateWith
.