JAX 0.8 and PyMC v5 compatibility#164
JAX 0.8 and PyMC v5 compatibility#164MilesCranmer wants to merge 21 commits intoexoplanet-dev:mainfrom
Conversation
dfm
left a comment
There was a problem hiding this comment.
Thank you - this looks great! Some small comments/questions inline.
| typedef typename LowRank::Scalar Scalar; | ||
| typedef typename Eigen::internal::plain_col_type<Coeffs>::type CoeffVector; | ||
| typedef typename Eigen::Matrix<Scalar, LowRank::ColsAtCompileTime, RightHandSide::ColsAtCompileTime> Inner; | ||
| typedef typename Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> Inner; |
There was a problem hiding this comment.
From my experience, this change will dramatically impact performance because Eigen won't be able to generate properly vectorized code for small systems. It's really useful to compile for specific sizes! Why did you make this change?
There was a problem hiding this comment.
Thanks, I see. I couldn't get it working initially but this change seemed to do it. I didn't know it would hurt performance though so I'll fix it now.
| typedef typename LowRank::Scalar Scalar; | ||
| typedef typename Eigen::internal::plain_col_type<Coeffs>::type CoeffVector; | ||
| typedef typename Eigen::Matrix<Scalar, LowRank::ColsAtCompileTime, RightHandSide::ColsAtCompileTime> Inner; | ||
| typedef typename Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> Inner; |
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
for more information, see https://pre-commit.ci
Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
|
Really struggling to get it working with |
|
Sorry, turns out I won't have enough bandwidth to fix the PR for fixed sized matrices (seems hard), so I will have to leave things as-is for now. Feel free to take this PR as-is if you are okay with the speed hit to get things up to JAX 0.8, or we can just point people to this PR if they need it |
|
No problem! I think it shouldn't be a big deal since it's just for the "general" cases that are only used for predictions. Log prob calculations should still be fast. I'll try and get this merged soon - thanks!! |
|
Oops sorry I forgot to re-run generate.py |
|
Ping @dfm let me know if anything else is left! |
|
It looks like all the CI is failing for various reasons. Can you take a look at those? I'm not totally sure how to have them automatically run for you, but I'll try to be faster to press the button, and you could plausibly run them on your own fork by temporarily adding: here: It also looks like we'll need to update the Python version that we're using on ReadTheDocs. I think that should be a simple as just bumping it here: Line 11 in e7974e4 Do you mind doing that too? |
I think if you put me as collaborator status in the org it might do this? I think it's just a user trust scopes thing. (No need to give me merge rights though) |
|
Good idea! I've invited you - let me know if that works (or doesn't). |
|
@MilesCranmerBot can you make a new PR based on this one and try to get the tests working again? |
|
Actually ugh it might need to make the PR to my account's fork so I can get the CI tested via this PR |
This PR upgrades celerite2 to JAX 0.8.x. I updated the jinja templates and re-generated the cpp files. I also upgraded PyMC to v5 and ditched compatibility with PyMC3 since stuff wasn't working anyways and it will be easier to maintain.
I'm unfamiliar with a lot of the lower level JAX stuff, so I am not confident about some of this PR, especially the FFI stuff. A good look over by someone else would be helpful.
Also, PyMC v5 seemed to need this
@jax_funcify.register(_CeleriteOp)thing but I am not 100% sure about this.Paging @dfm for review.