43 changes: 42 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ The library now supports reasoning traces through the `reasoning_content` field
- [Using the library](#using-the-library)
- [Data format](#data-format)
- [Reasoning content support](#reasoning-content-support-1)
- [Continual pretraining mode](#continual-pretraining-mode)
- [Documentation](#documentation)
- [Learning about the training arguments](#learning-about-training-arguments)
- [`TrainingArgs`](#trainingargs)
@@ -122,6 +123,46 @@ The library now supports an optional `reasoning_content` field in addition to th
}
```

## Continual pretraining mode

In addition to instruction tuning, the library can run document-style continual pretraining on raw text corpora.
Enable this by supplying a block size when invoking `main_ds.py`:

```bash
torchrun main_ds.py \
--model_name_or_path mistralai/Mistral-7B-v0.1 \
--data_path /data/documents.jsonl \
--ckpt_output_dir ./checkpoints \
--effective_batch_size 128 \
--max_batch_len 60000 \
--block-size 8192 \
--document-column-name text # optional, defaults to "document"
```

- `--block-size` (required) enables continual pretraining and sets how many tokens are packed into each block.
- `--document-column-name` (optional) specifies which JSONL field contains the raw document text; a sample input file is shown below.
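
Each line of the input file holds one raw document under the configured column. A minimal illustration, assuming the default `document` column name (the strings are placeholders):

```json
{"document": "First raw document text for continual pretraining."}
{"document": "Second raw document text for continual pretraining."}
```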

The same options are available programmatically via `TrainingArgs.pretraining_config`:

```python
from instructlab.training import TrainingArgs, PretrainingConfig

train_args = TrainingArgs(
model_name_or_path="mistralai/Mistral-7B-v0.1",
data_path="documents.jsonl",
ckpt_output_dir="./checkpoints",
max_seq_len=4096,
max_batch_len=40000,
effective_batch_size=128,
pretraining_config=PretrainingConfig(
block_size=2048,
document_column_name="text", # optional
),
)
```

When a pretraining config is provided, `process_documents_for_pretraining()` is invoked under the hood to tokenize raw documents before training.
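
For reference, a minimal sketch of calling the document-processing step directly; it assumes the function is imported from `instructlab.training.data_process` and that the input file exists:

```python
from instructlab.training.data_process import process_documents_for_pretraining

# Tokenizes each document as [BOS][tokens][EOS] and writes data.jsonl
# (one record per document, holding `input_ids` and `len`) into the output directory.
process_documents_for_pretraining(
    data_path="documents.jsonl",             # raw documents, one JSON object per line
    data_output_path="./processed_data",     # directory that will receive data.jsonl
    model_path="mistralai/Mistral-7B-v0.1",
    num_cpu_procs=8,
    document_column_name="text",             # optional, defaults to "document"
)
```

Blocking into fixed-size chunks does not happen at this stage; it is applied later during training, driven by `block_size`.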

**Standard message structure:**

```json
@@ -139,7 +180,7 @@ The library now supports an optional `reasoning_content` field in addition to th
}
```

#### Important Notes
### Important Notes

1. **Automatic reasoning content processing**: If `reasoning_content` exists in a message, it will always be processed and unmasked as long as the message role is targeted for unmasking. This ensures that reasoning traces are properly included in the training data.

2 changes: 2 additions & 0 deletions src/instructlab/training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"FSDPOptions",
"ShardingStrategies",
"DistributedBackend",
"PretrainingConfig",
)

# First Party
@@ -23,6 +24,7 @@
DistributedBackend,
FSDPOptions,
LoraOptions,
PretrainingConfig,
QuantizeDataType,
ShardingStrategies,
TorchrunArgs,
5 changes: 4 additions & 1 deletion src/instructlab/training/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,17 @@ def __init__(
self.lr_scheduler = None
if self.distributed_framework == DistributedBackend.DEEPSPEED:
# Standard
cpu_offload_optimizer_ratio = (
self.deepspeed_cpu_offload_optimizer_ratio or 0.0
)
accel_args = {
"deepspeed_plugin": self.get_ds_plugin(
world_size=torch.distributed.get_world_size(),
samples_per_gpu=samples_per_gpu,
grad_accum=grad_accum,
opts=DeepSpeedOptions(
cpu_offload_optimizer=deepspeed_cpu_offload_optimizer,
cpu_offload_optimizer_ratio=self.deepspeed_cpu_offload_optimizer_ratio,
cpu_offload_optimizer_ratio=cpu_offload_optimizer_ratio,
cpu_offload_optimizer_pin_memory=self.deepspeed_cpu_offload_optimizer_pin_memory,
save_samples=save_samples,
),
32 changes: 32 additions & 0 deletions src/instructlab/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,34 @@ class DataProcessArgs(BaseModel):
description="this is the number of CPU procs we use for data processing parallelization",
)

# Pretraining mode flag
is_pretraining: bool = Field(
default=False,
description="Enable pretraining mode: tokenizes raw documents without chat templates or chunking",
)
pretraining_column_name: str = Field(
default="document",
description="the name of the column containing the text to pretrain on",
)

# disable the protected namespace for the model_config field
model_config = ConfigDict(protected_namespaces=())


class PretrainingConfig(BaseModel):
"""
Configuration for pretraining mode.
"""

block_size: int = Field(
description="Size of each block in tokens for pretraining datasets."
)
document_column_name: str = Field(
default="document",
description="Name of the column containing raw documents for pretraining.",
)


# public API
class TorchrunArgs(BaseModel):
"""
@@ -266,6 +290,14 @@ class TrainingArgs(BaseModel):
# "last_epoch". This works alongside the '--checkpoint_at_epoch' flag.
keep_last_checkpoint_only: Optional[bool] = False

pretraining_config: Optional[PretrainingConfig] = Field(
default=None,
description=(
"Pretraining configuration. When provided, enables block-based sampling "
"for raw document pretraining datasets."
),
)

# TODO(osilkin):
# we are only exposing this here because `run_training` today is implicitly coupled
# with `process_data`. Since we don't have a specific field for data processing arguments,
111 changes: 109 additions & 2 deletions src/instructlab/training/data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,10 @@ def process_messages_into_input_ids_with_chat_template(args: DataProcessArgs):
logger.info("Tokenizing the dataset with %s tokenizer...", args.model_path)
data_with_input_ids = data.map(
lambda x: {
"input_ids": tokenizer.apply_chat_template(x["messages"], tokenize=True),
# newer versions of transformers have `return_dict=True` by default
"input_ids": tokenizer.apply_chat_template(
x["messages"], tokenize=True, return_dict=False
),
"unmask": bool(x["unmask"]) if "unmask" in x else False,
},
num_proc=NUM_PROC,
@@ -687,7 +690,8 @@ def unmask_messages(
if regions:
message_regions_map[idx] = regions

input_ids = tokenizer.apply_chat_template(msgs_with_unmasking)
# newer versions of transformers have `return_dict=True` by default
input_ids = tokenizer.apply_chat_template(msgs_with_unmasking, return_dict=False)

# Get token IDs for all unmask tokens
unmask_begin_token_id = tokenizer.encode(
@@ -1133,6 +1137,109 @@ def process_messages_into_input_ids(
save_dataset(final_dataset, data_output_path, num_cpu_procs)


def process_documents_for_pretraining(
data_path: str,
data_output_path: str,
model_path: str,
num_cpu_procs: int,
document_column_name: str = "document",
) -> None:
"""
Process raw documents for pretraining by tokenizing without chunking.

Outputs one JSONL record per document with only input_ids (no labels).
Blocking/chunking happens later during training.

Pattern: Each document → [BOS][tokens][EOS]

Args:
data_path: Path to input JSONL with {"document": "text"} format
data_output_path: Directory for processed data output
model_path: Path to model/tokenizer
num_cpu_procs: Number of parallel processes
document_column_name: Name of the column containing the documents
"""
ensure_can_write_to_directory(data_output_path)

# Load and validate dataset
try:
data = load_dataset("json", data_files=data_path, split="train")
except Exception as e:
raise ValueError(
"Malformed or missing data, please ensure your dataset is correctly formatted"
) from e

if data.num_rows == 0:
raise ValueError("The provided dataset is empty")

if document_column_name not in data.column_names:
raise ValueError(
f"Pretraining data must have '{document_column_name}' field. Found: {data.column_names}"
)

logger.info("Loading tokenizer from %s", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

if tokenizer.eos_token_id is None:
raise ValueError("Tokenizer must have an EOS token defined for pretraining")

logger.info("Tokenizing %d documents for pretraining...", data.num_rows)

# Tokenize each document: encode() adds BOS, then append EOS
def tokenize_document(sample):
input_ids = tokenizer.encode(
sample[document_column_name], add_special_tokens=True
)

# ensures eos token is present without double-adding it.
if input_ids[-1] != tokenizer.eos_token_id:
input_ids.append(tokenizer.eos_token_id)

return {
"input_ids": input_ids,
"len": len(input_ids),
}

# Filter out empty documents before tokenization
def filter_empty_documents(batch):
return [bool(doc) for doc in batch[document_column_name]]

filtered_data = data.filter(
filter_empty_documents,
batched=True,
num_proc=num_cpu_procs,
desc="Filtering empty documents",
)

dropped_count = data.num_rows - filtered_data.num_rows
if dropped_count > 0:
logger.info(f"Dropped {dropped_count:,} empty documents")
tokenized_data = filtered_data.map(
tokenize_document,
num_proc=num_cpu_procs,
desc="Tokenizing documents",
remove_columns=filtered_data.column_names,
)

# Calculate statistics
total_tokens = sum(tokenized_data["len"])
avg_tokens = total_tokens / len(tokenized_data)
logger.info(f"Processed {len(tokenized_data):,} documents")
logger.info(f"Total tokens: {total_tokens:,}")
logger.info(f"Average tokens per document: {avg_tokens:.1f}")

# Save to JSONL (one record per document)
os.makedirs(data_output_path, exist_ok=True)
output_file = Path(data_output_path) / "data.jsonl"

tokenized_data.to_json(
output_file, num_proc=num_cpu_procs, lines=True, orient="records"
)

logger.info(f"Saved tokenized documents to {output_file}")
logger.info("Note: Blocking into fixed-size chunks will happen during training")


def ensure_can_write_to_directory(output_dir: str) -> None:
"""
Ensure that we can write to the output directory.