Issue
The following code is an instance method built in a class
def get_samples_from_component(self,batchSize):
SMALL = torch.tensor(1e-10, dtype=torch.float64, device=local_device)
a_inv = torch.pow(self.kumar_a,-1)
b_inv = torch.pow(self.kumar_b,-1)
r1 = torch.tensor(SMALL, dtype=torch.float64,device=self.device)
r2 = torch.tensor(1-SMALL, dtype=torch.float64, device=self.device)
v_means = torch.mul(self.kumar_b, beta_fn(1.+a_inv, self.kumar_b)).to(device=self.device)
u = torch.distributions.uniform.Uniform(low=r1, high=r2).sample([1]).squeeze()
v_samples = torch.pow(1 - torch.pow(u, b_inv), a_inv).to(device=self.device)
if v_samples.ndim > 2:
v_samples = v_samples.squeeze()
v0 = v_samples[:, -1].pow(0).reshape(v_samples.shape[0], 1)
v1 = torch.cat([v_samples[:, :self.z_dim - 1], v0], dim=1)
n_samples = v1.size()[0]
n_dims = v1.size()[1]
components = torch.zeros((n_samples, n_dims)).to(device=self.device)
for k in range(n_dims):
if k == 0:
components[:, k] = v1[:, k]
else:
components[:, k] = v1[:, k] * torch.stack([(1 - v1[:, j]) for j in range(n_dims) if j < k]).prod(axis=0)
# ensure stick segments sum to 1
assert_almost_equal(torch.ones(n_samples,device=self.device).cpu().numpy(), components.sum(axis=1).detach().cpu().numpy(),
decimal=4, err_msg='stick segments do not sum to 1')
print(f'size of sticks: {components}')
components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), dtype=torch.long, device=self.device)
components = torch.cat( [torch.range(0, batchSize).unsqueeze(1), components.unsqueeze(1)], 1)
print(f'size of sticks: {components}')
all_z = []
for d in range(self.z_dim):
temp_z = torch.cat(1, [self.z_sample_list[k][:, d].unsqueeze(1) for k in range(self.K)])
all_z.append(gather_nd(temp_z, components).unsqueeze(1))
out = torch.cat( all_z,1)
return out
By running my code I get the following error message
components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), dtype=torch.long, device=self.device)
TypeError: new() received an invalid combination of arguments - got (Tensor, device=torch.device, dtype=torch.dtype), but expected one of:
* (*, torch.device device)
* (torch.Storage storage)
* (Tensor other)
* (tuple of ints size, *, torch.device device)
didn't match because some of the keywords were incorrect: dtype
* (object data, *, torch.device device)
didn't match because some of the keywords were incorrect: dtype
I will appreciate if someone suggests a solution for this error.
Solution
v_means
is already a tensor, try to simple remove the tensor re-implementation in:
components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), dtype=torch.long, device=self.device)
to:
components = torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1)
or simply remove the dtype, since it seems to cast it to integer anyway:
components = torch.IntTensor(torch.argmax(torch.cat( self.compose_stick_segments(v_means),1) ,1), device=self.device)
Answered By - Guinther Kovalski
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.