Working with AI Models in a distributed compute infrastructure
A fork-safe Python implementation of the Token Bucket Algorithm on MongoDB
Abstract
In this article, a fork-safe Python implementation of the Token Bucket Algorithm is explored, focusing on a scenario where the token bucket is shared among different workers using a central database. More specifically, a MongoDB database, utilizing the non-fork-safe PyMongo connector, will be employed to illustrate how this goal can be achieved. Forking is managed using a pool of Celery workers, which can be started on (m)any machine(s).
Introduction
When working with production-grade AI systems, one of the key challenges is managing the rate limits imposed by various AI APIs. These limits are essential for ensuring fair usage and maintaining the stability of the services, but they can also become a bottleneck for your application if not handled correctly.
Consider this scenario: Your AI-powered application is built initially as a small proof-of-concept and has now passed all relevant stakeholders, ready to scale from the project team to the larger organization. The app is gaining popularity among colleagues, and more users are interacting with its features. As usage increases, you start hitting the rate limits of your AI APIs, e.g. the Large Language Model (LLM) endpoint. Requests are throttled, performance suffers, and your users experience delays or failures. This isn’t just a hypothetical situation — it’s a common issue faced by developers integrating AI APIs.
Or consider an implementation where several AI Agents collaborate to reach a common goal. If these agents overwhelm the AI endpoints, the entire swarm could fail, resulting in no meaningful outcomes.
To effectively manage these rate limits and ensure a smooth user experience, you need a robust solution. The Token Bucket Algorithm is one such approach that provides a controlled and predictable way to handle request rates. In this article, we’ll explore how to implement the Token Bucket Algorithm using Python Celery and MongoDB.
We’ll dive into the details of the Token Bucket Pattern, explain why it’s well-suited for AI API rate limiting, and provide a step-by-step guide to implementing it in your application. By the end of this article, you’ll have a solid understanding of how to maintain your application’s reliability and performance, even under high load.
The Token Bucket Algorithm
The Token Bucket algorithm is a widely used rate-limiting technique that helps control the flow of requests to an API, ensuring system stability and preventing overload. In this algorithm, tokens are added to a bucket at a fixed rate, representing the capacity to make API calls. Each outgoing API call consumes a token from the bucket. If tokens are available, the API call is made immediately; if the bucket is empty, the call must wait until new tokens are added. This method allows for bursts of incoming API calls while maintaining a controlled average outgoing rate, making it ideal for managing variable workloads and ensuring reliable performance in high-traffic environments. You can refer to the Wikipedia article on the token bucket for a more in-depth explanation.
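To make the mechanics concrete, here is a minimal single-process sketch of the algorithm in plain Python. It is purely illustrative (the class and method names are made up for this example) and, as we will see below, not sufficient once several independent workers have to share one budget:

import time

class LocalTokenBucket:
    """A minimal in-memory token bucket (single process only)."""

    def __init__(self, rate_per_sec, capacity):
        self.rate = rate_per_sec            # tokens added per second
        self.capacity = capacity            # maximum number of tokens
        self.tokens = capacity              # start with a full bucket
        self.last_refill = time.monotonic()

    def try_consume(self):
        now = time.monotonic()
        # refill proportionally to the elapsed time, capped at the capacity
        self.tokens = min(self.capacity,
                          self.tokens + (now - self.last_refill) * self.rate)
        self.last_refill = now
        if self.tokens >= 1:
            self.tokens -= 1  # a token is available: consume it
            return True
        return False          # bucket empty: the caller has to wait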
The technical setup
The technical setup for this project is built around a Flask API, which is utilized to serve the final product, e.g. a web app. Python Flask is a good choice for API development because it is a lightweight and flexible framework.
Asynchronous API calls are managed through a Celery worker pool, configured with a potentially arbitrary number of workers to ensure scalability and efficient task processing. Tasks are enqueued using RabbitMQ, a robust message broker that reliably stores and distributes tasks among the worker pool. Processed results are stored in a MongoDB database, selected for its ability to scale and efficiently manage large datasets. This architecture is designed to provide a responsive and resilient system, capable of handling high traffic and complex operations with ease.
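As a rough sketch, the wiring between these components could look like the following. All URLs, credentials, and the app name are placeholders to be replaced with your own configuration:

from celery import Celery

# RabbitMQ acts as the message broker, MongoDB stores the task results
app = Celery(
    'ai_tasks',
    broker='amqp://user:password@rabbitmq-host:5672//',
    backend='mongodb://user:password@mongo-host:27017/results',
)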
The challenge
The Flask app is tied to a MongoDB, but PyMongo is not fork-safe. Since the workers all run independently, potentially on different machines, a Python-based singleton pattern around some shared resource is of little use. Instead, we will use a bucket on the MongoDB instance that can be shared between all workers.
But this approach introduces a challenge of its own, since it is prone to race conditions by design. Writing Python code that first inspects whether the bucket has a token and then updates the bucket will not work here. We need to execute the code atomically on MongoDB, without letting any other process slip in between the inspect and the update call.
The solution
At the core of the solution is the observation that the find and update call needs to be done in an atomic fashion — so the document (our Token Bucket) needs to be locked when it’s inspected and updated by one process. This can be achieved with an Aggregation Pipeline.
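To illustrate the difference, compare a naive read-then-write approach with the atomic variant. In the naive version there is a window between the two calls in which another worker can grab the same token; find_one_and_update with an update pipeline performs both steps as a single server-side operation. The snippet below is a simplified sketch (it omits the refill logic, which the full pipeline in the next section adds):

# RACY - do not use: another worker can run between find_one and update_one
doc = collection.find_one({'bucket_name': 'my_task'})
if doc['token'] >= 1:
    collection.update_one({'bucket_name': 'my_task'}, {'$inc': {'token': -1}})

# ATOMIC: inspect and update happen in one server-side operation
doc = collection.find_one_and_update(
    {'bucket_name': 'my_task'},
    [{'$set': {'token': {'$max': [0, {'$subtract': ['$token', 1]}]}}}],
)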
The data model on the Mongo collection is rather simple; this is an example:
{
  "_id": {
    "$oid": "650460f1cab8f432f119c6a4"
  },
  "bucket_name": "my_task",
  "token": 0,
  "last_time": {
    "$date": "2025-01-20T09:18:09.241Z"
  }
}
A single collection can contain many different buckets for different tasks, so "bucket_name" is used as the identifier. "token" is the current number of allowed requests, i.e. the available tokens in the bucket, and "last_time" is the timestamp of the last time tokens were added to the bucket.
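As an optional hardening step (not strictly required for the setup described here), a unique index on "bucket_name" guards against duplicate bucket documents ever being created:

# ensure at most one bucket document exists per bucket name
db_con.token_buckets.create_index('bucket_name', unique=True)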
The core of the TokenBucket implementation is this formulation of the update pipeline:
def get_token(self):
    # Two cases to distinguish:
    # 1) There is at least one token in the bucket:
    #    retrieve it → reduce the number of tokens by one, and add the tokens
    #    that should have been added since the last refill operation
    # 2) There is no token in the bucket: don't subtract, don't register a consumption

    # tokens accrued since the last refill: floor(elapsed milliseconds * rate per millisecond)
    tokens_to_add = {"$floor": {"$multiply": [
        {"$dateDiff": {"startDate": "$last_time", "endDate": "$$NOW", "unit": "millisecond"}},
        self.rate_per_millisec]}}

    pipeline = [{
        "$set": {
            "token": {"$cond": [
                {"$gte": ["$token", 1]},
                # a token is available: refill, then consume one, capped at the capacity
                {"$min": [self.capacity, {"$add": ["$token", tokens_to_add, -1]}]},
                # no token available: refill only, capped at the capacity
                {"$min": [self.capacity, {"$add": ["$token", tokens_to_add]}]}
            ]},
            # last_time (the refill timestamp) only gets an update if tokens were
            # actually added; otherwise the fractional progress towards the next
            # token would be silently discarded
            "last_time": {"$cond": [{"$gte": [tokens_to_add, 1]}, "$$NOW", "$last_time"]}
        }
    }]
    # Find and update the document in one atomic operation; BEFORE returns the
    # pre-update state, so the token count we inspect is the one we consumed from
    document = self.bucket_collection.find_one_and_update(
        {'bucket_name': self.bucket_name},
        pipeline,
        return_document=pymongo.ReturnDocument.BEFORE
    )
    is_allowed = document['token'] >= 1
    logging.info("Token bucket with %s tokens returns %s for %s at %s",
                 document['token'], is_allowed, self.bucket_name,
                 datetime.datetime.utcnow())
    return is_allowed
So we test whether there is at least one token in the bucket; if that is the case, we retrieve it and update the number of tokens by multiplying the elapsed time with the refill rate, but never exceeding the capacity. You can set this capacity to 1 to prevent even small bursts. Alternatively, if the API being queried sets a limit of e.g. 4 calls per second and also checks traffic only on a per-second level, the capacity can be set to 4, allowing a small burst of 4 calls within a second.
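Using the constructor of the TokenBucket class shown in the next section, the two configurations for such an API limit of 4 calls per second would look like this:

# strict pacing: at most one call in flight, a new token every 250 ms
bucket = TokenBucket(refill_rate_per_sec=4, capacity=1,
                     db_con=db_con, bucket_name='my_task')

# per-second enforcement: allow a burst of up to 4 calls, refilled at 4 tokens/s
bucket = TokenBucket(refill_rate_per_sec=4, capacity=4,
                     db_con=db_con, bucket_name='my_task')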
The whole implementation is then a Python class, with get_token() at the heart of it:
import datetime
import logging

import pymongo


class TokenBucket:
    """
    The TokenBucket class is a helper to implement global rate limits that
    are adhered to even in a distributed environment
    """

    def __init__(self, refill_rate_per_sec, capacity, db_con, bucket_name, init_new=False):
        """
        The TokenBucket implements a token bucket for outgoing API calls on the shared Mongo instance
        :param refill_rate_per_sec: How often are tokens refilled? e.g. 1 for one token per second, 1/30 for one token every 30 seconds
        :param capacity: The maximum number of tokens the bucket can hold
        :param db_con: The database connection
        :param bucket_name: The name of the token bucket, this must match the name in the collection
        :param init_new: Boolean to indicate whether the bucket should be newly set up
        """
        self.refill_rate_per_sec = refill_rate_per_sec
        self.rate_per_millisec = refill_rate_per_sec / 1000.0
        self.capacity = capacity
        self.db_con = db_con
        self.bucket_name = bucket_name
        self.bucket_collection = db_con.token_buckets
        if init_new:
            # an aggregation-pipeline update is used here so that "$$NOW" resolves
            # to the current server time (in a plain $set it would be stored as a
            # literal string and break the $dateDiff in get_token)
            self.bucket_collection.update_one(
                {'bucket_name': bucket_name},
                [{"$set": {'bucket_name': bucket_name, 'token': capacity, 'last_time': "$$NOW"}}],
                upsert=True)

    def get_token(self):
        …
Usage
This token bucket can then be integrated with a Celery infrastructure by making a call to the Python class within the shared task:
import os
import time

from celery import shared_task
from pymongo import MongoClient


@shared_task(queue=os.environ.get('CELERY_QUEUE'), soft_time_limit=10)
def get_response_from_api(params):
    # open the worker's own db client (the MONGO_* values come from your configuration):
    db_client = MongoClient(
        MONGO_SCHEME + '://' + MONGO_USER + ':' + MONGO_PW +
        '@' + MONGO_HOST + '/' + MONGO_DB + '?retryWrites=true&w=majority')
    db_con = db_client.get_database()
    # now init the token bucket:
    token_bucket = TokenBucket(refill_rate_per_sec=4, capacity=1,
                               db_con=db_con, bucket_name='my_task')
    while not token_bucket.get_token():
        time.sleep(0.1)  # wait for 100 milliseconds before trying again
    result = your_api_call(params)
    return result
The shared task can then be executed, for instance, in a Celery group, and you can check the log file for correctness — with the setting above, only 4 calls per second will go through. All other tasks will need to wait until it is their turn.
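Such a group could be enqueued like this (params stands in for your own task arguments):

from celery import group

# enqueue 20 tasks at once; the token bucket paces the actual API calls
job = group(get_response_from_api.s(params) for _ in range(20))
result = job.apply_async()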
Remember to adjust the soft_time_limit parameter accordingly. If you have very bursty traffic and many distributed workers trying to access a very slow API that does not allow many calls per second, the workers will wait in the queue for a while and might hit the soft time limit.
Simulation test
To confirm that this approach works as expected, we can run a small test. Set the refill rate to, e.g., ½; then an outgoing request is allowed every 2 seconds. If we add a print statement to the worker that attempts outgoing requests, showing the timestamp and whether the worker's request is held back (not allowed) or goes through (allowed), you see the following in your log file:
2025-01-21 09:40:09.929536: Not allowed
2025-01-21 09:40:10.061105: Allowed
2025-01-21 09:40:10.741350: Not allowed
2025-01-21 09:40:10.861117: Not allowed
2025-01-21 09:40:10.991974: Not allowed
2025-01-21 09:40:11.114126: Not allowed
2025-01-21 09:40:11.240675: Not allowed
2025-01-21 09:40:11.362580: Not allowed
2025-01-21 09:40:11.481074: Not allowed
2025-01-21 09:40:11.604872: Not allowed
2025-01-21 09:40:11.735626: Not allowed
2025-01-21 09:40:11.869136: Not allowed
2025-01-21 09:40:11.998463: Not allowed
2025-01-21 09:40:12.131688: Not allowed
2025-01-21 09:40:12.262276: Allowed
2025-01-21 09:40:12.604045: Not allowed
2025-01-21 09:40:12.729667: Not allowed
2025-01-21 09:40:12.854335: Not allowed
2025-01-21 09:40:12.999993: Not allowed
2025-01-21 09:40:13.128204: Not allowed
2025-01-21 09:40:13.252679: Not allowed
2025-01-21 09:40:13.378298: Not allowed
2025-01-21 09:40:13.508852: Not allowed
2025-01-21 09:40:13.633653: Not allowed
2025-01-21 09:40:13.756551: Not allowed
2025-01-21 09:40:13.887134: Not allowed
2025-01-21 09:40:14.017464: Not allowed
2025-01-21 09:40:14.143564: Not allowed
2025-01-21 09:40:14.281994: Not allowed
2025-01-21 09:40:14.416012: Allowed
2025-01-21 09:40:14.812309: Not allowed
2025-01-21 09:40:14.933719: Not allowed
2025-01-21 09:40:15.056074: Not allowed
2025-01-21 09:40:15.179546: Not allowed
…
You can see from the timestamps that our bursty traffic is now regularized to a single request every two seconds.
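For reference, a minimal driver for such a test could look like the sketch below, with the refill rate of ½ configured via the constructor (here the bucket is freshly initialized with init_new=True):

import datetime
import time

# one token every 2 seconds, no bursts
bucket = TokenBucket(refill_rate_per_sec=0.5, capacity=1,
                     db_con=db_con, bucket_name='my_task', init_new=True)

while True:
    allowed = bucket.get_token()
    print(f"{datetime.datetime.utcnow()}: {'Allowed' if allowed else 'Not allowed'}")
    time.sleep(0.1)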
Conclusion
In this article, we have outlined the implementation of a fork-safe Python version of the Token Bucket algorithm, specifically tailored for managing outgoing API calls in a distributed environment. By leveraging a MongoDB database and the Celery worker pool, we demonstrated how to efficiently handle rate limits imposed by AI APIs, ensuring system stability and optimal performance.
The Token Bucket algorithm provides a robust solution for controlling the flow of outgoing API calls. By implementing this algorithm, we can manage variable workloads and prevent system overload, thus maintaining a smooth user experience even under high traffic conditions.
The use of MongoDB to store the token bucket allows for a centralized and consistent rate limiting mechanism that can be shared across multiple Celery workers, running independently on different machines.
We addressed the challenge of PyMongo’s non-fork-safe nature and the potential for race conditions by employing an atomic update pipeline. This ensures that the token bucket’s state remains consistent and reliable, even when accessed concurrently by multiple workers.
By following the steps outlined in this article, you can implement a similar setup in your own applications, ensuring that your AI-powered systems can scale effectively while adhering to the rate limits of external APIs. This approach not only enhances the reliability and performance of your application but also provides a scalable solution that can grow with your user base.
In summary, we see the combination of Flask, Celery, RabbitMQ, and MongoDB, along with the Token Bucket algorithm, as a powerful and flexible architecture for running production-grade AI systems.