Welcome to Vigges Developer Community - Open, Learning, Share

python 3.x - Problem with tf.while_loop InvalidArgumentError: Index out of range using input dim 0; input has only 0 dims [Op:StridedSlice] name: strided_slice/

I'm trying to parallelize this chunk of code using TensorFlow:

import numpy as np
import tensorflow as tf
import time
#Importing a generic dataset from Keras
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    path='mnist.npz'
)

#I would like to compare a reference image to a bunch of images

#This is my reference image
x_reference = np.expand_dims(x_train[0],axis = 2)

start = time.time()
for index, image in enumerate(x_train):
    # The tf.image.ssim is a similarity metric
    tf.image.ssim(np.expand_dims(x_train[index],axis = 2), x_reference, 255)
print("Total Time=", time.time() - start)

I decided to use tf.while_loop to parallelize it:

import numpy as np
import tensorflow as tf
#Importing a generic dataset from Keras
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    path='mnist.npz'
)

#This is my reference image
x_reference = np.expand_dims(x_train[0],axis = 2)
t1 = tf.constant(x_train)
t2 = tf.constant(x_reference)
iters = tf.constant(60000)
def cond(t1, t2, i, iters):
  return tf.less(i, iters)

def body(t1, t2, i, iters):
  return [tf.image.ssim(np.expand_dims(t1[i],axis = 2), t2, 255), t2, tf.add(i,1), iters]

res = tf.while_loop(cond, body, [t1, t2, 0 , iters], parallel_iterations=60000)

However, I am running into errors. I'm quite new to TensorFlow, so I expect this code is broken in several places. All forms of guidance (even through comments) are highly appreciated!


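He who fights with dragons for too long becomes a dragon himself; gaze into the abyss for too long, and the abyss gazes back into you…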

1 Answer


I have fixed your code.

There were two issues. The main one was the return value of the body function. At every iteration of the "loop", the outputs of the previous iteration are passed back into the body function as its arguments. Your body returned the SSIM value as the first element of the list, so in the second iteration the first argument was no longer t1 but that single value. That is why you got the error: t1[i] was indexing the scalar SSIM value (a tensor with 0 dimensions), not the image stack.
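The loop-variable contract can be illustrated without TensorFlow. The sketch below uses a hypothetical helper, `emulated_while_loop`, that only mimics how tf.while_loop feeds each iteration's outputs back in as the next iteration's inputs. It is not TensorFlow's implementation and skips the checks the real function performs; the integer lists and `sum` are stand-ins for the images and the SSIM call.

```python
def emulated_while_loop(cond, body, loop_vars):
    """Hypothetical stand-in for tf.while_loop's data flow (no TensorFlow).

    Whatever body returns becomes, position by position, the argument list
    of the next call to cond and body.
    """
    loop_vars = list(loop_vars)
    while cond(*loop_vars):
        loop_vars = list(body(*loop_vars))
    return loop_vars


images = [[1, 2], [3, 4], [5, 6]]  # stand-in for t1: three tiny "images"


def cond(t1, i, n, *rest):
    return i < n


def broken_body(t1, i, n):
    score = sum(t1[i])           # stand-in for the SSIM value
    return [score, i + 1, n]     # BUG: score takes the slot t1 used to have


def fixed_body(t1, i, n, score):
    score = sum(t1[i])
    return [t1, i + 1, n, score]  # t1 keeps its slot; score gets its own


try:
    emulated_while_loop(cond, broken_body, [images, 0, len(images)])
except TypeError as err:
    # Second iteration: t1 is now the int 3, so t1[i] cannot be indexed.
    print("broken:", err)

print("fixed:", emulated_while_loop(cond, fixed_body, [images, 0, len(images), 0]))
```

The broken version fails on its second pass, just as the TensorFlow code did. The fixed version ends with `[images, 3, 3, 11]`: t1 survives every iteration and the last score (5 + 6) sits in its own slot.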

The other issue was the call to tf.image.ssim. It is documented as taking a batch of images of shape [batch, height, width, channels], so for a single MNIST image it needs [1, 28, 28, 1], but you were passing [28, 28, 1].
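A hedged alternative for the original goal of parallelizing the comparison (this is not what the answer proposes): according to its documentation, tf.image.ssim returns one value per image and broadcasts the leading batch dimensions. A single call can therefore score a whole batch against one reference, with no explicit loop. The sketch uses random uint8 arrays as stand-ins for MNIST so it runs without downloading the dataset; the names `images` and `reference` are illustrative.

```python
import numpy as np
import tensorflow as tf

# Random uint8 stand-ins for MNIST digits (28x28, one channel). With the real
# data you would use x_train[..., np.newaxis] and x_train[0][..., np.newaxis].
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(100, 28, 28, 1), dtype=np.uint8)
reference = images[0]                      # shape (28, 28, 1)

# One image against the reference: add a batch axis to get [1, 28, 28, 1].
single = tf.image.ssim(images[0][np.newaxis], reference[np.newaxis], max_val=255)
print(single.shape)                        # one score per batch entry

# Whole batch against the reference in one call: the leading batch
# dimensions broadcast, so [100, 28, 28, 1] vs [1, 28, 28, 1] gives [100].
scores = tf.image.ssim(images, reference[np.newaxis], max_val=255)
print(scores.shape)
```

For all 60,000 MNIST images, the float32 intermediates can take several hundred MB, so scoring the data in chunks may be preferable.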

def body(t1, t2, i, iters, sim):
  a = tf.expand_dims(t1[i], axis=-1)  # add a channel axis: [28, 28, 1]
  a = tf.expand_dims(a, axis=0)       # make it a batch: [1, 28, 28, 1]
  b = tf.expand_dims(t2, axis=0)      # make it a batch: [1, 28, 28, 1]
  sim = tf.image.ssim(a, b, 255)      # shape [1]
  return [t1, t2, tf.add(i, 1), iters, sim]

Here is the whole code:

import numpy as np
import tensorflow as tf


# Importing a generic dataset from Keras
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    path='mnist.npz'
)

# This is my reference image
x_reference = np.expand_dims(x_train[0], axis=2)  # [28, 28, 1]
t1 = tf.constant(x_train)                         # [60000, 28, 28]

t2 = tf.constant(x_reference)
iters = tf.constant(60000)

def cond(t1, t2, i, iters, sim):
  return tf.less(i, iters)

def body(t1, t2, i, iters, sim):
  a = tf.expand_dims(t1[i], axis=-1)  # add a channel axis: [28, 28, 1]
  a = tf.expand_dims(a, axis=0)       # make it a batch: [1, 28, 28, 1]
  b = tf.expand_dims(t2, axis=0)      # make it a batch: [1, 28, 28, 1]
  sim = tf.image.ssim(a, b, 255)      # overwritten on every iteration, shape [1]
  return [t1, t2, tf.add(i, 1), iters, sim]

# Give sim an initial value with the same dtype and shape that body
# returns for it (float32, shape [1]).
res = tf.while_loop(cond, body, [t1, t2, 0, iters, tf.zeros([1])],
                    parallel_iterations=60000)
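Note that the loop above overwrites sim on every iteration, so res[4] ends up holding only the score of the last image. If every score is needed, one common pattern is to carry a tf.TensorArray as a loop variable and write one entry per iteration. The sketch below does that inside a tf.function. The function name `all_ssim` and the random stand-in data are assumptions for illustration. Its output should agree with calling tf.image.ssim on the whole batch at once.

```python
import numpy as np
import tensorflow as tf

# Random uint8 stand-ins for MNIST (assumption); swap in x_train for real data.
rng = np.random.default_rng(0)
x_small = rng.integers(0, 256, size=(50, 28, 28), dtype=np.uint8)

t1 = tf.constant(x_small)                      # [N, 28, 28]
t2 = tf.constant(x_small[0][..., np.newaxis])  # reference, [28, 28, 1]


@tf.function
def all_ssim(t1, t2):
    n = tf.shape(t1)[0]
    scores = tf.TensorArray(tf.float32, size=n)  # one slot per image

    def cond(i, scores):
        return i < n

    def body(i, scores):
        a = tf.expand_dims(t1[i], axis=-1)       # [28, 28] -> [28, 28, 1]
        a = tf.expand_dims(a, axis=0)            # -> [1, 28, 28, 1]
        b = tf.expand_dims(t2, axis=0)           # -> [1, 28, 28, 1]
        sim = tf.image.ssim(a, b, 255)           # shape [1]
        return [i + 1, scores.write(i, sim[0])]  # store this image's score

    _, scores = tf.while_loop(cond, body, [tf.constant(0), scores])
    return scores.stack()                        # shape [N]


result = all_ssim(t1, t2)
print(result.shape)
```

Every loop variable here is returned in the same position it came in, which is the rule the answer describes.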

