Mix batch 0429 #5439
base: master
`.gitignore`:

```diff
@@ -73,3 +73,10 @@ frozen_model.*
+# Test system directories
+system/
+temp/
+test_mptraj/
+pkl/
+history/
+deepmd-kit/
+*.hdf5
```
Loss `forward`:

```diff
@@ -224,7 +224,20 @@ def forward(
         more_loss = {}
         # more_loss['log_keys'] = []  # showed when validation on the fly
         # more_loss['test_keys'] = []  # showed when doing dp test
-        atom_norm = 1.0 / natoms
+        # Detect mixed batch format
+        is_mixed_batch = "ptr" in input_dict and input_dict["ptr"] is not None
+
+        # For mixed batch, compute per-frame atom_norm and average
+        if is_mixed_batch:
+            ptr = input_dict["ptr"]
+            nframes = ptr.numel() - 1
+            # Compute natoms for each frame
+            natoms_per_frame = ptr[1:] - ptr[:-1]  # [nframes]
+            # Average atom_norm across frames
+            atom_norm = torch.mean(1.0 / natoms_per_frame.float())
+        else:
+            atom_norm = 1.0 / natoms
```
**Review comment on lines +228 to +240 (Contributor):** This mixed-batch normalization misweights frames with different atom counts. Averaging `1.0 / natoms_per_frame` into a single scalar applies the same norm to every frame's summed loss, so frames with more atoms contribute more than they would under true per-frame normalization.
```diff
 
         if self.has_e and "energy" in model_pred and "energy" in label:
             energy_pred = model_pred["energy"]
             energy_label = label["energy"]
```
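A minimal sketch of one way to normalize per frame instead, assuming the per-atom loss terms are available as a flat tensor (`per_atom_loss` and the helper name are hypothetical, not from this PR):

```python
import torch


def per_frame_normalized_loss(
    per_atom_loss: torch.Tensor, ptr: torch.Tensor
) -> torch.Tensor:
    """Scale each atom's loss by its own frame's 1/natoms, then average frames.

    per_atom_loss: [total_atoms] flat per-atom loss terms (hypothetical input).
    ptr: [nframes + 1] frame boundary offsets, as in the PR.
    """
    nframes = ptr.numel() - 1
    natoms_per_frame = ptr[1:] - ptr[:-1]  # [nframes]
    # Frame assignment for every local atom, e.g. ptr=[0,2,5] -> [0,0,1,1,1].
    batch = torch.repeat_interleave(
        torch.arange(nframes, device=ptr.device), natoms_per_frame
    )
    # Each atom is scaled by the norm of the frame it belongs to.
    inv_natoms = (1.0 / natoms_per_frame.float())[batch]  # [total_atoms]
    per_frame_loss = torch.zeros(
        nframes, dtype=per_atom_loss.dtype, device=per_atom_loss.device
    )
    per_frame_loss.scatter_add_(0, batch, per_atom_loss * inv_natoms)
    return per_frame_loss.mean()
```

This keeps each frame's contribution equal regardless of its atom count, which a single averaged `atom_norm` applied to the batch total cannot guarantee.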
Atomic model, new `forward_common_atomic_flat`:

```diff
@@ -318,6 +318,137 @@ def forward_atomic(
         )
         return fit_ret
 
+    def forward_common_atomic_flat(
+        self,
+        extended_coord: torch.Tensor,
+        extended_atype: torch.Tensor,
+        extended_batch: torch.Tensor,
+        nlist: torch.Tensor,
+        mapping: torch.Tensor,
+        batch: torch.Tensor,
+        ptr: torch.Tensor,
+        fparam: torch.Tensor | None = None,
+        aparam: torch.Tensor | None = None,
+        extended_ptr: torch.Tensor | None = None,
+        central_ext_index: torch.Tensor | None = None,
+        nlist_ext: torch.Tensor | None = None,
+        a_nlist: torch.Tensor | None = None,
+        a_nlist_ext: torch.Tensor | None = None,
+        nlist_mask: torch.Tensor | None = None,
+        a_nlist_mask: torch.Tensor | None = None,
+        edge_index: torch.Tensor | None = None,
+        angle_index: torch.Tensor | None = None,
+    ) -> dict[str, torch.Tensor]:
+        """Forward pass with flat batch format.
+
+        Parameters
+        ----------
+        extended_coord : torch.Tensor
+            Extended coordinates [total_extended_atoms, 3].
+        extended_atype : torch.Tensor
+            Extended atom types [total_extended_atoms].
+        extended_batch : torch.Tensor
+            Frame assignment for extended atoms [total_extended_atoms].
+        nlist : torch.Tensor
+            Neighbor list [total_atoms, nnei].
+        mapping : torch.Tensor
+            Extended atom -> local flat index mapping [total_extended_atoms].
+        batch : torch.Tensor
+            Frame assignment for local atoms [total_atoms].
+        ptr : torch.Tensor
+            Frame boundaries [nframes + 1].
+        fparam : torch.Tensor | None
+            Frame parameters [nframes, ndf].
+        aparam : torch.Tensor | None
+            Atomic parameters [total_atoms, nda].
+        central_ext_index : torch.Tensor | None
+            Extended-atom indices corresponding to local atoms.
+        nlist_ext, a_nlist_ext : torch.Tensor | None
+            Edge and angle neighbor lists indexing concatenated extended atoms.
+        nlist_mask, a_nlist_mask : torch.Tensor | None
+            Valid-neighbor masks for flat edge and angle neighbor lists.
+        edge_index, angle_index : torch.Tensor | None
+            Dynamic graph indices produced by the flat graph preprocessor.
+
+        Returns
+        -------
+        result_dict : dict[str, torch.Tensor]
+            Model predictions in flat format.
+        """
+        if self.do_grad_r() or self.do_grad_c():
+            extended_coord.requires_grad_(True)
+
+        # Descriptor and fitting both consume the flat atom layout.
+        descriptor_out = self.descriptor.forward_flat(
+            extended_coord,
+            extended_atype,
+            extended_batch,
+            nlist,
+            mapping,
+            batch,
+            ptr,
+            fparam=fparam if self.add_chg_spin_ebd else None,
```
**Review comment on lines +378 to +390 (Contributor):** Mirror the default-`fparam` handling here: when the fitting net declares a nonzero `fparam` dimension and no `fparam` was passed, tile the fitting net's default across the batch's frames before handing it to the descriptor.

Suggested fix:

```diff
         if self.do_grad_r() or self.do_grad_c():
             extended_coord.requires_grad_(True)
+        if (
+            hasattr(self.fitting_net, "get_dim_fparam")
+            and self.fitting_net.get_dim_fparam() > 0
+            and fparam is None
+        ):
+            default_fparam_tensor = self.fitting_net.get_default_fparam()
+            assert default_fparam_tensor is not None
+            fparam_input_for_des = torch.tile(
+                default_fparam_tensor.unsqueeze(0), [ptr.numel() - 1, 1]
+            )
+        else:
+            fparam_input_for_des = fparam
+
         # Descriptor and fitting both consume the flat atom layout.
         descriptor_out = self.descriptor.forward_flat(
             extended_coord,
@@
-            fparam=fparam if self.add_chg_spin_ebd else None,
+            fparam=fparam_input_for_des if self.add_chg_spin_ebd else None,
             central_ext_index=central_ext_index,
```
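For reference, a small self-contained sketch of the tiling the suggestion relies on (the default value and frame layout here are made up for illustration):

```python
import torch

# Hypothetical default frame parameter with ndf = 2 entries.
default_fparam_tensor = torch.tensor([0.5, 1.0])
ptr = torch.tensor([0, 2, 5])  # mixed batch with two frames
nframes = ptr.numel() - 1
# One copy of the default per frame, shape [nframes, ndf] = [2, 2],
# matching the fparam layout documented in forward_common_atomic_flat.
fparam_input_for_des = torch.tile(default_fparam_tensor.unsqueeze(0), [nframes, 1])
assert fparam_input_for_des.shape == (2, 2)
```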
```diff
+            central_ext_index=central_ext_index,
+            nlist_ext=nlist_ext,
+            a_nlist=a_nlist,
+            a_nlist_ext=a_nlist_ext,
+            nlist_mask=nlist_mask,
+            a_nlist_mask=a_nlist_mask,
+            edge_index=edge_index,
+            angle_index=angle_index,
+        )
+
+        descriptor = descriptor_out.get("descriptor")
+        rot_mat = descriptor_out.get("rot_mat")
+        g2 = descriptor_out.get("g2")
+        h2 = descriptor_out.get("h2")
+
+        if self.enable_eval_descriptor_hook:
+            self.eval_descriptor_list.append(descriptor.detach())
+
+        if central_ext_index is None:
+            from deepmd.pt.utils.nlist import get_central_ext_index
+
+            central_ext_index = get_central_ext_index(extended_batch, ptr)
+            atype = extended_atype[central_ext_index]
+        else:
+            atype = extended_atype[central_ext_index]
+
+        fit_ret = self.fitting_net.forward_flat(
+            descriptor,
+            atype,
+            batch,
+            ptr,
+            gr=rot_mat,
+            g2=g2,
+            h2=h2,
+            fparam=fparam,
+            aparam=aparam,
+        )
+        fit_ret = self.apply_out_stat(fit_ret, atype)
+
+        atom_mask = self.make_atom_mask(atype).to(torch.int32)
+        if self.atom_excl is not None:
+            atom_mask *= self.atom_excl(atype.unsqueeze(0)).squeeze(0)
+
+        for kk in fit_ret.keys():
+            out_shape = fit_ret[kk].shape
+            out_shape2 = 1
+            for ss in out_shape[1:]:
+                out_shape2 *= ss
+            fit_ret[kk] = (
+                fit_ret[kk].reshape([out_shape[0], out_shape2]) * atom_mask[:, None]
+            ).view(out_shape)
+        fit_ret["mask"] = atom_mask
+
+        if self.enable_eval_fitting_last_layer_hook:
+            if "middle_output" in fit_ret:
+                self.eval_fitting_last_layer_list.append(
+                    fit_ret.pop("middle_output").detach()
+                )
+
+        return fit_ret
+
     def compute_or_load_stat(
         self,
         sampled_func: Callable[[], list[dict]],
```
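The flat layout keeps `ptr`, `batch`, and the per-frame atom counts mutually consistent; a toy round-trip under assumed values (not taken from the PR) illustrates the convention documented in the method's docstring:

```python
import torch

# Hypothetical two-frame mixed batch: frame 0 has 2 atoms, frame 1 has 3.
ptr = torch.tensor([0, 2, 5])            # frame boundaries, [nframes + 1]
natoms_per_frame = ptr[1:] - ptr[:-1]    # tensor([2, 3])
# Frame assignment for each local atom, matching the documented `batch`.
batch = torch.repeat_interleave(torch.arange(ptr.numel() - 1), natoms_per_frame)
# batch == tensor([0, 0, 1, 1, 1])
# Round trip: rebuild ptr from batch via per-frame counts.
counts = torch.bincount(batch, minlength=ptr.numel() - 1)
ptr_rebuilt = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)])
assert torch.equal(ptr, ptr_rebuilt)
```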
**Review comment (on the `.gitignore` change):** Ignoring `test_mptraj/` conflicts with tracked files in this PR. The PR commits `test_mptraj/lmdb_baseline.json` and `test_mptraj/lmdb_mixed_batch.json` under `test_mptraj/`, but the new `test_mptraj/` ignore rule will hide any future additions or edits in that directory and force `git add -f`. Either drop the directory-level ignore or add explicit negations for the tracked configs.
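A sketch of the negation option (the exact patterns are an assumption; note that git cannot re-include files inside an ignored directory, so the rule has to target the directory's contents rather than the directory itself):

```diff
 # Test system directories
 system/
 temp/
-test_mptraj/
+# Ignore generated contents but keep the tracked baseline configs.
+test_mptraj/*
+!test_mptraj/lmdb_baseline.json
+!test_mptraj/lmdb_mixed_batch.json
```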