[Submission] Cautious NAdamW jax #9
MLCommons CLA bot: All contributors have signed the MLCommons CLA ✍️ ✅
recheck
Hi! Thanks for your submission. We are very interested in benchmarking Cautious optimizers.
I just copied the implementation over to self-tuning. |
@kyleliang919 we have just released v0.6.0 of the AlgoPerf benchmark, which moves JAX workloads from pmap to jit sharding. We have temporarily halted scoring of new submissions so that all new submissions can be scored on >= v0.6.0.
@priyakasimbeg it looks like the baseline example is still the same as before. Is there a good example of the new approach? I am not really familiar with JAX, so I will probably need some help here.
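For readers following the pmap-to-jit question, here is a minimal generic JAX sketch of a jit-compiled training step with sharded inputs. It is not the AlgoPerf submission API: the mesh axis name `batch`, the toy `loss_fn`, `train_step`, and `LEARNING_RATE` are all hypothetical names chosen for illustration. It only shows the general pattern of placing data with `jax.sharding` and letting `jax.jit` handle the cross-device reduction.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional device mesh; the axis name "batch" is an arbitrary choice.
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
batch_sharding = NamedSharding(mesh, P("batch"))  # split leading axis across devices
replicated = NamedSharding(mesh, P())             # full copy on every device

LEARNING_RATE = 0.05  # hypothetical value for this toy problem


def loss_fn(w, x, y):
    # Toy linear-regression loss, averaged over the global batch.
    return jnp.mean((x @ w - y) ** 2)


@jax.jit
def train_step(w, x, y):
    # With pmap one would call jax.lax.pmean on the gradients explicitly.
    # Under jit with sharded inputs, the mean is over the global batch and
    # the compiler inserts the required cross-device communication.
    loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
    return w - LEARNING_RATE * grad, loss


# Global batch size must be divisible by the number of devices.
n = 8 * jax.device_count()
x = jax.random.normal(jax.random.PRNGKey(0), (n, 4))
y = x @ jnp.arange(1.0, 5.0)

# Place data sharded along the batch axis and parameters replicated.
x = jax.device_put(x, batch_sharding)
y = jax.device_put(y, batch_sharding)
w = jax.device_put(jnp.zeros(4), replicated)

w, first_loss = train_step(w, x, y)
for _ in range(20):
    w, last_loss = train_step(w, x, y)
```

The same script runs unchanged on a single device, where the mesh simply has one entry.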
Cautious NAdamW jax
Submission Information
Evidence for the Submission's Performance
Paper:
https://huggingface.co/papers/2411.16085
Results on RL: https://x.com/KyleLiang5/status/1931344549302927444
Independent verification:
https://huggingface.co/rwightman/timm-optim-caution
https://x.com/_clashluke/status/1935961388553290108
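As a rough illustration of the idea behind the linked paper, the sketch below shows a "cautious" mask in JAX: components of the proposed update whose sign disagrees with the current gradient are zeroed, and the surviving components are rescaled by the fraction kept. This is a minimal sketch under my reading of the paper, not the code in this submission; the function names `cautious_mask` and `apply_cautious`, and the `eps` floor of `1e-3`, are assumptions.

```python
import jax
import jax.numpy as jnp


def cautious_mask(update, grad, eps=1e-3):
    # Keep only components where the proposed update (e.g. the NAdamW
    # direction before the learning rate is applied) agrees in sign
    # with the current gradient.
    mask = (update * grad > 0).astype(update.dtype)
    # Rescale by the fraction of kept components so the overall update
    # magnitude stays comparable; eps guards against an all-zero mask.
    mask = mask / jnp.maximum(jnp.mean(mask), eps)
    return update * mask


def apply_cautious(updates, grads):
    # Apply the mask leaf by leaf over matching parameter pytrees.
    return jax.tree_util.tree_map(cautious_mask, updates, grads)


example_update = jnp.array([1.0, -2.0, 3.0, 4.0])
example_grad = jnp.array([1.0, 1.0, -1.0, 2.0])
masked = cautious_mask(example_update, example_grad)
```

In this example only the first and last components agree in sign with the gradient, so the two middle components are dropped and the remaining ones are scaled up by a factor of two.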
Comments
Fingers crossed