Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(

# Tracks activity since last reset
self.recently_active = torch.zeros(hidden_dim, dtype=torch.bool, device=device)
self._last_avg_nonzero: float = 0.0

@torch.no_grad()
def update(self, codes: torch.Tensor) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def process_item( # noqa: D417
prob_input_sequence = np.ones_like(input_sequence_toks) * mlm_probability
if codon_weights is not None:
pos_weight = codon_weights[input_sequence_toks]
pos_weight[1:-1] = pos_weight[1:-1] / pos_weight[1:-1].mean()
mean_weight = pos_weight[1:-1].mean()
if mean_weight > 0:
pos_weight[1:-1] = pos_weight[1:-1] / mean_weight
prob_input_sequence = prob_input_sequence * pos_weight
prob_input_sequence = np.clip(prob_input_sequence, 0.05, 0.4)
mask_indices = np.random.binomial(1, prob_input_sequence).astype(bool) # noqa: NPY002
Expand All @@ -82,10 +84,13 @@ def process_item( # noqa: D417
masked_input_sequence_toks[indices_replaced] = tokenizer.mask_token_id

if random_replace_prob > 0.0:
remaining_prob = 1.0 - mask_replace_prob
# Conditional probability of random replacement given the token was not replaced by [MASK].
# Guard against mask_replace_prob == 1.0 (no remaining positions) or rounding above 1.0.
conditional_prob = min(random_replace_prob / remaining_prob, 1.0) if remaining_prob > 0.0 else 0.0
indices_random = np.random.binomial( # noqa: NPY002
1,
(np.ones_like(mask_indices).astype(bool) & mask_indices & ~indices_replaced)
* (random_replace_prob / (1 - mask_replace_prob)),
(np.ones_like(mask_indices).astype(bool) & mask_indices & ~indices_replaced) * conditional_prob,
).astype(bool)
valid_tokens = np.setdiff1d(
np.arange(tokenizer.vocab_size),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ class Evo2Dataset(GPTDataset):
MAX_TAG_LEN = 2048

VALID_DNA_AND_DEGENERATE: ClassVar[set[int]] = {
45,
45,
65,
66,
Expand Down
Loading