Skip to content

Commit 357ecd1

Browse files
committed
DOC & ENH: update latent embeddings method
1 parent c5bd804 commit 357ecd1

2 files changed

Lines changed: 50 additions & 16 deletions

File tree

nff/nn/models/painn.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
ReadoutBlock_Complex,
1515
ReadoutBlock_Tuple,
1616
ReadoutBlock_Vec,
17+
ReadoutBlock_Vec2,
1718
TransformerMessageBlock,
1819
UpdateBlock,
1920
)
@@ -84,7 +85,6 @@ def __init__(self, modelparams):
8485
)
8586
self.update_blocks = nn.ModuleList(
8687
[UpdateBlock(feat_dim=feat_dim, activation=activation, dropout=conv_dropout) for _ in range(num_conv)]
87-
[UpdateBlock(feat_dim=feat_dim, activation=activation, dropout=conv_dropout) for _ in range(num_conv)]
8888
)
8989

9090
self.output_keys = output_keys
@@ -256,6 +256,7 @@ def run(
256256
requires_embedding=False,
257257
requires_stress=False,
258258
inference=False,
259+
pool_embeddings=False,
259260
):
260261
atomwise_out, xyz, r_ij, nbrs = self.atomwise(batch=batch, xyz=xyz)
261262

@@ -275,7 +276,9 @@ def run(
275276
)
276277

277278
if requires_embedding:
278-
all_results = add_embedding(atomwise_out=atomwise_out, all_results=all_results)
279+
all_results = add_embedding(
280+
atomwise_out=atomwise_out, all_results=all_results, pool_embeddings=pool_embeddings, batch=batch
281+
)
279282

280283
if requires_stress:
281284
all_results = add_stress(batch=batch, all_results=all_results, nbrs=nbrs, r_ij=r_ij)
@@ -292,6 +295,7 @@ def forward(
292295
requires_embedding=False,
293296
requires_stress=False,
294297
inference=False,
298+
pool_embeddings=False,
295299
**kwargs,
296300
):
297301
"""
@@ -308,6 +312,7 @@ def forward(
308312
requires_embedding=requires_embedding,
309313
requires_stress=requires_stress,
310314
inference=inference,
315+
pool_embeddings=pool_embeddings,
311316
)
312317

313318
return results
@@ -517,9 +522,6 @@ def __init__(self, modelparams):
517522
"""
518523
Args:
519524
modelparams (dict): dictionary of model parameters
520-
521-
522-
523525
"""
524526

525527
super().__init__(modelparams)

nff/nn/modules/schnet.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
from typing import Dict, Optional
23

34
import numpy as np
45
import torch
@@ -56,8 +57,35 @@ def get_rij(xyz, batch, nbrs, cutoff):
5657
return r_ij, nbrs
5758

5859

59-
def add_embedding(atomwise_out, all_results):
60-
all_results["embedding"] = atomwise_out["features"]
60+
def add_embedding(
61+
atomwise_out: dict,
62+
all_results: dict,
63+
pool_embeddings: bool = False,
64+
pool_type: str = "sum",
65+
batch: Optional[Dict] = None,
66+
) -> dict:
67+
"""Add node-wise embeddings to the results dictionary.
68+
69+
Args:
70+
atomwise_out (dict): output of the atomwise layers
71+
all_results (dict): results dictionary from the forward pass through the full network
72+
pool_embeddings (bool, optional): whether or not to pool the embedding. Defaults to False.
73+
pool_type (str, optional): type of pooling to use, either "sum" or "mean". Defaults to "sum".
74+
75+
Returns:
76+
_type_: _description_
77+
"""
78+
if pool_embeddings:
79+
if not batch:
80+
raise ValueError("batch must be provided if pooling is requested.")
81+
n_atoms = batch["num_atoms"].detach().cpu().tolist()
82+
split_feat = torch.split(atomwise_out["features"], n_atoms)
83+
if pool_type == "sum":
84+
all_results["embedding"] = torch.stack([i.sum(0) for i in split_feat])
85+
elif pool_type == "mean":
86+
all_results["embedding"] = torch.stack([i.mean(0) for i in split_feat])
87+
else:
88+
all_results["embedding"] = atomwise_out["features"]
6189

6290
return all_results
6391

@@ -74,10 +102,12 @@ def add_stress(batch, all_results, nbrs, r_ij):
74102
if batch["num_atoms"].shape[0] == 1:
75103
all_results["stress_volume"] = torch.matmul(Z.t(), r_ij)
76104
else:
77-
allstress = torch.stack([
78-
torch.matmul(Z[torch.where(nbrs[:, 0] == j)].t(), r_ij[torch.where(nbrs[:, 0] == j)])
79-
for j in range(batch["nxyz"].shape[0])
80-
])
105+
allstress = torch.stack(
106+
[
107+
torch.matmul(Z[torch.where(nbrs[:, 0] == j)].t(), r_ij[torch.where(nbrs[:, 0] == j)])
108+
for j in range(batch["nxyz"].shape[0])
109+
]
110+
)
81111
N = batch["num_atoms"].detach().cpu().tolist()
82112
split_val = torch.split(allstress, N)
83113
all_results["stress_volume"] = torch.stack([i.sum(0) for i in split_val])
@@ -1014,7 +1044,7 @@ def sum_and_grad(batch, xyz, r_ij, nbrs, atomwise_output, grad_keys, out_keys=No
10141044
use_val = val.sum(-1)
10151045

10161046
else:
1017-
raise Exception("Don't know how to handle val shape " f"{val.shape} for key {key}")
1047+
raise Exception(f"Don't know how to handle val shape {val.shape} for key {key}")
10181048

10191049
pooled_result = scatter_add(use_val, mol_idx, dim_size=dim_size)
10201050
if mean:
@@ -1030,10 +1060,12 @@ def sum_and_grad(batch, xyz, r_ij, nbrs, atomwise_output, grad_keys, out_keys=No
10301060
if key == "stress":
10311061
output = results["energy"]
10321062
grad_ = compute_grad(output=output, inputs=r_ij)
1033-
allstress = torch.stack([
1034-
torch.matmul(grad_[torch.where(nbrs[:, 0] == i)].t(), r_ij[torch.where(nbrs[:, 0] == i)])
1035-
for i in range(batch["nxyz"].shape[0])
1036-
])
1063+
allstress = torch.stack(
1064+
[
1065+
torch.matmul(grad_[torch.where(nbrs[:, 0] == i)].t(), r_ij[torch.where(nbrs[:, 0] == i)])
1066+
for i in range(batch["nxyz"].shape[0])
1067+
]
1068+
)
10371069
split_val = torch.split(allstress, N)
10381070
grad_ = torch.stack([i.sum(0) for i in split_val])
10391071
if "cell" in batch:

0 commit comments

Comments
 (0)