Hello, thanks for the nice implementation. I think I may have found a small bug in your LARS implementation.

    trust_ratio = tf.where(
        tf.greater(w_norm, 0),
        tf.where(tf.greater(g_norm, 0), (self.eeta * w_norm / g_norm), 1.0),
        1.0)

is a little different from

    trust_ratio = torch.where(
        w_norm.ge(0),
        torch.where(
            g_norm.ge(0),
            (self.eeta * w_norm / g_norm),
            torch.Tensor([1.0]).to(device),
        ),
        torch.Tensor([1.0]).to(device),
    ).item()

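The difference is that greater is > while ge is >=. With ge, a weight tensor whose norm is exactly 0 still passes the check, so trust_ratio becomes self.eeta * 0 / g_norm = 0 and the scaled update is zero. Thus a bias parameter that is initialized to 0 is never updated. I think

    trust_ratio = torch.where(
        w_norm.gt(0),
        torch.where(
            g_norm.gt(0),
            (self.eeta * w_norm / g_norm),
            torch.Tensor([1.0]).to(device),
        ),
        torch.Tensor([1.0]).to(device),
    ).item()

may work better.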
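To make the difference concrete, here is a minimal scalar sketch of the nested where() logic in plain Python. It is not the repository's code: the function name trust_ratio, the strict flag, and the value eeta = 0.001 are assumptions chosen for illustration only.

```python
def trust_ratio(w_norm, g_norm, eeta=0.001, strict=True):
    """Scalar stand-in for the nested where() expression (hypothetical helper).

    strict=True  mimics tf.greater / Tensor.gt  (x > 0)
    strict=False mimics Tensor.ge               (x >= 0)
    eeta=0.001 is an assumed value, not taken from the repository.
    """
    def passes(x):
        return x > 0 if strict else x >= 0

    if passes(w_norm):
        if passes(g_norm):
            # Plain Python raises ZeroDivisionError when g_norm == 0;
            # this sketch only exercises g_norm > 0.
            return eeta * w_norm / g_norm
        return 1.0
    return 1.0


# A bias initialized to 0 has w_norm == 0, while its gradient is non-zero.
ratio_ge = trust_ratio(0.0, 0.5, strict=False)  # >= lets w_norm == 0 through
ratio_gt = trust_ratio(0.0, 0.5, strict=True)   # > falls back to 1.0

print(ratio_ge)  # 0.0, so the scaled learning rate is 0 and the bias never moves
print(ratio_gt)  # 1.0, so the bias is updated with the unscaled learning rate
```

For non-zero norms both comparisons give the same result, so the change only affects parameters (or gradients) whose norm is exactly 0.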
(The PyTorch snippet quoted above is from SimCLR/modules/lars.py, lines 119 to 127 in 654f05f.)