Conversation
- raise exception when forecasting with levels - handle varying shapes better for mean calculation
f7f2b86 to
1a25cbe
Compare
| def _left_pad_and_stack_1D(self, tensors: list[torch.Tensor]) -> torch.Tensor: | ||
| max_len = max(len(c) for c in tensors) | ||
| padded = [] | ||
| for c in tensors: | ||
| assert isinstance(c, torch.Tensor) | ||
| assert c.ndim == 1 | ||
| padding = torch.full( | ||
| size=(max_len - len(c),), | ||
| fill_value=torch.nan, | ||
| device=c.device, | ||
| dtype=c.dtype, | ||
| ) | ||
| padded.append(torch.concat((padding, c), dim=-1)) | ||
| return torch.stack(padded) | ||
|
|
||
| def _prepare_and_validate_context( | ||
| self, | ||
| context: list[torch.Tensor] | torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| if isinstance(context, list): | ||
| context = self._left_pad_and_stack_1D(context) | ||
| assert isinstance(context, torch.Tensor) | ||
| if context.ndim == 1: | ||
| context = context.unsqueeze(0) | ||
| assert context.ndim == 2 | ||
| return context | ||
|
|
||
| def _maybe_impute_missing( | ||
| self, batch: torch.Tensor, dtype=torch.float32 | ||
| ) -> torch.Tensor: | ||
| if torch.isnan(batch).any(): | ||
| batch = batch.to(dtype=dtype).numpy(force=True) | ||
| imputed_rows = [] | ||
| for i in range(batch.shape[0]): | ||
| row = batch[i] | ||
| imputed_row = LastValueImputation()(row) | ||
| imputed_rows.append(imputed_row) | ||
| batch = np.vstack(imputed_rows) | ||
| batch = torch.tensor( | ||
| batch, | ||
| dtype=self.dtype, | ||
| device=self.device, | ||
| ) | ||
| return batch |
There was a problem hiding this comment.
these three methods are also being used by sundial
timecopilot/timecopilot/models/foundation/sundial.py
Lines 94 to 137 in fe710ca
There was a problem hiding this comment.
I copied them from Flowstate, so that sounds good.
The only attributes of the class that are used are dtype and device, so the options that come to mind are:
- A class with
dtypeanddeviceattributes - A class with static methods and
dtypeanddeviceadded as args where needed - Functions with
dtypeanddeviceadded as args where needed
There was a problem hiding this comment.
including them as attributes in the class sounds good.
There was a problem hiding this comment.
@AzulGarza does this seem good?
db20d39
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 6 out of 7 changed files in this pull request and generated 10 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
|
@spolisar let's also take a look at copilot's comments. they are valuable. |
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Add support for the new PatchTST-FM forecasting model.
PatchTST-FM is currently setup to use pytorch's mps backend when available, but there is a chance that will have issues. If it becomes a problem, we can disable that.
For now predicting with levels is disabled.
TODO: