Issue
I'm trying to train a net using two generators, one for training and one for validation. They are simply two functions that yield samples indefinitely.
I get the following error at the very end of the validation:
File "/home/ubuntu/tensorflow/lib/python3.5/site-packages/numpy/lib/function_base.py", line 1142, in average
    "Axis must be specified when shapes of a and weights "
I looked into the code: the function training_generator in keras.engine contains the following line:
averages.append(np.average([out[i] for out in outs_per_batch], weights=batch_sizes))
Looking at the definition of np.average, the function requires axis when the array and the weights do not have the same shape. I debugged the code, and by passing axis=0 or applying np.squeeze to out[i] it "works", only to fail a few lines later when it collects the summary stats of the validation. I can't shake the feeling that there is an error somewhere else in my code.
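The NumPy behaviour described above can be reproduced in isolation. This is a minimal sketch with made-up numbers (batch_sizes, scalar_outs and array_outs are hypothetical values, not taken from Keras): when each per-batch output is a scalar, np.average works without axis; when each one is an array, NumPy raises a TypeError, and axis=0 avoids the error but returns an array instead of a single number.

```python
import numpy as np

# Hypothetical per-batch results: one batch size and one output per batch.
batch_sizes = [4, 4]
scalar_outs = [0.5, 0.7]  # one number per batch: shapes (2,) and (2,) match
array_outs = [np.array([0.1, 0.2, 0.3]), np.array([0.3, 0.4, 0.5])]  # shape (2, 3)

# Scalars: array and weights have the same shape, so no axis is needed.
scalar_avg = np.average(scalar_outs, weights=batch_sizes)

# Arrays: shape (2, 3) vs weights shape (2,), so NumPy refuses without axis.
try:
    np.average(array_outs, weights=batch_sizes)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)

# axis=0 makes the call succeed, but the result is a length-3 array,
# a per-column weighted mean rather than one number.
array_avg = np.average(array_outs, axis=0, weights=batch_sizes)
print(scalar_avg, array_avg)
```

That last point may be why patching the call with axis=0 only moves the failure further along: code that expects one number per metric would then receive an array.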
This is my generator:
from os import listdir
from os.path import join

import numpy as np
import SimpleITK
from tqdm import tqdm

def batch_generator(batch_size, folder):
    files = listdir(folder)
    print("Folder " + folder + " with " + str(len(files)) + " files.")
    np.random.shuffle(files)
    while True:
        # Reshuffle the file list at the start of every pass.
        np.random.shuffle(files)
        # len(files) + 1 so the last full batch is not skipped; a remainder
        # smaller than batch_size is still dropped.
        for i in range(batch_size, len(files) + 1, batch_size):
            batch = files[(i - batch_size):i]
            batch = tensor_generator(folder, files=batch)
            # Autoencoder: the input batch is also the target.
            yield (batch, batch)

def tensor_generator(folder, files=None):
    if files is None:
        files = listdir(folder)
    verbose = len(files) > 100
    if verbose:
        pbar = tqdm(total=len(files), unit='img')
    tensor = []
    for f in files:
        f = SimpleITK.ReadImage(join(folder, f))
        f = SimpleITK.GetArrayFromImage(f)
        # Linear rescaling of the raw intensities.
        f = (f + 1000) / 4000
        tensor.append(f)
        if verbose:
            pbar.update(1)
    if verbose:
        pbar.close()
    return np.stack(tensor, axis=0)
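To check which slices the batching loop in batch_generator produces, here is a small sketch that works only on file indices (batch_slices is a hypothetical helper, not part of the question's code). With an upper bound of len(files), the final full batch is skipped whenever the file count is a multiple of batch_size; an upper bound of len(files) + 1 includes it.

```python
def batch_slices(n_files, batch_size, include_last=True):
    """Return the (start, stop) pairs the batching loop would slice out."""
    # include_last=False mimics range(batch_size, len(files), batch_size);
    # include_last=True mimics range(batch_size, len(files) + 1, batch_size).
    stop = n_files + 1 if include_last else n_files
    return [(i - batch_size, i) for i in range(batch_size, stop, batch_size)]

print(batch_slices(10, 5, include_last=False))  # [(0, 5)]
print(batch_slices(10, 5))                      # [(0, 5), (5, 10)]
print(batch_slices(12, 5))                      # [(0, 5), (5, 10)]
```

In both variants a trailing remainder smaller than batch_size (files 10 and 11 in the last call) is left out of that pass; because the list is reshuffled every pass, different files end up in the remainder each time.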
And this is the fit call:
self.autoencoder.fit_generator(
    generator=x_train,
    steps_per_epoch=iters,
    epochs=epochs,
    callbacks=[log, rop],
    validation_data=x_test,
    validation_steps=10)
Any idea what's wrong?
Solution
I encountered the same problem. Even though I don't know what causes this issue, I have worked around it.
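You only need to change validation_data=x_test to validation_data=next(x_test), i.e. call next() on your validation data generator. Note that next(x_test) pulls a single (inputs, targets) batch, so validation then runs on that one fixed batch every epoch instead of on validation_steps batches drawn from the generator.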
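The difference between passing the generator and passing next() of it can be seen without Keras. This is a minimal sketch: toy_batch_generator is a hypothetical stand-in for the question's batch_generator that yields random NumPy arrays instead of reading image files. The generator object supports __next__, while next(x_test) returns one plain (inputs, targets) tuple that stays the same from then on.

```python
import numpy as np

def toy_batch_generator(batch_size, n_samples=20, shape=(8, 8)):
    """Hypothetical stand-in for batch_generator: yields (x, x) pairs forever."""
    data = np.random.rand(n_samples, *shape)
    while True:
        idx = np.random.permutation(n_samples)
        # Only full batches are yielded; any remainder is skipped.
        for start in range(0, n_samples - batch_size + 1, batch_size):
            batch = data[idx[start:start + batch_size]]
            yield (batch, batch)  # autoencoder: input == target

x_test = toy_batch_generator(batch_size=4)

# The generator object itself can keep producing new batches.
print(hasattr(x_test, "__next__"))  # True

# next(x_test) pulls exactly one batch as an (inputs, targets) tuple.
validation_data = next(x_test)
x_val, y_val = validation_data
print(type(validation_data).__name__, x_val.shape)  # tuple (4, 8, 8)
```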
Answered By - Taylor Mei