@@ -21,7 +21,7 @@ class DIPSDGLDataModule(LightningDataModule):
2121 dips_test = None
2222
2323 def __init__ (self , data_dir : str , batch_size : int , num_dataloader_workers : int , knn : int , self_loops : bool ,
24- pn_ratio : float , percent_to_use : float , use_dgl : bool , process_complexes : bool , input_indep : bool ):
24+ pn_ratio : float , percent_to_use : float , process_complexes : bool , input_indep : bool ):
2525 super ().__init__ ()
2626
2727 self .data_dir = data_dir
@@ -31,32 +31,30 @@ def __init__(self, data_dir: str, batch_size: int, num_dataloader_workers: int,
3131 self .self_loops = self_loops
3232 self .pn_ratio = pn_ratio
3333 self .percent_to_use = percent_to_use # Fraction of DIPS dataset splits to use
34- self .use_dgl = use_dgl # Whether to process each complex into a pair of DGL graphs for its final representation
3534 self .process_complexes = process_complexes # Whether to process any unprocessed complexes before training
3635 self .input_indep = input_indep # Whether to use an input-independent pipeline to train the model
36+ self .collate_fn = dgl_picp_collate # Which collation function to use
3737
3838 def setup (self , stage : Optional [str ] = None ):
3939 # Assign training/validation/testing data set for use in DataLoaders - called on every GPU
4040 self .dips_train = DIPSDGLDataset (mode = 'train' , raw_dir = self .data_dir , knn = self .knn , self_loops = self .self_loops ,
4141 pn_ratio = self .pn_ratio , percent_to_use = self .percent_to_use ,
42- use_dgl = self .use_dgl , process_complexes = self .process_complexes ,
43- input_indep = self .input_indep )
42+ process_complexes = self .process_complexes , input_indep = self .input_indep )
4443 self .dips_val = DIPSDGLDataset (mode = 'val' , raw_dir = self .data_dir , knn = self .knn , self_loops = self .self_loops ,
45- pn_ratio = self .pn_ratio , percent_to_use = self .percent_to_use , use_dgl = self . use_dgl ,
44+ pn_ratio = self .pn_ratio , percent_to_use = self .percent_to_use ,
4645 process_complexes = self .process_complexes , input_indep = self .input_indep )
4746 self .dips_test = DIPSDGLDataset (mode = 'test' , raw_dir = self .data_dir , knn = self .knn , self_loops = self .self_loops ,
4847 pn_ratio = self .pn_ratio , percent_to_use = self .percent_to_use ,
49- use_dgl = self .use_dgl , process_complexes = self .process_complexes ,
50- input_indep = self .input_indep )
48+ process_complexes = self .process_complexes , input_indep = self .input_indep )
5149
5250 def train_dataloader (self ) -> DataLoader :
5351 return DataLoader (self .dips_train , batch_size = self .batch_size , shuffle = True ,
54- num_workers = self .num_dataloader_workers , collate_fn = dgl_picp_collate , pin_memory = True )
52+ num_workers = self .num_dataloader_workers , collate_fn = self . collate_fn , pin_memory = True )
5553
5654 def val_dataloader (self ) -> DataLoader :
5755 return DataLoader (self .dips_val , batch_size = self .batch_size , shuffle = False ,
58- num_workers = self .num_dataloader_workers , collate_fn = dgl_picp_collate , pin_memory = True )
56+ num_workers = self .num_dataloader_workers , collate_fn = self . collate_fn , pin_memory = True )
5957
6058 def test_dataloader (self ) -> DataLoader :
6159 return DataLoader (self .dips_test , batch_size = self .batch_size , shuffle = False ,
62- num_workers = self .num_dataloader_workers , collate_fn = dgl_picp_collate , pin_memory = True )
60+ num_workers = self .num_dataloader_workers , collate_fn = self . collate_fn , pin_memory = True )
0 commit comments