@@ -24,12 +24,17 @@ def __init__(self, *args, **kwargs):
2424 self .padding [1 ], 2 * self .padding [0 ], 0 )
2525 self .padding = (0 , 0 , 0 )
2626
def forward(self, x, cache_x=None, cache_list=None, cache_idx=None):
    """Run the causal 3-D convolution, optionally consuming cached frames.

    Args:
        x: input tensor; concatenation below fixes dim 2 as the temporal axis.
        cache_x: frames from a previous chunk to prepend along time, or None.
        cache_list: optional mutable list of cache tensors; when given, the
            entry at ``cache_idx`` is popped (slot set to None) so its memory
            can be released as early as possible.
        cache_idx: index into ``cache_list`` (required when it is given).

    Returns:
        The padded input passed through the parent convolution's forward.
    """
    # Pull the cache entry out of the shared list in place; clearing the
    # slot drops the caller's reference, lowering peak memory.
    if cache_list is not None:
        cache_x = cache_list[cache_idx]
        cache_list[cache_idx] = None

    pad = list(self._padding)
    if cache_x is not None and self._padding[4] > 0:
        # Prepend the cached frames on the temporal axis and shrink the
        # leading temporal padding by the number of frames supplied.
        cache_x = cache_x.to(x.device)
        x = torch.cat([cache_x, x], dim=2)
        pad[4] -= cache_x.shape[2]
        # Release the cached tensor before F.pad allocates the padded copy.
        del cache_x
    x = F.pad(x, pad)

    return super().forward(x)
@@ -166,7 +171,7 @@ def __init__(self, in_dim, out_dim, dropout=0.0):
166171 if in_dim != out_dim else nn .Identity ()
167172
168173 def forward (self , x , feat_cache = None , feat_idx = [0 ]):
169- h = self . shortcut ( x )
174+ old_x = x
170175 for layer in self .residual :
171176 if isinstance (layer , CausalConv3d ) and feat_cache is not None :
172177 idx = feat_idx [0 ]
@@ -178,12 +183,12 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
178183 cache_x .device ), cache_x
179184 ],
180185 dim = 2 )
181- x = layer (x , feat_cache [ idx ] )
186+ x = layer (x , cache_list = feat_cache , cache_idx = idx )
182187 feat_cache [idx ] = cache_x
183188 feat_idx [0 ] += 1
184189 else :
185190 x = layer (x )
186- return x + h
191+ return x + self . shortcut ( old_x )
187192
188193
189194class AttentionBlock (nn .Module ):
0 commit comments