Issue
I came across this on GitHub (snippet from here):
(...)
for epoch in range(round):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ############################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, device=device)
(...)
Would replacing batch_size = real_cpu.size(0) with batch_size = len(data[0]) give the same effect (or at least with batch_size = len(real_cpu))? The reason I'm asking is that, if I recall correctly, the official PyTorch tutorial uses len(X) when displaying training progress inside its for (X, y) in dataloader: loop, so I was wondering whether the two methods are equivalent for reporting the number of samples in the current batch.
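The loop I have in mind looks roughly like this (a sketch from memory, not a verbatim quote of the tutorial; the names size and the 100-batch interval are illustrative):

size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
    # ... forward pass, loss, backward pass ...
    if batch % 100 == 0:
        # len(X) is the number of samples in the current batch
        print(f"[{batch * len(X):>5d}/{size:>5d}]")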
Solution
If you are working with data where the batch size is the first dimension, then you can interchange real_cpu.size(0) with len(real_cpu) or with len(data[0]): calling len on a tensor returns the size of its first dimension, so all three expressions yield the same number.
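A quick way to convince yourself (a minimal sketch with invented shapes, not taken from the linked code):

import torch

# A fake batch of 16 RGB 64x64 images, batch size on the first
# dimension, as in the snippet above
real_cpu = torch.randn(16, 3, 64, 64)

print(real_cpu.size(0))  # 16
print(len(real_cpu))     # 16 -- len(tensor) returns tensor.size(0)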
However, with some models, such as LSTMs, the batch size can sit on the second dimension (nn.LSTM defaults to batch_first=False, expecting input of shape (seq_len, batch, features)). In that case len would return the sequence length rather than the batch size, and you would need real_cpu.size(1) instead.
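For example (again with invented shapes):

import torch

# Sequence-first layout, as nn.LSTM expects by default:
# (seq_len=50, batch=16, features=10)
x = torch.randn(50, 16, 10)

print(len(x))     # 50 -- the sequence length, not the batch size
print(x.size(1))  # 16 -- the actual batch size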
Answered By - Alka