Issue
I started learning PyTorch recently, and for practice I am trying to implement the model from a paper I read.
This is the PDF of the paper I am referring to: https://dl.acm.org/doi/pdf/10.1145/3178876.3186066?download=true
Here is the code I wrote.
import torch
import torch.nn.functional as F
from torch.autograd import Variable

# XGBoost below is a custom wrapper class that is not shown in this post.
class Tem(torch.nn.Module):
    def __init__(self, embedding_size, hidden_size):
        super(Tem, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.leaf_size = 0
        self.xgb_model = None
        self.vec_embedding = None
        self.multi_hot_Q = None
        self.user_embedding = torch.nn.Linear(1, embedding_size)
        self.item_embedding = torch.nn.Linear(1, embedding_size)

    def pretrain(self, ui_attributes, labels):
        print("Start XGBoost Training...")
        self.xgb_model = XGBoost(ui_attributes, labels)
        self.leaf_size = self.xgb_model.leaf_size
        self.vec_embedding = Variable(torch.rand(self.embedding_size, self.leaf_size, requires_grad=True))
        self.h = Variable(torch.rand(self.hidden_size, 1, requires_grad=True))
        self.att_w = Variable(torch.rand(2 * self.embedding_size, self.hidden_size, requires_grad=True))
        self.att_b = Variable(torch.rand(self.leaf_size, self.hidden_size, requires_grad=True))
        self.r_1 = Variable(torch.rand(self.embedding_size, 1, requires_grad=True))
        self.r_2 = Variable(torch.rand(self.embedding_size, 1, requires_grad=True))
        self.bias = Variable(torch.rand(1, 1, requires_grad=True))

    def forward(self, ui_ids, ui_attributes):
        if self.xgb_model is None:
            raise Exception("Please run Tem.pretrain() to pre-train XGBoost model first.")
        n_data = len(ui_ids)
        att_input = torch.FloatTensor(ui_attributes)
        self.multi_hot_Q = torch.FloatTensor(self.xgb_model.multi_hot(att_input)).permute(0,2,1)
        vq = self.vec_embedding * self.multi_hot_Q
        id_input = torch.FloatTensor(ui_ids)
        user_embedded = self.user_embedding(id_input[:,0].reshape(n_data, 1))
        item_embedded = self.item_embedding(id_input[:,1].reshape(n_data, 1))
        ui = (user_embedded * item_embedded).reshape(n_data, self.embedding_size, 1)
        ui_repeat = ui.repeat(1, 1, self.leaf_size)
        cross = torch.cat([ui_repeat, vq], dim=1).permute(0,2,1)
        re_cross = cross.reshape(cross.shape[0] * cross.shape[1], cross.shape[2])
        attention = torch.mm(re_cross, self.att_w)
        attention = F.leaky_relu(attention + self.att_b.repeat(n_data, 1))
        attention = torch.mm(attention, self.h).reshape(n_data, self.leaf_size)
        attention = F.softmax(attention).reshape(n_data, self.leaf_size, 1)
        attention = self.vec_embedding.permute(1,0) * attention.repeat(1,1,20)
        pool = torch.max(attention, 1).values
        y_hat = self.bias.repeat(n_data, 1) + torch.mm(ui.reshape(n_data, self.embedding_size), self.r_1) + torch.mm(pool, self.r_2)
        y_hat = F.softmax(torch.nn.Linear(1, 2)(y_hat))
        return y_hat
My question: it seems PyTorch doesn't know which tensors it should compute gradients for during backpropagation.
print(tem)
Tem(
  (user_embedding): Linear(in_features=1, out_features=20, bias=True)
  (item_embedding): Linear(in_features=1, out_features=20, bias=True)
)
I googled this problem. Some people say these tensors should be wrapped in torch.autograd.Variable(), but that didn't solve my problem. Others say that autograd now supports tensors directly, so torch.autograd.Variable() is not necessary.
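The following is a minimal sketch of that distinction, not the code from the question. The class name Plain and the attribute w are made up for illustration. It suggests that a plain tensor created with requires_grad=True does receive a gradient without Variable, yet is not listed among the module's parameters:

```python
import torch

class Plain(torch.nn.Module):  # hypothetical toy module, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 1)
        # Plain tensor attribute: tracked by autograd, but NOT registered
        # as a parameter of the module.
        self.w = torch.rand(3, 1, requires_grad=True)

    def forward(self, x):
        return self.linear(x) + x @ self.w

model = Plain()
names = [name for name, _ in model.named_parameters()]

loss = model(torch.ones(2, 3)).sum()
loss.backward()
has_grad = model.w.grad is not None

print(names)     # only the Linear layer's weight and bias are listed
print(has_grad)  # the plain tensor still received a gradient
```

Because an optimizer built from model.parameters() only sees the registered entries, w would never be updated in this sketch even though its gradient exists.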
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adagrad(tem.parameters(), lr=0.02)
for t in range(20):
    prediction = tem(ids_train, att_train)
    loss = loss_func(prediction, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 5 == 0:
        print("loss: ", loss)
loss: tensor(0.8133, grad_fn=<NllLossBackward>)
loss: tensor(0.8133, grad_fn=<NllLossBackward>)
loss: tensor(0.8133, grad_fn=<NllLossBackward>)
loss: tensor(0.8133, grad_fn=<NllLossBackward>)
Solution
Your problem is not related to Variable. As you said, it is no longer necessary. For a tensor declared in a model (a class that extends nn.Module) to be trained, it must be registered as one of the model's parameters by wrapping it in nn.Parameter(). Only then is it returned by model.parameters() and updated by the optimizer. For example, to register self.h, you can do:
self.h = nn.Parameter(torch.zeros(10, 10))
Now, when you call loss.backward(), the gradient for this parameter will be collected (provided, of course, that loss depends on self.h).
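As a rough illustration of this advice, here is a small self-contained sketch. The class name Registered, the tensor shapes, and the MSE loss are assumptions chosen to keep the example short, not part of the original model. It checks that a tensor wrapped in nn.Parameter shows up in the model's parameters and is changed by one optimizer step:

```python
import torch
from torch import nn

class Registered(nn.Module):  # hypothetical toy module, for illustration only
    def __init__(self, hidden_size=4):
        super().__init__()
        # Wrapping the tensor in nn.Parameter registers it with the module.
        self.h = nn.Parameter(torch.rand(hidden_size, 1))

    def forward(self, x):
        return x @ self.h

torch.manual_seed(0)
model = Registered()
names = [name for name, _ in model.named_parameters()]

# Build the optimizer only after all parameters exist on the model.
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.02)
before = model.h.detach().clone()

x = torch.ones(8, 4)
target = torch.zeros(8, 1)
loss = nn.functional.mse_loss(model(x), target)

optimizer.zero_grad()
loss.backward()
optimizer.step()

changed = not torch.equal(before, model.h.detach())
print(names)    # the parameter is listed under the name "h"
print(changed)  # the optimizer step modified it
```

One thing to watch for when applying this to the code in the question: the extra tensors there are created in pretrain(), so the optimizer would presumably need to be constructed after pretrain() has run. Otherwise tem.parameters() would not yet contain them when the optimizer is built.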
Answered By - André Pacheco