python - tensorflow repeated calls to train_on_batch - Stack Overflow

admin, 2025-04-17

I have a TensorFlow model that I am training. The y data is partially generated from the neural network's own predictions, so I cannot simply use model.fit(x,y,...).

This means that, for each batch, I need to sample some x data, make a prediction, do something with that prediction to generate the y data for the batch, and then call model.train_on_batch(x,y). I have to do that several times, inside a loop. This produces a warning about the graph, and training slows down MASSIVELY after only a few iterations of the loop, even when the x and y data are converted to tensors before being passed to train_on_batch and are always the same size.

I have found a workaround using the gradient tape, but it seems like this is more complicated than it needs to be. Is there anything easier?

My workaround is:

myloss = tf.keras.losses.MeanSquaredError()

@tf.function
def train_step(x_, y_):
    with tf.GradientTape() as tape:
        predictions = model(x_, training=True)
        loss = myloss(y_, predictions)
    gradients = tape.gradient(loss, model.trainable_weights)
    model.optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    return loss

Then, inside my training loop, I generate numpy arrays for x and y, convert them to tf tensors, and pass them to train_step. If the x and y data aren't the same size on every iteration of the loop, this also causes a slowdown, which isn't so convenient.
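The slowdown with varying sizes is what one would expect from tf.function tracing a new graph for each new input shape. One possible way around it, sketched here under the assumption that only the batch dimension changes, is to give train_step an input_signature with None as the batch dimension. The model is a stand-in with 3 input features and 1 output, and trace_log is a hypothetical helper list used only to count how often the Python body is traced.

```python
import numpy as np
import tensorflow as tf

# Stand-in model (assumption: 3 input features, 1 output).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(loss="mse", optimizer="adam")
myloss = tf.keras.losses.MeanSquaredError()

trace_log = []  # hypothetical helper: grows by one entry per trace


# None in the batch dimension lets one traced graph serve any batch size.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 3], dtype=tf.float32),
    tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
])
def train_step(x_, y_):
    trace_log.append("traced")  # Python side effect: runs only while tracing
    with tf.GradientTape() as tape:
        predictions = model(x_, training=True)
        loss = myloss(y_, predictions)
    gradients = tape.gradient(loss, model.trainable_weights)
    model.optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    return loss


rng = np.random.default_rng(0)
traces_after_first_call = None
for batch_size in (4, 8, 16, 32, 5):
    # Random data of a different size on every iteration.
    xb = tf.convert_to_tensor(rng.random((batch_size, 3)), dtype=tf.float32)
    yb = tf.convert_to_tensor(rng.random((batch_size, 1)), dtype=tf.float32)
    loss = train_step(xb, yb)
    if traces_after_first_call is None:
        traces_after_first_call = len(trace_log)

print("traces after first call:", traces_after_first_call)
print("traces after all calls: ", len(trace_log))
```

If the two printed counts match, the later batch sizes reused the graph built on the first call. The feature and output dimensions are fixed in the signature, so this only helps when the batch dimension is the part that varies.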

Is there an easier way to do this?

If it were just for me, this would be fine, but I have to teach it to some other people, and I don't want to get into how the gradient tape works...

UPDATE TO INCLUDE MINIMAL EXAMPLE: Here's some code that brings up the graph retracing issue:

import numpy as np
import tensorflow as tf

model = tf.keras.models.Sequential([tf.keras.layers.Dense(10,activation='relu',input_shape=(3,)),
                                  tf.keras.layers.Dense(1,activation='sigmoid')])
model.compile(loss='MeanSquaredError',optimizer='adam')

x = np.random.random(size=(1000,3))

for i in range(50):
    ind = np.random.choice(1000,16,replace=False)
    xt = x[ind,:]
    yt0 = model.predict_on_batch(xt)
    yt = yt0**2
    xt = tf.convert_to_tensor(xt,dtype=tf.float32)
    yt = tf.convert_to_tensor(yt,dtype=tf.float32)
    model.train_on_batch(xt,yt)

This code is not what I actually want to do, but it's a simple example that illustrates the issue.

If I replace the model.train_on_batch(xt,yt) line with train_step(xt,yt), I don't get the warning and the code runs MUCH faster!
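For reference, the minimal example and the train_step workaround can be put together into one script along the following lines. This is only a sketch of the substitution described above: the model is declared with tf.keras.Input rather than input_shape, the loss is given as "mse", and the losses list is a hypothetical addition for inspecting progress.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(loss="mse", optimizer="adam")
myloss = tf.keras.losses.MeanSquaredError()


@tf.function
def train_step(x_, y_):
    # One optimizer update on a single batch; returns the batch loss.
    with tf.GradientTape() as tape:
        predictions = model(x_, training=True)
        loss = myloss(y_, predictions)
    gradients = tape.gradient(loss, model.trainable_weights)
    model.optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    return loss


x = np.random.random(size=(1000, 3))
losses = []  # hypothetical addition: keeps the loss of every step

for i in range(50):
    ind = np.random.choice(1000, 16, replace=False)
    xt = x[ind, :]
    yt0 = model.predict_on_batch(xt)  # targets depend on the current model
    yt = yt0 ** 2
    xt = tf.convert_to_tensor(xt, dtype=tf.float32)
    yt = tf.convert_to_tensor(yt, dtype=tf.float32)
    losses.append(float(train_step(xt, yt)))  # replaces model.train_on_batch(xt, yt)

print("first loss:", losses[0], "last loss:", losses[-1])
```

Every batch here has 16 rows, so the inputs to train_step always have the same shape and dtype.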
