Issue
I am building training data for a model using PyTorch.
import torch


def shufflerow(tensor1, tensor2, axis):
    row_perm = torch.rand(tensor1.shape[:axis+1]).argsort(axis)  # get permutation indices
    for _ in range(tensor1.ndim-axis-1):
        row_perm.unsqueeze_(-1)
    row_perm = row_perm.repeat(*[1 for _ in range(axis+1)], *(tensor1.shape[axis+1:]))  # reformat this for the gather operation
    return tensor1.gather(axis, row_perm), tensor2.gather(axis, row_perm)


class Dataset:
    def __init__(self, observation, next_observation):
        self.data = (observation, next_observation)
        indices = torch.randperm(observation.shape[0])
        self.train_samples = (observation[indices, :], next_observation[indices, :])
        self.test_samples = shufflerow(observation, next_observation, 0)
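The train_samples line works because a single permutation is shared by both tensors, so each observation stays paired with its next observation. A minimal check of that idea, using made-up data (an assumption for illustration: 4 rows, with next_obs equal to obs + 100). For axis=0 and two tensors of the same shape, shufflerow should give the same kind of result through gather, since both tensors receive the same row_perm.

```python
import torch

# Illustrative data (assumption): 4 paired rows, next_obs = obs + 100
obs = torch.arange(12, dtype=torch.float32).reshape(4, 3)
next_obs = obs + 100

# One permutation shared by both tensors keeps the (obs, next_obs) pairs aligned
perm = torch.randperm(obs.shape[0])
shuffled_obs, shuffled_next = obs[perm], next_obs[perm]

# Every shuffled pair still differs by exactly 100
print(torch.equal(shuffled_next - shuffled_obs, torch.full_like(obs, 100.0)))  # True
```

Drawing two independent permutations (one per tensor) would break the pairing, which is why the same indices are reused.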
I also have this function, which checks whether the data is already a torch.Tensor, converts it if not, and sets the device:
import numpy as np
import torch


def to_tensor(x, device):
    if torch.is_tensor(x):
        return x
    elif isinstance(x, np.ndarray):
        return torch.from_numpy(x).to(device=device, dtype=torch.float32)
    elif isinstance(x, list):
        if all(isinstance(item, np.ndarray) for item in x):
            return [torch.from_numpy(item).to(device=device, dtype=torch.float32) for item in x]
    elif isinstance(x, tuple):
        return (torch.from_numpy(item).to(device=device, dtype=torch.float32) for item in x)
    else:
        print(f"X:{x} and X's type{type(x)}")
        return torch.tensor(x).to(device=device, dtype=torch.float32)
But when I pass input data like the following through the Dataset class and print the training samples:
data = Dataset(s1, s2)
print(data.train_samples)
(tensor([[-0.3121, -0.9500,  1.4518],
        [-0.9903, -0.1391, -4.4141],
        [-0.9645, -0.2642,  5.0233],
        [-0.6413,  0.7673, -4.5495],
        [-0.3073,  0.9516, -1.0128],
        [-0.5495,  0.8355,  3.4044],
        [-0.5710, -0.8209, -3.2716],
        [-0.9388,  0.3445,  3.9225],
        [-0.8402, -0.5423, -4.0820]]), tensor([[-0.2723, -0.9622,  0.8342],
        [-0.9958,  0.0912, -4.6186],
        [-0.8747, -0.4847,  4.7741],
        [-0.5495,  0.8355,  3.4044],
        [-0.7146,  0.6996,  4.2841],
        [-0.7128, -0.7014, -3.7148],
        [-0.9915,  0.1303,  4.4200],
        [-0.9358, -0.3526, -4.2585]]))
I get this error message:
-> 1725 self._target_samples = to_tensor(true_samples)
1726 self._steps = []
/content/data_gen.py in to_tensor(x)
1368 else:
1369 print(f"X:{x} and X's type{type(x)}")
-> 1370 return torch.tensor(x).to(device=device, dtype=torch.float32)
X:<generator object to_tensor.<locals>.<genexpr> at 0x7f380235d6d0> and X's type<class 'generator'>
RuntimeError: Could not infer dtype of generator
Any suggestions as to why I am getting this error?
Solution
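The expression (torch.from_numpy(item).to(device=device, dtype=torch.float32) for item in x) doesn't create a tuple; it's a generator expression. Since it sits in the branch where you test for tuples, I suspect you wanted a tuple instead of a generator. When that generator reaches to_tensor again (your traceback shows x is a generator), it matches none of the type checks, falls through to the else branch, and torch.tensor(x) cannot infer a dtype from a generator, hence the RuntimeError. Try: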
    elif isinstance(x, tuple):
        return tuple(torch.from_numpy(item).to(device=device, dtype=torch.float32) for item in x)
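The difference is easy to check in plain Python, with numbers standing in for tensors: parentheses around a comprehension build a lazy, single-use generator, while tuple(...) around the same expression builds a real tuple.

```python
# Parentheses around a comprehension create a generator, not a tuple
gen = (item * 2 for item in (1, 2, 3))
print(type(gen).__name__)       # generator
print(isinstance(gen, tuple))   # False

# A generator is single-use: the second pass sees nothing
print(list(gen))  # [2, 4, 6]
print(list(gen))  # []

# tuple(...) around the same expression builds a real tuple
tup = tuple(item * 2 for item in (1, 2, 3))
print(type(tup).__name__, tup)  # tuple (2, 4, 6)
```

Because a generator is exhausted after one pass, even code that iterated it correctly would find it empty the second time, so returning tuple(...) is the safer choice here.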
Answered By - Blckknght
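A fuller version of the helper, sketched under a few assumptions that are not in the question or answer: device defaults to "cpu", existing tensors are also cast and moved (the original tensor branch returns them unchanged), and lists or tuples of tensors/arrays are converted element by element while keeping their container type. In the original code, a list containing anything other than NumPy arrays matches the list branch but hits no return statement, so it returns None; here such input falls back to torch.tensor instead.

```python
import numpy as np
import torch


def to_tensor(x, device="cpu"):  # default device is an assumption
    """Convert x to float32 tensor(s) on `device`, preserving list/tuple containers."""
    if torch.is_tensor(x):
        # Cast and move existing tensors too, so the result is always on `device`
        return x.to(device=device, dtype=torch.float32)
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x).to(device=device, dtype=torch.float32)
    if isinstance(x, (list, tuple)) and all(
        torch.is_tensor(item) or isinstance(item, np.ndarray) for item in x
    ):
        # Convert each element; a tuple stays a tuple (never a generator)
        converted = [to_tensor(item, device) for item in x]
        return tuple(converted) if isinstance(x, tuple) else converted
    # Fallback: scalars and (nested) lists/tuples of plain numbers
    return torch.tensor(x, dtype=torch.float32, device=device)


pair = (np.zeros((2, 3)), np.ones((2, 3)))
out = to_tensor(pair)
print(type(out).__name__)       # tuple
print(to_tensor(out)[1].sum())  # tensor(6.)
```

Since the result is a real tuple, passing it back into to_tensor is safe: it takes the container branch again rather than falling through to torch.tensor.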