• FastAPI CPU 密集型任务的处理办法

    FastAPI 经常号称是非常快的框架,因为它大量使用 async。对于一些 IO 密集型的任务,是有一些坑要处理的,不然整个 async 就被堵住了。这个处理还是比较简单的,按官网文档的说法,

    如果你的应用程序不需要与其他任何东西通信而等待其响应,请使用async def
    如果你不清楚,使用def就好

    那如果是 CPU 密集型的任务呢?这是工作中同事遇到的问题。

    def的办法就不行了——会把线程池堵住。

    下面 4 种写法,有 2 种都可以解决这个问题,都是进程池的方案。

    from concurrent.futures import ProcessPoolExecutor
    import asyncio
    import time
    
    from fastapi import FastAPI  
    import uvicorn
    
    
    app = FastAPI()  
    
    def fib(n: int) -> int:
        if n == 1 or n == 2:
            return 1
        return fib(n-1) + fib(n-2)
    
    def some_cpu_work() -> int:
        return fib(36)  # 大概1s
    
    
    # 通用配置:20个并发,同时请求
    
    @app.get("/test1")  
    async def test1():
        """ async定义+直接计算    排队 27s  CPU占用1核心 """
        data = some_cpu_work()
        print(data)
        return data  
    
    @app.get("/test2")  
    def test2():
        """ 普通定义+直接计算      自带线程池 依然27s  CPU占用1核心 """
        data = some_cpu_work()
        print(data)
        return data  
    
    #################################### 开一个进程池
    process_pool_executor = ProcessPoolExecutor(max_workers=4)
    
    @app.get("/test3")
    def test3():
        """ 普通定义+进程池   8s CPU占用4核心   """
        task = process_pool_executor.submit(some_cpu_work)
        data = task.result()
        print(data)
        return data
    
    @app.get("/test4")
    async def test4():
        """ async定义+交给进程池  8s CPU占用4核心  """
        data = await asyncio.get_running_loop().run_in_executor(process_pool_executor, some_cpu_work)
        print(data)
        return data
    
    if __name__ == "__main__":
        uvicorn.run(app, host="127.0.0.1", port=8000)

    具体来说,20 个请求同时请求,这 20 个请求在 4 种写法的结果如下:(其中第2种堵住了自带线程池的结果最出乎意料)

    阅读更多…
  • Google Colab 运行其他 Python 版本

    首先说明,这个运行其他 Python 版本并不是影响 jupyter notebook 的 Python 版本,而是在 notebook 中通过!python xxxxx.py执行我上传的脚本所使用的 Python。

    撰写本文时,Colab 的 Python 是 3.7。

    我的项目中用到了一些 3.8、3.9 的新语法,所以没法在 Colab 上面跑,但我也不想把这些语法改成旧的,所以只能在 Colab 上升级了。

    经过一番摸索,发现最靠谱的方法是在 apt 中安装新版 Python,然后全程使用 venv,而不用系统的 Python。系统的 Python 总是遇到 pip 安装库时的各种奇怪问题。

    所以,在 notebook 中的命令如下:

    !sudo apt-get update -y
    !sudo apt-get install python3.9 python3.9-distutils python3.9-venv
    
    !python3.9 --version
    
    !python3.9 -m venv venv

    之后运行 pip 安装第三方库、运行 python xxxx.py 时,都使用如下方式:

    !venv/bin/pip install xxx_library
    !venv/bin/python xxx.py
  • 真 · Python 序列化/反序列化库:marshmallow

    前言

    本文通过实际项目的经历,经过搜索与比较,终于找到了好用的 Python 序列化库:marshmallow。
    它拥有类似 django 的语法,支持详细的自定义设置,适合各种各样的序列化/反序列化情景。

    问题

    我们都知道,Python 自带了一些序列化的库,例如 json、pickle、marshal 等。现在的问题是,要向一个 HTTP 服务器提交一个 POST 请求,附带一个 JSON 作为 payload,假设这个 payload 是这样的:

    {
      "name": "abc",
      "price": 1.23456,
      "date": "2021-06-26"
    }

    那么,它对应的 Python 定义是

    from dataclasses import dataclass
    import datetime
    
    @dataclass
    class Item:
        name: str
        price: float
        date: datetime.date

    所以,这个 POST 请求会是这样发送的:

    import json
    import requests
    
    item = Item('abc', Decimal('1.23456'), '2021-06-26')
    requests.post(url, data=json.dumps(item))

    这样一跑,马上给报错:TypeError: Object of type Item is not JSON serializable。嗯?????怎么报错了?才知道 Python 自带的 json 库不支持序列化一个自定义的 class,仅仅支持那几个 Python 内置的类,如 dict, str, float, list, bool。对此,网上是有一些解决方法,但其实都是妥协之举:如果服务器返回一个 {"name": "abc","price": 1.23456,"date":"2021-06-26"},能方便地反序列化成实例么?

    遇到这种问题,首先就想起“不要自己造轮子”的原则。Python 作为一门非常成熟的语言,就没有什么 pythonic 的解决方案?然后我去百度、Google,似乎还真没有,marshmallow 已经是我能找到的最好的了。

    手上还有 JVM 的项目,可以用 jackson/fastjson 等库,那才是方便啊,类型传进去后,正反序列化一气呵成。

    如果自己给这个类写一个 to_json() 方法,手动构造一个 dict 呢?想想还是不行,如果未来要修改这个类的属性,那么还得对应改。什么?你说连这个类都不要了,post 的时候直接现场构造一个 dict?这绝对是在给自己挖坑啊。

    所以,marshmallow 赶紧学起来。

    Marshmallow

    Schema

    最重要的,是定义 schema。对于刚才的 Item 类,对应的 schema 会这样写:

    @dataclass
    class Item:
        name: str
        price: Decimal
        date: datetime.date
    
    # 对于其他的所有 fields, 可以参考文档
    # https://marshmallow.readthedocs.io/en/stable/marshmallow.fields.html#api-fields
    
    from marshmallow import Schema, fields
    
    class ItemSchema(Schema):
        name = fields.String()
        price = fields.Decimal()
        date = fields.Date()

    如果以前接触过 django,应该对这种写法很熟悉,不过不熟悉也没关系。

    序列化

    然后,就可以用 dump() 序列化了。

    item = Item('abc', Decimal('1.23456'), datetime.date(2021, 6, 26))
    
    schema = ItemSchema()
    result_obj = schema.dump(item)
    print(result_obj)
    
    # 会输出 {'price': Decimal('1.23456'), 'name': 'abc', 'date': '2021-06-26'}

    可以看到,在定义好的 schema 的帮助下,item 实例变成了一个 dict。

    如果想输出一个字符串呢?可以把 schema.dump 换成 schema.dumps(不过由于 price 是一个 Decimal 实例,而自带的 json 不支持 Decimal 会报错,可以考虑另外装一个 simplejson 库来解决。这个问题在后文会再讲到)

    反序列化

    使用 load() 进行反序列化:

    input_data = { 'name': 'abc', 'price': Decimal('1.23456'), 'date': '2021-06-26'}
    
    load_result = schema.load(input_data)
    print(load_result)
    
    # 输出 {'name': 'abc', 'date': datetime.date(2021, 6, 26), 'price': Decimal('1.23456')}

    注意到 date 属性变成了一个 Python 的 datetime.date 实例,这是我们希望的。

    这里只是反序列化成了一个 dict,那么怎样才能返回一个真正的 Item 实例?可以给 schema 添加一个返回 Item 的方法,并打上 post_load 的注解(装饰器)。

    from marshmallow import Schema, fields, post_load
    
    class ItemSchema(Schema):
        name = fields.String()
        price = fields.Decimal()
        date = fields.Date()
    
        @post_load
        def make_item(self, data, **kwargs):
            return Item(**data)
    
    result_obj = schema.dump(item)
    print(result_obj)
    # 返回 Item(name='abc', price=Decimal('1.23456'), date=datetime.date(2021, 6, 26))

    可以,有点样子了。

    数据校验

    既然有了 schema,当然可以顺便在 load 时做校验了。例如输入的字段有没有多,有没有缺,属性的类型对不对,值的范围有没有超,哪些属性可以填 None(null)等等。这里不再赘述,见文档

    自定义字段

    Marshmallow 库自带的一些类型可能不够满足需求,例如想加一个 省份、身份证 什么的,用字符串的话好像感觉欠缺了一点校验。而刚才提到的 Decimal 比较难处理,我觉得也可以用自定义字段来解决。

    例如,可以用 Method 字段来自定义正反序列化时所使用的方法。

    class ItemSchema(Schema):
        name = fields.String()
        price = fields.Method('price_decimal_2_float', deserialize='float_2_decimal')
        date = fields.Date()
    
        @post_load
        def make_item(self, data, **kwargs):
            return Item(**data)
    
        def price_decimal_2_float(self, item: Item):
            return float(item.price)
    
        def float_2_decimal(self, float):
            return decimal.Decimal(str(float))

    这样就可以正常用 dumps(), loads() 了:

    result_str = schema.dumps(item)
    print(result_str)
    # {"date": "2021-06-26", "name": "abc", "price": 1.23456}
    
    input_str = '{"date": "2021-06-26", "name": "abc", "price": 1.23456}'
    load_result = schema.loads(input_str)
    print(load_result)
    # Item(name='abc', price=Decimal('1.23456'), date=datetime.date(2021, 6, 26))

    太棒了,输入的 JSON 的 date 和 price,在反序列化后,自动变成了 datetime.dateDecimal

    小结

    Marshmallow 的核心是 schema,数据类型、校验等都记录在 schema 中,从而支持复杂对象的序列化和反序列化。如果说它有什么缺点,那必须是它仍然需要专门定义一个 schema 才能使用。如果是 JVM 系列的 jackson 等库,可以直接使用 data class 作为 model/schema,额外的配置通过注解等方式引入,这样会少一个对应的 schema 类。marshmallow 的做法使类的数量双倍了。总体来说,它仍然是不造轮子的情况下的好选择。

    至于更深入的用法,还是需要参考官方文档,见文末。

    相关参考

  • 使反向代理后的 Flask 的 url_for() 使用 https

    这个问题困扰我多时,曾尝试多个方案未果,今天终于搜到了正确的解决方法,记录下来。

    我的网络拓扑: 外部 → docker的nginx → docker的gunicorn → flask

    问题

    Flask 代码中的 redirect(url_for('login')) 会跳回 http 的登录页面。

    我也不想给每个 url_for() 添加强制 https 的参数。

    解决

    from werkzeug.middleware.proxy_fix import ProxyFix
    from flask import Flask
    
    app = Flask(__name__)
    app.wsgi_app = ProxyFix(app.wsgi_app)

    相关参考

  • Redis 的 Python 客户端推荐:walrus

    最近在工作中需要在 Python 程序中读写 Redis。之前在自己的项目中,用的是 redis-py ——也是最广为人知的客户端。

    不过,它相当难用,几乎就是原生的 Redis 命令,在大一点的项目中,写一堆 Redis 命令,我估计是受不了的。而我在工作的其他 Java/JVM 项目里,用的是经过抽象封装后的库 Redisson,使用的体验就很舒服。

    那么,在 Python 环境,如果不用 redis-py,用什么库好呢,还是自己造轮子?

    在官网的客户端列表中可以找到,除了 redis-py,Python 另外还有 walrus 这个推荐的客户端。直接点进去,了解到它支持很 pythonic 风格的 Hash、List、Set、Sorted Set 等容器,足以满足我的使用需求,就用它了~

    附其代码样例:

    >>> h = db.Hash('charlie')
    >>> h['age'] = 31
    >>> print(h)
    <Hash "charlie": {'age': '31'}>
  • FIX 协议开发(1):协议介绍及开发方案

    本系列导航

    FIX 协议开发(1):协议介绍及开发方案(本文)
    FIX 协议开发(2):QuickFIX/J 入门
    FIX 协议开发(3):QuickFIX/J 实战经验小结

    本文背景

    公司因业务需要,准备接入 FIX 协议。在调研过程中,我发现中文的 FIX 协议相关资料不太多,准备边学边记录,预计会写 3 篇左右。

    FIX 协议简介

    FIX(Financial Information eXchange Protocol,金融信息交换协议)是由国际FIX协会组织提供的一个开放式协议,目的是推动国际贸易电子化进程,在各类参与者之间,包括投资经理、经纪人、买方、卖方建立起实时的电子化通讯协议。

    FIX协议的目标是把各类证券金融业务需求流程格式化,成为一个可用计算机语言描述的功能流程,并在每个业务功能接口上统一交换格式,方便各个功能模块的连接。

    消息格式

    FIX 协议消息均由多个 “key=value” 组成。其中 key 可以是协议规定的字段,或自定义字段。协议规定的key可查询 FIX 协议字典,【不同版本的 FIX 协议均有其字典,用于开发的库一般也有自带;也可参考第三方,如 wireshark】例如 8 代表 begin string,34 代表消息的序列号,52 代表时间戳等。自定义字段不与规定的 key 重复,供金融机构定制,开发时需要向对应金融机构获取其专有字段的字典。只要有了对应的字典,就可以读懂 FIX 数据包的内容。

    一般来说,一个 消息由“头部 + 消息体 + 尾部”构成。头部包含一些必要的字段,例如 BeginString (8)、BodyLength (9)、MsgType (35)、MsgSeqNum (34)、SenderCompID (49) 等,尾部包含的必要字段是 CheckSum (10)。

    FIX 登陆消息示例(假设”^”是分隔符):

    8=FIX.4.3^9=65^35=A^34=1^49=TESTACC^52=20130703-15:55:08.609^56=EXEC^98=0^108=30^10=225^

    对照字典可知,BeginString (8) 是 FIX.4.3;BodyLength (9) 是 65 字节;MsgType (35) 是 A,A 对应 logon 操作;MsgSeqNum (34) 是 1,即这是我方发送的第 1 个消息。

    有关更详细的协议介绍,可参考 https://blog.51cto.com/9291927/2536105

    开发方案

    不使用库

    由上面的示例可以发现,FIX 协议十分简单。可以不需要依赖第三方库,手动查字典构造消息
    8=FIX.4.3^9=65^35=A^34=1<省略>
    再通过标准的 socket 通信,即可完成交互。

    这个方案自由度最高,不依赖底层开发语言,但开发流程与查字典较为繁琐,后续维护也不太方便。

    Python 库 simplefix

    simplefix 是一个 FIX 协议的简易实现。它使用户可以方便地任意构造 FIX 消息,非常适合学习、测试协议。但这个库不包含任何网络收发、FIX 异常处理等功能模块。因此,开发 FIX 客户端时,我使用该库构造数据包,然后通过标准 socket 发送,再分析其网络底层交互。

    示例:发送 logon 消息的代码

    import socket
    import time
    
    import simplefix
    from simplefix.constants import (ENCRYPTMETHOD_NONE, MSGTYPE_LOGON,
                                     TAG_BEGINSTRING, TAG_CLIENTID,
                                     TAG_ENCRYPTMETHOD, TAG_HEARTBTINT,
                                     TAG_MSGSEQNUM, TAG_MSGTYPE, TAG_SENDER_COMPID,
                                     TAG_SENDING_TIME, TAG_TARGET_COMPID)
    
    HOST = '1.2.3.45'
    PORT = 9000
    CLIENT_ID = 12345678
    PASSWORD = 'mypassword'
    
    seq_num = 1  # 需维护msg sequence number,每次发送后加1
    
    
    if __name__ == "__main__":
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)  # 长连接
        sock.connect((HOST, PORT))
    
        # logon
        msg_logon = simplefix.FixMessage()
        msg_logon.append_pair(TAG_BEGINSTRING, 'FIX.4.2')
        msg_logon.append_pair(TAG_MSGSEQNUM, seq_num)  # 需维护msg sequence number,每次发送后加1
        msg_logon.append_pair(TAG_SENDER_COMPID, 'FIXTEST001')
        msg_logon.append_utc_timestamp(TAG_SENDING_TIME)
        msg_logon.append_pair(TAG_TARGET_COMPID, 'TESTENV')
    
        msg_logon.append_pair(TAG_MSGTYPE, MSGTYPE_LOGON)  # 类型
    
        msg_logon.append_pair(TAG_ENCRYPTMETHOD, ENCRYPTMETHOD_NONE)
        msg_logon.append_pair(TAG_HEARTBTINT, 30)
        msg_logon_buffer = msg_logon.encode()
    
        sock.send(msg_logon_buffer)
        print('seq', seq_num, 'sent')
        seq_num += 1
    
    
        time.sleep(1)
        sock.close()

    多平台 QuickFix 引擎

    QuickFix 是全功能的 FIX 开源引擎,目前很多 Fix 解决方案都是根据或参考 QuickFix 实现的。目前(2020年10月)它有 C++、Python、Java、.NET、Go 和 Ruby 共 6 种语言的实现/接口。

    根据我司的情况,选择其中的 Java 实现 QuickFIX/J 进行下一步开发。其使用方法,将在下一篇文章继续。

  • 树状数组(Binary Indexed Tree / Fenwick Tree)学习与实现

    树状数组是一个能高效处理数组①更新、②求前缀和的数据结构。它提供了2 个方法,时间复杂度均为O(log n)

    1. update(index, delta):将 delta 加到数组的 index 位置
    2. prefix_sum(n):获取数组的前 n 个元素的和
      range_sum(start, end):获取数组从 [start, end] 的和,相当于 prefix_sum(end) – prefix_sum(start-1)

    如果只追求第 1 点,即快速修改数组,普通的线性数组可满足需求。但对于 range sum(),需要O(n)

    如果只追求第 2 点,即快速求 range sum,使用前缀数组的效果更好。但对于 add() 操作,则需要O(n),所以只适合更新较少的情况。

    树状数组则处于两者之间,适合数组又修改,又获取区间和的情景。

    思想

    树状数组的思想是怎样的呢?

    假设有一个数组 [1, 7, 3, 0, 5, 8, 3, 2, 6, 2, 1, 1, 4, 5],想求前 13 个元素的和。那么,

    13 = 23 + 22 + 20 = 8 + 4 + 1

    前 13 个数的和等于【前 8 个数的和】+【接下来 4 个数的和】+【接下来 1 个数的和】,即 range(1, 13) = range(1, 8) + range(9, 12) + range(13, 13)。如果有一种方法,可以保存 range(1, 8)、range(9, 12)、range(13, 13),那么计算这个区间和就可以加快了。

    这里给出已经计算好的结果(即最下面的 array 层)。例如 array[8] 是 29,往上可以找到 29 对应的是 [1,8],即 range(1, 8) = array[8]。同理,range(9, 12) = array[12],range(13, 13) = array[13]。

    range(1, 13) = range(1, 8) + range(9, 12) + range(13, 13) = array[8] + array[12] + array[13]

    由此图可以发现,虽然它的英文是含有 Tree,中间的部分看起来也是树状的,但是最终用到的 array 是线性的数组(太好了,复杂程度大减)。

    那中间这 3 层是怎么来的呢?——需要从上到下,从左到右看。

    首先计算 [1, 1] 的和,然后计算 [1, 2] 的和,然后计算 [1, 4]、[1, 8] 的和,每次乘 2,直到越界([1, 16] 越界),这里分别算出来了1、8、11、29。

    然后是第二层,从空缺的位置继续,这里的“界”不是整个数组的最大值,而是所有上层中下一个非空缺的位置。计算 [3, 3] 的和,[3, 4] 不用算,因为越界了。然后计算 [5, 5] 的和,接下来是 [5, 6] 的和,[5, 8] 越界不用算。

    第三层也是类似,然后发现填完了。

    以上可以帮助理解 result 数组中各值的来源,实际建立时有更简洁的做法。至于为什么是这样定义,可以另外找找资料,我看起来这有点像“分形”的感觉。

    阅读更多…
  • 一图流领悟跳跃列表(Skip List),附 Python/Java/Kotlin 实现

    跳跃列表是一种随机数据结构。它使得包含 n 个元素的有序序列的查找、插入、删除操作的平均时间复杂度都是 O(log n)。(注意关键词有序,如果不需要有序,也不需要用到跳跃列表;数据量大时,时间复杂度退化到较慢的概率微乎其微)

    平均最差
    搜索O(log n)O(n)
    插入O(log n)O(n)
    删除O(log n)O(n)
    空间O(n)O(n log n)

    跳跃列表是通过维护一个多层链表实现的。每一层链表中的元素的数量,在统计上来说都比下面一层链表元素的数量更少。也就是说,上层疏,下层密,底层数据是完整的,上面的稀疏层作为索引——这就是链表的“二分查找”啊。

    一开始时,算法在最稀疏的层次进行搜索,直至需要查找的元素在该层两个相邻的元素中间。这时,算法将跳转到下一个层次,重复刚才的搜索,直到找到需要查找的元素为止。

    Wikipedia 的道理就讲到这里,我不希望把本文写得难懂。说好的一图流就能领悟呢?其实我有点标题党,本文不止一幅图,但是核心的图只有一幅,上图(来自 Wikipedia):

    请多次认真观看插入节点的全过程 gif。我看完之后,就觉得自己可以实现出来了(虽然后来实际开发调试了很多次)。

    例如想在上图中所示的跳跃列表中插入 80,首先要找到应该插入的位置。

    首先从最稀疏的层的 30 开始,把当前位置设置为顶层的 30。
    80 比当前位置右边的 NIL 小,所以把当前位置移到下一层的 30;
    80 比当前位置右边的 50 大,所以把当前位置右移到 50;
    80 比当前位置右边的 NIL 小,所以把当前位置移到下一层的 50;
    80 比当前位置右边的 70 大,所以把当前位置右移到 70;
    80 比当前位置右边的 NIL 小,所以把当前位置移到下一层的 70;(当前位置已到达底层)
    之后用 80 不断与右边的节点比较大小,右移至合适的位置插入 80 节点。(底层插入完毕)
    接下来用随机决定是否把这个 80 提升到上面的层中,例如图中的提升概率是 50%(抛硬币),不断提升至硬币为反面为止。

    上面一段描述了 gif 中插入 80 的搜索和插入过程。那么,代码如何实现?右移和下移的逻辑很浅显,那么重点就在如何提升节点到上层的逻辑。

    阅读更多…
  • Python 分布式任务队列 Celery 入门与使用

    最近工作的项目使我接触到了 Celery 这个任务队列。看了一下官方的文档,感觉设计得还挺 Pythonic,理念也非常简单易懂——类似生产者与消费者。在这里稍微总(fan)结(yi)一下 Celery 的使用方法。

    简介

    Celery 是一个分布式任务队列,网上也有说是分布式任务调度框架,这里我以官方文档的“Distributed Task Queue”为准。它简单、灵活、可靠,可以处理大量的大量的任务,其主要专注于实时处理,同时也支持计划任务。

    为什么要用任务队列?我的理解是,首先方便了任务的分发调度与管理,另外也使调用的过程变得异步(非常适合 web 请求)。

    名词解释

    • 任务队列(task queue):一种分发任务到不同的线程或机器的方法,其输入为一个任务(task)。
    • Worker:实际执行任务的进程,它不断检查任务队列中的新任务并执行。
    • Broker:客户端与 worker 通信的中介。客户端发送任务的消息到队列中,broker 把这条消息传递给一个 worker。

    入门

    如果不考虑进阶用法,5 分钟入门。

    安装

    首先安装 Celery 并选择 broker。其中 broker 主要支持 RabbitMQ 和 Redis。RabbitMQ 不需要额外依赖,直接执行pip install -U celery安装。 Redis 需要在安装 Celery 时安装附加的依赖:pip install -U "celery[redis]"

    RabbitMQ 更为适合生产环境,也稍微大型;Redis 更轻量级,但突然退出时可能丢失数据。为了稍微简单轻量,本文都用 Redis。(如何安装 broker 不在本文内讨论,docker 启动一个最为简单)

    新建 Celery 应用

    新建一个mytasks.py

    from celery import Celery
    
    app = Celery('tasks', broker='redis://localhost:6379/0')
    
    @app.task
    def add(x, y):
        return x + y

    接下来就可以启动 worker 了:(生产环境当然不会这样手动运行,而会把它作为后台程序运行)

    $ celery -A mytasks worker --loglevel=info
    
    # 如果不了解上面的命令用法,可查看命令帮助
    # celery help
    # celery worker --help

    调用 task

    在当前的目录,运行

    >>> from mytasks import add
    >>> add.delay(1, 2)  # 使用 delay() 来使worker调用这个task

    可以得到类似<AsyncResult: fd9cdbe3-bcb3-432a-8d46-67b41243cfed>的返回值,而不会返回 3;这个 3 在 worker 的控制台里可以看到:Task mytasks.add[fd9cdbe3-bcb3-432a-8d46-67b41243cfed] succeeded in 0.0003419s: 3

    保存结果

    默认情况下,结果是不保存的。如果想保存结果,需要指定 result backend,支持 SQLAlchemy/Django ORM, MongoDB, Memcached, Redis, RPC (RabbitMQ/AMQP) 等。例如app = Celery('tasks', broker='redis://localhost:6379/0', backend='redis://localhost:6379/0'),调用之后就可以查询任务的状态及结果。

    >>> result = add.delay(1, 2)
    >>> result.ready()
    True
    >>> result.get(timeout=1)
    3

    参数配置

    简单的参数配置可以直接在代码中修改app.conf,例如:

    app.conf.task_serializer = 'json'

    对于大型一点的项目,最好专门做一个配置模块。先新建一个 celeryconfig.py:

    broker_url = 'pyamqp://xxxx'
    result_backend = 'rpc://xxxx'
    
    task_serializer = 'json'
    timezone = 'Europe/Oslo'
    enable_utc = True
    
    task_annotations = {
        'mytasks.add': {'rate_limit': '10/m'}
    }

    然后通过app.config_from_object('celeryconfig')导入。

    稍微深入

    Task

    Task 有很多选项可以填入,例如用@app.task(bind=True, default_retry_delay=30 * 60),可以修改任务失败后,等待重试的时间。

    关于任务的重试,我后来因工作需要,又深入阅读了文档。理想的目标是使一个任务可以自动重试,若重试一定次数仍失败,则发送通知。

    首先我看到了acks_late这个参数,它的意思是说一个 task 只有在执行成功后,才给队列 ack(移除)。我试了一下,似乎是不行的,fail 一次之后就没有然后了:

    # 是不行的,会被ack
    @app.task(acks_late=True)
    def add_may_fail_late_ack(x, y):
        if random.random() < 0.5:
            raise RuntimeError('unlucky')
        print('ok')  
        return x + y

    然后是autoretry_for=(XxxException,)参数。这个是最简单的自动重试写法,不需要修改原代码的逻辑,但不够灵活,对于简单的任务比较适用。

    最后是功能最全面的写法。首先定义一个自己的 Task,而不使用自带的 Task,因为 Task 可以提供一系列的回调函数(on_xxx)供自定义。例如我可以覆写on_failure方法,在任务超过一定重试次数仍失败时报警。然后是要注意两处地方:一是bind=True,对应的要把def add(x, y)改为def add(self, x, y);二是重试的操作是在业务逻辑手动触发的,且是通过 raise 的方式进行。代码大概是这样子:

    class MyTask(Task):
        def on_failure(self, exc, task_id, args, kwargs, einfo):  # einfo是完整的traceback
            print(f'on failure!!! name={self.name}, exc={exc}, task_id={task_id}, args={args}, kwargs={kwargs}')
    
    @app.task(base=MyTask, bind=True, default_retry_delay=5, max_retries=1)
    def add_may_fail_custom_retry(self: Task, x, y):
        try:
            if random.random() < 0.5:
                print('fail')
                raise RuntimeError('unlucky')
            print('ok')
            return x + y
        except RuntimeError as e:
            raise self.retry(exc=e)

    上述的代码在第一次遇到RuntimeError时,会等待 5s 重新执行,若仍然遇到RuntimeError(设置了max_retries=1),worker 才会抛出异常。此时会调用 on_failure(),把有用的信息记录下来,例如

    on failure!!! name=mytasks.add_may_fail_custom_retry, exc=unlucky, task_id=9ad47d43-7b7f-4d8d-a078-e54934f54d6e, args=[1, 7], kwargs={}

    这样就基本达成了预想的效果。其他有关 task 的具体内容,见Tasks文档

    调用 task

    前面用到的 delay() 方法是 apply_async() 的简化,但前者不支持传递执行的参数。举例来说,

    task.delay(arg1, arg2, kwarg1='x', kwarg2='y')
    # 等价于
    task.apply_async(args=(arg1, arg2), kwargs={'kwarg1': 'x', 'kwarg2': 'y'})

    可见简化了许多。

    Countdown 参数可以设置任务至少(可能受 worker busy 或其他原因有所推迟)多少秒后执行;而 eta (estimated time of arrival) 参数可以设置任务至少(原因相同)在具体时刻之后执行:

    >>> result = add.apply_async((1, 2), countdown=5)  # 至少5秒后执行
    >>> result.get()  # 阻塞至任务完成
    
    >>> tomorrow = datetime.utcnow() + timedelta(days=1)
    >>> add.apply_async((1, 2), eta=tomorrow)

    一个任务由于种种原因,延迟太久了,我们可以把它设置为过期,支持输入秒数或一个 datetime:

    add.apply_async((10, 10), expires=60)  # 如果任务延迟超过60s,将不会被执行

    对于一个任务,还可以指定这个任务放到哪个队列中(routing),例如

    add.apply_async(queue='priority.high')

    使用 -Q 来给 worker 指定监听的队列:

    $ celery -A mytasks worker -l info -Q celery,priority.high

    像上面这样硬编码 add 的对应 queue 不是太好,更佳的方法是使用 configuration routers

    其他调用 task 的文档,见 Calling Tasks

    函数签名(signature)

    对于简单的 task 调用,使用 .delay() 或 .apply_async() 方法一般就已足够。但有时我们需要更高级的调用,例如把任务的返回值用作下一个任务的输入,如果把一系列任务写成串行,就很不推荐了。为此,可以通过函数签名来调用 tasks。

    下面给 add() 函数创建一个签名(signature):

    >>> add.signature((2, 2), countdown=10)
    tasks.add(2, 2)
    >>> add.s(2, 2)  # 简化,但不能传入task的option,例如countdown
    tasks.add(2, 2)
    >>> sig = add.signature((2, 2), {'debug': True}, countdown=10)  # 完全版

    定义了签名后,就可以用sig.delay()来调用这个任务。

    签名的一个很重要的功能是它可以定义偏函数,类似 Python 的 functools.partial:

    >>> partial = add.s(2)          # 不完整的 signature
    >>> partial.delay(1)            # 1 + 2  注意这个1是填在前面的

    偏函数主要的应用场合是各种的原语(Primitives)。这些 primitives 主要包括 group、chain、chord、map、starmap、chunks 等。下面介绍其中几个的用法。

    group

    group 可以实现任务的并行:

    >>> from celery import group
    >>> res = group(add.s(i, i) for i in range(10))()
    >>> res.get(timeout=1)
    [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
    chain

    chain 可以按顺序执行任务,把前一个任务的结果作为接下来的任务的输入。注意偏函数的使用:

    >>> from celery import chain
    >>> result = chain(add.s('h', 'e'), add.s('llo'), add.s(' world'))()
    >>> result.get()
    'hello world'
    >>> (add.s('h', 'e') | add.s('llo') | add.s(' world'))().get()  # 也可以用 | 连接

    有关这一部分的更详细内容,见 Canvas: Designing Work-flows

    后台启动

    在实际环境中,celery 肯定是以后台服务的方式运行的。文档给出了 systemd、init.d、supervisor 等启动的方式。具体见 Daemonization

    定时任务

    定时运行任务的功能由 celery beat 完成。它按设定的周期/时间把任务发送到队列中,并被 worker 执行。对于一个集群的 worker,celery beat 应只有一个,否则任务会重复。

    mytasks.py改成下面所示:

    from celery import Celery
    from celery.schedules import crontab
    
    app = Celery('tasks', broker='redis://localhost:6379/0')
    
    @app.task
    def add(x, y):
        return x + y
    
    @app.on_after_configure.connect
    def setup_periodic_tasks(sender, **kwargs):
        # 每10s执行一次
        sender.add_periodic_task(10.0, add.s('hello', ' world'), name='every 10s')
    
        # 按crontab的格式定期执行
        sender.add_periodic_task(
            crontab(hour='*', minute=5),
            add.s('it is', ' xx:05')
        )

    然后启动 beat:

    $ celery -A mytasks beat

    可以在 worker 看到每 10s 输出了一次 “hello world”。每个小时的 5 分,都会输出 “it is xx:05”

    关于定时任务,具体见 Periodic Tasks

    相关参考

    上面只是比较基本的用法。对于更多深入使用中遇到的问题,还是应该参考官网文档