Skip to content
45 changes: 41 additions & 4 deletions pufferlib/config/ocean/drive.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ batch_size = auto

[policy]
; Encoder layer
input_size = 64
input_size = 256
encoder_gigaflow = True
Comment on lines 13 to 16
dropout = 0.0
; Shared backbone layer
Expand All @@ -25,7 +25,7 @@ actor_num_layers = 0
critic_hidden_size = 512
critic_num_layers = 0
; Dual or shared actor-critic backbone
split_network = False
split_network = True

[rnn]
input_size = 512
Expand All @@ -48,6 +48,24 @@ dynamics_model = "jerk"
dt = 0.1
; Optional nonzero launch speed for gigaflow random spawns
spawn_initial_speed = 0.0
; Per-agent spawn dimensions (meters). Training and eval use independent
; ranges so a policy trained on non-car shapes (e.g. trucks) can be eval'd
; on matching shapes. width is clipped to <= length at spawn time.
spawn_length_min = 0.8
spawn_length_max = 7.0
spawn_width_min = 0.8
spawn_width_max = 3.0
eval_spawn_length_min = 2.0
eval_spawn_length_max = 5.5
eval_spawn_width_min = 1.5
eval_spawn_width_max = 2.5
; Mixed-population spawn: per-agent P(is_truck). 0.0 (default) preserves
; the legacy single-population behavior
truck_fraction = 0.0
truck_spawn_length_min = 9.0
truck_spawn_length_max = 15.0
truck_spawn_width_min = 2.0
truck_spawn_width_max = 2.6
Comment on lines +51 to +68
; Collision behavior - options: 0 - Ignore, 1 - Stop, 2 - Remove
collision_behavior = 1
; Offroad behavior - options: 0 - Ignore, 1 - Stop, 2 - Remove
Expand Down Expand Up @@ -85,8 +103,8 @@ min_waypoint_spacing = 20.0
max_waypoint_spacing = 60.0

; --- Rewards ---
reward_conditioning = False
reward_randomization = False
reward_conditioning = True
reward_randomization = True
reward_goal = 1.0
reward_vehicle_collision = 1.0
reward_offroad_collision = 1.0
Expand Down Expand Up @@ -578,3 +596,22 @@ values = [0.001, 0.003, 0.01]

[controlled_exp.train.ent_coef]
values = [0.01, 0.005]

[finetune]
enabled = False
; Strategy: full | freeze | lora
; full - train every parameter starting from base weights (standard finetune)
; freeze - freeze params matching freeze_regex; train the rest
; lora - wrap target Linears with LoRA adapters; freeze base weights of those layers
mode = full
; Overrides train.learning_rate when enabled. None = inherit train.learning_rate.
base_lr = None
; Regex over named_parameters; matched params get requires_grad=False (used by
; mode=freeze; also additive on top of mode=lora).
freeze_regex = None
; LoRA knobs (used when mode=lora). lora_target is a regex over named_modules
; matching nn.Linear layers to wrap.
lora_rank = 16
lora_alpha = 32
lora_target = None
lora_lr_mult = 10.0
28 changes: 28 additions & 0 deletions pufferlib/config/ocean/drive_finetune_nuplan.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
[env]
; --- nuPlan training data (replay-mode rollouts of real traffic logs) ---
map_dir = "/scratch/ev2237/data/nuplan/nuplan_mini_train_bins"
num_maps = 200
simulation_mode = "replay"
control_mode = "control_sdc_only"
init_mode = "create_all_valid"
; nuPlan log length is ~20s at 10Hz = 200 steps. +1 to match behavior eval sections.
scenario_length = 201
resample_frequency = 201
; SDC is the only controlled agent; the "too many inactive agents" early-reset
; trigger doesn't apply here. Use scenario-length termination.
termination_mode = 0

[train]
total_timesteps = 2_000_000_000
checkpoint_interval = 50

[finetune]
enabled = True
mode = lora
; LoRA only on the shared 4-layer backbone — encoders + heads stay fully
; trainable so they can adapt to replay-style observations end-to-end.
lora_target = "actor_backbone\\.backbone\\."
lora_rank = 32
lora_alpha = 64
lora_lr_mult = 10.0
base_lr = 1e-4
18 changes: 18 additions & 0 deletions pufferlib/config/ocean/drive_finetune_reward.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[env]
reward_goal = 2.0
reward_velocity = 0.3
reward_vehicle_collision = 0.5
reward_offroad_collision = 0.5

[train]
total_timesteps = 2_000_000_000
checkpoint_interval = 50

[finetune]
enabled = True
mode = lora
lora_target = "actor_backbone\\.backbone\\."
lora_rank = 32
lora_alpha = 64
lora_lr_mult = 10.0
base_lr = 1e-4
19 changes: 19 additions & 0 deletions pufferlib/config/ocean/drive_finetune_truck_mixed.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[env]
; --- Mixed population: 20% trucks at spawn, rest sample the default car range ---
truck_fraction = 0.2
truck_spawn_length_min = 9.0
truck_spawn_length_max = 15.0
truck_spawn_width_min = 2.0
truck_spawn_width_max = 2.6

goal_radius = 4.0
reward_lane_align = 0.01

[train]
total_timesteps = 2_000_000_000
checkpoint_interval = 50

[finetune]
enabled = True
mode = full
base_lr = 1e-4
191 changes: 191 additions & 0 deletions pufferlib/finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Parameter-efficient finetuning primitives for PufferLib.

Activated by [finetune].enabled = True in the env config (drive.ini) or the
overlay passed via --finetune-config. Three strategies:

- full : every parameter trainable; nothing here applies.
- freeze : params matching [finetune].freeze_regex get requires_grad=False.
- lora : nn.Linear submodules whose dotted name matches
[finetune].lora_target get wrapped with LoRALinear (frozen base
weight + trainable rank-r adapter B@A).

Ordering inside load_policy:
policy = build()
policy.load_state_dict(base_pt) # base weights present, NO lora keys
apply_freeze(policy, ...) # freeze the layers the user pinned
wrap_lora(policy, ...) # swap nn.Linear -> LoRALinear in place
# ... then DDP wrap, torch.compile, optimizer
"""

import math
import re
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinear(nn.Module):
"""Drop-in replacement for nn.Linear with a frozen base weight and a
trainable rank-r adapter.

forward(x) = F.linear(x, W, b) + (alpha / r) * F.linear(F.linear(x, A), B)

State-dict layout:
Saves under {prefix}weight (MERGED: W + scaling * B @ A) and {prefix}bias.
Does NOT save lora_A / lora_B — by design. Every LoRA finetune restarts
its adapter from scratch on resume; the merged weight carries forward
the previous run's learning so progress isn't lost. This keeps saved
.pt files load-compatible with a vanilla nn.Linear at the same path,
which is what subprocess evals and downstream finetunes expect.
"""

def __init__(self, base: nn.Linear, rank: int, alpha: float):
super().__init__()
self.in_features = base.in_features
self.out_features = base.out_features

# Copy the base weight + bias as own parameters (NOT a nested
# submodule) so the state_dict key layout matches a vanilla Linear.
self.weight = nn.Parameter(base.weight.data.clone())
self.weight.requires_grad = False
if base.bias is not None:
self.bias = nn.Parameter(base.bias.data.clone())
self.bias.requires_grad = False
else:
self.register_parameter("bias", None)

self.rank = int(rank)
self.alpha = float(alpha)
self.scaling = (self.alpha / self.rank) if self.rank > 0 else 0.0

# Kaiming-uniform init for A (standard LoRA practice).
# B is zero so at step 0 the adapter contributes nothing — the
# wrapped model is mathematically identical to the base policy
# before any gradient step.
self.lora_A = nn.Parameter(torch.empty(self.rank, self.in_features))
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
self.lora_B = nn.Parameter(torch.zeros(self.out_features, self.rank))

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = F.linear(x, self.weight, self.bias)
if self.rank > 0:
out = out + self.scaling * F.linear(F.linear(x, self.lora_A), self.lora_B)
return out

def merged_weight(self) -> torch.Tensor:
if self.rank > 0:
return self.weight + self.scaling * (self.lora_B @ self.lora_A)
return self.weight

def _save_to_state_dict(self, destination, prefix, keep_vars):
merged = self.merged_weight()
destination[prefix + "weight"] = merged if keep_vars else merged.detach()
if self.bias is not None:
destination[prefix + "bias"] = self.bias if keep_vars else self.bias.detach()

def extra_repr(self) -> str:
return (
f"in_features={self.in_features}, out_features={self.out_features}, "
f"rank={self.rank}, alpha={self.alpha}"
)


def _resolve_optional_str(value) -> Optional[str]:
"""drive.ini may carry the literal string 'None' for unset values when
routed through ast.literal_eval; treat that the same as Python None."""
if value is None:
return None
if isinstance(value, str) and value.strip() in ("", "None"):
return None
return value


def apply_freeze(policy: nn.Module, regex) -> int:
"""Set requires_grad=False on every named_parameter whose name matches
`regex` (re.search semantics). Returns the count of frozen tensors.
No-op if regex is None / empty.
"""
regex = _resolve_optional_str(regex)
if not regex:
return 0
pattern = re.compile(regex)
count = 0
for name, p in policy.named_parameters():
if pattern.search(name):
p.requires_grad = False
count += 1
return count


def wrap_lora(policy: nn.Module, target_regex, rank: int, alpha: float) -> int:
"""Replace nn.Linear submodules whose dotted module-name matches
`target_regex` (re.search semantics) with LoRALinear. Returns the count
of wrapped modules. No-op if target_regex is None / empty or rank <= 0.

Note: regex matches MODULE NAMES (e.g. 'actor_backbone.backbone.1'), not
parameter names. Type filtering to nn.Linear is enforced internally, so
your regex doesn't need to include 'Linear'.
"""
target_regex = _resolve_optional_str(target_regex)
if not target_regex or rank <= 0:
return 0
pattern = re.compile(target_regex)
targets = []
for name, module in policy.named_modules():
# Skip LoRALinear instances so re-runs are idempotent.
if isinstance(module, LoRALinear):
continue
if isinstance(module, nn.Linear) and pattern.search(name):
targets.append((name, module))
for name, module in targets:
parent_name, _, child_name = name.rpartition(".")
parent = policy.get_submodule(parent_name) if parent_name else policy
wrapped = LoRALinear(module, rank=rank, alpha=alpha).to(
device=module.weight.device, dtype=module.weight.dtype
)
setattr(parent, child_name, wrapped)
return len(targets)


def get_lora_params(policy: nn.Module) -> List[nn.Parameter]:
"""Collect every LoRALinear.lora_A / lora_B in the policy. Used to build
a separate optimizer parameter group with its own LR."""
params: List[nn.Parameter] = []
for module in policy.modules():
if isinstance(module, LoRALinear):
params.append(module.lora_A)
params.append(module.lora_B)
return params


def build_param_groups(policy: nn.Module, base_lr: float, lora_lr: float) -> list:
"""Build optimizer parameter groups. LoRA adapter params get `lora_lr`;
every other trainable param gets `base_lr`. Frozen params are excluded.
Returns a list of dicts suitable for torch.optim.Optimizer.
"""
lora_params = get_lora_params(policy)
lora_param_ids = {id(p) for p in lora_params}
base_params = [
p for p in policy.parameters()
if p.requires_grad and id(p) not in lora_param_ids
]
groups = []
if base_params:
groups.append({"params": base_params, "lr": base_lr})
active_lora = [p for p in lora_params if p.requires_grad]
if active_lora:
groups.append({"params": active_lora, "lr": lora_lr})
return groups


def trainable_summary(policy: nn.Module) -> str:
total = sum(p.numel() for p in policy.parameters())
trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad)
lora_modules = sum(1 for m in policy.modules() if isinstance(m, LoRALinear))
pct = (trainable / total * 100.0) if total > 0 else 0.0
return (
f"[finetune] trainable {trainable:,} / {total:,} ({pct:.2f}%) "
f"| LoRA modules wrapped: {lora_modules}"
)
13 changes: 13 additions & 0 deletions pufferlib/ocean/drive/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -1814,6 +1814,19 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) {
env->traffic_control_scope = (int) unpack(kwargs, "traffic_control_scope");
env->dt = (float) unpack(kwargs, "dt");
env->spawn_initial_speed = (float) unpack(kwargs, "spawn_initial_speed");
env->spawn_length_min = (float) unpack(kwargs, "spawn_length_min");
env->spawn_length_max = (float) unpack(kwargs, "spawn_length_max");
env->spawn_width_min = (float) unpack(kwargs, "spawn_width_min");
env->spawn_width_max = (float) unpack(kwargs, "spawn_width_max");
env->eval_spawn_length_min = (float) unpack(kwargs, "eval_spawn_length_min");
env->eval_spawn_length_max = (float) unpack(kwargs, "eval_spawn_length_max");
env->eval_spawn_width_min = (float) unpack(kwargs, "eval_spawn_width_min");
env->eval_spawn_width_max = (float) unpack(kwargs, "eval_spawn_width_max");
env->truck_fraction = (float) unpack(kwargs, "truck_fraction");
env->truck_spawn_length_min = (float) unpack(kwargs, "truck_spawn_length_min");
env->truck_spawn_length_max = (float) unpack(kwargs, "truck_spawn_length_max");
env->truck_spawn_width_min = (float) unpack(kwargs, "truck_spawn_width_min");
env->truck_spawn_width_max = (float) unpack(kwargs, "truck_spawn_width_max");
env->goal_speed = (float) unpack(kwargs, "goal_speed");
env->scenario_length = (int) unpack(kwargs, "scenario_length");
env->termination_mode = (int) unpack(kwargs, "termination_mode");
Expand Down
Loading
Loading