As an algorithm engineer, you will inevitably face the problem of bringing models online in your daily work. For less demanding scenarios, you can handle this with a web framework: for each user request, call the model to run inference and return the result. However, this straightforward implementation often fails to make full use of the GPU and struggles in scenarios with high performance requirements.
There are many ways to optimize, and one useful technique is to switch from running inference per request to running inference on multiple requests at once. Around this time last year I wrote a small tool to do exactly that and gave it the rather grandiose name InferLight. Honestly, that tool was not very well implemented. Recently, I refactored it with reference to Shannon Technology's Service-Streamer.
The feature sounds simple, but implementing it teaches a lot about Python asynchronous programming and shows off the parallel computing power of modern GPUs.
Architecture
First, to improve the model's online inference throughput, the inference service should be made asynchronous. For a web service, asynchronous means the program can handle other requests while the model is computing. In Python, asynchronous services can be built on good Asyncio-based frameworks such as Sanic, which I commonly use. Inference itself is computationally intensive, so our goal is to aggregate multiple inference requests, make efficient use of the GPU's parallel computing power, and return the results of batched inference to the correct requesters.
To achieve this goal, the following modules are needed:
- Front-end service: receives requests and returns results. It can speak various protocols such as HTTP or RPC, and runs as an independent process.
- Inference worker: responsible for model initialization, batch construction, and inference computation. It runs as an independent process.
- Task queue: the front-end service receives a request and sends the computation task to the task queue; the inference worker listens to the queue and takes out a small batch at a time for model inference.
- Result queue: after inference is done, the inference worker sends the results to the result queue; the front-end service listens to the queue and collects the inference results.
- Result distribution: before a task is sent to the task queue, a unique identifier is generated for it; when a result is retrieved from the result queue, it is matched back to its request by that identifier.
There are many ways to implement the task queue and result queue; you could use mature middleware such as Kafka or Redis. To avoid external dependencies, I chose Python's native multiprocessing queues this time. The result queue is read, and its results distributed, by a sub-thread of the front-end service process.
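To make the flow concrete, here is a heavily simplified sketch of how a single request travels through these components. The function names (worker_loop, handle_request) are illustrative only, not part of InferLight, and the blocking wait in handle_request stands in for the asynchronous result distribution described above.

import multiprocessing as mp
import uuid

def worker_loop(data_queue, result_queue):
    # inference worker process: take tasks, "infer", and send results back
    while True:
        task_id, payload = data_queue.get()
        result = payload.upper()  # stand-in for a real model forward pass
        result_queue.put((task_id, result))

def handle_request(data_queue, result_queue, payload):
    # front-end side: tag the task, submit it, then wait for its result
    task_id = str(uuid.uuid4())
    data_queue.put((task_id, payload))
    while True:
        tid, result = result_queue.get()  # InferLight does this in a background thread
        if tid == task_id:
            return result

if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    data_queue, result_queue = ctx.Queue(), ctx.Queue()
    ctx.Process(target=worker_loop, args=(data_queue, result_queue), daemon=True).start()
    print(handle_request(data_queue, result_queue, 'hello'))  # prints 'HELLO'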
Implementation
The inference worker is relatively simple. Since there is a variety of models to load and data-processing steps to run, I designed the inference worker as a base class that is subclassed, with the model-specific methods implemented by the user.
import logging
import multiprocessing as mp
import time
from queue import Empty


class BaseInferLightWorker:

    def __init__(self, data_queue:mp.Queue, result_queue:mp.Queue,
                 model_args:dict,
                 batch_size=16, max_delay=0.1,
                 ready_event=None) -> None:
        self.data_queue = data_queue
        self.result_queue = result_queue
        self.batch_size = batch_size
        self.max_delay = max_delay
        self.logger = logging.getLogger('InferLight-Worker')
        self.logger.setLevel(logging.DEBUG)
        self.load_model(model_args)

        # inform the parent process once the model is loaded
        if ready_event:
            ready_event.set()

    def run(self):
        self.logger.info('Worker started!')
        while True:
            data, task_ids = [], []
            since = time.time()
            for i in range(self.batch_size):
                try:
                    # get data from the data queue
                    d = self.data_queue.get(block=True, timeout=self.max_delay)
                    task_ids.append(d[0])
                    data.append(d[1])
                    self.logger.info('get one new task')
                except Empty:
                    pass
                if time.time() - since >= self.max_delay:
                    break
            if len(data) > 0:
                start = time.perf_counter()
                batch = self.build_batch(data)
                results = self.inference(batch)
                end = time.perf_counter()
                time_elapsed = (end - start) * 1000
                self.logger.info(f'inference succeeded. batch size: {len(data)}, time elapsed: {time_elapsed:.3f} ms')
                # write results to the result queue
                for (task_id, result) in zip(task_ids, results):
                    self.result_queue.put((task_id, result))

    def build_batch(self, requests):
        raise NotImplementedError

    def inference(self, batch):
        raise NotImplementedError

    def load_model(self, model_args):
        raise NotImplementedError

    @classmethod
    def start(cls, data_queue:mp.Queue, result_queue:mp.Queue, model_args:dict,
              batch_size=16, max_delay=0.1, ready_event=None):
        w = cls(data_queue, result_queue, model_args, batch_size, max_delay, ready_event)
        w.run()
Alongside the worker is a wrapper class used in the front-end service to receive inference requests and to collect and distribute the results.
import asyncio
import logging
import multiprocessing as mp
import threading
import uuid
from queue import Empty

from cachetools import TTLCache

from .data import InferStatus, InferResponse


class LightWrapper:

    def __init__(self, worker_class, model_args: dict,
                 batch_size=16, max_delay=0.1) -> None:
        # set up the logger
        self.logger = logging.getLogger('InferLight-Wrapper')
        self.logger.setLevel(logging.INFO)

        # save results in a TTL cache
        self.result_cache = TTLCache(maxsize=10000, ttl=5)

        self.mp = mp.get_context('spawn')
        self.result_queue = self.mp.Queue()
        self.data_queue = self.mp.Queue()

        # start the inference worker process
        self.logger.info('Starting worker...')
        worker_ready_event = self.mp.Event()
        self._worker_p = self.mp.Process(target=worker_class.start, args=(
            self.data_queue, self.result_queue, model_args, batch_size, max_delay, worker_ready_event
        ), daemon=True)
        self._worker_p.start()

        # wait at most 30 seconds for the model to load
        is_ready = worker_ready_event.wait(timeout=30)
        if is_ready:
            self.logger.info('Worker started!')
        else:
            self.logger.error('Failed to start worker!')

        # start the result collecting thread
        self.back_thread = threading.Thread(
            target=self._collect_result, name="thread_collect_result")
        self.back_thread.daemon = True
        self.back_thread.start()

    def _collect_result(self):
        # keep reading the result queue and write each result
        # to the cache with its task_id as the key
        self.logger.info('Result collecting thread started!')
        while True:
            try:
                msg = self.result_queue.get(block=True, timeout=0.01)
            except Empty:
                msg = None
            if msg is not None:
                (task_id, result) = msg
                self.result_cache[task_id] = result

    async def get_result(self, task_id):
        # non-blocking check for the result
        while task_id not in self.result_cache:
            await asyncio.sleep(0.01)
        return self.result_cache[task_id]

    async def predict(self, input, timeout=2) -> InferResponse:
        # generate a unique task_id
        task_id = str(uuid.uuid4())

        # send the input to the worker process
        self.data_queue.put((task_id, input))
        try:
            # set a timeout threshold to avoid waiting forever
            result = await asyncio.wait_for(self.get_result(task_id), timeout=timeout)
        except asyncio.TimeoutError:
            return InferResponse(InferStatus.TIMEOUT, None)

        return InferResponse(InferStatus.SUCCEED, result)
Some of the data structures used are defined as follows:
from enum import Enum


class InferStatus(Enum):
    SUCCEED = 0
    TIMEOUT = 1


class InferResponse:

    def __init__(self, status: InferStatus, result) -> None:
        self.status = status
        self.result = result

    def succeed(self):
        return self.status == InferStatus.SUCCEED
Use Case and Test Result
Here we show how the above components can be used with a sentiment analysis BERT model.
First, define the model:
import torch
from torch import nn
from transformers import AutoModelForSequenceClassification


class BertModel(nn.Module):

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.bert = AutoModelForSequenceClassification.from_pretrained(config['model'])
        self.bert.eval()
        self.device = torch.device('cuda' if config.get('use_cuda') else 'cpu')
        self.bert.to(self.device)

    def forward(self, inputs):
        return self.bert(**inputs).logits
Then inherit from BaseInferLightWorker and implement its three abstract methods to get a complete worker class:
import torch
from scipy.special import softmax
from transformers import AutoTokenizer


class MyWorker(BaseInferLightWorker):

    def load_model(self, model_args):
        self.model = BertModel(model_args)
        self.tokenizer = AutoTokenizer.from_pretrained(model_args['model'])
        self.device = torch.device('cuda' if model_args.get('use_cuda') else 'cpu')
        return

    def build_batch(self, requests):
        # this function builds the input for batched inference
        encoded_input = self.tokenizer.batch_encode_plus(requests,
                                                         return_tensors='pt',
                                                         padding=True,
                                                         truncation=True,
                                                         max_length=512)
        return encoded_input.to(self.device)

    @torch.no_grad()
    def inference(self, batch):
        model_output = self.model.forward(batch).cpu().numpy()
        scores = softmax(model_output, axis=1)
        # return the results of the whole batch as a list
        ret = [x.tolist() for x in scores]
        return ret
Finally, build the service:
from sanic import Sanic
from sanic.response import json as json_response


if __name__ == '__main__':
    # for convenience, we use a fixed text from Aesop's Fables as input
    text = """
    A Fox one day spied a beautiful bunch of ripe grapes hanging from a vine trained along the branches of a tree. The grapes seemed ready to burst with juice, and the Fox's mouth watered as he gazed longingly at them.
    The bunch hung from a high branch, and the Fox had to jump for it. The first time he jumped he missed it by a long way. So he walked off a short distance and took a running leap at it, only to fall short once more. Again and again he tried, but in vain.
    Now he sat down and looked at the grapes in disgust.
    "What a fool I am," he said. "Here I am wearing myself out to get a bunch of sour grapes that are not worth gaping for."
    And off he walked very, very scornfully.
    """
    config = {
        'model': "nlptown/bert-base-multilingual-uncased-sentiment",
        'use_cuda': True
    }
    wrapped_model = LightWrapper(MyWorker, config, batch_size=16, max_delay=0.05)

    app = Sanic('test')

    @app.get('/batched_predict')
    async def batched_predict(request):
        dummy_input = text
        response = await wrapped_model.predict(dummy_input)
        if not response.succeed():
            return json_response({'output': None, 'status': 'failed'})
        return json_response({'output': response.result})

    app.run(port=8888)
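Before benchmarking, you can sanity-check the endpoint with a single request. This is just an illustrative client, assuming the service is running locally on port 8888 and the requests package is installed:

import requests

# hit the endpoint once and print the sentiment scores
resp = requests.get('http://localhost:8888/batched_predict')
print(resp.json())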
I ran some tests with Apache's well-known ab tool. I started the app above on my HP Z4 Workstation and made sure the worker process was running on an RTX 6000 GPU.
With ab -n 1000 -c 32 http://localhost:8888/batched_predict, I got the following result:
Concurrency Level: 32
Time taken for tests: 4.019 seconds
Complete requests: 1000
Failed requests: 999
(Connect: 0, Receive: 0, Length: 999, Exceptions: 0)
Total transferred: 202978 bytes
HTML transferred: 111978 bytes
Requests per second: 248.79 [#/sec] (mean)
Time per request: 128.620 [ms] (mean)
Time per request: 4.019 [ms] (mean, across all concurrent requests)
Transfer rate: 49.32 [Kbytes/sec] received
For comparison, the test result of a straightforward implementation without batch inference is as follows:
Concurrency Level: 32
Time taken for tests: 10.164 seconds
Complete requests: 1000
Failed requests: 0
Total transferred: 202000 bytes
HTML transferred: 111000 bytes
Requests per second: 98.39 [#/sec] (mean)
Time per request: 325.234 [ms] (mean)
Time per request: 10.164 [ms] (mean, across all concurrent requests)
Transfer rate: 19.41 [Kbytes/sec] received
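The exact baseline code is not part of this post; presumably it is a per-request handler that runs the forward pass synchronously inside the request, roughly along these lines. This is a hypothetical sketch that reuses the BertModel class, config, text, and app objects from above, with an illustrative /predict route:

model = BertModel(config)
tokenizer = AutoTokenizer.from_pretrained(config['model'])

@app.get('/predict')
async def predict(request):
    # encode and run the model inside the request handler, one request at a time
    encoded = tokenizer([text], return_tensors='pt', padding=True,
                        truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        logits = model.forward(encoded).cpu().numpy()
    scores = softmax(logits, axis=1)
    return json_response({'output': scores[0].tolist()})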
As you can see, we got about 2.5 times the throughput with batch inference! While running the benchmark, I also observed that GPU utilization was much higher with batch inference.
I have open-sourced InferLight; it can be found at https://github.com/thuwyh/InferLight. Hope you love it :)