ML models in Flask
Does this situation sound familiar to you?
- You are a data scientist, you developed an ML model in Python (using PyTorch, TensorFlow, or something like that), and you'd like your users to interact with it,
- You would like to make either an API or a web interface to your model,
- Your model is big enough (and therefore slow) that you would prefer not to load it from scratch every time a user wants to use it, and
- You know a thing or two about servers, but you don't have a deep background, you don't have the time and/or patience to get into it, you don't have the proper server administrator rights, or a combination of all three.
If that's your situation, this post is for you.
Flask is a Python package for the quick and easy creation of APIs that you can use to serve model predictions over the internet. And if your users need a GUI, Dash is a software package built on top of Flask that allows you to quickly create web interfaces - if you are familiar with R, Dash has been described as "ShinyApps for Python".
Typically you would use a "real" web server (Apache, NGINX, etc) to do the heavy lifting, but this post focuses on how to use Flask alone to quickly return results generated from an ML model. In particular, we will focus on how to keep a model in memory between calls so you don't need to restart your model at every turn.
WARNING: Flask is not designed to work this way. The Flask documentation itself tells you not to use their integrated web server for anything other than testing, and if you blindly expose this code to the internet things can get ugly. It will also be much slower than using a proper web server. And yet, I am painfully aware that sometimes you don't have the resources to do things right, and people telling you "there's a way but I'm not telling you because it's ugly" doesn't help. So remember: this solution is ideal for situations where you have low traffic, ideally inside an intranet and/or behind a firewall, and you don't have the technical help you'd need to do it right. But be aware of its limitations!
Method 1: global variables for simple models
Let's start with the simplest of web servers. This code exposes a single API endpoint /helloworld that receives a name and returns a greeting:
```python
from flask import Flask

app = Flask(__name__)

@app.route('/helloworld/<name>')
def hello_world(name):
    return f'Hello, {name}'

if __name__ == '__main__':
    app.run(debug=True)
```
If you send a request to http://localhost:5000/helloworld/Test, you get Hello, Test as a result.
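If you'd rather check the endpoint without a browser, Flask's built-in test client lets you exercise the route directly. Here is a quick sketch using the same app as above:

```python
from flask import Flask

app = Flask(__name__)

@app.route('/helloworld/<name>')
def hello_world(name):
    return f'Hello, {name}'

# The test client calls the route in-process, without starting a real server
client = app.test_client()
response = client.get('/helloworld/Test')
print(response.get_data(as_text=True))  # Hello, Test
```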
Let's say that you now want to return a counter of how many times you have received a request - every time you get a request you simply increase a counter, and then you return that counter. One simple solution is using a global variable, like so:
```python
from flask import Flask

counter = 0
app = Flask(__name__)

@app.route('/helloworld/<name>')
def hello_world(name):
    global counter
    counter += 1
    return f'Hello, {name}, you are request number {counter}'

if __name__ == '__main__':
    app.run(debug=True)
```
This code only works on Linux (I think), but under some circumstances it could be all you need - all that remains is to replace the variable initialization with code that loads your model into memory. You would use this method, for instance, when your startup task takes long enough to hurt, such as parsing a long list of JSON files. If that's your case, you can leave this server running like so:
- Disable debug and open the server to the world by changing the app.run line to app.run(debug=False, host='0.0.0.0').
- Install the screen utility (tmux is also good), start it by simply typing screen in your console, run your server (python script_name.py), and leave the server running in the background by pressing Ctrl+A, then D. The server will keep running until the computer is restarted.
Unfortunately for some of you, this solution doesn't work under Windows, nor does it work if you use a "real" web server instead of the one provided with Flask. More importantly, it also tends to fail with some ML libraries that are not happy with simultaneous access from multiple workers. If that's your case, your best solution is to create a sub-process (ugh) and communicate with your model via IPC (double ugh).
Method 2: Inter-Process Communication (IPC)
Before we jump into the code, we need to understand who is going to talk to whom, and how. It goes as follows: the ML model runs in its own process, which we'll call the ML-process. This process can only be reached via a multiprocessing queue, a data structure into which you put elements that are later retrieved in the same order in which they were inserted (first in, first out), and which can be shared between multiple processes.
Whenever you make a request to the API, Flask serves it in a separate worker (a thread in Flask's built-in server) that we will call a request-process. The first thing that this process does is open a Pipe. You can think of a pipe as a special pair of telephones that can only talk to each other and where sound only comes out when someone is listening - you can talk for hours into the receiver, but nothing will come out of the other end until someone listens (in which case they'll get all of your talking at once) or until the pipe is full. Whenever a request-process needs to send a request to the ML-process, it does so as follows:
- As we said above, the request-process opens a Pipe. It has two ends, which we'll call the 'source' and 'target' ends. Remember, though, that despite their names, communication can flow in both directions.
- The request-process puts some data into the 'source' end of the pipe. Whoever picks up the 'target' end of the pipe will receive this data.
- Next, the 'target' end of the pipe is put into the multiprocessing queue. If we stick to our analogy, this is the equivalent of having two cell phones, putting one of them in a box, and mailing it to another person.
- And now, we wait.
The ML-process is constantly monitoring the queue, and it will eventually receive the 'target' end of the pipe that we put in the queue. I say "eventually" because other processes are also trying to talk to the ML-process, and therefore every process has to wait for its turn. In our analogy, it is the equivalent of a person receiving package after package, each one containing a cell phone. Once the ML-process receives our 'target' end of the pipe, it reads the data, processes it, and sends the result back through that same 'target' end. The request-process then picks the result up at its own end of the pipe and serves it back to the user who made the original request.
The following code does exactly that:
```python
from flask import Flask
from multiprocessing import Process, SimpleQueue, Pipe

# This is the global queue that we use for communication
# between the ML-process and the request-processes
job_queue = SimpleQueue()

app = Flask(__name__)

# This is the process that will run the ML model
class MLProcess(Process):
    def __init__(self, queue):
        super(MLProcess, self).__init__()
        self.queue = queue
        # The slow initialization code should come here.
        # For this example, we just create a really bad cache
        self.cache = dict()

    def run(self):
        # Query the queue until we are told to stop
        stop = False
        while not stop:
            # Receive the next message
            incoming = self.queue.get()
            if incoming == 'shutdown':
                # We got the magic value that tells us to stop.
                # Make sure this value doesn't happen by accident!
                stop = True
            else:
                # `incoming` is one end of a pipe, so we can read from it
                data = incoming.recv()
                # Do something with the data. In this case, we simply
                # convert it to lower case and store it in the cache,
                # but you would probably call an ML model here
                if data not in self.cache:
                    self.cache[data] = data.lower()
                # Send the result back to the process that requested it
                incoming.send(self.cache[data])
        # If your model requires any shutdown code, you would place it here.
        pass

# This is a normal API endpoint that will communicate with the ML-process
@app.route('/helloworld/<name>')
def hello_world(name):
    # Create both ends of a pipe
    my_end, other_end = Pipe()
    # Send the data through my end
    my_end.send(name)
    # Send the other end of the pipe via the queue
    job_queue.put(other_end)
    # This process will now wait forever for a reply
    # to come via its own end of the pipe
    result = my_end.recv()
    # Return the result from the model
    return 'Hello, {}'.format(result)

if __name__ == '__main__':
    ml_process = MLProcess(job_queue)
    ml_process.start()
    # use_reloader=False stops the debug reloader from re-running this
    # script and starting a second MLProcess
    app.run(debug=True, use_reloader=False)
    # Once the server exits, tell the ML-process to shut down cleanly
    job_queue.put('shutdown')
    ml_process.join()
```
This code works well as long as there is perfect communication between all moving parts. If the ML-process hangs, for instance, then no more data will be returned and all request-processes will wait forever for a reply that will never come. The same will happen if you send the pipe end to the server but don't put any data in it. You can mitigate these problems by using the poll method of a Pipe (which checks whether there's any data to read and returns immediately, optionally after a timeout), but you should be aware that synchronization errors are both common and mean to debug.
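For example, the request-process could poll its end of the pipe with a timeout instead of blocking forever on recv. A sketch (the one-second timeout is an arbitrary choice):

```python
from multiprocessing import Pipe

my_end, other_end = Pipe()
# poll() returns immediately; poll(timeout) waits at most `timeout` seconds
print(my_end.poll())  # False, since nothing has been sent yet
other_end.send('pong')
if my_end.poll(1.0):
    result = my_end.recv()
else:
    result = None  # here you would return an error to the user instead
print(result)  # pong
```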
Note also that we have a special value that we use for instructing the ML-process to shut down - this is necessary to ensure that we clean up everything before exiting our program, but make sure no client can send this special value by accident!
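One simple safeguard (my own suggestion, not something Flask provides) is to generate a random shutdown token once at startup, so that no user-supplied string can ever match it:

```python
import uuid

# Generated once at startup; request data coming from users can never
# accidentally equal this value
SHUTDOWN = f'shutdown-{uuid.uuid4()}'
print(SHUTDOWN)  # e.g. shutdown-9f1c... (different on every run)
```

The ML-process would then test incoming == SHUTDOWN instead of incoming == 'shutdown'. Since the ML-process is started from the same script, both sides see the same token.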
Final thoughts
Is this a good idea? Probably not - if the developers of Flask themselves tell you not to do something, then "don't do it" sounds like solid advice. And we both know that there's nothing more permanent than a temporary solution.
Having said that, and as far as horrible hacks go, I like this one: if you are a data scientist then you are not here to learn how web servers work nor to have long discussions with your system administrator on what CGI means. You are here to get stuff done and getting your ML model in front of users as fast as possible is a great way to achieve that.
I used to know a way to extend this method to work with Apache (you know, a "real" web server) but I honestly can't remember it right now. If you need some help with that then reach out to me and I'll try to figure it out.