Issue
To the RBM class we add methods to convert the visible input to the hidden representation and the hidden representation back to a reconstructed visible input. Both methods return the activation probabilities, and the sample_h method additionally returns the sampled hidden state.
<pre><code>
import torch

class RBM():
    def __init__(self, visible_dim, hidden_dim, gaussian_hidden_distribution=False):
        self.visible_dim = visible_dim
        self.hidden_dim = hidden_dim
        self.gaussian_hidden_distribution = gaussian_hidden_distribution

        # initialize parameters
        self.W = torch.randn(visible_dim, hidden_dim) * 0.1
        self.h_bias = torch.zeros(hidden_dim)  # visible --> hidden
        self.v_bias = torch.zeros(visible_dim)  # hidden --> visible

        # parameters for learning with momentum
        self.W_momentum = torch.zeros(visible_dim, hidden_dim)
        self.h_bias_momentum = torch.zeros(hidden_dim)
        self.v_bias_momentum = torch.zeros(visible_dim)

    def sample_h(self, v):
        """Get sample hidden values and activation probabilities"""
        activation = torch.mm(v, self.W) + self.h_bias
        if self.gaussian_hidden_distribution:
            return activation, torch.normal(activation, torch.tensor([1.]))
        else:
            p = torch.sigmoid(activation)
            return p, torch.bernoulli(p)

    def sample_v(self, h):
        """Get visible activation probabilities"""
        activation = torch.mm(h, self.W.t()) + self.v_bias
        p = torch.sigmoid(activation)
        return p

    def update_weights(self, v0, vk, ph0, phk, lr,
                       momentum_coef, weight_decay, batch_size):
        """Contrastive-divergence update with momentum and L2 weight decay"""
        self.W_momentum *= momentum_coef
        self.W_momentum += torch.mm(v0.t(), ph0) - torch.mm(vk.t(), phk)
        self.h_bias_momentum *= momentum_coef
        self.h_bias_momentum += torch.sum((ph0 - phk), 0)
        self.v_bias_momentum *= momentum_coef
        self.v_bias_momentum += torch.sum((v0 - vk), 0)

        self.W += lr * self.W_momentum / batch_size
        self.h_bias += lr * self.h_bias_momentum / batch_size
        self.v_bias += lr * self.v_bias_momentum / batch_size

        self.W -= self.W * weight_decay  # L2 weight decay
</code></pre>
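As a quick sanity check of the two sampling methods, here is a minimal usage sketch (the batch size and dimensions are arbitrary, chosen only for illustration); note that torch.mm expects the input to be a 2-D batch:
<pre><code>
import torch

rbm = RBM(visible_dim=784, hidden_dim=1000)
v = torch.rand(64, 784)   # a batch of 64 flattened 28x28 images -- 2-D

ph, h = rbm.sample_h(v)   # hidden probabilities and sampled binary states
pv = rbm.sample_v(h)      # reconstructed visible probabilities

print(ph.shape, h.shape)  # torch.Size([64, 1000]) torch.Size([64, 1000])
print(pv.shape)           # torch.Size([64, 784])
</code></pre>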
Training the RBM
While training the model I am getting "RuntimeError: self must be a matrix". Can someone help me out and tell me what changes I should make in the code?
<pre><code>
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

models = []  # store trained RBM models
visible_dim = 784
rbm_train_dl = train_dl_flat  # train_dl_flat is defined earlier in the notebook

for hidden_dim in [1000, 500, 250, 2]:
    # configs - we have a different configuration for the last layer
    num_epochs = 30 if hidden_dim == 2 else 10
    lr = 1e-3 if hidden_dim == 2 else 0.1
    use_gaussian = hidden_dim == 2

    # train RBM
    rbm = RBM(visible_dim=visible_dim, hidden_dim=hidden_dim,
              gaussian_hidden_distribution=use_gaussian)
    for epoch in range(num_epochs):
        for i, data_list in enumerate(train_dl):  # train_dl is defined earlier in the notebook
            v0 = data_list[0]

            # get reconstructed input via Gibbs sampling with k=1
            _, hk = rbm.sample_h(v0)
            pvk = rbm.sample_v(hk)

            # update weights
            rbm.update_weights(v0, pvk, rbm.sample_h(v0)[0], rbm.sample_h(pvk)[0], lr,
                               momentum_coef=0.5 if epoch < 5 else 0.9,
                               weight_decay=2e-4,
                               batch_size=sample_data.shape[0])  # sample_data comes from earlier in the notebook
    models.append(rbm)

    # rederive new data loader based on hidden activations of trained model
    new_data = [rbm.sample_h(data_list[0])[0].detach().numpy() for data_list in rbm_train_dl]
    rbm_train_dl = DataLoader(
        TensorDataset(torch.Tensor(np.concatenate(new_data))),
        batch_size=64, shuffle=False
    )
    visible_dim = hidden_dim
</code></pre>
ERROR
<pre><code>
RuntimeError Traceback (most recent call last)
<ipython-input-3-53fe4223334d> in <module>()
16
17 # get reconstructed input via Gibbs sampling with k=1
---> 18 _, hk = rbm.sample_h(v0)
19 pvk = rbm.sample_v(hk)
20 # update weights
<ipython-input-1-49d2abc1da92> in sample_h(self, v)
15 def sample_h(self, v):
16 """Get sample hidden values and activation probabilities"""
---> 17 activation = torch.mm(v, self.W) + self.h_bias
18 if self.gaussian_hidden_distribution:
19 return activation, torch.normal(activation, torch.tensor([1]))
RuntimeError: self must be a matrix
</code></pre>
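A quick way to pinpoint the problem is to check the shape of the first batch coming out of the loader (a diagnostic sketch, assuming the train_dl loader above):
<pre><code>
for data_list in train_dl:
    v0 = data_list[0]
    print(v0.dim(), v0.shape)  # torch.mm requires v0.dim() == 2
    break
</code></pre>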
Solution
It seems you need broadcasting, because you are multiplying a 1-D vector by a 2-D matrix, and torch.mm only accepts 2-D matrices. Try using torch.matmul instead.
For the difference between mm and matmul, see: What's the difference between torch.mm, torch.matmul and torch.mul?
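A minimal sketch of the suggested change (the dimensions below are illustrative, matching the first RBM layer): torch.mm rejects a 1-D input, while torch.matmul treats a 1-D first argument as a row vector and broadcasts:
<pre><code>
import torch

W = torch.randn(784, 1000) * 0.1
h_bias = torch.zeros(1000)
v = torch.rand(784)  # a single flattened image -- 1-D, not a matrix

# torch.mm(v, W)                          # RuntimeError: self must be a matrix
activation = torch.matmul(v, W) + h_bias  # works; shape torch.Size([1000])
</code></pre>
Replacing torch.mm with torch.matmul in sample_h (and likewise in sample_v) lets the same code handle both 1-D and batched 2-D inputs; alternatively, v.unsqueeze(0) turns the vector into a 1×N matrix so torch.mm still works.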
Answered By - Alex