
  • SKYE WANDERMAN-MILNE: I'm Skye, for those who don't know me.

  • I've been working on Control Flow in TensorFlow

  • for quite some time, with the help of [? Sarab ?]

  • and many other individuals on the team.

  • And so my goal with this talk is to tell you

  • everything I know about Control Flow that's important.

  • Let's get started.

  • I'm going to start by going over the lay

  • of the land with Control Flow in TensorFlow.

  • So starting with what I'm going to call the Base APIs,

  • tf dot cond and tf dot while loop.

  • So these are the primitives that are

  • exposed in the public Python API for users

  • to access Control Flow.

  • So you have conditional execution and loops.

  • That's it.
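
A minimal sketch of the two base APIs in TF2-style Python (the specific tensors and the loop bound here are made up for illustration):

```python
import tensorflow as tf

x, y = tf.constant(2.0), tf.constant(3.0)

# tf.cond: runs exactly one of the two branch functions, chosen by the predicate.
z = tf.cond(x < y, true_fn=lambda: x + y, false_fn=lambda: x * y)

# tf.while_loop: repeatedly applies `body` to the loop variables while `cond`
# returns True.  Here it counts i up to 10 while accumulating a running sum.
i, total = tf.while_loop(
    cond=lambda i, total: i < 10,
    body=lambda i, total: (i + 1, total + i),
    loop_vars=(tf.constant(0), tf.constant(0)))
```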

  • So you might be wondering, what about all the other Control

  • Flow functions I know and love, like map or case?

  • These are all built on those two base APIs, cond and while loop.

  • They're sort of wrappers around them

  • that add useful functionality.
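
For instance, two of those wrapper APIs, in a rough sketch (the inputs here are invented):

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])

# tf.map_fn applies a function elementwise and is built on tf.while_loop.
squares = tf.map_fn(tf.square, x)          # [1.0, 4.0, 9.0]

# tf.case picks one of several branches and is built on tf.cond.
k = tf.constant(2)
y = tf.case([(tf.equal(k, 1), lambda: x + 1.0),
             (tf.equal(k, 2), lambda: x * 2.0)],
            default=lambda: x)             # [2.0, 4.0, 6.0]
```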

  • So diving down into the stack, how

  • are these primitives, cond and while, actually implemented?

  • How are they represented in the graph?

  • So in TensorFlow 1.x, we have these low-level Control Flow

  • ops.

  • You might have heard of them: Exit, Enter, NextIteration,

  • Switch, and Merge.

  • We'll talk more about these in a bit.

  • There's also an alternate representation.

  • That's what Control Flow version 2 is all about.

  • These are the "functional" ops.

  • And I put "functional" in quotes because it's caused

  • some confusion in the past.

  • It's not pure functional in the programming sense;

  • there's still state.

  • But they're higher order functions

  • that take functions as input.

  • So now, the cond branches will be represented as functions.

  • So these sort of do the same thing as the low-level ops,

  • but the higher level functionality is all

  • wrapped up into a single op.
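
One way to see that functional representation, sketched here with an assumed scalar signature (the op shows up as If or StatelessIf depending on whether the branches have side effects):

```python
import tensorflow as tf

@tf.function
def f(pred, x, y):
    return tf.cond(pred, lambda: x + 1.0, lambda: tf.square(y))

graph = f.get_concrete_function(
    tf.TensorSpec([], tf.bool),
    tf.TensorSpec([], tf.float32),
    tf.TensorSpec([], tf.float32)).graph

# With Control Flow v2, the whole conditional is one op whose true/false
# branches are attached as functions, not inlined Switch/Merge ops.
print([op.type for op in graph.get_operations() if op.type.endswith("If")])
```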

  • Moving back up the stack, you might

  • be wondering what's going to happen with TensorFlow 2.0.

  • If you're using Eager execution, you just write Python

  • and you just use Python Control Flow.

  • So if statements, or loops, or list comprehensions,

  • that kind of thing.

  • So there's no arrow connecting it to this graph mode stuff.

  • But if you use tf dot function, then Autograph, which maybe some people

  • have heard of, and which is automatically

  • included in tf dot function, attempts

  • to take your eager-style, plain Python code

  • and convert it into new Python code that calls the TensorFlow

  • graph APIs.

  • So it's going to try to rewrite all

  • that Python Control Flow, your if statements and while loops,

  • into tf dot cond and tf dot while loop.

  • So note that Autograph is just dealing

  • at this abstraction layer of the public TensorFlow API.

  • It doesn't have to dive down into the low-level ops

  • or anything like that.
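
As a rough illustration of that conversion (a made-up function; because the conditions depend on tensors, Autograph rewrites the while into tf.while_loop and the if into tf.cond when tracing):

```python
import tensorflow as tf

@tf.function
def collatz_steps(n):
    steps = tf.constant(0)
    while n > 1:            # tensor-dependent loop   -> tf.while_loop
        if n % 2 == 0:      # tensor-dependent branch -> tf.cond
            n = n // 2
        else:
            n = 3 * n + 1
        steps += 1
    return steps

print(collatz_steps(tf.constant(6)))  # tf.Tensor(8, shape=(), dtype=int32)
```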

  • So that's kind of where we're at.

  • We have the 2.0 world where you just write Python, which maybe

  • gets converted into our public Graph APIs, which

  • in turn produce these various operators in the graph.

  • And one more thing.

  • Right now, in this new implementation

  • of Control Flow, Control Flow version 2,

  • we are still converting the functional ops back

  • into the low-level ops.

  • This is basically a performance optimization.

  • I hope we don't have to do it in the future.

  • That's why it's this faded-dash arrow.

  • So in this talk, we're gonna focus on the base APIs

  • and how they're implemented.

  • I think there'll be another talk about Autograph,

  • so hopefully they can talk about Control Flow there.

  • Maybe there's also a talk about Eager execution.

  • And the high-level APIs are not so complicated,

  • so I'll leave those as an exercise for the viewer.

  • OK.

  • So I'm going to start with going over Control Flow

  • v1, the original low-level representation.

  • You might be asking, why?

  • Why do we care at all?

  • So like I showed in the diagram, we

  • do still convert the functional ops to this representation.

  • So this is basically how it's executed today, always.

  • Furthermore, this is still what we use in TensorFlow 1.x.

  • So all 1.x code is using Control Flow v1.

  • Still very much alive.

  • And I hope it provides a little bit of motivation

  • for why we wanted to implement Control Flow using

  • the functional ops.

  • So I'm going to start with these low-level ops.

  • So up here, Switch and Merge are used for conditional execution;

  • this is tf dot cond.

  • They're also used in while loops to determine whether we need to keep

  • iterating or we're done.

  • And then Enter, Exit, and NextIteration

  • are just used in while loops to manage the iterations.

  • So let's dive in.

  • So Switch and Merge, these are for conditionals.

  • Let's just start with Switch.

  • The idea is you get your predicate tensor in,

  • this is a Boolean, that tells you which conditional branch

  • you want to take.

  • And then it has a single data input, so

  • [INAUDIBLE] some tensor.

  • And it's just going to forward that data input to one

  • of its two outputs depending on the predicate.

  • So in this picture, the predicate must be false.

  • And so the data's coming out of the false output.

  • Merge basically does the opposite.

  • It takes two inputs, but it only expects data

  • from one of its inputs.

  • And then it just outputs a single output.

  • So Switch is how you start your conditional execution,

  • because it's going to divert that data into one branch.

  • And then Merge brings it back together

  • into your mainline execution.

  • It's not conditional anymore.
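
Here's a hand-wired sketch of that, building the raw ops in a TF1-style graph (you'd never write this directly; tf.cond generates it for you, and the branch computations here are invented):

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    pred = tf.compat.v1.placeholder(tf.bool, shape=[])
    x = tf.constant(3.0)

    # Switch forwards x to exactly one of its two outputs, based on pred.
    false_out, true_out = tf.raw_ops.Switch(data=x, pred=pred)

    # Each branch hangs off one side; only the taken side carries real data.
    true_val = tf.square(true_out)
    false_val = false_out + 1.0

    # Merge outputs whichever of its inputs actually has data.
    merged, _ = tf.raw_ops.Merge(inputs=[false_val, true_val])

with tf.compat.v1.Session(graph=g) as sess:
    print(sess.run(merged, {pred: True}))   # 9.0 (the Square branch)
    print(sess.run(merged, {pred: False}))  # 4.0 (the Add branch)
```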

  • One implementation detail I'm going to mention here

  • is dead tensors.

  • So you might think that nothing is

  • going to come out of the true output of the Switch,

  • but it actually does output a special dead tensor, which

  • is just like a sentinel value.

  • Like a little tiny thing.

  • And dead tensors flow through the whole untaken

  • conditional branch.

  • And eventually, you're going to get a dead tensor

  • into this Merge.

  • It just ignores it and outputs whatever data tensor it gets.

  • So dead tensors are needed for distributed Control

  • Flow, which I'm actually not going to cover in this talk.

  • Because it's kind of technical and I

  • haven't found it that important to know the details of it.

  • It's covered in Yuan's paper.

  • But I'm mentioning dead tensors because they do show up

  • a lot in the execution.

  • Like, if you look at the executor code,

  • there's all this special-casing for dead tensors.

  • This is what they're about, it's for conditional execution

  • so we can do distribution.

  • SPEAKER 1: And retval zero doesn't help any.

  • SKYE WANDERMAN-MILNE: Oh, yeah.

  • And that famous error message I want

  • to put on a t-shirt, retval zero does not have a value.

  • That means you're trying to read a dead tensor,

  • or it probably means there's a bug.

  • OK.

  • Moving on to the low-level ops we use for while loops.

  • These manage iterations, basically.

  • The concept you need to know about in execution is frames.

  • So you have one frame per while loop execution.

  • And this is what allows the executor

  • to keep track of multiple iterations,

  • and allows a single op to be run multiple times as you

  • do multiple iterations.

  • So a frame defines a name, which is for the whole while loop.

  • And then it also has an iteration number.

  • So the Enter op, that just establishes a new frame.

  • It means we're starting a new while loop.

  • So it just forwards its input.

  • It's like an identity, except that output is now

  • in this new frame.

  • And it has an attribute that's the frame

  • name, and the iteration count starts at 0.

  • Exit's the opposite.

  • It's just like an identity, except it strips

  • the frame from its input.

  • So the output is now not in that frame anymore.

  • And these can be stacked.

  • So if you have a bunch of Enters establishing a bunch of frames,

  • and you have a bunch of Exits, they'll pop them off one at a time.

  • The NextIteration's just the final piece, in order

  • to increment that iteration count.
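
If you want to see these loop ops in an actual graph, here's a rough sketch that forces the v1 representation and dumps the op types (tf.compat.v1.disable_control_flow_v2 is a real switch, but flipping it globally like this is only for illustration):

```python
import tensorflow as tf

tf.compat.v1.disable_control_flow_v2()  # build the v1 representation

g = tf.Graph()
with g.as_default():
    tf.while_loop(lambda i: i < 10, lambda i: i + 1, [tf.constant(0)])

print(sorted({op.type for op in g.get_operations()}))
# Expect Enter, Exit, Merge, NextIteration, Switch (plus the loop body ops)
# to show up in the list.
```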

  • This might make more sense when we put it all together,

  • so let's do that.

  • Starting with tf cond again.

  • Let's just work through this.

  • So down here, you have the API call that we're using.

  • So we start, we have this predicate.

  • Note that the predicate isn't actually part of the cond.

  • It happens outside here, but then we feed it

  • into the Switch operators.

  • So the Switches and Merges mark the boundary

  • of the conditional execution, remember.

  • So we'll feed in this predicate, and then the true branch

  • is an Add.

  • So we have a Switch for each input,

  • for x and z, which are the external tensors we

  • use in that branch.

  • You'll note that they are only being emitted

  • from the true side of it.

  • So if the false branch is taken, nothing's connected to that.

  • That comes out of the Add. Then similarly, on the other side,

  • we're Squaring y, so we have a Switch for the y.

  • This time, it's going to be emitted from the false branch

  • into the Square.

  • And then, we only have one output

  • from this cond so we have a single Merge.

  • Either the Square or the Add, only one of those

  • is going to actually have data, and that's what will be output.

  • So note that there is a Switch for each input

  • and a Merge for each output; they don't have to match.

  • And in this example, the two branches

  • are using disjoint tensors.

  • But say we did the Square of x instead of y;

  • then you would have an edge from both

  • the true output and the false output of that Switch,

  • going to the Add or the Square respectively.
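
Putting that together, the call on the slide is presumably along these lines (reconstructed from the description above, so the exact names are assumptions):

```python
# true branch adds x and z; false branch squares y
out = tf.cond(pred,
              true_fn=lambda: x + z,          # the Add branch
              false_fn=lambda: tf.square(y))  # the Square branch
```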

  • Let's quickly, actually, go over the while loop

  • API, just to make sure we all remember.

  • So the first argument is a function.

  • That's the predicate function.

  • The second argument is the body function that we're going to execute.

  • And this is where it's kind of interesting.

  • So you have some inputs; these are called the loop variables,

  • the inputs to the while loop.

  • And then it's going to output updated versions

  • of those same loop variables.

  • So the inputs of the body match the outputs of the body.

  • Like, same number, type, and shape of tensors, because they're

  • just the updated variables.

  • SPEAKER 2: Can't the shape-type [INAUDIBLE]

  • SKYE WANDERMAN-MILNE: The shape can change, you're right.

  • Same number and types.

  • And then finally, we provide some initial input

  • to start it off.

  • So that's the 0, the final argument.

  • And then the output is going to be

  • whatever the final value of the loop variables are.

  • And then the predicate function takes those same loop variables

  • as input but just outputs a Boolean, like,

  • do we continue execution or not?
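
For reference, the loop on the slide is presumably something like this (the actual predicate isn't spelled out here, so the i < 10 bound is an assumption):

```python
result = tf.while_loop(
    cond=lambda i: i < 10,   # predicate over the loop variables
    body=lambda i: i + 1,    # body returns the updated loop variables
    loop_vars=[0])           # the 0: initial value of the loop variables
```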

  • So now we'll start with the Enter node.

  • This, remember, establishes the new frame.

  • We're starting a new while loop.

  • I guess it's called L for loop.

  • We go through a Merge now, kind of reversed from the cond

  • where you start with the Switch.

  • Now you start with a Merge.

  • Because it's choosing: is this the initial value,

  • or is this the new, updated value from an iteration?

  • That feeds into the predicate.

  • Note that the predicate is inside the while loop

  • now because it has to execute multiple times.

  • The output goes into the Switch node

  • to choose: if it's false, we're

  • going to exit the while loop with that Exit node.

  • Otherwise, we go into the body, which is an Add in this case,

  • take the output of the body, and feed it to the NextIteration node.