Issue
I'm implementing my own iterator. tqdm does not show a progress bar, as it does not know the total number of elements in the list. I don't want to use "total=" as it looks ugly. Rather, I would prefer to add something to my iterator that tqdm can use to figure out the total.
class Batches:
    def __init__(self, batches, target_input):
        self.batches = batches
        self.pos = 0
        self.target_input = target_input
    def __iter__(self):
        return self
    def __next__(self):
        if self.pos < len(self.batches):
            minibatch = self.batches[self.pos]
            target = minibatch[:, :, self.target_input]
            self.pos += 1
            return minibatch, target
        else:
            raise StopIteration
    def __len__(self):
        return self.batches.len()
Is this even possible? What should I add to the code above?
I am using tqdm like this:
for minibatch, target in tqdm(Batches(test, target_input)):
    output = lstm(minibatch)
    loss = criterion(output, target)
    writer.add_scalar('loss', loss, tensorboard_step)
Solution
The original question states:
I don't want to use "total=" as it looks ugly. Rather I would prefer to add something to my iterator that tqdm can use to figure out the total.
However, the currently accepted answer explicitly recommends using total:
with tqdm(total=len(my_iterable)) as progress_bar:
In fact, that example is more complicated than necessary, since the original question did not ask for any custom updating of the bar. Hence,
for i in tqdm(my_iterable, total=my_total):
    do_something()
is already sufficient (as its author, @emem, noted in a comment).
This question is relatively old (4 years at the time of writing), yet tqdm's code shows that from the very beginning (8 years ago at the time of writing) the default has been total = len(iterable) whenever total is not given.
Thus, the correct answer to the question is to implement __len__, which, as the question shows, the original example already does. Hence, it should already work correctly.
A full toy example to test the behavior follows (please note the comment above the __len__ method):
from time import sleep
from tqdm import tqdm
class Iter:
    def __init__(self, n=10):
        self.n = n
        self.iter = iter(range(n))
    def __iter__(self):
        return self
    def __next__(self):
        return next(self.iter)
    # commenting out the next two lines disables the bar,
    # because tqdm then does not know the total number of elements:
    def __len__(self):
        return self.n
it = Iter()
for i in tqdm(it):
    sleep(0.2)
This is what tqdm does internally when no total is given:
try:
    total = len(iterable)
except (TypeError, AttributeError):
    total = None
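To make the fallback concrete, here is a minimal stdlib-only sketch that mirrors the snippet above. infer_total and the Sized class are hypothetical names chosen for illustration; they are not part of tqdm's API:

```python
from typing import Iterable, Optional


def infer_total(iterable: Iterable) -> Optional[int]:
    """Hypothetical helper mirroring tqdm's fallback: use len() if possible."""
    try:
        return len(iterable)
    except (TypeError, AttributeError):
        # No usable __len__: tqdm then shows a plain counter instead of a bar.
        return None


class Sized:
    """Toy iterable that, like the Iter example, reports its own length."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

    def __len__(self) -> int:
        return self.n


print(infer_total([1, 2, 3]))            # 3
print(infer_total(x for x in range(5)))  # None (generators have no __len__)
print(infer_total(Sized(10)))            # 10
```

In other words, anything that works with the built-in len() gives tqdm its total for free.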
... and since we do not know exactly what @Duane used as batches, I suspect this is just a well-hidden typo (self.batches.len()), which raises an AttributeError that is caught within tqdm.
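The suspected typo can be reproduced without tqdm. This sketch assumes batches is a plain Python list (the Broken class is hypothetical): len() raises an AttributeError, and a try/except like tqdm's quietly turns it into an unknown total, so no error ever reaches the user.

```python
class Broken:
    """Hypothetical reproduction of the question's __len__."""

    def __init__(self, batches):
        self.batches = batches

    def __len__(self):
        # The typo from the question: Python lists have no .len() method.
        return self.batches.len()


broken = Broken([[1], [2], [3]])

caught = None
try:
    len(broken)
except AttributeError as exc:
    caught = exc  # keep a reference; `exc` is cleared after the except block
print(type(caught).__name__, caught)

# The same try/except as in tqdm turns the error into "unknown total".
try:
    total = len(broken)
except (TypeError, AttributeError):
    total = None
print(total)  # None
```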
If batches is just a sequence type, then this was probably the intended definition:
def __len__(self):
    return len(self.batches)
The definition of __next__ (using len(self.batches)) also points in this direction.
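Putting the fix together, a corrected version of the question's class might look like the following sketch. Only to avoid a NumPy/PyTorch dependency, it assumes each minibatch is a nested Python list shaped (batch, time, features), so the minibatch[:, :, target_input] slice is replaced by an equivalent list comprehension; with real arrays, the original slice should be kept.

```python
class Batches:
    """Corrected sketch of the question's iterator (nested-list assumption)."""

    def __init__(self, batches, target_input):
        self.batches = batches
        self.pos = 0
        self.target_input = target_input

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos < len(self.batches):
            minibatch = self.batches[self.pos]
            # Nested-list equivalent of minibatch[:, :, self.target_input].
            target = [[step[self.target_input] for step in seq] for seq in minibatch]
            self.pos += 1
            return minibatch, target
        raise StopIteration

    def __len__(self):
        # The fix: use the built-in len() instead of a nonexistent .len() method.
        return len(self.batches)


# Two minibatches, each holding 1 sequence of 2 time steps with 3 features.
data = [
    [[[1, 2, 3], [4, 5, 6]]],
    [[[7, 8, 9], [10, 11, 12]]],
]
batches = Batches(data, target_input=2)
print(len(batches))             # 2, so tqdm can infer total=2
print([t for _, t in batches])  # [[[3, 6]], [[9, 12]]]
```

With __len__ fixed, wrapping the object as tqdm(Batches(test, target_input)) needs no total= argument.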
Answered By - NichtJens