KRZYSZTOF OSTROWSKI: My name is Chris. I'm leading the TensorFlow Federated team in Seattle, and I'm here to tell you about federated learning and the platform we've built to support it. There are two parts to this talk. First I'll talk about federated learning, how it works and why, and then I'll switch to talk about the platform. All right, let's do it.

This is a machine learning story, so it begins with data, of course. And today, the most exciting data out there is born decentralized, on billions of personal devices like cell phones. So how can we create intelligence and better products from data that's decentralized? Traditionally, there's a server in the cloud that hosts the machine learning model in TensorFlow, and clients all talk to it to make predictions on their behalf. As they do, the client data accumulates on the server, right next to the model. So model and data are all in one place. Very easy. We can use the traditional techniques that we all know. And what's also great about this scenario is that the same model is exposed to data from all clients, potentially millions of them, so it's very efficient.

All right. If it's so good, why change it? Well, actually, in some applications it's not so great. First, it doesn't work offline. There's high latency, so applications that need fast turnaround may not work. All this network communication consumes battery life and bandwidth. And some data is too sensitive, so collecting it is not an option, or too large. Some sensitive data can also be large.

OK. What can we do? Maybe we go to the complete other extreme. Ditch the server in the cloud. Now each client is its own bubble. It has its own TensorFlow runtime, its own model, and it's training, grinding over its own data, without communicating with anything. Now, of course, nothing leaves the device. None of the concerns from the preceding slide apply, but you have other problems. A single client very often just doesn't have enough data to create a good model on its own. So this doesn't always work.

What if we bring the server back, but the clients only receive data from the server? Could that work? If you have some proxy data on the server that's similar to the on-device data, you could use it: pre-train the model on the server, then deploy it to clients, and then let it potentially evolve further. So that could work. Except, very often, there's no good proxy data, or not enough of it, for the kinds of on-device data you're interested in. A second problem is that the intelligence we're creating here is kind of frozen in time, in the sense that, as I mentioned, clients won't be able to do a whole lot on their own.

And why does that matter? Here's one concrete example from an actual production application. Consider a smart keyboard that's trying to learn to autocomplete. You train a model on the server and deploy it, and now suddenly millions of people start using a new word. What happens? You'd think, hey, millions of people, that's a strong signal. But if you're not one of those millions, your phone has no clue, right? It could take a lot of punching into that phone to make it notice that something new has happened. So this is not what we want. We really need the clients to somehow contribute back towards the common good, so they can all benefit. Federated learning is one way to do that.
Here we start with an initial model provided by the server. This one is not pre-trained; we don't assume we have proxy data. It doesn't matter. It can be just zeros. We send it to the clients, and each client trains it locally on its own data. This is more than just one step of gradient descent, but it's also not training to convergence; typically you would just make a few passes over the data on the client, produce a locally trained model, and send it to the server. All the clients are training independently, but they all start from the same initial model. The server's job is to orchestrate this process, to make it happen and to feed the same initial model to all the clients.

Once the server collects the locally trained models from the clients, it aggregates them into a so-called federated model. Typically, we simply average the model parameters across all clients; the server just adds up the numbers, and that's it. This federated model has been influenced by data from all the clients, because it's been influenced by the client models, and those, in turn, have been influenced by the client data. So we do get the benefits of scale in this scenario, which is great.

But there's one question: what happens to privacy? Let's look at this closely. First, client data never left the device. Only the models trained on that data were shared. Next, the server does not retain or store any of the client models. It simply adds them up and then throws them away. It deletes them, so they are ephemeral. But you may be asking: how do we know that this is what the server is actually doing? Maybe the server is secretly logging something on the side. There are cryptographic protocols that we can use to ensure that it's all legit. With those protocols, the server will only see the final result of the aggregation and will not have access to any of the individual client contributions. And we use those in practice, so hopefully that puts your mind at rest.

So the server only ever sees the final aggregate. You can still wonder: how do we know that doesn't contain anything sensitive? This is where you would use differential privacy. In a nutshell, each client clips its updates, and a little bit of noise is added, so that once the final aggregate emerges on the server, there's enough noise to mask any individual contribution, but still enough signal to make progress. Not to get too much into the details, but this is also a technique we use in production. Differential privacy is an established and commonly used way to provide anonymity. If you have any more concerns, I'll be happy to discuss them offline.

So how does it work in practice? First, it's not enough to do this just once. Once you produce a federated model, you feed it back to the server as the initial model for the next round, and you execute potentially many thousands of rounds. That's how long it takes to converge. In this scenario, both the clients and the server have a role. The clients are doing all the learning; that's where all the machine learning sits. The server is just orchestrating the process, aggregating, and also providing continuity as we move from one round to another, because the server is what carries the state between rounds. And to drill into this a little bit more, in practical applications, clients are not all available at the same time.
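To make the round structure described above concrete, here is a minimal sketch of federated averaging in plain Python with NumPy. This is not the TensorFlow Federated API or Google's production code; the names client_update and server_round, the toy linear model, and all constants are hypothetical, chosen only to illustrate the broadcast, train-locally, then average loop.

    # Illustrative sketch of federated averaging, assuming a toy linear model.
    import numpy as np

    def client_update(global_weights, local_data, lr=0.1, epochs=3):
        """Train locally for a few passes over the client's own data.

        The 'model' here is linear regression trained by gradient descent;
        a real deployment would train a full TensorFlow model on-device.
        """
        w = global_weights.copy()
        X, y = local_data
        for _ in range(epochs):                    # a few passes, not convergence
            grad = 2 * X.T @ (X @ w - y) / len(y)  # mean-squared-error gradient
            w -= lr * grad
        return w                                   # only the model leaves the device

    def server_round(global_weights, client_datasets):
        """One round: broadcast, local training, then average the results."""
        local_models = [client_update(global_weights, d) for d in client_datasets]
        return np.mean(local_models, axis=0)       # the 'federated model'

    # Many rounds: each federated model becomes the next round's initial model.
    rng = np.random.default_rng(0)
    true_w = np.array([2.0, -1.0])
    clients = []
    for _ in range(10):                            # 10 simulated client datasets
        X = rng.normal(size=(20, 2))
        clients.append((X, X @ true_w + 0.1 * rng.normal(size=20)))

    weights = np.zeros(2)                          # no pre-training; start from zeros
    for round_num in range(50):
        weights = server_round(weights, clients)
    print(weights)                                 # approaches true_w over the rounds

In this simulation every client participates in every round; the talk's point about clients not all being available at once would, in practice, mean sampling a different subset of clients per round.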
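And here is a similarly hedged sketch of the clip-and-noise idea behind the differential privacy step. The clip norm and noise multiplier below are made-up values, and this version adds the noise centrally to the aggregate (in some designs the clients add it themselves); a real deployment calibrates both to a target privacy guarantee.

    # Illustrative sketch of differentially private averaging of client updates.
    import numpy as np

    def clip_update(update, clip_norm=1.0):
        """Bound one client's contribution by rescaling it to a max L2 norm."""
        norm = np.linalg.norm(update)
        return update * min(1.0, clip_norm / max(norm, 1e-12))

    def dp_average(updates, clip_norm=1.0, noise_multiplier=0.5, rng=None):
        """Average the clipped updates, then add Gaussian noise to the aggregate."""
        if rng is None:
            rng = np.random.default_rng()
        clipped = [clip_update(u, clip_norm) for u in updates]
        mean = np.mean(clipped, axis=0)
        # Noise scaled to the per-client clip bound masks any single contribution.
        noise_std = noise_multiplier * clip_norm / len(updates)
        return mean + rng.normal(scale=noise_std, size=mean.shape)

Because each clipped update can move the average by at most clip_norm / n, noise of that scale is enough to mask any one client, while the sum of many clients still carries the shared signal, which is exactly the trade-off described above.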