8 changes: 5 additions & 3 deletions README.md
@@ -166,7 +166,7 @@ either by randomly sampling the sequences (“Random”) or by greedily maximizi

It is possible to unconditionally generate an entire MSA using the following script:
```
-python evodiff/generate-msa.py --model-type msa_oa_dm_maxsub --batch-size 1 --n-sequences 64 --seq-length 256 --subsampling MaxHamming
+python evodiff/generate_msa.py --model-type msa_oa_dm_maxsub --batch-size 1 --n-sequences 64 --seq-length 256 --subsampling MaxHamming
```

The default model type is `msa_oa_dm_maxsub`, which is EvoDiff-MSA-OADM trained on Max subsampled sequences, and the other available
@@ -193,15 +193,17 @@ thus generating new members of a protein family without needing to train family-
To generate a new query sequence given an alignment, use the following command with the `--start-msa` flag. This starts conditional
generation by sampling from a validation MSA. To run this script, you must have the OpenFold dataset and its splits downloaded.
```
-python evodiff/generate-msa.py --model-type msa_oa_dm_maxsub --batch-size 1 --n-sequences 64 --seq-length 256 --subsampling MaxHamming --start-msa
+python evodiff/generate_msa.py --model-type msa_oa_dm_maxsub --batch-size 1 --n-sequences 64 --seq-length 256 --subsampling MaxHamming --start-msa
```
+If you want to generate from a custom MSA, you can retrofit the existing code; one possible invocation is sketched below.
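+A sketch of one such invocation, using the `--dataset` and `--out_fpath` flags added in this PR (the `data/custom_msas/` directory is hypothetical):
+```
+python evodiff/generate_msa.py --model-type msa_oa_dm_maxsub --batch-size 1 --n-sequences 64 --seq-length 256 --subsampling MaxHamming --start-msa --dataset custom_msas --out_fpath out
+```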

Additionally, the code can generate an alignment given a query sequence; use the `--start-query` flag.
This starts from the query and generates the rest of the alignment.
```
-python evodiff/generate-msa.py --model-type msa_oa_dm_maxsub --batch-size 1 --n-sequences 64 --seq-length 256 --subsampling MaxHamming --start-query
+python evodiff/generate_msa.py --model-type msa_oa_dm_maxsub --batch-size 4 --n-sequences 2 --gpus 0 --subsampling MaxHamming --start-query --dataset openfold/test.a3m --out_fpath out
```
+This command takes an `.a3m` file at data/`--dataset` as input. The file must contain at least as many non-all-gap sequences as the number of sequences you want to generate. The default `openfold` dataset receives special pre-processing.
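+For reference, `data/openfold/test.a3m` (added in this PR) shows the expected layout: a query record followed by enough non-all-gap buffer records, for example:
+```
+>Test_MSA
+ACDEFGHIKLMNPQRSTVY
+>Buffer_1
+ACDEFGHIKLMNPQRSTVY
+```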

NOTE: you can only specify one of the above flags at a time; `--start-query` and `--start-msa` cannot be passed together.
Please look at `generate_msa.py` for more information.
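+For programmatic use, here is a minimal sketch. It assumes the `MSA_OA_DM_MAXSUB` loader in `evodiff.pretrained` returns `(model, collater, tokenizer, scheme)` and relies on the `generate_msa` signature from `evodiff/generate_msa.py` in this PR; paths are illustrative only.
+```
+import os
+import torch
+from evodiff.pretrained import MSA_OA_DM_MAXSUB
+from evodiff.generate_msa import generate_msa
+
+model, collater, tokenizer, scheme = MSA_OA_DM_MAXSUB()
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model = model.to(device)
+os.makedirs('out', exist_ok=True)  # out_path must exist; outputs are appended here
+
+# Generate the rest of an alignment conditioned on the query row of test.a3m
+sample, strings = generate_msa(
+    model, tokenizer,
+    batch_size=1, n_sequences=2, seq_length=256,
+    start_query=True,              # condition on the query sequence
+    data_top_dir='data/',          # parent of the dataset path
+    data_dir='openfold/test.a3m',  # the test alignment added in this PR
+    openfold=False,                # treat it as a custom file, not the OpenFold set
+    selection_type='MaxHamming',
+    out_path='out/',
+    device=device,
+)
+```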

20 changes: 20 additions & 0 deletions data/openfold/test.a3m
@@ -0,0 +1,20 @@
+>Test_MSA
+ACDEFGHIKLMNPQRSTVY
+>Buffer_1
+ACDEFGHIKLMNPQRSTVY
+>Buffer_2
+ACDEFGHIKLMNPQRSTVY
+>Buffer_3
+ACDEFGHIKLMNPQRSTVY
+>Buffer_4
+ACDEFGHIKLMNPQRSTVY
+>Buffer_5
+ACDEFGHIKLMNPQRSTVY
+>Buffer_6
+ACDEFGHIKLMNPQRSTVY
+>Buffer_7
+ACDEFGHIKLMNPQRSTVY
+>Buffer_8
+ACDEFGHIKLMNPQRSTVY
+>Buffer_9
+ACDEFGHIKLMNPQRSTVY
91 changes: 55 additions & 36 deletions evodiff/data.py
@@ -1,4 +1,5 @@
import os
+from pathlib import Path
from tqdm import tqdm
from scipy.spatial.distance import hamming, cdist

@@ -308,7 +309,7 @@ def __getitem__(self, idx):
class A3MMSADataset(Dataset):
"""Build dataset for A3M data: MSA Absorbing Diffusion model"""

-    def __init__(self, selection_type, n_sequences, max_seq_len, data_dir=None, min_depth=None):
+    def __init__(self, selection_type, n_sequences, max_seq_len, data_dir=None, min_depth=None, openfold=True):
"""
Args:
selection_type: str,
@@ -319,52 +320,68 @@ def __init__(self, selection_type, n_sequences, max_seq_len, data_dir=None, min_
maximum MSA sequence length
data_dir: str,
if you have a specified data directory
+            min_depth: int,
+                filter out MSAs with fewer than min_depth sequences
+            openfold: bool,
+                if True, use the OpenFold dataset; otherwise load a custom dataset from data_dir
"""
alphabet = PROTEIN_ALPHABET
self.tokenizer = Tokenizer(alphabet)
self.alpha = np.array(list(alphabet))
self.gap_idx = self.tokenizer.alphabet.index(GAP)
+        self.openfold = openfold

# Get npz_data dir
if data_dir is not None:
self.data_dir = data_dir
else:
raise FileNotFoundError(data_dir)

[print("Excluding", x) for x in os.listdir(self.data_dir) if x.endswith('.npz')]
all_files = [x for x in os.listdir(self.data_dir) if not x.endswith('.npz')]
all_files = sorted(all_files)
[print(f"Excluding {x}") for x in Path(self.data_dir).glob("*.npz")]
if Path(self.data_dir).is_dir():
all_files = [x for x in Path(self.data_dir).glob("*[!.npz]")]
all_files = sorted(all_files)
else:
all_files = [self.data_dir]
print("unfiltered length", len(all_files))

-        ## Filter based on depth (keep > 64 seqs/MSA)
-        if not os.path.exists(data_dir + 'openfold_lengths.npz'):
-            raise Exception("Missing openfold_lengths.npz in openfold/")
-        if not os.path.exists(data_dir + 'openfold_depths.npz'):
-            #get_msa_depth_openfold(data_dir, sorted(all_files), 'openfold_depths.npz')
-            raise Exception("Missing openfold_depths.npz in openfold/")
-        if min_depth is not None: # reindex, filtering out MSAs < min_depth
-            _depths = np.load(data_dir+'openfold_depths.npz')['arr_0']
-            depths = pd.DataFrame(_depths, columns=['depth'])
-            depths = depths[depths['depth'] >= min_depth]
-            keep_idx = depths.index
-
-            _lengths = np.load(data_dir+'openfold_lengths.npz')['ells']
-            lengths = np.array(_lengths)[keep_idx]
-            all_files = np.array(all_files)[keep_idx]
-            print("filter MSA depth > 64", len(all_files))
-
-        # Re-filter based on high gap-contining rows
-        if not os.path.exists(data_dir + 'openfold_gap_depths.npz'):
-            #get_sliced_gap_depth_openfold(data_dir, all_files, 'openfold_gap_depths.npz', max_seq_len=max_seq_len)
-            raise Exception("Missing openfold_gap_depths.npz in openfold/")
-        _gap_depths = np.load(data_dir + 'openfold_gap_depths.npz')['arr_0']
-        gap_depths = pd.DataFrame(_gap_depths, columns=['gapdepth'])
-        gap_depths = gap_depths[gap_depths['gapdepth'] >= min_depth]
-        filter_gaps_idx = gap_depths.index
-        lengths = np.array(lengths)[filter_gaps_idx]
-        all_files = np.array(all_files)[filter_gaps_idx]
-        print("filter rows with GAPs > 512", len(all_files))

+        if openfold:
+            ## Filter based on depth (keep > 64 seqs/MSA)
+            if not os.path.exists(data_dir + 'openfold_lengths.npz'):
+                raise Exception("Missing openfold_lengths.npz in openfold/")
+            if not os.path.exists(data_dir + 'openfold_depths.npz'):
+                #get_msa_depth_openfold(data_dir, sorted(all_files), 'openfold_depths.npz')
+                raise Exception("Missing openfold_depths.npz in openfold/")
+            if min_depth is not None: # reindex, filtering out MSAs < min_depth
+                _depths = np.load(data_dir+'openfold_depths.npz')['arr_0']
+                depths = pd.DataFrame(_depths, columns=['depth'])
+                depths = depths[depths['depth'] >= min_depth]
+                keep_idx = depths.index
+
+                _lengths = np.load(data_dir+'openfold_lengths.npz')['ells']
+                lengths = np.array(_lengths)[keep_idx]
+                all_files = np.array(all_files)[keep_idx]
+                print("filter MSA depth > 64", len(all_files))
+
+            # Re-filter based on high gap-containing rows
+            if not os.path.exists(data_dir + 'openfold_gap_depths.npz'):
+                #get_sliced_gap_depth_openfold(data_dir, all_files, 'openfold_gap_depths.npz', max_seq_len=max_seq_len)
+                raise Exception("Missing openfold_gap_depths.npz in openfold/")
+            _gap_depths = np.load(data_dir + 'openfold_gap_depths.npz')['arr_0']
+            gap_depths = pd.DataFrame(_gap_depths, columns=['gapdepth'])
+            gap_depths = gap_depths[gap_depths['gapdepth'] >= min_depth]
+            filter_gaps_idx = gap_depths.index
+            lengths = np.array(lengths)[filter_gaps_idx]
+            all_files = np.array(all_files)[filter_gaps_idx]
+            print("filter rows with GAPs > 512", len(all_files))
+        else:
+            # Custom dataset: no precomputed .npz filter files; compute lengths directly
+            all_files = np.array(all_files)
+            lengths = []
+            for file in all_files:
+                parsed_msa = parse_fasta(file)
+                lengths.append(max(len(line) for line in parsed_msa))
+            lengths = np.array(lengths)
self.filenames = all_files # IDs of samples to include
self.lengths = lengths # pass to batch sampler
self.n_sequences = n_sequences
@@ -376,7 +393,10 @@ def __len__(self):

def __getitem__(self, idx):
filename = self.filenames[idx]
-        path = read_openfold_files(self.data_dir, filename)
+        if self.openfold:
+            path = read_openfold_files(self.data_dir, filename)
+        else:
+            path = filename  # custom dataset: filename is already a usable path
parsed_msa = parse_fasta(path)

aligned_msa = [[char for char in seq if (char.isupper() or char == '-') and not char == '.'] for seq in parsed_msa]
@@ -522,7 +542,6 @@ def __len__(self):

def __getitem__(self, idx):
filename = self.filenames[idx]
-        print(filename)
path = read_idr_files(self.data_dir, filename)
parsed_msa = parse_fasta(path)
aligned_msa = [[char for char in seq if (char.isupper() or char == '-') and not char == '.'] for seq in parsed_msa]
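A quick usage sketch for the new `openfold` flag on `A3MMSADataset` (directory paths are hypothetical; the constructor signature is the one introduced above):
```
from evodiff.data import A3MMSADataset

# OpenFold data: expects openfold_lengths.npz / openfold_depths.npz /
# openfold_gap_depths.npz next to the MSAs and applies the depth and gap filters
openfold_set = A3MMSADataset('MaxHamming', n_sequences=64, max_seq_len=512,
                             data_dir='data/openfold/', min_depth=64)

# Custom data: a directory of .a3m files (or a single .a3m file); the .npz
# filters are skipped and sequence lengths are computed on the fly
custom_set = A3MMSADataset('MaxHamming', n_sequences=64, max_seq_len=512,
                           data_dir='data/custom_msas/', openfold=False)
first_msa = custom_set[0]  # parsed and subsampled according to selection_type
```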
59 changes: 30 additions & 29 deletions evodiff/generate_msa.py
@@ -18,10 +18,10 @@
def main():
parser = argparse.ArgumentParser()
#parser.add_argument('config_fpath')
-    #parser.add_argument('out_fpath', type=str, nargs='?',
+    parser.add_argument('--out_fpath', type=str, default=None)  # nargs='?',
     #                    default=os.getenv('AMLT_OUTPUT_DIR', '/tmp') + '/')
-    parser.add_argument('-g', '--gpus', default=1, type=int,
-                        help='number of gpus per node')
+    parser.add_argument('-g', '--gpus', default=0, type=int,
+                        help='Number of GPUs per node')
parser.add_argument('-off', '--offset', default=0, type=int,
help='Number of GPU devices to skip.')
parser.add_argument('--model-type', type=str, default='msa_oa_dm_maxsub')
@@ -113,11 +113,11 @@ def main():

if args.amlt:
home = os.getenv('AMLT_OUTPUT_DIR', '/tmp') + '/'
-        out_fpath = home
+        out_fpath = home if args.out_fpath is None else args.out_fpath
else:
home = str(pathlib.Path.home()) + '/Desktop/DMs/'
top_dir = home
-        out_fpath = home + args.model_type + '/gen-'+str(args.run) + '/'
+        out_fpath = home + args.model_type + '/gen-' + str(args.run) + '/' if args.out_fpath is None else args.out_fpath

if not os.path.exists(out_fpath):
os.makedirs(out_fpath)
@@ -133,23 +133,21 @@ def main():
print("Penalizing GAPS by factor of", 1+args.penalty_value)
else:
print("Not penalizing GAPS")

+    # generating from a single .a3m file only supports batch size 1
+    batch_size = args.batch_size if pathlib.Path(data_dir).is_dir() else 1
     if scheme == 'mask':
-        sample, _string = generate_msa(model, tokenizer, args.batch_size, args.n_sequences, args.seq_length,
+        sample, _string = generate_msa(model, tokenizer, batch_size, args.n_sequences, args.seq_length,
                                        penalty_value=args.penalty_value, device=device, start_query=args.start_query,
                                        start_msa=args.start_msa,
-                                       data_top_dir=data_top_dir, selection_type=args.subsampling, out_path=out_fpath)
+                                       data_top_dir=data_top_dir, selection_type=args.subsampling, out_path=out_fpath,
+                                       openfold=args.dataset == "openfold", data_dir=args.dataset)
elif scheme == 'd3pm':
-        sample, _string = generate_msa_d3pm(model, args.batch_size, args.n_sequences, args.seq_length,
+        sample, _string = generate_msa_d3pm(model, batch_size, args.n_sequences, args.seq_length,
                                             Q_bar=Q_bar, Q=Q, tokenizer=Tokenizer(), data_top_dir=data_top_dir,
                                             selection_type=args.subsampling, out_path=out_fpath,
                                             max_timesteps=timestep, start_query=args.start_query,
-                                            no_step=False, penalty_value=args.penalty_value, device=device)
+                                            no_step=False, penalty_value=args.penalty_value, device=device,
+                                            openfold=args.dataset == "openfold", data_dir=args.dataset)
     for count, msa in enumerate(_string):
         fasta_string = ""
-        with open(out_fpath + 'generated_msas.a3m', 'a') as f:
+        with open(pathlib.Path(out_fpath) / 'generated_msas.a3m', 'a') as f:
for seq in range(args.n_sequences):
seq_num = seq * args.seq_length
next_seq_num = (seq+1) * args.seq_length
@@ -160,19 +158,19 @@ def main():
f.write(">tr \n" + str(seq_string) + "\n" )
f.write(fasta_string)
f.close()
-    np.save(out_fpath+'generated_msas', np.array(sample.cpu()))
+    np.save(pathlib.Path(out_fpath) / 'generated_msas', np.array(sample.cpu()))


def generate_msa(model, tokenizer, batch_size, n_sequences, seq_length, penalty_value=2, device='gpu',
-                 start_query=False, start_msa=False, data_top_dir='../data', selection_type='MaxHamming', out_path='../ref/'):
+                 start_query=False, start_msa=False, data_top_dir='../data', selection_type='MaxHamming', out_path='../ref/',
+                 openfold=False, data_dir="openfold/"):
mask_id = tokenizer.mask_id
src = torch.full((batch_size, n_sequences, seq_length), fill_value=mask_id)
masked_loc_x = np.arange(n_sequences)
masked_loc_y = np.arange(seq_length)
if start_query:
-        valid_msas, query_sequences, tokenizer =get_valid_data(data_top_dir, batch_size, 'autoreg', data_dir='openfold/',
+        valid_msas, query_sequences, tokenizer = get_valid_data(data_top_dir, batch_size, 'autoreg', data_dir=data_dir,
                                                                 selection_type=selection_type, n_sequences=n_sequences, max_seq_len=seq_length,
-                                                                out_path=out_path)
+                                                                out_path=out_path, openfold=openfold)
# First row is query sequence
for i in range(batch_size):
seq_len = len(query_sequences[i])
@@ -184,10 +182,11 @@ def generate_msa(model, tokenizer, batch_size, n_sequences, seq_length, penalty_
y_indices = np.arange(seq_len)
elif start_msa:
valid_msas, query_sequences, tokenizer = get_valid_data(data_top_dir, batch_size, 'autoreg',
-                                                                data_dir='openfold/',
+                                                                data_dir=data_dir,
                                                                 selection_type=selection_type, n_sequences=n_sequences,
                                                                 max_seq_len=seq_length,
-                                                                out_path=out_path)
+                                                                out_path=out_path,
+                                                                openfold=openfold)
for i in range(batch_size):
seq_len = len(query_sequences[i])
src[i, 1:n_sequences, :seq_len] = valid_msas[i][0, 1:n_sequences, :seq_len].squeeze()
@@ -270,14 +269,14 @@ def generate_query_oadm_msa_simple(path_to_msa, model, tokenizer, n_sequences, s

def generate_msa_d3pm(model, batch_size, n_sequences, seq_length, Q_bar=None, Q=None, tokenizer=Tokenizer(),
start_query=False, data_top_dir='../data', selection_type='MaxHamming', out_path='../ref/',
-                      max_timesteps=500, no_step=False, penalty_value=0, device='gpu'):
+                      max_timesteps=500, no_step=False, penalty_value=0, device='gpu', openfold=False, data_dir="openfold/"):
sample = torch.randint(0, tokenizer.K, (batch_size, n_sequences, seq_length))
if start_query:
x_indices = []
y_indices = []
-        valid_msas, query_sequences, tokenizer =get_valid_data(data_top_dir, batch_size, 'autoreg', data_dir='openfold/',
+        valid_msas, query_sequences, tokenizer = get_valid_data(data_top_dir, batch_size, 'autoreg', data_dir=data_dir,
                                                                 selection_type=selection_type, n_sequences=n_sequences, max_seq_len=seq_length,
-                                                                out_path=out_path)
+                                                                out_path=out_path, openfold=openfold)
# First row is query sequence
for i in range(batch_size):
seq_len = len(query_sequences[i])
@@ -340,22 +339,24 @@ def generate_msa_d3pm(model, batch_size, n_sequences, seq_length, Q_bar=None, Q=


def get_valid_data(data_top_dir, num_seqs, arg_mask, data_dir='openfold/', selection_type='MaxHamming', n_sequences=64, max_seq_len=512,
-                   out_path='../DMs/ref/'):
+                   out_path='../DMs/ref/', openfold=True):
valid_msas = []
query_msas = []
seq_lens = []

_ = torch.manual_seed(1) # same seeds as training
np.random.seed(1)

-    dataset = A3MMSADataset(selection_type, n_sequences, max_seq_len, data_dir=os.path.join(data_top_dir, data_dir), min_depth=64)
+    dataset = A3MMSADataset(selection_type, n_sequences, max_seq_len, data_dir=os.path.join(data_top_dir, data_dir), min_depth=64, openfold=openfold)

     train_size = len(dataset)
-    random_ind = np.random.choice(train_size, size=(train_size - 10000), replace=False)
-    val_ind = np.delete(np.arange(train_size), random_ind)
-
-
-    ds_valid = Subset(dataset, val_ind)
+    if openfold:
+        # reproduce the training split and keep the held-out validation MSAs
+        random_ind = np.random.choice(train_size, size=(train_size - 10000), replace=False)
+        val_ind = np.delete(np.arange(train_size), random_ind)
+        ds_valid = Subset(dataset, val_ind)
+    else:
+        # custom datasets are used in full
+        ds_valid = dataset

if arg_mask == 'autoreg':
tokenizer = Tokenizer()
Expand Down Expand Up @@ -394,7 +395,7 @@ def get_valid_data(data_top_dir, num_seqs, arg_mask, data_dir='openfold/', selec
print("LEN VALID MSAS", len(valid_msas))
untokenized = [[tokenizer.untokenize(msa.flatten())] for msa in valid_msas]
fasta_string = ""
-    with open(out_path + 'valid_msas.a3m', 'a') as f:
+    with open(pathlib.Path(out_path) / 'valid_msas.a3m', 'a') as f:
for i, msa in enumerate(untokenized):
for seq in range(n_sequences):
seq_num = seq * seq_lens[i]