11import unittest
2+ from typing import Dict , Optional
23
34import numpy as np
45import torch
@@ -56,8 +57,35 @@ def get_rij(xyz, batch, nbrs, cutoff):
5657 return r_ij , nbrs
5758
5859
def add_embedding(
    atomwise_out: dict,
    all_results: dict,
    pool_embeddings: bool = False,
    pool_type: str = "sum",
    batch: Optional[Dict] = None,
) -> dict:
    """Add node-wise embeddings to the results dictionary.

    Args:
        atomwise_out (dict): output of the atomwise layers; must contain a
            "features" tensor of per-atom embeddings (first dim = total atoms).
        all_results (dict): results dictionary from the forward pass through
            the full network. Modified in place.
        pool_embeddings (bool, optional): whether to pool the per-atom
            embeddings into one vector per molecule. Defaults to False.
        pool_type (str, optional): type of pooling to use, either "sum" or
            "mean". Defaults to "sum".
        batch (dict, optional): batch dictionary containing a "num_atoms"
            tensor with the atom count of each molecule; required when
            pooling is requested.

    Returns:
        dict: ``all_results`` with the "embedding" key added.

    Raises:
        ValueError: if pooling is requested without a batch, or if
            ``pool_type`` is not "sum" or "mean".
    """
    if pool_embeddings:
        if not batch:
            raise ValueError("batch must be provided if pooling is requested.")
        if pool_type not in ("sum", "mean"):
            # An unrecognized pool_type used to fall through silently,
            # leaving "embedding" unset; fail loudly instead.
            raise ValueError(f"pool_type must be 'sum' or 'mean', got '{pool_type}'.")
        # Split the flat per-atom feature tensor into one chunk per molecule.
        n_atoms = batch["num_atoms"].detach().cpu().tolist()
        split_feat = torch.split(atomwise_out["features"], n_atoms)
        if pool_type == "sum":
            all_results["embedding"] = torch.stack([i.sum(0) for i in split_feat])
        else:
            all_results["embedding"] = torch.stack([i.mean(0) for i in split_feat])
    else:
        # No pooling: expose the raw per-atom features directly.
        all_results["embedding"] = atomwise_out["features"]

    return all_results
6391
@@ -74,10 +102,12 @@ def add_stress(batch, all_results, nbrs, r_ij):
74102 if batch ["num_atoms" ].shape [0 ] == 1 :
75103 all_results ["stress_volume" ] = torch .matmul (Z .t (), r_ij )
76104 else :
77- allstress = torch .stack ([
78- torch .matmul (Z [torch .where (nbrs [:, 0 ] == j )].t (), r_ij [torch .where (nbrs [:, 0 ] == j )])
79- for j in range (batch ["nxyz" ].shape [0 ])
80- ])
105+ allstress = torch .stack (
106+ [
107+ torch .matmul (Z [torch .where (nbrs [:, 0 ] == j )].t (), r_ij [torch .where (nbrs [:, 0 ] == j )])
108+ for j in range (batch ["nxyz" ].shape [0 ])
109+ ]
110+ )
81111 N = batch ["num_atoms" ].detach ().cpu ().tolist ()
82112 split_val = torch .split (allstress , N )
83113 all_results ["stress_volume" ] = torch .stack ([i .sum (0 ) for i in split_val ])
@@ -1014,7 +1044,7 @@ def sum_and_grad(batch, xyz, r_ij, nbrs, atomwise_output, grad_keys, out_keys=No
10141044 use_val = val .sum (- 1 )
10151045
10161046 else :
1017- raise Exception ("Don't know how to handle val shape " f" { val .shape } for key { key } " )
1047+ raise Exception (f "Don't know how to handle val shape { val .shape } for key { key } " )
10181048
10191049 pooled_result = scatter_add (use_val , mol_idx , dim_size = dim_size )
10201050 if mean :
@@ -1030,10 +1060,12 @@ def sum_and_grad(batch, xyz, r_ij, nbrs, atomwise_output, grad_keys, out_keys=No
10301060 if key == "stress" :
10311061 output = results ["energy" ]
10321062 grad_ = compute_grad (output = output , inputs = r_ij )
1033- allstress = torch .stack ([
1034- torch .matmul (grad_ [torch .where (nbrs [:, 0 ] == i )].t (), r_ij [torch .where (nbrs [:, 0 ] == i )])
1035- for i in range (batch ["nxyz" ].shape [0 ])
1036- ])
1063+ allstress = torch .stack (
1064+ [
1065+ torch .matmul (grad_ [torch .where (nbrs [:, 0 ] == i )].t (), r_ij [torch .where (nbrs [:, 0 ] == i )])
1066+ for i in range (batch ["nxyz" ].shape [0 ])
1067+ ]
1068+ )
10371069 split_val = torch .split (allstress , N )
10381070 grad_ = torch .stack ([i .sum (0 ) for i in split_val ])
10391071 if "cell" in batch :
0 commit comments