7c0h

ML models in Flask

Does this situation sound familiar to you?

  • You are a data scientist, you developed an ML model in Python (using PyTorch, TensorFlow, or something like that), and you'd like your users to interact with it,
  • You would like to make either an API or a web interface to your model,
  • Your model is big enough (and therefore slow) that you would prefer not to load it from scratch every time a user wants to use it, and
  • You know a thing or two about servers, but you don't have a deep background, you don't have the time and/or patience to get into it, you don't have the proper server administrator rights, or a combination of all three.

If that's your situation, this post is for you.

Flask is a Python package for the quick and easy creation of APIs that you can use to serve model predictions over the internet. And if your users need a GUI, Dash is a software package built on top of Flask that allows you to quickly create web interfaces - if you are familiar with R, Dash has been described as "ShinyApps for Python".

Typically you would use a "real" web server (Apache, NGINX, etc) to do the heavy lifting, but this post focuses on how to use Flask alone to quickly return results generated from an ML model. In particular, we will focus on how to keep a model in memory between calls so you don't need to restart your model at every turn.

WARNING: Flask is not designed to work this way. The Flask documentation itself tells you not to use their integrated web server for anything other than testing, and if you blindly expose this code to the internet things can get ugly. It will also be much slower than using a proper web server. And yet, I am painfully aware that sometimes you don't have the resources to do things right, and people telling you "there's a way but I'm not telling you because it's ugly" doesn't help. So remember: this solution is ideal for situations where you have low traffic, ideally inside an intranet and/or behind a firewall, and you don't have the technical help you'd need to do it right. But be aware of its limitations!

Method 1: global variables for simple models

Let's start with the simplest of web servers. This code exposes a single API endpoint /helloworld that receives a name and returns a greeting:

from flask import Flask
app = Flask(__name__)

@app.route('/helloworld/<name>')
def hello_world(name):
    return f'Hello, {name}'

if __name__ == '__main__':
    app.run(debug=True)

If you send a request to http://localhost:5000/helloworld/Test, you would get Hello, Test as a result.

Let's say that you now want to return a counter of how many times you have received a request - every time you get a request you simply increase a counter, and then you return that counter. One simple solution is using a global variable, like so:

from flask import Flask
counter = 0
app = Flask(__name__)

@app.route('/helloworld/<name>')
def hello_world(name):
    global counter
    counter += 1
    return f'Hello, {name}, you are request number {counter}'

if __name__ == '__main__':
    app.run(debug=True)

This code only works on Linux (I think), but under some circumstances it could be all you need - all that remains is to replace the variable initialization with code that loads your model into memory. You would use this method, for instance, when a task has a long startup time, such as parsing a long list of JSON files. If that's your case, you can leave this server running like so:

  • Disable debug and open the server to the world by changing the app.run line to app.run(debug=False, host='0.0.0.0').
  • Install the screen utility (tmux is also good), start it by typing screen in your console, run your server (python script_name.py) and leave the server running in the background (press Ctrl+A, then D). The server will keep running until the computer is restarted.

Unfortunately for some of you, this solution works neither under Windows nor with a "real" web server in place of the one provided with Flask. More importantly, it also tends to fail with ML libraries that are not happy when several requests access the model at the same time. If that's your case, your best solution is to create a sub-process (ugh) and communicate with your model via IPC (double ugh).

Method 2: Inter-Process Communication (IPC)

Before we jump into the code, we need to understand who is going to talk to whom and how. It goes as follows: the ML model will run in its own process, which we'll call the ML-process. This process can only be reached via a multiprocessing queue: a data structure that can be shared between multiple processes and that returns its elements in the same order in which they were inserted.
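If the ordering guarantee sounds abstract, here is a minimal sketch of it. To keep it short, everything happens inside a single process; in the real setup the put and get calls would live in different processes. The job names are made up for illustration.

```python
from multiprocessing import SimpleQueue

# The same kind of queue that the full example uses further down
queue = SimpleQueue()

# Three "jobs" go in, one after the other
for job in ["first", "second", "third"]:
    queue.put(job)

# They come out in exactly the order in which they were inserted
received = [queue.get() for _ in range(3)]
print(received)
```

This is what guarantees that requests are served in the order in which they arrived.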

Whenever you make a request to the API, Flask handles it in its own worker (a thread by default in the built-in server, a separate process in other setups) that we will call a request-process. The first thing that this process does is to open a Pipe. You can think of a pipe like a special pair of telephones that can only talk to each other and where sound is only emitted when someone is listening - you can talk for hours into the receiver, but nothing will come out of the other end until someone listens (in which case they'll get all of your talking at once) or until the pipe is full. Whenever a request-process needs to perform a request to the ML-process, it does so as follows:

  • As we said above, the request-process opens a Pipe. It has two ends which we'll call the 'source' and 'target' ends. Remember, though, that despite their names, communication can flow in both directions.
  • The request-process puts some data in the 'source' end of the pipe. Whenever someone picks up the 'target' end of the pipe they'll receive this data.
  • Next, the 'target' end of the pipe is put in the multiprocessing queue. If we stick to our analogy, it would be the equivalent of having two cell phones, putting one of them in a box, and mailing it to another person.
  • And now, we wait.
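To make the telephone analogy concrete, here is a small sketch of how the two ends of a Pipe behave. It runs in a single process and skips the queue entirely; the names source and target simply mirror the ones used in the list above.

```python
from multiprocessing import Pipe

# Pipe() returns two connected ends; by default both can send and receive
source, target = Pipe()

# Sending does not wait for a listener: the message sits in the pipe
source.send("Test")

# poll() reports whether something is waiting (here: up to 1 second)
has_data = target.poll(1)

# Whoever holds the 'target' end picks up the message...
message = target.recv()

# ...and can answer through the very same end
target.send(message.lower())
reply = source.recv()

print(has_data, message, reply)
```

Note that recv blocks until something arrives, which is exactly the "and now, we wait" step.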

The ML-process is constantly monitoring the queue, and it will eventually receive the 'target' end of the pipe that we put in the queue. I say "eventually" because other processes are also trying to talk to the ML-process, and therefore every process has to wait for its turn. In our analogy, it is the equivalent of a person receiving package after package, each one containing a cell phone. Once the ML-process receives our 'target' end of the pipe it extracts the data, processes it, and puts the result back into the pipe using the 'target' end it received earlier. This result is then sent back via the pipe, where our request-process retrieves it and where it can be served back to the user that made the original request.

The following code does exactly that:

from flask import Flask
from multiprocessing import Process, SimpleQueue, Pipe

# This is the global queue that we use for communication
# between the ML-process and the request-processes
job_queue = SimpleQueue()
app = Flask(__name__)

# This is the process that will run the ML model
class MLProcess(Process):
    def __init__(self, queue):
        super().__init__()
        self.queue = queue
        # The slow initialization code should come here.
        # For this example, we just create a really bad cache
        self.cache = dict()

    def run(self):
        # Wait on the queue until we are told to stop
        stop = False
        while not stop:
            # Receive the next message
            incoming = self.queue.get()
            if incoming == 'shutdown':
                # We got the magic value that tells us to stop.
                # Make sure this value doesn't happen by accident!
                stop = True
            else:
                # `incoming` is a pipe and therefore I can read from it
                data = incoming.recv()
                # Do something with the data. In this case, we simply
                # convert it to lower case and store it in the cache,
                # but you would probably call an ML model here
                if data not in self.cache:
                    self.cache[data] = data.lower()
                # Send the result back to the process that requested it
                incoming.send(self.cache[data])
        # If your model requires any shutdown code, you would place it here.
        pass

# This is a normal API endpoint that will communicate with the ML-process
@app.route('/helloworld/<name>')
def hello_world(name):
    # Create both ends of a pipe
    my_end, other_end = Pipe()
    # Send the data through my end
    my_end.send(name)
    # Send the other end of the pipe via the queue
    job_queue.put(other_end)
    # This process will now wait forever for a reply
    # to come via its own end of the pipe
    result = my_end.recv()
    # Return the result from the model
    return 'Hello, {}'.format(result)

if __name__ == '__main__':
    ml_process = MLProcess(job_queue)
    ml_process.start()
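    # use_reloader=False: the debug reloader re-runs this script in a
    # child process, which would start a second ML-process
    app.run(debug=True, use_reloader=False)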
    job_queue.put('shutdown')
    ml_process.join()

This code works well as long as there is perfect communication between all moving parts. If the ML-process hangs, for instance, then no more data will be returned and all request-processes will keep waiting forever for a reply that will never come. The same will happen if you send the pipe to the ML-process but you don't put any data in it. You can mitigate these problems by using the poll method of a Pipe (which checks whether any data is available, returning either immediately or after an optional timeout), but you should be aware that synchronization errors are both common and nasty to debug.
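As a rough sketch of that mitigation, the helper below waits for a reply only for a limited time. The function recv_with_timeout and its default argument are my own invention for this example, not part of the multiprocessing module; only poll and recv are.

```python
from multiprocessing import Pipe

def recv_with_timeout(connection, timeout_seconds, default=None):
    """Return the next message, or `default` if none arrives in time."""
    # poll() waits at most `timeout_seconds` for data to show up
    if connection.poll(timeout_seconds):
        return connection.recv()
    return default

my_end, other_end = Pipe()

# Nobody has answered yet, so we give up after 0.1 seconds
no_reply = recv_with_timeout(my_end, 0.1, default="timed out")

# Once the other side answers, the same call returns the message
other_end.send("test")
reply = recv_with_timeout(my_end, 0.1, default="timed out")

print(no_reply, reply)
```

In the endpoint you would call something like this instead of my_end.recv() and return an error message to the user when the default comes back. Pick a timeout that is comfortably longer than your model's slowest prediction.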

Note also that we have a special value that we use for instructing the ML-process to shut down - this is necessary to ensure that we clean up everything before exiting our program, but make sure no client can send this special value by accident!
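One way to lower that risk, sketched below under my own assumptions, is to use None as the special value instead of a string. None survives the trip through a multiprocessing queue and can never be mistaken for a pipe end. In the code above user data travels through the pipe rather than the queue, so a user calling themselves "shutdown" is harmless either way. The names SHUTDOWN and serve are hypothetical, and the demo uses an in-process queue.Queue only so that it can run without starting a second process.

```python
import queue
from multiprocessing import Pipe

SHUTDOWN = None  # hypothetical sentinel; never a valid pipe end

def serve(job_queue):
    """Answer jobs until the sentinel arrives; return how many were handled."""
    handled = 0
    while True:
        incoming = job_queue.get()
        if incoming is SHUTDOWN:
            break
        # Anything else is the 'target' end of a pipe
        data = incoming.recv()
        incoming.send(data.lower())
        handled += 1
    return handled

jobs = queue.Queue()
my_end, other_end = Pipe()
my_end.send("Shutdown")  # user data that merely looks like a command
jobs.put(other_end)
jobs.put(SHUTDOWN)       # the real shutdown request

handled = serve(jobs)
reply = my_end.recv()
print(handled, reply)
```

The loop treats the user's "Shutdown" as ordinary data and only stops when the sentinel itself shows up in the queue.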

Final thoughts

Is this a good idea? Probably not - if the developers of Flask themselves tell you not to do something, then "don't do it" sounds like solid advice. And we both know that there's nothing more permanent than a temporary solution.

Having said that, and as far as horrible hacks go, I like this one: if you are a data scientist then you are not here to learn how web servers work nor to have long discussions with your system administrator on what CGI means. You are here to get stuff done and getting your ML model in front of users as fast as possible is a great way to achieve that.

I used to know a way to extend this method to work with Apache (you know, a "real" web server) but I honestly can't remember it right now. If you need some help with that then reach out to me and I'll try to figure it out.