@@ -148,29 +148,6 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
148148 feat_idx [0 ] += 1
149149 return x
150150
151- def init_weight (self , conv ):
152- conv_weight = conv .weight
153- nn .init .zeros_ (conv_weight )
154- c1 , c2 , t , h , w = conv_weight .size ()
155- one_matrix = torch .eye (c1 , c2 )
156- init_matrix = one_matrix
157- nn .init .zeros_ (conv_weight )
158- #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
159- conv_weight .data [:, :, 1 , 0 , 0 ] = init_matrix #* 0.5
160- conv .weight .data .copy_ (conv_weight )
161- nn .init .zeros_ (conv .bias .data )
162-
163- def init_weight2 (self , conv ):
164- conv_weight = conv .weight .data
165- nn .init .zeros_ (conv_weight )
166- c1 , c2 , t , h , w = conv_weight .size ()
167- init_matrix = torch .eye (c1 // 2 , c2 )
168- #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
169- conv_weight [:c1 // 2 , :, - 1 , 0 , 0 ] = init_matrix
170- conv_weight [c1 // 2 :, :, - 1 , 0 , 0 ] = init_matrix
171- conv .weight .data .copy_ (conv_weight )
172- nn .init .zeros_ (conv .bias .data )
173-
174151
175152class ResidualBlock (nn .Module ):
176153
@@ -485,12 +462,6 @@ def __init__(self,
485462 self .decoder = Decoder3d (dim , z_dim , dim_mult , num_res_blocks ,
486463 attn_scales , self .temperal_upsample , dropout )
487464
488- def forward (self , x ):
489- mu , log_var = self .encode (x )
490- z = self .reparameterize (mu , log_var )
491- x_recon = self .decode (z )
492- return x_recon , mu , log_var
493-
494465 def encode (self , x ):
495466 self .clear_cache ()
496467 ## cache
@@ -536,18 +507,6 @@ def decode(self, z):
536507 self .clear_cache ()
537508 return out
538509
539- def reparameterize (self , mu , log_var ):
540- std = torch .exp (0.5 * log_var )
541- eps = torch .randn_like (std )
542- return eps * std + mu
543-
544- def sample (self , imgs , deterministic = False ):
545- mu , log_var = self .encode (imgs )
546- if deterministic :
547- return mu
548- std = torch .exp (0.5 * log_var .clamp (- 30.0 , 20.0 ))
549- return mu + std * torch .randn_like (std )
550-
551510 def clear_cache (self ):
552511 self ._conv_num = count_conv3d (self .decoder )
553512 self ._conv_idx = [0 ]
0 commit comments