计算复杂度

提示:计算复杂度的简单理解(第一次写博客)

计算复杂度

计算复杂度

我们以Vicinity Vision Transformer论文中的图为例。
在这里插入图片描述图注:标准自注意力(左)和线性化自注意力(右)的图示。 N N N表示输入图像的 p a t c h patch patch数, d d d是特征维度。使 N ≫ d Ngg d Nd,线性化自注意力的计算复杂度相对于输入长度线性增长,而标准自注意力的计算复杂度是二次的。

从输入到输出可以这样计算:
( N × d ) × ( d × N ) = N × N × ( d × N ) × ( N × N ) = d × N (Ntimes d)times (dtimes N)=Ntimes Ntimes (dtimes N)times (Ntimes N)=dtimes N (N×d)×(d×N)=N×N×(d×N)×(N×N)=d×N
( d × N ) × ( N × d ) = d × d × ( d × d ) × ( d × N ) = d × N (dtimes N)times (Ntimes d)=dtimes dtimes (dtimes d)times (dtimes N)=dtimes N (d×N)×(N×d)=d×d×(d×d)×(d×N)=d×N

关于计算复杂度:其实可以认为是乘法次数。我们给出最直观的解释。

假设有两个矩阵做乘法,如下:
[ 1 2 3 4 5 6 ] × [ 1 2 3 4 5 6 ] = [ 1 2 3 4 5 6 7 8 9 ] left[begin{matrix}1&2\3&4\5&6\end{matrix}right]timesleft[begin{matrix}1&2&3\4&5&6\end{matrix}right]=left[begin{matrix}1&2&3\4&5&6\7&8&9\end{matrix}right] 135246×[142536]=147258369,其中行数为 N N N,列数为 d d d

( 3 × 2 ) × ( 2 × 3 ) = ( 3 × 3 ) × ( N × d ) × ( d × N ) = ( N × N ) (3times 2)times (2times 3)=(3times 3)times (Ntimes d)times (dtimes N)=(Ntimes N) (3×2)×(2×3)=(3×3)×(N×d)×(d×N)=(N×N)

3 × 3 3times 3 3×3矩阵第一个元素涉及的乘法次数: 1 × 1 + 2 × 4 = 9 1times 1+2times 4=9 1×1+2×4=9 共2次乘法;其它元素是一样的。最后可以得到 2 × 9 = 2 × 3 × 3 = d × N × N = N 2 d 2times 9=2times 3times 3=dtimes Ntimes N=N^{2}d 2×9=2×3×3=d×N×N=N2d.

假设又有两个矩阵做乘法,如下:
[ 1 2 3 4 5 6 ] × [ 1 2 3 4 5 6 ] = [ 1 2 3 4 ] left[begin{matrix}1&2&3\4&5&6\end{matrix}right]timesleft[begin{matrix}1&2\3&4\5&6\end{matrix}right]=left[begin{matrix}1&2\3&4\end{matrix}right] [142536]×135246=[1324],其中行数为 d d d,列数为 N N N

( 2 × 3 ) × ( 3 × 2 ) = ( 2 × 2 ) × ( d × N ) × ( N × d ) = ( d × d ) (2times 3)times (3times 2)=(2times 2)times (dtimes N)times (Ntimes d)=(dtimes d) (2×3)×(3×2)=(2×2)×(d×N)×(N×d)=(d×d)

2 × 2 2times 2 2×2矩阵第一个元素涉及的乘法次数: 1 × 1 + 2 × 3 + 2 × 5 = 17 1times 1+2times 3+2times 5=17 1×1+2×3+2×5=17 共3次乘法;其它元素是一样的。最后可以得到 3 × 4 = 3 × 2 × 2 = N × d × d = N d 2 3times 4=3times 2times 2=Ntimes dtimes d=Nd^2 3×4=3×2×2=N×d×d=Nd2 .

为什么会有这种情况呢?以第二个例子为例,可以观察到,所得结果的一个元素的乘法数量和消失的维度大小有关,也就是列数 N N N,或者说,列数 N N N就是所得结果一个元素的乘法次数。那么多少个元素呢?元素个数就要看你是如何进行的乘法操作,其实就是矩阵大小。比如 ( 2 × 3 ) × ( 3 × 2 ) = ( 2 × 2 ) × ( d × N ) × ( N × d ) = ( d × d ) (2times 3)times (3times 2)=(2times 2)times (dtimes N)times (Ntimes d)=(dtimes d) (2×3)×(3×2)=(2×2)×(d×N)×(N×d)=(d×d),那么就是 d 2 d^2 d2个元素,最后乘法次数就是 N d 2 Nd^2 Nd2

乘法次数=消失的维度 × 所得矩阵大小

那么计算复杂度呢?我们不要去管 O ( ∙ ) O(bullet) O()具体代表什么,这不重要。
以第一个图为例,乘法次数1: ( N × d ) × ( d × N ) = N 2 d (Ntimes d)times (dtimes N)=N^{2}d (N×d)×(d×N)=N2d;乘法次数 2 2 2 ( N × d ) × ( d × N ) = N 2 d (Ntimes d)times (dtimes N)=N^{2}d (N×d)×(d×N)=N2d O ( N 2 d + N 2 d ) = O ( N 2 ) O(N^{2}d+N^{2}d)=O(N^2) O(N2d+N2d)=O(N2)。因为 N ≫ d Ngg d Nd,所以 d d d(还有常数 2 2 2)被省略了,即 O ( N 2 ) O(N^2) O(N2)
以第二个图为例,乘法次数1: ( d × N ) × ( N × d ) = N d 2 (dtimes N)times (Ntimes d)=Nd^2 (d×N)×(N×d)=Nd2;乘法次数2: ( d × d ) × ( d × N ) = N d 2 (dtimes d)times (dtimes N)=Nd^2 (d×d)×(d×N)=Nd2 O ( N d 2 + N d 2 ) = O ( N ) O(Nd^2+Nd^2)=O(N) O(Nd2+Nd2)=O(N)。因为 N ≫ d Ngg d Nd,所以 d d d(还有常数2)被省略了,即 O ( N ) O(N) O(N)

事实告诉我们,我们两个的结果一样,但是我们可以通过控制中间过程减少计算复杂度。