Numba 简介
Numba 是 Python 的一个 JIT (just-in-time) 编译器,最适用于 NumPy 数组、函数,以及 Python 循环。基本上,用法就是给原来的 Python 函数加一个修饰器,当运行到经 Numba 修饰的函数时,它会被编译为机器码,之后再调用时,就能以机器码的速度来执行了。
按我上手使用的经验来看,Numba 对原代码的改动不是太大,对能加速的部分,加速效果明显;对不支持的加速的 Python 语句/第三方库,可以选择不使用 numba 来规避。这是我选择 Numba 的原因。
首先:应该编译(优化)什么?
由于 Numba 本身的限制(稍后介绍),不能做到对整个程序完全的优化。实际上,也没必要这样做——只需要优化真正耗时间的部分即可。
怎么找到真正耗时间的部分?除了靠直觉,还可以借用工具来分析,例如 Python 自带的 cProfile,还有 line_profiler 等,这里不再细讲。
安装
可以通过 conda 或 pip,一个命令安装:
conda / pip install numba
什么样的代码能加速?
按照官方文档的示例代码,如果代码中含有很多数学运算、使用 NumPy,或者有大量 Python 的 for 循环(这可是 Python 性能大忌),那么 Numba 就能给你很好的效果。尤其是多重 for 循环,可以获得极大的加速。
大家都知道,给一个 np.ndarray 加 1 是很快的(向量化、广播),但是如果 for 遍历这个 array 的元素再每个加 1就会很慢(新手容易犯的小错误);但是这都没关系,有了 Numba 再 for 遍历元素加 1,和直接用 ndarray 加 1 的耗时是差不多的!
再举个例子,下面这段代码,就能享受到 JIT:
from numba import jit
import numpy as np
x = np.arange(100).reshape(10, 10)
@jit(nopython=True) # 设置为"nopython"模式 有更好的性能
def go_fast(a): # 第一次调用时会编译
trace = 0
for i in range(a.shape[0]): # Numba likes loops
trace += np.tanh(a[i, i]) # Numba likes NumPy functions
return a + trace # Numba likes NumPy broadcasting
print(go_fast(x))
但是,类似下面的代码,Numba 就没什么效果:
from numba import jit
import pandas as pd
x = {'a': [1, 2, 3], 'b': [20, 30, 40]}
@jit
def use_pandas(a): # 这个函数就加速不了
df = pd.DataFrame.from_dict(a) # Numba 不支持 pd.DataFrame
df += 1 # Numba 也不支持这个
return df.cov() # 和这个
print(use_pandas(x))
总之,Numba 应付不了 pandas。以我的经验,需要先把 DataFrame 转成 np.ndarray,再输入给 Numba。
要强制用 nopython 模式
刚才有效果的代码中,@jit(nopython=True) 这里传入了 nopython 这个参数,而没什么效果的代码中,就没有这个参数。为什么呢?
这是因为,@jit 实际上有两种模式,分为别 nopython 和 object 模式。只有 nopython 模式,才是能真正大幅加速的模式。而 nopython 模式只支持部分的 Python 和 NumPy 函数,如果运行时用到了不支持的函数/方法,程序就会崩掉 (例如刚才不能加速的例子如果加上 nopython 就会崩) 。如果不强制设定 nopython 模式,编译函数失败时,会回退到 object 模式,程序虽然不会崩,但却偏离了我们给它加速的本意。
我既然用了 Numba,我就希望它能真正地发挥作用。所以选择强制开启 nopython ,如果不能加速,不如让它直接崩溃,我们再作对应修改。
不支持哪些 features?
说了那么多,哪些能用,哪些不能用?当前版本(本文撰写时为 0.45)的 Numba 不支持以下 Python 功能:
- Exception handling (try .. except, try .. finally)
- Context management (the with statement)
- Dict/set/generator comprehensions
- Generator delegation (yield from)
此外,还有一些限制,例如 list/tuple 的元素必须同类(静态)等,见文末参考链接的“支持的 Python 功能”。对于 NumPy,其支持的 feature 也在参考链接中列出。
一些使用心得
- 再次提醒,不支持 pandas,不然分分钟报错。另外发现 pandas 的效率其实挺低的,例如取 column 的操作。应该把数据取出来,然后转成 numpy ndarray。
- 变量不能做类型转换(类似静态语言的要求)。
- 待加速的函数,要么是运行很多很多次,要么是运行时间很长;这是因为第一次编译都需要时间,如果编译完只运行这个函数几次,说不定还没有直接跑快;虽然有一个可以把编译结果存在硬盘的
cache
参数,但是实测发现从硬盘读取这个缓存结果也还是需要一点时间的。 - 涉及到具体的业务代码,如果有 class,或者比较长的函数,最好把真正需要优化的部分提取出来做 JIT,绕过不支持的 methods。class 虽然有初步的支持 (jitclass),但是限制很多,一般只能用于简单的 class。
- (2022年4月补充)list 和 dict 等容器比较麻烦。由于默认情况下 numba 的泛型是假的,一切重点都是要让 numba 能推断容器里的元素的类型,否则运行时会报错。
例如我的例子是有一个 nopython 模式的函数,它会接收一个 list 作为参数。问题是这个 list 在运行的时候,有时是有元素的(元素保证是同样的类型),有时是空的,那么传入这个空列表的时候,就会报错ValueError: cannot compute fingerprint of empty list
。
解决的办法是不使用 Python list,改用 numba 的 typed list,并且创建这个 typed list 时显式声明它的类型,就变成真泛型了。
后来发现,如果不是要不断往列表中增加元素,用 typed list 的性能好像还不如直接用 np.ndarray 来存放数据。
# 解决 "ValueError: cannot compute fingerprint of empty list" 问题 from typing import List import numba from numba.typed import List as NumbaList @numba.jit(nopython=True) def my_function(the_list: List): for element in the_list: print(element) print('finish', the_list) # my_list: List[int] = [1, 2, 3] # my_function(my_list) # 这个可以 # my_list = [] # my_function(my_list) # 但是这个会报错 generic_type = numba.typeof(1) # 让numba自己判断数据是什么类型 print(generic_type) # int64 my_list: List[int] = NumbaList.empty_list(generic_type) # 构建一个空的numba typed list,注意这里是真泛型 print(repr(my_list)) # ListType[int64]([]) print(type(my_list)) # <class 'numba.typed.typedlist.List'> my_function(my_list) # 这次传入了空列表,但可以正常运行 my_list.append(1) my_list.append(2) my_list.append(3) my_function(my_list)
相关参考
本文只是很初步的介绍,更深度的使用方法,还是要自己看文档:
发表评论