
Commit 3669afc

Yn update (#304)
* add continuing training from a specific directory [load_model_dir]
* update model loading to handle different output dimensions in retraining (transfer learning)
* update docs
* update Uni-Mol format
* update URL dptech-core to deepmodeling
* update version in setup
* update Uni-Mol V2 docs
* update split methods
* update split method: group split; kfold=1 for all training
* merge main
* update train docs
* fix: unimol_tools using unimolv2 sometimes hangs in multiprocessing
* update version 0.1.2
* update logging for conformer generation
1 parent 90ad6af commit 3669afc

File tree: 9 files changed (+117, -33 lines)


unimol_tools/setup.py
Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@

 setup(
     name="unimol_tools",
-    version="0.1.1.post1",
+    version="0.1.2",
     description=("unimol_tools is a Python package for property prediciton with Uni-Mol in molecule, materials and protein."),
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
unimol_tools/unimol_tools/data/conformer.py
Lines changed: 20 additions & 7 deletions

@@ -291,6 +291,7 @@ def single_process(self, smiles):
         :return: A unimolecular data representation (dictionary) of the molecule.
         :raises ValueError: If the conformer generation method is unrecognized.
         """
+        torch.set_num_threads(1)
         if self.method == 'rdkit_random':
             mol = inner_smi2coords(smiles, seed=self.seed, mode=self.mode, remove_hs=self.remove_hs, return_mol=True)
             return mol2unimolv2(mol, self.max_atoms, remove_hs=self.remove_hs)
@@ -306,14 +307,26 @@ def transform_raw(self, atoms_list, coordinates_list):
         return inputs

     def transform(self, smiles_list):
-        pool = Pool()
+        torch.set_num_threads(1)
+        pool = Pool(processes=min(8, os.cpu_count()))
         logger.info('Start generating conformers...')
         inputs = [item for item in tqdm(pool.imap(self.single_process, smiles_list))]
         pool.close()
-        # failed_cnt = np.mean([(item['src_coord']==0.0).all() for item in inputs])
-        # logger.info('Succeeded in generating conformers for {:.2f}% of molecules.'.format((1-failed_cnt)*100))
-        # failed_3d_cnt = np.mean([(item['src_coord'][:,2]==0.0).all() for item in inputs])
-        # logger.info('Succeeded in generating 3d conformers for {:.2f}% of molecules.'.format((1-failed_3d_cnt)*100))
+
+        failed_conf = [(item['src_coord']==0.0).all() for item in inputs]
+        logger.info('Succeeded in generating conformers for {:.2f}% of molecules.'.format((1-np.mean(failed_conf))*100))
+        failed_conf_indices = [index for index, value in enumerate(failed_conf) if value]
+        if len(failed_conf_indices) > 0:
+            logger.info('Failed conformers indices: {}'.format(failed_conf_indices))
+            logger.debug('Failed conformers SMILES: {}'.format([smiles_list[index] for index in failed_conf_indices]))
+
+        failed_conf_3d = [(item['src_coord'][:,2]==0.0).all() for item in inputs]
+        logger.info('Succeeded in generating 3d conformers for {:.2f}% of molecules.'.format((1-np.mean(failed_conf_3d))*100))
+        failed_conf_3d_indices = [index for index, value in enumerate(failed_conf_3d) if value]
+        if len(failed_conf_3d_indices) > 0:
+            logger.info('Failed 3d conformers indices: {}'.format(failed_conf_3d_indices))
+            logger.debug('Failed 3d conformers SMILES: {}'.format([smiles_list[index] for index in failed_conf_3d_indices]))
+
         return inputs

 def create_mol_from_atoms_and_coords(atoms, coordinates):
@@ -365,13 +378,13 @@ def mol2unimolv2(mol, max_atoms=128, remove_hs=True, **params):
     coordinates = coordinates[idx]
     # tokens padding
     src_tokens = torch.tensor([AllChem.GetPeriodicTable().GetAtomicNumber(item) for item in atoms])
-    src_pos = torch.tensor(coordinates)
+    src_coord = torch.tensor(coordinates)
     # change AllChem.RemoveHs to AllChem.RemoveAllHs
     mol = AllChem.RemoveAllHs(mol)
     node_attr, edge_index, edge_attr = get_graph(mol)
     feat = get_graph_features(edge_attr, edge_index, node_attr, drop_feat=0)
     feat['src_tokens'] = src_tokens
-    feat['src_pos'] = src_pos
+    feat['src_coord'] = src_coord
     return feat

 def safe_index(l, e):
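The hang fix above has two parts: every worker pins PyTorch to a single intra-op thread, and the pool size is capped instead of defaulting to one process per core. Pinning threads in the forked workers avoids the thread oversubscription and post-fork OpenMP deadlocks that commonly cause this kind of stall. A minimal standalone sketch of the same pattern (the `featurize_one` worker is hypothetical and stands in for `single_process`; this is not the unimol_tools code itself):

# Sketch of the pattern used above: limit intra-op threads per worker and
# bound the pool size so multiprocess featurization does not oversubscribe or hang.
import os
from multiprocessing import Pool

import torch


def featurize_one(smiles):  # hypothetical worker, stands in for single_process
    torch.set_num_threads(1)  # avoid spawning a full thread pool inside each worker
    # ... generate conformer / tensor features for `smiles` here ...
    return {"smiles": smiles}


def featurize_all(smiles_list):
    torch.set_num_threads(1)             # also limit threads in the parent
    n_proc = min(8, os.cpu_count())      # cap workers, mirroring the diff above
    with Pool(processes=n_proc) as pool:
        return list(pool.imap(featurize_one, smiles_list))


if __name__ == "__main__":
    print(featurize_all(["CCO", "c1ccccc1"]))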

unimol_tools/unimol_tools/data/datahub.py
Lines changed: 26 additions & 0 deletions

@@ -7,6 +7,9 @@
 from .datareader import MolDataReader
 from .datascaler import TargetScaler
 from .conformer import ConformerGen, UniMolV2Feature
+from .split import Splitter
+from ..utils import logger
+

 class DataHub(object):
     """
@@ -31,6 +34,7 @@ def __init__(self, data=None, is_train=True, save_path=None, **params):
         self.multiclass_cnt = params.get('multiclass_cnt', None)
         self.ss_method = params.get('target_normalize', 'none')
         self._init_data(**params)
+        self._init_split(**params)

     def _init_data(self, **params):
         """
@@ -89,3 +93,25 @@ def _init_data(self, **params):
             no_h_list = UniMolV2Feature().transform(smiles_list)

         self.data['unimol_input'] = no_h_list
+
+    def _init_split(self, **params):
+
+        self.split_method = params.get('split_method','5fold_random')
+        kfold, method = int(self.split_method.split('fold')[0]), self.split_method.split('_')[-1] # Nfold_xxxx
+        self.kfold = params.get('kfold', kfold)
+        self.method = params.get('split', method)
+        self.split_seed = params.get('split_seed', 42)
+        self.data['kfold'] = self.kfold
+        if not self.is_train:
+            return
+        self.splitter = Splitter(self.method, self.kfold, seed=self.split_seed)
+        split_nfolds = self.splitter.split(**self.data)
+        if self.kfold == 1:
+            logger.info(f"Kfold is 1, all data is used for training.")
+        else:
+            logger.info(f"Split method: {self.method}, fold: {self.kfold}")
+            nfolds = np.zeros(len(split_nfolds[0][0])+len(split_nfolds[0][1]), dtype=int)
+            for enu, (tr_idx, te_idx) in enumerate(split_nfolds):
+                nfolds[te_idx] = enu
+        self.data['split_nfolds'] = split_nfolds
+        return split_nfolds
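For reference, the default `'Nfold_xxxx'` string is decomposed exactly as the comment in `_init_split` indicates; a quick sketch of that parsing (the string value is illustrative):

# How a split_method string such as '5fold_random' is decomposed
# (same expressions as in _init_split above).
split_method = '5fold_random'
kfold = int(split_method.split('fold')[0])   # -> 5
method = split_method.split('_')[-1]         # -> 'random'
print(kfold, method)                         # 5 random

Explicit `kfold` and `split` parameters, when passed, override these parsed defaults via `params.get(...)`.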

unimol_tools/unimol_tools/data/datareader.py
Lines changed: 7 additions & 7 deletions

@@ -82,17 +82,17 @@ def read_data(self, data=None, is_train=True, **params):
             target_cols = target_cols.split(',')
         elif isinstance(target_cols, list):
             pass
-        else:
+        else:
             for col in target_cols:
                 if col not in data.columns:
                     data[target_cols] = -1.0
                     break
-
-        if is_train and anomaly_clean:
-            data = self.anomaly_clean(data, task, target_cols)
-
-        if is_train and task == 'multiclass':
-            multiclass_cnt = int(data[target_cols].max() + 1)
+
+        if is_train:
+            if anomaly_clean:
+                data = self.anomaly_clean(data, task, target_cols)
+            if task == 'multiclass':
+                multiclass_cnt = int(data[target_cols].max() + 1)

         targets = data[target_cols].values.tolist()
         num_classes = len(target_cols)
Lines changed: 40 additions & 7 deletions

@@ -4,26 +4,30 @@

 from __future__ import absolute_import, division, print_function

+import numpy as np
 from sklearn.model_selection import (
     GroupKFold,
     KFold,
     StratifiedKFold,
 )
+from ..utils import logger
+

 class Splitter(object):
     """
     The Splitter class is responsible for splitting a dataset into train and test sets
     based on the specified method.
     """
-    def __init__(self, split_method='5fold_random', seed=42):
+    def __init__(self, method='random', kfold=5, seed=42, **params):
         """
         Initializes the Splitter with a specified split method and random seed.

         :param split_method: (str) The method for splitting the dataset, in the format 'Nfold_method'.
                              Defaults to '5fold_random'.
         :param seed: (int) Random seed for reproducibility in random splitting. Defaults to 42.
         """
-        self.n_splits, self.method = int(split_method.split('fold')[0]), split_method.split('_')[-1] # Nfold_xxxx
+        self.method = method
+        self.n_splits = kfold
         self.seed = seed
         self.splitter = self._init_split()

@@ -34,18 +38,22 @@ def _init_split(self):
         :return: The initialized splitter object.
         :raises ValueError: If an unknown splitting method is specified.
         """
+        if self.n_splits == 1:
+            return None
         if self.method == 'random':
             splitter = KFold(n_splits=self.n_splits, shuffle=True, random_state=self.seed)
         elif self.method == 'scaffold' or self.method == 'group':
             splitter = GroupKFold(n_splits=self.n_splits)
         elif self.method == 'stratified':
             splitter = StratifiedKFold(n_splits=self.n_splits, shuffle=True, random_state=self.seed)
+        elif self.method == 'select':
+            splitter = GroupKFold(n_splits=self.n_splits)
         else:
             raise ValueError('Unknown splitter method: {}fold - {}'.format(self.n_splits, self.method))

         return splitter

-    def split(self, data, target=None, group=None):
+    def split(self, smiles, target=None, group=None, scaffolds=None, **params):
         """
         Splits the dataset into train and test sets based on the initialized method.

@@ -56,7 +64,32 @@ def split(self, data, target=None, group=None):
         :return: An iterator yielding train and test set indices for each fold.
         :raises ValueError: If the splitter method does not support the provided parameters.
         """
-        try:
-            return self.splitter.split(data, target, group)
-        except:
-            raise ValueError('Unknown splitter method: {}fold - {}'.format(self.n_splits, self.method))
+        if self.n_splits == 1:
+            logger.warning('Only one fold is used for training, no splitting is performed.')
+            return [(np.arange(len(smiles)), ())]
+        if self.method in ['random']:
+            self.skf = self.splitter.split(smiles)
+        elif self.method in ['scaffold']:
+            self.skf = self.splitter.split(smiles, target, scaffolds)
+        elif self.method in ['group']:
+            self.skf = self.splitter.split(smiles, target, group)
+        elif self.method in ['stratified']:
+            self.skf = self.splitter.split(smiles, group)
+        elif self.method in ['select']:
+            unique_groups = np.unique(group)
+            if len(unique_groups) == self.n_splits:
+                split_folds = []
+                for unique_group in unique_groups:
+                    train_idx = np.where(group != unique_group)[0]
+                    test_idx = np.where(group == unique_group)[0]
+                    split_folds.append((train_idx, test_idx))
+                self.split_folds = split_folds
+                return self.split_folds
+            else:
+                logger.error('The number of unique groups is not equal to the number of splits.')
+                exit(1)
+        else:
+            logger.error('Unknown splitter method: {}'.format(self.method))
+            exit(1)
+        self.split_folds = list(self.skf)
+        return self.split_folds
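The new 'select' branch turns a user-supplied group column directly into folds: each unique group value becomes the held-out set exactly once, and the number of unique values must equal the fold count. A toy sketch of that behaviour with plain NumPy (the data is made up, not taken from the repository's tests):

# Mirrors the 'select' branch above: values 0..kfold-1 in `group`
# directly choose which samples form each validation fold.
import numpy as np

group = np.array([0, 0, 1, 2, 1, 2])   # toy fold labels for 6 samples
folds = []
for g in np.unique(group):
    train_idx = np.where(group != g)[0]
    test_idx = np.where(group == g)[0]
    folds.append((train_idx, test_idx))

for i, (tr, te) in enumerate(folds):
    print(f"fold {i}: train={tr.tolist()} test={te.tolist()}")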

unimol_tools/unimol_tools/models/nnmodel.py
Lines changed: 4 additions & 4 deletions

@@ -70,7 +70,7 @@ def __init__(self, data, trainer, **params):
         self.data_type = params.get('data_type', 'molecule')
         self.loss_key = params.get('loss_key', None)
         self.trainer = trainer
-        self.splitter = self.trainer.splitter
+        #self.splitter = self.trainer.splitter
         self.model_params = params.copy()
         self.task = params['task']
         if self.task in OUTPUT_DIM:
@@ -150,7 +150,7 @@ def run(self):
                 y.reshape(y.shape[0], self.num_classes)).astype(float)
         else:
             y_pred = np.zeros((y.shape[0], self.model_params['output_dim']))
-        for fold, (tr_idx, te_idx) in enumerate(self.splitter.split(X, y, group)):
+        for fold, (tr_idx, te_idx) in enumerate(self.data['split_nfolds']):
            X_train, y_train = X[tr_idx], y[tr_idx]
            X_valid, y_valid = X[te_idx], y[te_idx]
            traindataset = NNDataset(X_train, y_train)
@@ -220,7 +220,7 @@ def evaluate(self, trainer=None, checkpoints_path=None):
        """
        logger.info("start predict NNModel:{}".format(self.model_name))
        testdataset = NNDataset(self.features, np.asarray(self.data['target']))
-        for fold in range(self.splitter.n_splits):
+        for fold in range(self.data['kfold']):
            model_path = os.path.join(checkpoints_path, f'model_{fold}.pth')
            self.model.load_state_dict(torch.load(
                model_path, map_location=self.trainer.device)['model_state_dict'])
@@ -229,7 +229,7 @@ def evaluate(self, trainer=None, checkpoints_path=None):
            if fold == 0:
                y_pred = np.zeros_like(_y_pred)
            y_pred += _y_pred
-        y_pred /= self.splitter.n_splits
+        y_pred /= self.data['kfold']
        self.cv['test_pred'] = y_pred

    def count_parameters(self, model):
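With the splitter removed from the trainer, `evaluate` now reads the fold count from `self.data['kfold']` and averages the per-fold checkpoints. The ensembling idea in isolation (the `predict_fold` helper here is hypothetical and stands in for loading `model_{fold}.pth` and running inference):

# Averaging predictions from k fold models, as evaluate() does above.
import numpy as np


def predict_fold(fold, features):
    # hypothetical stand-in: load model_{fold}.pth and run inference
    rng = np.random.default_rng(fold)
    return rng.normal(size=(len(features), 1))


def ensemble_predict(features, kfold=5):
    y_pred = None
    for fold in range(kfold):
        _y_pred = predict_fold(fold, features)
        if fold == 0:
            y_pred = np.zeros_like(_y_pred)
        y_pred += _y_pred
    return y_pred / kfold  # mean over the k per-fold models


print(ensemble_predict(list(range(4))).shape)  # (4, 1)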

unimol_tools/unimol_tools/models/unimolv2.py
Lines changed: 4 additions & 4 deletions

@@ -165,14 +165,14 @@ def forward(
        pair_type,
        attn_bias,
        src_tokens,
-        src_pos,
+        src_coord,
        return_repr=False,
        return_atomic_reprs=False,
        **kwargs
    ):


-        pos = src_pos
+        pos = src_coord

        n_mol, n_atom = atom_feat.shape[:2]
        token_feat = self.embed_tokens(src_tokens)
@@ -232,7 +232,7 @@ def one_block(x, pos, return_x=False):
        filtered_tensors = []
        filtered_coords = []

-        for tokens, coord in zip(src_tokens, src_pos):
+        for tokens, coord in zip(src_tokens, src_coord):
            filtered_tensor = tokens[(tokens != 0) & (tokens != 1) & (tokens != 2)] # filter out BOS(0), EOS(1), PAD(2)
            filtered_coord = coord[(tokens != 0) & (tokens != 1) & (tokens != 2)]
            filtered_tensors.append(filtered_tensor)
@@ -315,7 +315,7 @@ def batch_collate_fn(self, samples):
                v = pad_2d([s[0][k] for s in samples], pad_idx=self.padding_idx)
            elif k == 'src_tokens':
                v = pad_1d_tokens([s[0][k] for s in samples], pad_idx=self.padding_idx)
-            elif k == 'src_pos':
+            elif k == 'src_coord':
                v = pad_coords([s[0][k] for s in samples], pad_idx=self.padding_idx)
            batch[k] = v
        try:
unimol_tools/unimol_tools/tasks/trainer.py
Lines changed: 0 additions & 2 deletions

@@ -15,7 +15,6 @@
 # from transformers.optimization import get_linear_schedule_with_warmup
 from ..utils import Metrics
 from ..utils import logger
-from .split import Splitter
 from tqdm import tqdm

 import time
@@ -46,7 +45,6 @@ def _init_trainer(self, **params):
        self.split_seed = params.get('split_seed', 42)
        self.seed = params.get('seed', 42)
        self.set_seed(self.seed)
-        self.splitter = Splitter(self.split_method, self.split_seed)
        self.logger_level = int(params.get('logger_level', 1))
        ### init NN trainer params ###
        self.learning_rate = float(params.get('learning_rate', 1e-4))

unimol_tools/unimol_tools/train.py
Lines changed: 15 additions & 1 deletion

@@ -72,9 +72,23 @@ def __init__(self,

            - multilabel_regression: mae, mse, r2.

-        :param split: str, default='random', split method of training dataset. currently support: random, scaffold, group, stratified.
+        :param split: str, default='random', split method of training dataset. currently support: random, scaffold, group, stratified, select.
+
+            - random: random split.
+
+            - scaffold: split by scaffold.
+
+            - group: split by group. `split_group_col` should be specified.
+
+            - stratified: stratified split. `split_group_col` should be specified.
+
+            - select: use `split_group_col` to manually select the split group. Column values of `split_group_col` should be range from 0 to kfold-1 to indicate the split group.
+
        :param split_group_col: str, default='scaffold', column name of group split.
        :param kfold: int, default=5, number of folds for k-fold cross validation.
+
+            - 1: no split. all data will be used for training.
+
        :param save_path: str, default='./exp', path to save training results.
        :param remove_hs: bool, default=False, whether to remove hydrogens from molecules.
        :param smiles_col: str, default='SMILES', column name of SMILES.
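For context, the options documented above would be used roughly as follows. This is a hedged sketch assuming the standard `MolTrain` entry point and a training CSV with a SMILES column, target column(s), and a `fold` column holding values 0 to kfold-1 for the 'select' method; the column name `fold` and the non-split constructor arguments are assumptions, not taken from this diff:

# Sketch only: parameter names follow the docstring above; check the exact
# MolTrain signature against the installed unimol_tools version.
from unimol_tools import MolTrain

clf = MolTrain(
    task='classification',
    data_type='molecule',
    metrics='auc',
    split='select',          # use a user-provided fold assignment
    split_group_col='fold',  # column with values in [0, kfold-1] (assumed name)
    kfold=5,                 # kfold=1 would skip splitting and train on all data
    smiles_col='SMILES',
    save_path='./exp',
)
clf.fit(data='train.csv')    # CSV containing SMILES, target and fold columns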
