
  • KRZYSZTOF OSTROWSKI: My name is Chris.

  • I'm leading the TensorFlow Federated team in Seattle.

  • And I'm here to tell you about federated learning

  • and the platform we've built to support it.

  • There are two parts of this talk.

  • First I'll talk about federated learning, how it works and why,

  • and then I'll switch to talk about the platform.

  • All right, let's do it.

  • And this is a machine learning story,

  • so it begins with data, of course.

  • And today, the most exciting data out there

  • is born decentralized on billions of personal devices,

  • like cell phones.

  • So how can we create intelligence and better

  • products from data that's decentralized?

  • Traditionally, what we do is that there's

  • a server in the clouds that is hosting the machine learning

  • model in TensorFlow, and clients all talk to it

  • to make predictions on their behalf.

  • And as they do, the client data accumulates on the server

  • right next to the model.

  • So model and data, that's all in one place.

  • Very easy.

  • We can use traditional techniques that we all know.

  • And what's also great about this scenario

  • is that the same model is exposed to data

  • from all clients, potentially millions of clients,

  • and so it's very efficient.

  • All right.

  • If it's so good, why change that, right?

  • Well actually, in some applications,

  • it's not so great.

  • First it doesn't work offline.

  • There's high latency, so applications that need

  • fast turnaround may not work.

  • All this network communication consumes battery life

  • and bandwidth.

  • And some data is too sensitive, so collecting it is not

  • an option, or too large.

  • Some sensitive data could be large.

  • OK.

  • What can we do?

  • Maybe we go to the complete other extreme.

  • So ditch the server in the clouds.

  • Now each client is in its own bubble, right?

  • It has its own TensorFlow run time, its own model.

  • And it's training.

  • It's grinding over its data to train

  • and doesn't communicate with anything.

  • So now, of course, nothing leaves the device.

  • None of the concerns from the preceding slide

  • apply, but you have other problems.

  • A single client just doesn't have enough data, very often.

  • It doesn't have enough data to create a good model on its own.

  • So this doesn't always work.

  • What if we bring the server back,

  • but the clients are actually only receiving data

  • from the server?

  • Could that work?

  • So if you have some proxy data on the server that's

  • similar to the on-device data, you could use it.

  • You could pre-train the model on the server,

  • then deploy it to clients, and then let

  • it potentially evolve further.

  • So that could work.

  • Except, very often, there's no good proxy data or not enough

  • of it for the kinds of on-device data you're interested in.

  • A second problem is that this here,

  • the intelligence we're creating is

  • kind of frozen in time, in the sense that, as I mentioned,

  • clients won't be able to do a whole lot on their own.

  • And why does it matter?

  • And here's one concrete example from an actual production

  • application.

  • Consider a smart keyboard that's trying

  • to learn to autocomplete.

  • If you train a model on the server

  • and deploy it, and now suddenly millions of people

  • start using a new word, what happens?

  • You'd think, hey, it's a strong signal, millions of people.

  • But if you're not one of those millions,

  • your phone has no clue, right?

  • And so it could take a lot of punching

  • into that phone to make it notice that something new has

  • happened, right?

  • So yeah, this is not what we want.

  • We really need the clients to somehow contribute back

  • towards the common good so they can all benefit.

  • Federated learning is one way to do that.

  • Here we start with an initial model provided by the server.

  • This one is not pre-trained.

  • We don't assume we have proxy data.

  • It doesn't matter.

  • It can be just 0s.

  • So we send it to the client.

  • The client now trains it locally on its own data.

  • And this is more than just one step of gradient descent,

  • but it's also not training to convergence.

  • Typically, you would just make a few passes

  • over the data on the clients and then produce

  • a locally trained model and send it to the server.
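That local step can be sketched in plain NumPy. This hypothetical linear-regression client and its `local_update` helper are illustrative only, not the actual TensorFlow Federated API:

```python
import numpy as np

def local_update(w, X, y, lr=0.1, epochs=3):
    """A few passes of gradient descent over one client's data --
    deliberately not training to convergence (toy linear model)."""
    w = w.copy()
    for _ in range(epochs):
        grad = 2 * X.T @ (X @ w - y) / len(y)  # mean-squared-error gradient
        w -= lr * grad
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 3))            # one client's private data (synthetic)
y = X @ np.array([1.0, -2.0, 0.5])
w0 = np.zeros(3)                        # initial model from the server: just zeros
w_local = local_update(w0, X, y)        # this is what gets sent back
```

Only `w_local` leaves the client; the raw `X` and `y` never do.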

  • And now all the clients are training independently,

  • but they all use the same initial model to start with.

  • And the server's job is to orchestrate this process,

  • to make it happen and to feed

  • the same initial model to all the clients.

  • So now once the server collects the locally trained models

  • from clients, it aggregates them into

  • a so-called federated model.

  • And typically what we do is simply average the model

  • parameters across all clients.

  • So the server just adds the numbers and that's it.
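The aggregation itself is just element-wise averaging of the collected parameters; a minimal sketch with made-up numbers:

```python
import numpy as np

# Locally trained parameter vectors collected from three clients
# (made-up numbers, just to show the aggregation step).
client_models = [
    np.array([1.0, 2.0]),
    np.array([3.0, 4.0]),
    np.array([2.0, 0.0]),
]

# The federated model is the element-wise mean of the client models.
federated_model = np.mean(client_models, axis=0)
# federated_model is now [2.0, 2.0]
```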

  • So this federated model, it has been influenced by data

  • from our clients, right?

  • Because it's been influenced by the client models, and those,

  • in turn, have been influenced by client data.

  • So we do get those benefits of scale in this scenario,

  • so that's great.

  • But there's one question.

  • What happens to privacy?

  • So let's look at this closely.

  • First, client data never left the device.

  • Only the models trained on this data were shared.

  • So next, the server does not retain, store,

  • any of the client models.

  • It simply adds them up and then throws them away.

  • It deletes them, right?

  • So they are ephemeral.

  • But you may be asking how you know that this

  • is what the server is doing.

  • Maybe the server is secretly, somehow,

  • logging something on the side.

  • So there are cryptographic protocols

  • that we can use to ensure that that's all legit.

  • So with those protocols, the server

  • will only see the final result of aggregation

  • and will not have access to any of the individual client contributions.

  • And we use those in practice, so hopefully

  • that puts your mind at rest.

  • So the server only ever sees the final aggregate.

  • You can still wonder how we know that that doesn't contain

  • anything sensitive.

  • So this is where you would use differential privacy.

  • In a nutshell, each client clips its updates

  • and adds a little bit of noise.

  • So once the final aggregate emerges on the server,

  • there's enough noise to sort of mask out any

  • of the individual contributions, but there is still

  • enough signal to make progress.

  • So not to get too much into the detail,

  • but this is also a technique we use in production.
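That clip-and-noise step might be sketched like this; `dp_sanitize`, the clip norm, and the noise scale are illustrative assumptions, not production values:

```python
import numpy as np

def dp_sanitize(update, clip_norm=1.0, noise_std=0.1, rng=None):
    """Clip a client update to a bounded L2 norm, then add Gaussian
    noise, so no single contribution stands out in the aggregate."""
    rng = rng or np.random.default_rng()
    norm = np.linalg.norm(update)
    clipped = update * min(1.0, clip_norm / max(norm, 1e-12))
    return clipped + rng.normal(0.0, noise_std, size=update.shape)

update = np.array([3.0, 4.0])                    # L2 norm 5.0
noisy = dp_sanitize(update, rng=np.random.default_rng(1))
```

Averaged over many clients, the zero-mean noise largely cancels, which is why the aggregate still carries enough signal to make progress.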

  • Differential privacy is an established and commonly used

  • way to provide anonymity.

  • If you have any more concerns, I'll

  • be happy to discuss them offline.

  • So how does it work in practice?

  • Firstly, it's not enough to just do it once.

  • So once you produce a federated model,

  • you'll feed it back into the server as the initial model

  • for the next round, then execute many thousands

  • of rounds, potentially.

  • That's how long it takes to converge.
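The round structure described above can be sketched end to end; `local_update`, `run_rounds`, and the toy linear model are all hypothetical stand-ins for the real system:

```python
import numpy as np

def local_update(w, X, y, lr=0.1, epochs=3):
    """A few local passes of gradient descent (toy linear model)."""
    w = w.copy()
    for _ in range(epochs):
        w -= lr * 2 * X.T @ (X @ w - y) / len(y)
    return w

def run_rounds(w, client_data, num_rounds=50):
    """Each round: broadcast the model, train locally on every client,
    average the results; the federated model seeds the next round."""
    for _ in range(num_rounds):
        local_models = [local_update(w, X, y) for X, y in client_data]
        w = np.mean(local_models, axis=0)   # the new federated model
    return w

rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0, 0.5])
clients = [(X, X @ true_w)
           for X in (rng.normal(size=(16, 3)) for _ in range(5))]
w_final = run_rounds(np.zeros(3), clients)  # converges toward true_w
```

The server's only state between rounds is `w`, which is what provides the continuity mentioned below.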

  • And so in this scenario, both clients and server have a role.

  • Clients are doing all the learning.

  • That's where all the machine learning sits.

  • And the server is just orchestrating the process,

  • aggregating and also providing continuity to this process

  • as we move from one round to another,

  • because the server is what carries

  • the state between rounds.

  • And to drill into this a little bit more,

  • in the practical applications, clients

  • are not all available at the same time.