Issue
While reading this annotated implementation of Diffusion Probabilistic Models in PyTorch, I got stuck trying to understand this function:
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
What's not clear to me is the final return statement: what does *((1,) * (len(x_shape) - 1)) mean inside the reshape call? Does that asterisk correspond to the unpacking operator? And if so, how is it used here?
Solution
(1,) * (len(x_shape) - 1) creates a tuple of length len(x_shape) - 1 filled with 1s.
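A quick way to see this in plain Python; the shape below is a hypothetical (batch, channels, height, width) example, not taken from the original code:

```python
# Multiplying a tuple by an int repeats its elements.
x_shape = (8, 3, 32, 32)  # hypothetical image-batch shape
ones = (1,) * (len(x_shape) - 1)
print(ones)        # (1, 1, 1)

# The trailing comma matters: (1) is just the int 1, (1,) is a one-element tuple.
print((1) * 3)     # 3
print((1,) * 3)    # (1, 1, 1)
```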
*(...) then unpacks ("spreads") that tuple into separate positional arguments of reshape. So yes, the asterisk is the unpacking operator.
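The difference between passing a tuple and unpacking it can be seen with a small helper; show_args is a hypothetical function used only for illustration:

```python
def show_args(*args):
    """Hypothetical helper: return the positional arguments exactly as received."""
    return args

ones = (1, 1, 1)
# Without *, the whole tuple is passed as one argument.
print(show_args(4, ones))   # (4, (1, 1, 1))
# With *, each element becomes its own positional argument.
print(show_args(4, *ones))  # (4, 1, 1, 1)
```

This is why reshape(batch_size, *ones) behaves like reshape(batch_size, 1, 1, 1).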
So, for example, if len(x_shape) == 5, the return statement ends up being:
return out.reshape(batch_size, 1, 1, 1, 1).to(t.device)
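To see why the function reshapes at all, here is a minimal sketch that runs the extract function from the question on made-up data. It assumes a is a 1-D tensor holding one coefficient per timestep and that t holds one timestep index per sample; the schedule, batch, and timesteps are hypothetical values chosen for illustration:

```python
import torch

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    # Pick a[t[i]] for every sample i in the batch -> shape (batch_size,)
    out = a.gather(-1, t.cpu())
    # Reshape to (batch_size, 1, ..., 1) with len(x_shape) - 1 trailing ones
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

a = torch.linspace(0.1, 1.0, 10)   # hypothetical schedule: 10 timesteps
x = torch.randn(4, 3, 8, 8)        # hypothetical batch of 4 images
t = torch.tensor([0, 3, 5, 9])     # one timestep index per image

coef = extract(a, t, x.shape)
print(coef.shape)        # torch.Size([4, 1, 1, 1])
print((coef * x).shape)  # torch.Size([4, 3, 8, 8])
```

The trailing 1s let coef broadcast against x, so every image is scaled by its own coefficient. Multiplying the un-reshaped tensor of shape (4,) by x would raise an error instead, because broadcasting aligns shapes from the right and would compare 4 with 8.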
Answered By - Samathingamajig