Today, we introduce eager execution for TensorFlow.
Eager execution is an imperative, define-by-run interface where operations are executed immediately as they are called from Python. This makes it easier to get started with TensorFlow, and can make research and development more intuitive.
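The benefits of eager execution include:
- Fast debugging with immediate run-time errors and integration with Python tools
- Support for dynamic models using easy-to-use Python control flow
- Strong support for custom and higher-order gradients
- Support for most existing TensorFlow operations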
Eager execution is available now as an experimental feature, so we're looking for feedback from the community to guide our direction.
To understand this all better, let's look at some code. This gets pretty technical; familiarity with TensorFlow will help.
When you enable eager execution, operations execute immediately and return their values to Python without requiring a `Session.run()`. For example, to multiply two matrices together, we write this:
```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

x = [[2.]]
m = tf.matmul(x, x)
```
It's straightforward to inspect intermediate results with `print` or the Python debugger.

```python
print(m)
# The 1x1 matrix [[4.]]
```
Dynamic models can be built with Python flow control. Here's an example of the Collatz conjecture using TensorFlow's arithmetic operations:
```python
a = tf.constant(12)
while not tf.equal(a, 1):
    if tf.equal(a % 2, 0):
        a = a // 2  # floor division keeps `a` an integer tensor
    else:
        a = 3 * a + 1
    print(a)
```
Here, using the `tf.constant(12)` `Tensor` object promotes all math operations to tensor operations, so all return values will be tensors.
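For reference, here is the same loop written with plain Python integers instead of tensors. It is a sketch that collects the visited values in a list (`collatz_sequence` is a name introduced here, not a TensorFlow API), which makes it easy to see which numbers the tensor version should print when starting from 12.

```python
def collatz_sequence(n):
    """Return the values visited on the way from n down to 1 (n itself excluded)."""
    values = []
    while n != 1:
        if n % 2 == 0:
            n = n // 2      # even: halve
        else:
            n = 3 * n + 1   # odd: triple and add one
        values.append(n)
    return values


print(collatz_sequence(12))  # [6, 3, 10, 5, 16, 8, 4, 2, 1]
```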
Most TensorFlow users are interested in automatic differentiation. Because different operations can occur during each call, we record all forward operations to a tape, which is then played backwards when computing gradients. After we've computed the gradients, we discard the tape.
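To make the tape idea concrete, here is a toy reverse-mode differentiator for scalars in plain Python. It is only an illustration under simplifying assumptions (one scalar input; only multiplication and negation are supported), `Var` and `make_grad` are hypothetical names, and this is not how TensorFlow implements its tape. Each call records the operations that actually ran, replays them in reverse to accumulate gradients with the chain rule, and then discards the tape, so a function with Python control flow gets the gradient of whichever branch executed.

```python
class Var:
    """A scalar that records every operation applied to it on a shared tape."""

    def __init__(self, value, tape):
        self.value = value
        self.tape = tape

    def _record(self, value, parents):
        # parents: list of (input Var, local derivative d(output)/d(input))
        out = Var(value, self.tape)
        self.tape.append((out, parents))
        return out

    def __mul__(self, other):
        return self._record(self.value * other.value,
                            [(self, other.value), (other, self.value)])

    def __neg__(self):
        return self._record(-self.value, [(self, -1.0)])


def make_grad(f):
    """Return a callable that computes df/dx at a point by replaying a tape backwards."""
    def grad(x):
        tape = []                 # a fresh tape per call: the recorded ops may differ
        xv = Var(x, tape)
        y = f(xv)                 # forward pass records operations onto the tape
        adjoint = {id(y): 1.0}    # dy/dy = 1
        for out, parents in reversed(tape):
            g = adjoint.get(id(out), 0.0)
            for inp, local in parents:
                adjoint[id(inp)] = adjoint.get(id(inp), 0.0) + g * local
        result = adjoint.get(id(xv), 0.0)
        tape.clear()              # discard the tape once the gradient is computed
        return result
    return grad


def square(v):
    return v * v


def abs_value(v):
    # Python control flow: which operations get recorded depends on the input.
    return v if v.value > 0 else -v


print(make_grad(square)(3.0))      # 6.0
print(make_grad(abs_value)(2.0))   # 1.0
print(make_grad(abs_value)(-2.0))  # -1.0
```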
If you're familiar with the autograd package, the API is very similar. For example:
```python
def square(x):
    return tf.multiply(x, x)

grad = tfe.gradients_function(square)

print(square(3.))  # 9.
print(grad(3.))    # [6.] (a list with one gradient per argument of square)
```
The `gradients_function` call takes a Python function `square()` as an argument and returns a Python callable that computes the partial derivatives of `square()` with respect to its inputs. So, to get the derivative of `square()` at 3.0, invoke `grad(3.0)`, which is 6.
The same `gradients_function` call can be used to get the second derivative of `square`:
```python
# Differentiate the first derivative again to get the second derivative.
gradgrad = tfe.gradients_function(lambda x: grad(x)[0])

print(gradgrad(3.))  # [2.]
```
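A quick way to sanity-check these values is numerical differentiation with finite differences, which is a different technique from the tape-based automatic differentiation described above. The sketch below is plain Python (`central_diff` and `second_diff` are names introduced here) and approximates the first and second derivatives of square at 3.0, which should come out close to 6 and 2.

```python
def square(x):
    return x * x


def central_diff(f, x, h=1e-3):
    """Approximate f'(x) with the central difference (f(x + h) - f(x - h)) / (2h)."""
    return (f(x + h) - f(x - h)) / (2 * h)


def second_diff(f, x, h=1e-3):
    """Approximate f''(x) with (f(x + h) - 2 f(x) + f(x - h)) / h**2."""
    return (f(x + h) - 2 * f(x) + f(x - h)) / h ** 2


print(central_diff(square, 3.0))  # approximately 6
print(second_diff(square, 3.0))   # approximately 2
```

Finite differences trade accuracy for simplicity: the step `h` must be small enough to approximate the limit but large enough to avoid cancellation error, so they are best used as a check rather than as a replacement for automatic differentiation.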
As we noted, control flow can cause different operations to run, such as in this example.
```python
def abs(x):
    return x if x > 0. else -x

grad = tfe.gradients_function(abs)

print(grad(2.0))   # [1.]
print(grad(-2.0))  # [-1.]
```
Users may want to define custom gradients for an operation, or for a function. This may be useful for multiple reasons, including providing a more efficient or more numerically stable gradient for a sequence of operations.
Here is an example that illustrates the use of custom gradients. Let's start by looking at the function $\log(1 + e^{x})$, which commonly occurs in the computation of cross entropy and log likelihoods.
```python
def log1pexp(x):
    return tf.log(1 + tf.exp(x))

grad_log1pexp = tfe.gradients_function(log1pexp)

# The gradient computation works fine at x = 0.
print(grad_log1pexp(0.))  # [0.5]

# However, it returns a `nan` at x = 100 due to numerical instability.
print(grad_log1pexp(100.))  # [nan]
```
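The `nan` comes from floating-point overflow. The sketch below mimics it in plain Python without TensorFlow; `ieee_exp` is a hypothetical helper that returns infinity on overflow instead of raising, as IEEE floating-point arithmetic does. Applying the chain rule op by op gives `(1 / (1 + e^x)) * e^x`, which becomes `0 * inf = nan` once `e^x` overflows. Python floats are 64-bit and overflow only above x ≈ 709.78, so x = 1000 is used here; float32, TensorFlow's default, overflows above x ≈ 88.72, which is why x = 100 already fails. The algebraically equivalent form `1 - 1 / (1 + e^x)` stays finite.

```python
import math


def ieee_exp(x):
    """exp(x) that returns inf on overflow (IEEE behavior) instead of raising OverflowError."""
    try:
        return math.exp(x)
    except OverflowError:
        return math.inf


def chain_rule_grad(x):
    # Gradient of log(1 + e^x) evaluated op by op: (1 / (1 + e)) * e
    e = ieee_exp(x)
    return (1 / (1 + e)) * e  # 0 * inf -> nan when e overflows


def simplified_grad(x):
    # Same derivative rewritten as 1 - 1 / (1 + e^x); 1 / inf is 0, so the result stays finite.
    e = ieee_exp(x)
    return 1 - 1 / (1 + e)


for x in (0.0, 1000.0):
    print(x, chain_rule_grad(x), simplified_grad(x))
```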
We can use a custom gradient for the above function that analytically simplifies the gradient expression. Notice how the gradient function implementation below reuses an expression (`tf.exp(x)`) that was computed during the forward pass, making the gradient computation more efficient by avoiding redundant computation.
```python
@tfe.custom_gradient
def log1pexp(x):
    e = tf.exp(x)

    def grad(dy):
        # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x), reusing `e` from the forward pass.
        return dy * (1 - 1 / (1 + e))

    return tf.log(1 + e), grad

grad_log1pexp = tfe.gradients_function(log1pexp)

# Gradient at x = 0 works as before.
print(grad_log1pexp(0.))  # [0.5]

# And now gradient computation at x = 100 works as well.
print(grad_log1pexp(100.))  # [1.0]
```
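The simplification used in `grad` is a one-line derivation. For $f(x) = \log(1 + e^{x})$, the chain rule gives the first form below, which is essentially what the unmodified function's gradient evaluates op by op; the third form is what the custom gradient returns (scaled by the upstream gradient `dy`), and it equals the logistic sigmoid $\sigma(x)$.

```latex
f'(x) = \frac{1}{1 + e^{x}} \cdot e^{x}
      = \frac{e^{x}}{1 + e^{x}}
      = 1 - \frac{1}{1 + e^{x}}
      = \sigma(x)
```

When $e^{x}$ overflows to $\infty$ in float32 (above $x \approx 88.72$), the first form evaluates $0 \cdot \infty = \mathrm{NaN}$, while the third gives $1 - 1/\infty = 1$, which matches the `[1.0]` printed for x = 100.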
Models can be organized in classes. Here's a model class that creates a (simple) two-layer network that can classify the standard MNIST handwritten digits.
```python
class MNISTModel(tfe.Network):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.layer1 = self.track_layer(tf.layers.Dense(units=10))
        self.layer2 = self.track_layer(tf.layers.Dense(units=10))

    def call(self, input):
        """Actually runs the model."""
        result = self.layer1(input)
        result = self.layer2(result)
        return result
```
We recommend using the classes (not the functions) in tf.layers since they create and contain model parameters (variables). Variable lifetimes are tied to the lifetime of the layer objects, so be sure to keep track of them.
Why are we using `tfe.Network`? A `Network` is a container for layers and is a `tf.layers.Layer` itself, allowing `Network` objects to be embedded in other `Network` objects. It also contains utilities to assist with inspection, saving, and restoring.
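The container idea can be sketched without TensorFlow. Below, `Layer` and `Network` are hypothetical minimal classes (not `tfe.Network` or `tf.layers.Layer`), and variables are just name strings. Because `Network` subclasses `Layer`, a `Network` can be tracked by another `Network`, and collecting variables recurses through the nesting, which is the kind of bookkeeping that inspection, saving, and restoring rely on.

```python
class Layer:
    """Hypothetical minimal layer that owns some named variables."""

    def __init__(self, name, num_vars=0):
        self.name = name
        self._own_vars = [f"{name}/v{i}" for i in range(num_vars)]

    @property
    def variables(self):
        return list(self._own_vars)


class Network(Layer):
    """Hypothetical container; since it *is* a Layer, another Network can track it."""

    def __init__(self, name):
        super().__init__(name)
        self._layers = []

    def track_layer(self, layer):
        self._layers.append(layer)  # remember the layer so its variables stay reachable
        return layer

    @property
    def variables(self):
        # Own variables plus, recursively, every tracked layer's variables.
        return super().variables + [v for layer in self._layers for v in layer.variables]


inner = Network("inner")
inner.track_layer(Layer("dense1", num_vars=2))  # e.g. a kernel and a bias
outer = Network("outer")
outer.track_layer(inner)                        # a Network embedded in a Network
outer.track_layer(Layer("dense2", num_vars=2))
print(outer.variables)  # ['dense1/v0', 'dense1/v1', 'dense2/v0', 'dense2/v1']
```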
Even without training the model, we can imperatively call it and inspect the output:
```python
# Let's make up a blank input image
model = MNISTModel()
batch = tf.zeros([1, 1, 784])
print(batch.shape)
# (1, 1, 784)

result = model(batch)
print(result)
# tf.Tensor([[[ 0. 0. ... 0.]]], shape=(1, 1, 10), dtype=float32)
```
Note that we do not need any placeholders or sessions. The first time we pass in the input, the sizes of the layers' parameters are set.
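The deferred sizing can be illustrated with a toy layer in plain Python. `LazyDense` is a hypothetical class, not the implementation of `tf.layers.Dense`, and it handles a single input vector rather than a batch. Its kernel is created only on the first call, once the input length (784 here) is known. The sketch assumes a small random kernel and a zero bias; with an all-zero input every output is then zero, consistent with the all-zeros result printed above (`tf.layers.Dense` initializes biases to zero by default).

```python
import random


class LazyDense:
    """Toy fully connected layer whose weights are created on first use."""

    def __init__(self, units):
        self.units = units
        self.kernel = None  # shape (input_dim, units), unknown until the first call
        self.bias = None

    def __call__(self, inputs):
        if self.kernel is None:
            input_dim = len(inputs)  # parameter sizes are fixed here, on the first call
            self.kernel = [[random.gauss(0.0, 0.05) for _ in range(self.units)]
                           for _ in range(input_dim)]
            self.bias = [0.0] * self.units
        return [sum(inputs[i] * self.kernel[i][j] for i in range(len(inputs))) + self.bias[j]
                for j in range(self.units)]


layer = LazyDense(units=10)
print(layer.kernel)       # None: no parameters yet
out = layer([0.0] * 784)  # a blank "image" of 784 pixels
print(len(layer.kernel), len(layer.kernel[0]))  # 784 10
print(out)                # ten zeros
```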
To train any model, we define a loss function to optimize, calculate gradients, and use an optimizer to update the variables. First, here's a loss function:
```python
def loss_function(model, x, y):
    y_ = model(x)
    return tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_)
```
And then, our training loop:
```python
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)

for (x, y) in tfe.Iterator(dataset):
    # implicit_gradients returns a list of (gradient, variable) pairs.
    grads = tfe.implicit_gradients(loss_function)(model, x, y)
    optimizer.apply_gradients(grads)
```
`implicit_gradients()` calculates the derivatives of `loss_function` with respect to all the TensorFlow variables used during its computation.
We can move computation to a GPU the same way we've always done with TensorFlow:
```python
with tf.device("/gpu:0"):
    for (x, y) in tfe.Iterator(dataset):
        optimizer.minimize(lambda: loss_function(model, x, y))
```
(Note: Here we take a shortcut and call `optimizer.minimize` directly instead of computing the gradients and then calling `apply_gradients()` as above; the two are equivalent.)
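The equivalence can be checked on a toy problem. The sketch below is plain Python with hypothetical names (`Variable`, `numeric_grad`, `GradientDescent`), not TensorFlow's optimizer classes, and it uses central finite differences as a stand-in for `tfe.implicit_gradients`. Its `minimize` simply computes (gradient, variable) pairs and passes them to `apply_gradients`, so running either path from the same starting point yields the same parameter; both fit `w` so that `w * 2 ≈ 6`.

```python
class Variable:
    """Hypothetical mutable scalar parameter."""

    def __init__(self, value):
        self.value = value


def numeric_grad(loss_fn, var, h=1e-6):
    """Central-difference estimate of d(loss)/d(var); a stand-in for automatic differentiation."""
    old = var.value
    var.value = old + h
    up = loss_fn()
    var.value = old - h
    down = loss_fn()
    var.value = old  # restore the original value
    return (up - down) / (2 * h)


class GradientDescent:
    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def apply_gradients(self, grads_and_vars):
        for g, v in grads_and_vars:
            v.value -= self.learning_rate * g

    def minimize(self, loss_fn, var_list):
        # Shortcut: compute (gradient, variable) pairs, then apply them.
        self.apply_gradients([(numeric_grad(loss_fn, v), v) for v in var_list])


w1, w2 = Variable(0.0), Variable(0.0)


def loss1():
    return (w1.value * 2.0 - 6.0) ** 2


def loss2():
    return (w2.value * 2.0 - 6.0) ** 2


opt = GradientDescent(learning_rate=0.05)
for _ in range(100):
    opt.apply_gradients([(numeric_grad(loss1, w1), w1)])  # explicit two-step path
    opt.minimize(loss2, [w2])                             # one-call shortcut
print(w1.value, w2.value)  # both approach 3.0
```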
Eager execution makes development and debugging far more interactive, but TensorFlow graphs have a lot of advantages with respect to distributed training, performance optimizations, and production deployment.
The same code that executes operations when eager execution is enabled will construct a graph describing the computation when it is not. To convert your models to graphs, simply run the same code in a new Python session where eager execution hasn't been enabled, as in the MNIST example. The values of model variables can be saved to and restored from checkpoints, allowing us to move easily between eager (imperative) and graph (declarative) programming. With this, models developed with eager execution enabled can be easily exported for production deployment.
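As a loose analogy (not TensorFlow's actual mechanism), the sketch below runs one model function against two hypothetical backends: `EagerOps` computes values immediately, while `GraphOps` records nodes that `run_graph` can evaluate later. The model code is identical in both cases, which is the property described above.

```python
class EagerOps:
    """Runs each operation immediately and returns its value."""

    def mul(self, a, b):
        return a * b

    def add(self, a, b):
        return a + b


class GraphOps:
    """Records each operation as a node and returns a symbolic handle instead of a value."""

    def __init__(self):
        self.nodes = []

    def _node(self, op, a, b):
        self.nodes.append((op, a, b))
        return f"%{len(self.nodes) - 1}"  # handle naming this node's output

    def mul(self, a, b):
        return self._node("mul", a, b)

    def add(self, a, b):
        return self._node("add", a, b)


def run_graph(nodes, x):
    """Evaluate recorded nodes in order, feeding the value x for the input named "x"."""
    values = {"x": x}
    for i, (op, a, b) in enumerate(nodes):
        va, vb = values[a], values[b]
        values[f"%{i}"] = va * vb if op == "mul" else va + vb
    return values[f"%{len(nodes) - 1}"]


def model(ops, x):
    # The same model code, unaware of whether it is executing or building a graph.
    return ops.add(ops.mul(x, x), x)


print(model(EagerOps(), 3))       # 12
graph = GraphOps()
model(graph, "x")                 # builds two nodes, computes nothing
print(graph.nodes)                # [('mul', 'x', 'x'), ('add', '%0', 'x')]
print(run_graph(graph.nodes, 3))  # 12
```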
In the near future, we will provide utilities to selectively convert portions of your model to graphs. In this way, you can fuse parts of your computation (such as the internals of a custom RNN cell) for high performance, while keeping the flexibility and readability of eager execution.
Using eager execution should be intuitive to current TensorFlow users. There are only a handful of eager-specific APIs; most of the existing APIs and operations work with eager enabled. Some notes to keep in mind:
- If you are not already using `tf.data` for input processing, you should. It's easier to use and usually faster. For help, see this blog post and the documentation page.
- Use object-oriented layers, such as `tf.layers.Conv2D()` or Keras layers; these have explicit storage for variables.
- Once you have enabled eager execution with `tfe.enable_eager_execution()`, it cannot be turned off. To get graph behavior, start a new Python session.

This is still a preview release, so you may hit some rough edges. To get started today:
There's a lot more to talk about with eager execution and we're excited… or, rather, we're eager for you to try it today! Feedback is absolutely welcome.