Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions finetune/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,7 @@ def __len__(self) -> int:
return self.n_samples

def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
Retrieves a random sample from the dataset.
Note: The `idx` argument is ignored. Instead, a random index is drawn
from the pre-computed `self.indices` list using `self.py_rng`. This
ensures random sampling over the entire dataset for each call.
Args:
idx (int): Ignored.
Returns:
tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- x_tensor (torch.Tensor): The normalized feature tensor.
- x_stamp_tensor (torch.Tensor): The time feature tensor.
"""

# Select a random sample from the entire pool of indices.
random_idx = self.py_rng.randint(0, len(self.indices) - 1)
symbol, start_idx = self.indices[random_idx]
Expand All @@ -118,8 +104,15 @@ def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
x = win_df[self.feature_list].values.astype(np.float32)
x_stamp = win_df[self.time_feature_list].values.astype(np.float32)

# Perform instance-level normalization.
x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
# Normalize the window. Mean and std are calculated strictly on the
# lookback window (past data) to prevent future data leakage.
past_len = self.config.lookback_window
past_x = x[:past_len]

x_mean = np.mean(past_x, axis=0)
x_std = np.std(past_x, axis=0)

# Apply normalization and robust clipping to the entire sequence
x = (x - x_mean) / (x_std + 1e-5)
x = np.clip(x, -self.config.clip, self.config.clip)

Expand Down