I have a TensorFlow model that I am training. The y data is partially generated from the network's own predictions, so I cannot simply use model.fit(x,y,...).
This means I need to take a sample batch of x data, make a prediction, do something with that prediction, use the result to generate the y data for the batch, and then call model.train_on_batch(x,y). I have to do that several times, inside a loop. This leads to a warning about graph retracing, and training then slows down MASSIVELY after only a few iterations of the loop. This happens even when the x and y data are converted to tensors before being fed to train_on_batch and are always the same size.
I have found a workaround using the gradient tape, but it seems like this is more complicated than it needs to be. Is there anything easier?
My workaround is:
myloss = tf.keras.losses.MeanSquaredError()

@tf.function
def train_step(x_, y_):
    with tf.GradientTape() as tape:
        predictions = model(x_, training=True)
        loss = myloss(y_, predictions)
    gradients = tape.gradient(loss, model.trainable_weights)
    model.optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    return loss
Then inside my training loop, I generate numpy arrays for x and y, convert them to tf tensors, and pass them to train_step. If the x and y data aren't the same size on each iteration of the loop, this also causes a slowdown, which isn't so convenient.
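One thing that might help with the varying batch size is giving tf.function an explicit input_signature whose batch dimension is None, so that a single traced graph can be reused for every batch size. The following is only a minimal sketch under assumptions: the small model, the standalone Adam optimizer, the names loss_fn and trace_log, and the toy target (prediction squared) are stand-ins for illustration, not my real setup.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in model with 3 input features and 1 output.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()

# Python side effects inside a tf.function run only while it is being traced,
# so the length of this list shows how often tracing happened.
trace_log = []

@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 3], dtype=tf.float32),  # x: any batch size
    tf.TensorSpec(shape=[None, 1], dtype=tf.float32),  # y: any batch size
])
def train_step(x_, y_):
    trace_log.append(1)  # runs once per trace, not once per call
    with tf.GradientTape() as tape:
        predictions = model(x_, training=True)
        loss = loss_fn(y_, predictions)
    gradients = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    return loss

x = np.random.random(size=(1000, 3)).astype(np.float32)
losses = []
for batch_size in (16, 16, 8, 32, 5):  # deliberately varying batch sizes
    ind = np.random.choice(1000, batch_size, replace=False)
    xt = tf.convert_to_tensor(x[ind, :], dtype=tf.float32)
    yt = model(xt, training=False) ** 2  # toy target built from predictions
    losses.append(float(train_step(xt, yt)))

print("number of traces:", len(trace_log))
```

If the printed trace count stays small and does not grow with the number of distinct batch sizes, the signature is doing its job. The trade-off is that the feature dimensions (3 and 1 here) are fixed in the signature, so inputs with any other shape or dtype are rejected instead of triggering a new trace.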
Is there an easier way to do this?
If it were just for me, this would be fine, but I have to teach it to some other people, and I don't want to get into how the gradient tape works...
** UPDATE TO INCLUDE MINIMAL EXAMPLE ** Here's some code that brings up the graph retracing issue:
import numpy as np
import tensorflow as tf

model = tf.keras.models.Sequential([tf.keras.layers.Dense(10,activation='relu',input_shape=(3,)),
                                    tf.keras.layers.Dense(1,activation='sigmoid')])
model.compile(loss='MeanSquaredError',optimizer='adam')

x = np.random.random(size=(1000,3))
for i in range(50):
    # sample a batch of 16 rows from x
    ind = np.random.choice(1000,16,replace=False)
    xt = x[ind,:]
    # build the targets from the model's own predictions
    yt0 = model.predict_on_batch(xt)
    yt = yt0**2
    xt = tf.convert_to_tensor(xt,dtype=tf.float32)
    yt = tf.convert_to_tensor(yt,dtype=tf.float32)
    model.train_on_batch(xt,yt)
This code is not what I actually want to do, but it's a simple example that illustrates the issue.
If I replace the model.train_on_batch(xt,yt) line with train_step(xt,yt), I don't get the warning and the code runs MUCH faster!