As an algorithms engineer, sooner or later you have to put a model into production. For undemanding scenarios, wiring the model behind any web framework is enough: for each incoming request, run the model and return the result. But this naive implementation rarely makes full use of the GPU, and it starts to struggle once performance requirements get serious.

There are many ways to optimize, and one change with a big payoff is to switch from running inference once per request to batching several requests together. Around this time last year I wrote a small tool for exactly this and gave it the rather grandiose name InferLight, but the code wasn't great; recently I rewrote it, borrowing ideas from Shannon AI's Service-Streamer. The feature looks simple, yet implementing it teaches you a lot about asynchronous programming in Python, so I decided to write a short post summarizing what I learned.

First, to raise online inference throughput, the serving layer should be asynchronous. For a web service, asynchronous means it can keep handling other requests while the model is busy computing. In Python this is easy to get from one of the excellent asyncio-based frameworks, such as Sanic, which I use regularly. Inference itself is compute-bound, so the sync/async distinction does not really apply to it; our goal is to gather multiple inference requests, exploit the GPU's parallel compute efficiently, and still return each result in the batch to the right requester.
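
For contrast, here is what the naive approach from the opening paragraph looks like as a minimal Sanic sketch. This is my own illustration, not InferLight code; run_model is a stand-in stub for a real model. Every request triggers its own forward pass, so the GPU never sees a batch larger than 1.

from sanic import Sanic
from sanic.response import json as json_response

app = Sanic('naive_demo')

def run_model(text: str):
    # Placeholder for a real forward pass (an assumption for illustration);
    # in practice this would be tens of milliseconds of GPU work per request.
    return {'label': 'positive', 'score': 1.0}

@app.post('/predict')
async def predict(request):
    # One request, one inference call: batch size is always 1.
    return json_response({'output': run_model(request.json['text'])})

if __name__ == '__main__':
    app.run(port=8887)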

Achieving this requires the following components:

  • Frontend service: receives requests and returns results. It can speak HTTP, RPC, or any other protocol, and runs as a separate process.
  • Inference worker: handles model initialization, batch construction, and the inference computation itself. Also a separate process.
  • Task queue: for every request it receives, the frontend pushes a task onto this queue; the inference worker listens on it and pulls off a small batch at a time for the model.
  • Result queue: after inference, the worker pushes results onto this queue; the frontend listens on it to retrieve them.
  • Result dispatching: before a task enters the task queue it is assigned a unique ID; when its result comes back from the result queue, that ID is used to hand the result to the matching request.

Both queues can be implemented in many ways, for example on top of mature middleware such as Kafka or Redis, but to avoid external dependencies I chose Python's native multiprocessing queues this time. Listening on the result queue and dispatching results is handled by a background thread inside the frontend process.
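
Before diving into the details, the sketch below (a toy of my own, not InferLight code) shows that wiring in its simplest form: two spawn-context multiprocessing queues, a dummy worker process standing in for the model, and a collector thread in the frontend process that routes results back by task id.

import multiprocessing as mp
import threading
import time
import uuid
from queue import Empty

def dummy_worker(data_queue, result_queue):
    # Stand-in for the inference worker: the real one pulls a whole batch;
    # here we take one task at a time and "infer" by uppercasing the payload.
    while True:
        task_id, payload = data_queue.get()
        result_queue.put((task_id, payload.upper()))

if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    data_queue, result_queue = ctx.Queue(), ctx.Queue()

    # The worker gets its own process so it can own the model/GPU.
    ctx.Process(target=dummy_worker, args=(data_queue, result_queue),
                daemon=True).start()

    results = {}

    def collect():
        # Collector thread in the frontend process: route results by task id.
        while True:
            try:
                task_id, result = result_queue.get(timeout=0.01)
                results[task_id] = result
            except Empty:
                pass

    threading.Thread(target=collect, daemon=True).start()

    # The "frontend": tag the request with a unique id, enqueue it,
    # then wait for that id to show up among the collected results.
    task_id = str(uuid.uuid4())
    data_queue.put((task_id, 'hello'))
    while task_id not in results:
        time.sleep(0.01)
    print(results[task_id])  # -> HELLO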

Implementation Details

The inference side is relatively simple. Since model loading and data preprocessing vary wildly from model to model, I designed the inference worker as a base class: to use it, you subclass it and implement a few specific methods.

import logging
import multiprocessing as mp
import time
from queue import Empty

class BaseInferLightWorker:

    def __init__(self, data_queue:mp.Queue, result_queue:mp.Queue, 
                 model_args:dict, 
                 batch_size=16, max_delay=0.1,
                 ready_event=None) -> None:
        self.data_queue = data_queue
        self.result_queue = result_queue
        self.batch_size = batch_size
        self.max_delay = max_delay
        self.logger = logging.getLogger('InferLight-Worker')
        self.logger.setLevel(logging.DEBUG)

        self.load_model(model_args)
        
        # Model loading can take quite a while, so once it is done
        # we set an event to notify the parent process.
        if ready_event:
            ready_event.set()

    def run(self):
        self.logger.info('Worker started!')
        while True:
            data, task_ids = [], []
            since = time.time()
            for i in range(self.batch_size):
                try:
                    # Pull one task off the data queue
                    d = self.data_queue.get(block=True, timeout=self.max_delay)
                    task_ids.append(d[0])
                    data.append(d[1])
                    self.logger.info('got one new task')
                except Empty:
                    pass
                if time.time()-since>=self.max_delay:
                    break
            if len(data)>0:
                start = time.perf_counter()
                batch = self.build_batch(data)
                results = self.inference(batch)
                end = time.perf_counter()
                time_elapsed = (end-start)*1000
                self.logger.info(f'inference succeeded. batch size: {len(data)}, time elapsed: {time_elapsed:.3f} ms')
                # Write the results back to the result queue, tagged with their task ids
                for (task_id, result) in zip(task_ids, results):
                    self.result_queue.put((task_id, result))


    def build_batch(self, requests):
        raise NotImplementedError

    def inference(self, batch):
        raise NotImplementedError

    def load_model(self, model_args):
        raise NotImplementedError

    @classmethod
    def start(cls, data_queue:mp.Queue, result_queue:mp.Queue, model_args:dict, batch_size=16, max_delay=0.1, ready_event=None):
        w = cls(data_queue, result_queue, model_args, batch_size, max_delay, ready_event)
        w.run()

Its counterpart is a wrapper class used inside the frontend service, which accepts inference requests and collects and dispatches the results.

import asyncio
import logging
import multiprocessing as mp
import threading
import uuid
from queue import Empty

from cachetools import TTLCache

from .data import InferStatus, InferResponse


class LightWrapper:

    def __init__(self, worker_class, model_args: dict,
                 batch_size=16, max_delay=0.1) -> None:
        # setup logger
        self.logger = logging.getLogger('InferLight-Wrapper')
        self.logger.setLevel(logging.INFO)
        
        # Results are kept in a cache whose entries expire after 5 seconds
        self.result_cache = TTLCache(maxsize=10000, ttl=5)

        self.mp = mp.get_context('spawn')
        self.result_queue = self.mp.Queue()
        self.data_queue = self.mp.Queue()

        # Start the inference worker
        self.logger.info('Starting worker...')
        worker_ready_event = self.mp.Event()
        self._worker_p = self.mp.Process(target=worker_class.start, args=(
            self.data_queue, self.result_queue, model_args, batch_size, max_delay, worker_ready_event
        ), daemon=True)
        self._worker_p.start()
        
        # Wait at most 30 seconds for the worker to become ready
        is_ready = worker_ready_event.wait(timeout=30)
        if is_ready:
            self.logger.info('Worker started!')
        else:
            self.logger.error('Failed to start worker!')
        
        # Start the result-collecting thread
        self.back_thread = threading.Thread(
            target=self._collect_result, name="thread_collect_result")
        self.back_thread.daemon = True
        self.back_thread.start()

    def _collect_result(self):
        # Continuously read from the result queue in this thread
        # and store each result in the cache keyed by its task_id
        self.logger.info('Result collecting thread started!')
        while True:
            try:
                msg = self.result_queue.get(block=True, timeout=0.01)
            except Empty:
                msg = None
            if msg is not None:
                (task_id, result) = msg
                self.result_cache[task_id] = result

    async def get_result(self, task_id):
        # Poll for the task's result without blocking the event loop
        while task_id not in self.result_cache:
            await asyncio.sleep(0.01)
        return self.result_cache[task_id]

    async def predict(self, input, timeout=2) -> InferResponse:
        # generate unique task_id
        task_id = str(uuid.uuid4())

        # send input to worker process
        self.data_queue.put((task_id, input))
        try:
            # Cap how long we are willing to wait for the result
            result = await asyncio.wait_for(self.get_result(task_id), timeout=timeout)
        except asyncio.TimeoutError:
            return InferResponse(InferStatus.TIMEOUT, None)

        return InferResponse(InferStatus.SUCCEED, result)

The data structures used above are defined as follows:

from enum import Enum

class InferStatus(Enum):
  SUCCEED = 0
  TIMEOUT = 1

class InferResponse:

  def __init__(self, status: InferStatus, result) -> None:
      self.status = status
      self.result = result

  def succeed(self):
      return self.status==InferStatus.SUCCEED

Usage and Testing

To show how the components above fit together, let's borrow a BERT sentiment-classification model.

First, define the model:

import torch
from torch import nn
from transformers import AutoModelForSequenceClassification

class BertModel(nn.Module):

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.bert = AutoModelForSequenceClassification.from_pretrained(config['model'])
        self.bert.eval()
        self.device = torch.device('cuda' if config.get('use_cuda') else 'cpu')
        self.bert.to(self.device)

    def forward(self, inputs):
        return self.bert(**inputs).logits

Then subclass BaseInferLightWorker and implement the three required methods to get a complete worker class:

from scipy.special import softmax
from transformers import AutoTokenizer

class MyWorker(BaseInferLightWorker):

    def load_model(self, model_args):
        self.model = BertModel(model_args)
        
        self.tokenizer = AutoTokenizer.from_pretrained(model_args['model'])
        self.device = torch.device('cuda' if model_args.get('use_cuda') else 'cpu')
        return

    def build_batch(self, requests):
        # Build the model input for batch inference from a list of raw requests
        encoded_input = self.tokenizer.batch_encode_plus(requests, 
                                                         return_tensors='pt',
                                                         padding=True,
                                                         truncation=True,
                                                         max_length=512)
        return encoded_input.to(self.device)

    @torch.no_grad()
    def inference(self, batch):
        model_output = self.model.forward(batch).cpu().numpy()
        scores = softmax(model_output, axis=1)
        # Return the whole batch's results as a list, one entry per request
        ret = [x.tolist() for x in scores]
        return ret

Finally, wire up the service:

from sanic import Sanic
from sanic.response import json as json_response

if __name__=='__main__':
    # To keep the test simple, a fixed input text is used,
    # taken from Aesop's Fables.
    text = """
    A Fox one day spied a beautiful bunch of ripe grapes hanging from a vine trained along the branches of a tree. The grapes seemed ready to burst with juice, and the Fox's mouth watered as he gazed longingly at them.
    The bunch hung from a high branch, and the Fox had to jump for it. The first time he jumped he missed it by a long way. So he walked off a short distance and took a running leap at it, only to fall short once more. Again and again he tried, but in vain.
    Now he sat down and looked at the grapes in disgust.
    "What a fool I am," he said. "Here I am wearing myself out to get a bunch of sour grapes that are not worth gaping for."
    And off he walked very, very scornfully.
    """
    
    config = {
        'model':"nlptown/bert-base-multilingual-uncased-sentiment",
        'use_cuda':True
    }
    wrapped_model = LightWrapper(MyWorker, config, batch_size=16, max_delay=0.05)
    
    app = Sanic('test')
    
    @app.get('/batch_predict')
    async def batched_predict(request):
        dummy_input = text
        response = await wrapped_model.predict(dummy_input)
        if not response.succeed():
            return json_response({'output':None, 'status':'failed'})
        return json_response({'output': response.result})

    app.run(port=8888)

In some simple benchmarks, this approach improved throughput by close to 2.5x, which is quite substantial. The gain becomes even more pronounced with larger models and longer inputs.
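
The exact gain of course depends on the model, input length, hardware and traffic pattern. If you want to reproduce a rough comparison yourself, a small load-generation script along the following lines is enough; the request count and concurrency below are arbitrary values of my own, and aiohttp is just one convenient client.

import asyncio
import time

import aiohttp

URL = 'http://127.0.0.1:8888/batch_predict'  # the endpoint started above

async def hit(session):
    # Issue one request and parse the JSON response.
    async with session.get(URL) as resp:
        await resp.json()

async def main(total=500, concurrency=32):
    # Keep at most `concurrency` requests in flight at once.
    sem = asyncio.Semaphore(concurrency)

    async def bounded(session):
        async with sem:
            await hit(session)

    async with aiohttp.ClientSession() as session:
        start = time.perf_counter()
        await asyncio.gather(*(bounded(session) for _ in range(total)))
        elapsed = time.perf_counter() - start
    print(f'{total} requests in {elapsed:.2f}s -> {total / elapsed:.1f} req/s')

if __name__ == '__main__':
    asyncio.run(main())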

The code is up on GitHub as InferLight; feel free to take a look.