详解Winograd变换矩阵生成原理

文本首发知乎:https://zhuanlan.zhihu.com/p/87516875


目录:

0、前言

1、卷积与多项式乘法

1.1、Convolution和Correlation的区别

1.2、卷积与多项式乘法的关系

2、理解Winograd算法需要的数学理论知识

2.1、欧几里得算法

2.2、多项式的欧几里得算法

2.3、扩展欧几里得算法

2.4、多项式的扩展欧几里得算法

2.5、乘法模逆元

2.6、多项式乘法模逆元

2.7、中国剩余定理

2.8、多项式的中国剩余定理

3、多项式的中国剩余定理的应用

3.1、卷积操作与中国剩余定理的联系

3.2、Winograd F(2,3)变换矩阵推导

4、参考资料


0、前言

其实网上已经有不少从数学原理的角度去解说Winograd[1,2,3,4,5,6,10]这个算法的文章了,为什么我还要写这篇文章。

主要是在看完许多相关的文章之后,对于Winograd这个算法背后的数学原理我还是没法完全理解,尤其是Winograd的变换矩阵究竟是如何生成的。而在查阅到的资料里面,在描述到一些相关数学定理的时候,许多细节部分都没有很详细的说明,只能通过额外去查找资料和手推公式来理解。

这也是促成我写这篇文章的主要原因,想把有关Winograd这个算法背后所涉及到的数学知识用比较通俗的方式给读者描述一遍,并且在这的过程中也会添加一些我个人的理解,当然我的理解也不一定正确,如果有误也请读者指出。

总的来说感觉Winograd这个算法真的很巧妙,要理解这个算法,需要懂得前置数学知识挺多的,如果其中一个地方没弄懂,都会对理解这个算法的数学原理造成困难。而且即使已经看懂了整体部分,但是很多细节部分如果仔细去想就会觉得自己还没有完全弄懂。

这里我把收集到的所有相关资料链接都统一放到文末参考资料里面,也方便读者去查阅。

1、卷积与多项式乘法

1.1、Convolution和Correlation的区别

首先卷积其实有两个含义[8,9]

第一个是指一般数学意义上的的两个离散序列的卷积(Convolution);

第二个是深度学习中所用到的卷积(操作上更像Correlation而不是Convolution);

通俗来说两个离散序列在做Convolution操作的时候,首先需要将其中一个序列做镜像翻转,然后两个序列相向移动,从开始第一个元素重合到最后一个元素重合为止,相向移动步长为1,每次把重合的部分做点乘累加的到新的元素,最后生成新的序列。

而两个离散序列在做Correlation操作的时候,除了不需要翻转序列,操作上和Convolution一致。而深度学习中的Convolution其实和Correlation很像,但是又不完全一样。

不一样的地方在于,计算第一个元素的时候是直接把较短的序列和较长的序列从左首元素开始对齐,然后较短的序列按步长向右移动(假设步长也是1),一直到当前较短序列最右边的元素和较长序列最右边的元素对齐,要求短序列的每个元素都必须要和长序列的元素有重合,然后每次把重合的部分点乘累加得到新的元素。

从这里开始,下文提到的Correlation操作都是指深度学习中的卷积操作。

所以给定同样长度的两个序列,分别做Convolution和Correlation操作得到的结果序列长度是不一样的。假设两个长度分别是 nk 的序列( n>=k ),分别做Convolution和Correlation操作(步长都为1)得到的结果序列长度的计算公式分别为:

l_{Convolution}=n+k-1 \\\\ l_{Correlation}=n-k+1

下面简单画个示意图解释的两者的区别:

左边的Convolution操作还有最后两步没有画出来,不过这已经能足够解释两者的区别了。

然后来看下两者的输出元素个数计算公式之间的联系,比如给定两个序列长度分别为 nk 的序列,Convolution操作的到的序列长度为 n+k-1 ,然后和 n 或者 k 长度的序列做Correlation 可以得到长度为 kn 的序列。

还有一点要提下,就是Winograd这个算法发明出来其实是用来加速Convolution操作的,所以计算变换矩阵也是从Convolution角度去计算,而计算出来的变换矩阵在做一点小变动之后,也可以直接应用在深度学习的Correlation操作中,这个在下文会讲到。

为什么提这个是因为,之前我在理解Winograd这个算法的是陷入了一个误区,一直是从深度学习卷积(Correlation)应用的这个角度去理解这个算法,然后一直想不明白,后来换成是从Convolution角度去理解很多地方就豁然开朗了。

1.2、卷积与多项式乘法的关系

Convolution操作其实直观上等价于多项式乘法操作[7]

还是用上面举的例子来说明,假设两个离散序列 [3, 1, 5, 4][7, 9] ,我们可以把这两个序列看成是两个多项式3+1x+5x^{2}+4x^{3} 7+9x 。这两个多项式相乘的结果是 21+34x^{}+44x^{2}+73x^{3}+36x^{4} ,把其系数从低次冥到高次冥排列,刚好就等价于这两个序列做 Convolution 的结果 [21,34,44,73,36]

从多项式相乘结果的最高次冥也可以看出两者的联系,假设两个离散序列长度分别为 nk

则对应的多项式的最高次冥分别为 x^{n-1}x^{k-1} ,则这两个多项式相乘结果的最高次冥为 x^{n+k-2} ,再加上一个最低次冥 x^{0} 总共就是 n+k-1 个元素,算上系数0的次冥。

下面再举一个例子,看下多项式乘法和Convolution操作的关系:

2、理解Winograd算法需要的数学理论知识

2.1、欧几里得算法

我们先来看下可以求解两个整数最大公约数的欧几里得算法[11,13],可能换成“辗转相除法”这个名字读者会更加熟悉。

我们知道两个整数 ab 的公约数是既能整除 a 又能整除 b 的整数,而两者的最大公约数 g 就是这些数里面最大的数,通常用数学公式表示为 g=GCD(a,b) ,GCD是Greatest common divisor 的缩写。

简单复习下整除的定义: m(m>0) 整除 n 可以记作 m|n ,即存在一个整数 k 使得 km=n

然后如果 GCD(a,b)=1 ,则称 ab 互素,表示它们的最大公约数为1。而两个整数是否互素和它们本身是否是素数无关。

简单复习下素数[12]的定义:素数(Prime number)又称质数,指在大于1的自然数中,除了1和该数自身外,无法被其他自然数整除的数。

引用[11]的一个例子来解释最大公因数的概念,假设现在 a>b ,么设一个长方形的高是 a ,宽是 b ,因为 ab 的任何公约数 c 都可以整除 ab ,所以长方形的高和宽都可以等分为长度是 c 的线段。通俗的说法是也就是长方形的内部可以刚好被边长是 c 的正方形填满。而最大公约数 g 是其中最大的一个正方形的边长,下面画个简单的示意图来说明:

接着我们看下如何用欧几里得算法求解最大公约数,先给出欧几里得算法的定义[11,13]

\\ \\\\GCD(a,b)=a, \\quad \\quad \\quad \\quad \\quad \\ \\ \\ \\quad b=0 \\\\ GCD(a,b)= GCD(b, a \\ \\% \\ b), \\quad b>0

其中 a\\ \\% \\ b 也可写作 a \\ mod \\ b ,也就是求 a 除以 b 的余数,上面公式还有个条件就是 a>b ,第一个公式很好理解,因为任何整数都能整除0。

这里首先引入同余式的概念:若正整数 ab 分别对 p 取模的余数相同,则可以记作 a\\equiv b(mod \\ p) ,也就是 abp 同余。

再继续证明欧几里得算法第二个公式之前,先来看一下求模运算的一些运算规则[15]

(a + b) % p = (a % p + b % p) % p
(a - b) % p = (a % p - b % p) % p
(a * b) % p = (a % p * b % p) % p

然后对于第二条公式的证明引用自[13],更多细节可以参考资料[11,13,14]

a 除以 b 的商是 p 余数是 q ,则可以表示为:

a=b* p+q

考虑到 bq 的最大公约数 GCD(b,q) 可知:

b\\ \\% \\ GCD(b,q)=0 \\\\ q\\ \\% \\ GCD(b,q)=0

然后根据运算规则和上面的两个公式可得:

a\\ \\% \\ GCD(b,q)=(b*p+q)\\ \\% \\ GCD(b,q) \\\\ =(b*p\\ \\% \\ GCD(b,q)+q\\ \\% \\ GCD(b,q))\\ \\% \\ GCD(b,q) \\\\ =((b\\ \\% \\ GCD(b,q)*p\\ \\% \\ GCD(b,q))\\ \\% \\ GCD(b,q)+q\\ \\% \\ GCD(b,q))\\ \\% \\ GCD(b,q) \\\\ =((0*p\\ \\% \\ GCD(b,q))\\ \\% \\ GCD(b,q)+0)\\ \\% \\ GCD(b,q) \\\\ =(0\\ \\% \\ GCD(b,q))\\ \\% \\ GCD(b,q) \\\\ =0

所以 bq 的最大公约数可以整除 a ,也即 GCD(b,q)|a,所以 GCD(b,q) 能同时整除 b、qa ,所以 b

q 的最大公约数也是 ab 的公约数,也即 GCD(b,q) 也可以整除 GCD(a,b) ,也就是 GCD(b,q)|GCD(a,b)

然后对等式做下变换:

q=a-b* p

然后考虑 ab 的最大公约数 GCD(a,b) 可知:

a\\ \\% \\ GCD(a,b)=0 \\\\ b\\ \\% \\ GCD(a,b)=0

接着同样根据取模运算规则和上面的两个公式可得:

q\\ \\% \\ GCD(a,b)=(a-b*p)\\ \\% \\ GCD(a,b) \\\\ =(a\\ \\% \\ GCD(a,b)-b*p\\ \\% \\ GCD(a,b))\\ \\% \\ GCD(a,b) \\\\ =(a\\ \\% \\ GCD(a,b)-(b\\ \\% \\ GCD(a,b)*p\\ \\% \\ GCD(a,b))\\ \\% \\ GCD(a,b)) \\ \\% \\ GCD(a,b) \\\\ =(0-(0*p\\ \\% \\ GCD(a,b))\\ \\% \\ GCD(a,b))\\ \\% \\ GCD(a,b)\\\\ =0

所以同理可得 GCD(a,b) 也可以整除 GCD(b,q) ,也就是 GCD(a,b)|GCD(b,q)

所以通过上面的推导可得:

GCD(a,b)=GCD(b,q)= GCD(b, a \\ \\% \\ b), \\quad b>0

然后就可以根据这个求解最大公约数的递归式来写实现代码了:

#include <iostream>

int GCD(int m, int n) {
	int t = 1;
	while(t != 0) {
		t=m%n;
		m=n;
		n=t;
	}
	return m;
}

int GCDRecursive(int m, int n) {
	if (n == 0) return m;
	return GCDRecursive(n, m % n);
}

int main(int argc, char *argv[]) {
	if (argc != 3) {
		return 0;
	}

	int a = atoi(argv[1]);
	int b = atoi(argv[2]);

	if (a < b) {
		int t = a;
		a = b;
		b = t;
	}

	int gcd = GCD(a, b);

	printf("GCD(%d, %d)=%d\\n", a, b, gcd);

	gcd = GCDRecursive(a, b);

	printf("GCD(%d, %d)=%d\\n", a, b, gcd);

	return 0;
}

运行结果:

2.2、多项式的欧几里得算法

欧几里得算法也可以推广到多项式上,和整数最大公因数类似的多项式上也有最大公因式的概念,一样也有整除和求余的概念。所以求解两个多项式的最大公因式一样也可以应用欧几里得算法[18]

首先下面通过3个例子来说明多项式除法[16,17]是如何u操作的,这里引用资料[17]对多项式除法规则的定义:

  • 首先把被除式的第一项除以除式的第一项;
  • 然后把除式乘以上面除法得到的结果,然后写在被除式下面;
  • 两者相减得到新的多项式;
  • 然后重复上面3个步骤,重复的时候用得到的新多项式作为被除式,一直相减到多项式的最高次小于除式的最高次或者得到0就停止;

例子一、整除的情况,x^{2}-3x-10 除以 x+2

所以有 x^{2}-3x-10=(x+2)*(x-5)

例子二、带余式的情况,2x^2-5x-1 除以 x-3

所以有 2x^{2}-5x-1=(x-3)*(2x+1)+2

例子三、带缺项的情况,x^6+2x^4+6x-9x^3+3

所以有 x^6+2x^4+6x-9=(x^3+3)*(x^3+2x-3)

多元多项式的情况可以参考资料[17]

有了多项式除法的概念之后,用一个例子来说明多项式中的欧几里得算法[18],f(x)=x^3+2x^2-x-2g(x)=x^2+2x-3 的最大公因式,同样利用性质 GCD(f(x),g(x))=GCD(g(x),f(x)\\ mod \\ g(x))

首先第一步计算 f(x) 除以 g(x) 得余式 2x-2

x^3+2x^2-x-2=(x^2+2x-3)*x+(2x-2)

然后用 x^2+2x-3 去除以余式 2x-2,可以整除 :

(x^2+2x-3)=(2x-2)*(\\frac{1}{2}x+\\frac{3}{2})+0

而因为 2x-2 能被 x-1 整除,所以最大公因式是 x-1。直接对 f(x)g(x) 做因式分解也能看出最大公因式:

x^3+2x^2-x-2=(x-1)(x+1)(x+2)
x^2+2x-3=(x-1)(x+3)

但是因式分解看起来就很难用代码实现,而欧几里得算法用代码来实现就容易多了。

2.3、扩展欧几里得算法

在介绍扩展欧几里得算法之前先来看下“裴蜀等式”,下面引用wikipedia上的解释[19]:

在数论中,裴蜀等式或裴蜀定理是一个关于最大公约数(或最大公因式)的定理。说明了对任何整数 abm ,关于未知数 xy 的方程:

ax+by=m

有整数解时当且仅当 mab 的最大公约数 GCD(a,b) 的倍数,也就是要求 GCD(a,b)|m 。裴蜀等式有解时必然有无穷多个整数解,每组解 xy 都称为裴蜀数,可用扩展欧几里得算法求解。

比如,12和42的最大公约数是6,则方程 12x+42y=6 。事实上有:

(-3)\\times12+1\\times42=6
4\\times12+(-1)\\times42=6

特别来说, ax+by=1 有整数解当且仅当 ab 互素,即 GCD(a,b)=1

证明过程有兴趣的读者可以参考[19]

接着来看下如何用扩展欧几里得[13,20]算法求解裴蜀等式,简单来说扩展欧几里德算法是对欧几里德算法的扩展,它可以用来求解形如 ax+by=c(a,b,c \\in Z) 的方程的一组整数解。我们可以从欧几里德算法的等式来实现扩展欧几里得算法:

我们先来看下方程 ax+by=GCD(a,b) 的边界情况,当 b=0 的时候,方程可化为 ax=GCD(a,0) ,然后根据最大公约数的性质可知 GCD(a,0)=a ,所以可以解得 x=1,y=0

然后对于一般情况,也是应用最大公因数的性质 GCD(a,b)=GCD(b,a \\ mod \\ b) ,首先设 a'=b,b'=a \\ mod \\ b ,然后同样有方程

a'x'+b'y'=GCD(a',b')=GCD(b,a \\ mod \\ b)

联合需要求解的方程 ax+by=GCD(a,b) 可得

ax+by=GCD(a,b)=GCD(b,a \\ mod \\ b)=GCD(a',b')=a'x'+b'y'

所以

ax+by=a'x'+b'y'=bx'+(a \\ mod \\ b)y'=bx'+(a-\\lfloor \\frac{a}{b} \\rfloor b)y'

整理一下式子可得

ax+by=ay'+b(x'-\\lfloor \\frac{a}{b} \\rfloor y')

对比系数两边可得求解递归式

x=y' \\\\ y=x'-\\lfloor \\frac{a}{b} \\rfloor y'

下面看下实现代码:

#include <iostream>

int exgcd(int a, int b, int &x, int &y) {
	if (b == 0) {
		x = 1;
		y = 0;
		return a;
	}
       int gcd = exgcd(b, a % b, x, y);
       int t = x;
       x = y;
       y = t - a / b * x;
	return gcd;
}

int main(int argc, char *argv[]) {
	if (argc != 3) {
		return 0;
	}

	int a = atoi(argv[1]);
	int b = atoi(argv[2]);

	if (a < b) {
		int t = a;
		a = b;
		b = t;
	}
	int x, y;
	int gcd = exgcd(a, b, x, y);

	printf("%d * %d + %d * %d = %d\\n", a, x, b, y, gcd);

	return 0;
}

运行结果:

所以扩展欧几里得算法可以同时求出 ax+by=GCD(a,b) 方程的解和最大公因数 GCD(a,b)

2.4、多项式的扩展欧几里得算法

同样类似的扩展欧几里得算法也可以应用在求解多项式的裴蜀等式,假设现在已知有两个多项式 f(x)

g(x) 以及最大公因式 GCD(f(x),g(x)) ,求解如下方程

u(x)f(x)+v(x)g(x)=GCD(f(x),g(x))

下面举个例子说明如何用扩展欧几里得算法求解,还是用上面的例子,已知 f(x)=x^3+2x^2-x-2g(x)=x^2+2x-3 的最大公因式是 x-1

直接套用扩展欧几里得算法递归式:

u(x)=v(x)' \\\\ v(x)=u(x)'-\\lfloor \\frac{f(x)}{g(x)} \\rfloor v(x)'

代码的话可以参考[21],下面看下每一步的计算过程:

第一步, \\ \\ x^3+2x^2-x-2=(x^2+2x-3)*x+(2x-2),商是x,余数是 2x-2
第二步,\\ \\ (x^2+2x-3)=(2x-2)*(\\frac{1}{2}x+\\frac{3}{2})+0,商是(\\frac{1}{2}x+\\frac{3}{2}),余数是0 停止递归,设置 u(x)=1, \\ v(x)=0
第三步, \\ \\ 开始回溯 t(x) = u(x) = 1, \\ u(x) = v(x) =0, \\ v(x) = t(x) - (\\frac{1}{2}x+\\frac{3}{2}) * u(x) = 1
第四步, \\ \\ t(x) = u(x) = 0, \\ u(x) = v(x) = 1, \\ v(x)=t(x)-x*u(x)=-x,得到解,u(x)=1和v(x)=-x

所以可得

1*(x^3+2x^2-x-2)+(-x)(x^2+2x-3)=2x-2

两边同除以2可得

(x^3+2x^2-x-2)*\\frac{1}{2}+(-\\frac{1}{2}x)(x^2+2x-3)=x-1

2.5、乘法模逆元

模逆元[22,23]也称为模倒数。整数 a 对同余 b 的模逆元是指满足下面公式的整数 x

ax\\equiv1(mod \\ b)

整数 a 对模数 b 的模逆元存在充分必要条件是 ab 互素,也即 GCD(a,b)=1 ,所以有 ax+by=1 ,可用扩展欧几里得算法求解。

求得的 x 即为 a 关于模 b 的其中一个模逆元。事实上 x+kb \\ (k\\in Z) 都是 a 关于模 b 的模逆元,这里我们取最小的正整数解 x \\ mod \\ b \\ (x<b) ,这也很好理解,假设 x 是最小的正整数解,则有 a(x+kb) \\ mod \\ b=(ax \\ mod \\ b + kb \\ mod \\ b) \\ mod \\ b=1\\ mod \\ b

代码:

#include <iostream>

int exgcd(int a, int b, int &x, int &y) {
	if (b == 0) {
		x = 1;
		y = 0;
		return a;
	}
       int gcd = exgcd(b, a % b, x, y);
       int t = x;
       x = y;
       y = t - a / b * x;
	return gcd;
}

int reverse_unit(int a, int b) {
	int x, y;
	int gcd = exgcd(a, b, x, y);
	if (gcd != 1) {
 		printf("reverse unit does not exist.\\n");
		return -1;
	}
	return (x % b + b) % b;
}

int main(int argc, char *argv[]) {
	if (argc != 3) {
		return 0;
	}

	int a = atoi(argv[1]);
	int b = atoi(argv[2]);

	int reverse= reverse_unit(a, b);

	if (reverse != -1)
		printf("%d * %d = 1 (mod %d) \\n", a, reverse, b);

	return 0;
}

运行结果:

2.6、多项式乘法模逆元

同理也可以应用扩展欧几里得算法求解多项式模的逆元,下面直接举例进行说明。

(x^2+x-6)u(x)\\equiv1(mod \\ x^2-1) 的逆元,而因为:

GCD(x^2+x-6,x^2-1)=1

所以有解。

构造方程

(x^2+x-6)u(x)+(x^2-1)v(x)=1

下面给出扩展欧几里得算法每一步的计算过程:

第一步,x^2+x-6=(x^2-1)*1+(x-5),商为1,余数为x-5
第二步,x^2-1=(x-5)*(x+5)+24,商为x+5,余数为24
第三步, x-5=24*(\\frac{1}{24}x-\\frac{5}{24})+0,商为\\frac{1}{24}x-\\frac{5}{24},余数为0,递归停止,设u(x)=1,v(x)=0。
第四步,t(x)=u(x)=1,u(x)=v(x)=0,v(x)=t(x)-(\\frac{1}{24}x-\\frac{5}{24})*u(x)=1
第五步,t(x)=u(x)=0,u(x)=v(x)=1,v(x)=t(x)-(x+5)*u(x)=-x-5
第六步,t(x)=u(x)=1,u(x)=v(x)=-x-5,v(x)=t(x)-1*u(x)=x+6,得解

所以有

两边同除以24得

(x^2+x-6)(-\\frac{1}{24}x-\\frac{5}{24})+(x^2-1)(\\frac{1}{24}x+\\frac{1}{4})=1

验证下

(x^2+x-6)(-\\frac{1}{24}x-\\frac{5}{24})(mod \\ x^2-1) \\\\ =-\\frac{1}{24}(x^2+x-6)(x+5)(mod \\ x^2-1) \\\\ =-\\frac{1}{24}(x^3+6x^2-x-30)(mod \\ x^2-1) \\\\ =-\\frac{1}{24}*-24=1

所以 (x^2+x-6)u(x)\\equiv1(mod \\ x^2-1) 的逆元是 u(x)=-\\frac{1}{24}x-\\frac{5}{24}

2.7、中国剩余定理

有了前面知识点的铺垫,理解中国剩余定理[24,25,26]就容易多了。文章[24]对中国剩余定理的解释非常透彻,下面对中国剩余定理的解释大部分是参考这篇文章。推荐对数学感兴趣读者可以关注该专栏,都是和数学相关的内容。

这里先来看下“孙子算经”[27]里面的第二十六题,原文如下:

今有物,不知其數。三三數之,賸二;五五數之,賸三;七七數之,賸二。
問:物幾何?
答曰:二十三。

術曰:

三三數之,賸二,置一百四十;
五五數之,賸三,置六十三;
七七數之,賸二 ,置三十。
并之,得二百三十三,以二百一十減之,即得。

凡三三數之,賸一,則置七十;
五五數之,賸一,則置二十一;
七七數之,賸一,則置十五。
一百六以上,以一百五減之,即得。

用通俗的语言描述第二十六题就是:

现在有一个整数,该整数除以3余2、除以5余3、除以7余2,求该整数是多少?
答案是:23

解法:

除以3余2,加140;
除以5余3,加63;
除以7余2,加30;
求和140+63+30=233,再减去210,就得到23。

只要是除以3余1,就加70;
只要是除以5余1,就加21;
只要是除以7余1,就加15;
然后累加,如果超过了106就减去105就得到结果了。

首先把这个问题转化为一个求解同余方程组的问题,然后对这个问题的解法就称为中国剩余定理:

x\\equiv2 \\ (mod \\ 3) \\\\ x\\equiv3 \\ (mod \\ 5) \\\\ x\\equiv2 \\ (mod \\ 7) \\\\

就是要求解一个整数 x ,同时满足除3余2,除5余3和除7余2。

首先我们可以把问题分解一下,如果能找到3个整数 x1,x2,x3 ,分别满足:

x1 \\ mod \\ 3 = 2, \\ x1 \\ mod \\ 5 = 0, \\ x1 \\ mod \\ 7 = 0
x2 \\ mod \\ 3 = 0, \\ x2 \\ mod \\ 5 = 3, \\ x2 \\ mod \\ 7 = 0
x3 \\ mod \\ 3 = 0, \\ x3 \\ mod \\ 5 = 0, \\ x3 \\ mod \\ 7 = 2

那么 x=x1+x2+x3 就是解,因为根据取模运算法则:

x \\% 3 = (x1 +x2 +x3)\\%3=(x1\\%3 +x2\\%3 +x3\\%3)\\%3=2
x \\% 5 = (x1 +x2 +x3)\\%5=(x1\\%5 +x2\\%5 +x3\\%5)\\%5=3
x \\% 7 = (x1 +x2 +x3)\\%7=(x1\\%7 +x2\\%7 +x3\\%7)\\%7=2

然后接着接续分解问题,如果能找到3个整数 y1,y2,y3 ,分别满足:

y1 \\ mod \\ 3 = 1, \\ y1 \\ mod \\ 5 = 0, \\ y1 \\ mod \\ 7 = 0
y2 \\ mod \\ 3 = 0, \\ y2 \\ mod \\ 5 = 1, \\ y2 \\ mod \\ 7 = 0
y3 \\ mod \\ 3 = 0, \\ y3 \\ mod \\ 5 = 0, \\ y3 \\ mod \\ 7 = 1

那么令 x1=2\\times y1,x2=3\\times y2,x3=2\\times y3, 即可求得解 x=2\\times y1+3\\times y2+2\\times y3, 因为根据取模运算法则有:

\\ \\\\ x1 \\% 3=(2\\times y1)\\%3=(2\\%3\\times y1\\%3)\\%3=2 \\\\ x1 \\% 5=(2\\times y1)\\%5=(2\\%5\\times y1\\%5)\\%5=0 \\\\ x1 \\% 7=(2\\times y1)\\%7=(2\\%7\\times y1\\%7)\\%7=0 \\\\ x2 \\% 3=(3\\times y2)\\%3=(3\\%3\\times y2\\%3)\\%3=0 \\\\ x2 \\% 5=(3\\times y2)\\%5=(3\\%5\\times y2\\%5)\\%5=3 \\\\ x2 \\% 7=(3\\times y2)\\%7=(3\\%7\\times y2\\%7)\\%7=0 \\\\ x3 \\% 3=(2\\times y3)\\%3=(2\\%3\\times y3\\%3)\\%3=0 \\\\ x3 \\% 5=(2\\times y3)\\%5=(2\\%5\\times y3\\%5)\\%5=0 \\\\ x3 \\% 7=(2\\times y3)\\%7=(2\\%7\\times y3\\%7)\\%7=2 \\\\

然后来看下怎么求解 y1,y2,y3

首先求 y1 ,因为其同时满足被5和7整除,所以一定是5和7的公倍数,也就是5x7=35的倍数,且除3余1,也就有 35k\\equiv1 (mod \\ 3) , 就把问题转化为求解35模3的逆元的问题,用上面讲到的扩展欧几里得算法就可以求出 k
,将k
记作 \\left[ 35^{-1} \\right]_{3} (口算可得 k=2),然后就可以求得:

y1=35\\times\\left[ 35^{-1} \\right]_{3}=5\\times7\\times\\left[ (5\\times7)^{-1} \\right]_{3}=35\\times2=70 \\\\ x1=2\\times y1=140

刚好对应了原文“三三數之,賸二,置一百四十;”这一句。

接着求 y2 ,因为其同时满足被3和7整除,所以一定是3和7的公倍数,也就是3x7=21的倍数,且除5余1,也就有 21k\\equiv1 (mod \\ 5) ,同样可以求得:

y2=21\\times\\left[ 21^{-1} \\right]_{5}=3\\times7\\times\\left[ (3\\times7)^{-1} \\right]_{5}=21\\times1=21 \\\\ x2=3\\times y2=63

刚好对应了原文“五五數之,賸三,置六十三;”这一句。

最后求 y3 ,因为其同时满足被3和5整除,所以一定是3和5的公倍数,也就是3x5=15的倍数,且除7余1,也就有 15k\\equiv1 (mod \\ 7) ,同样可以求得:

y3=15\\times\\left[ 15^{-1} \\right]_{7}=3\\times5\\times\\left[ (3\\times5)^{-1} \\right]_{7}=15\\times1=15 \\\\ x3=2\\times y3=30

刚好对应原文“七七數之,賸二 ,置三十。”这一句。

然后求得:

x=2\\times(5\\times7\\times\\left[ (5\\times7)^{-1} \\right]_{3})+3\\times(3\\times7\\times\\left[ (3\\times7)^{-1} \\right]_{5})+2\\times(3\\times5\\times\\left[ (3\\times5)^{-1} \\right]_{7})
x=140+63+30=233

最后注意到,如果 x 满足除以3余2、除以5余3、除以7余2,那么 x+3\\times5\\times7 也同样满足,这个结论应用取模运算法则推导一下就知道是正确的。

因此要计算满足要求的最小的非负整数,就只需要按照上面的方法计算得到总和之后,再除以105得到的余数就是最小的非负整数。

x \\% 105 = 23x = 105\\times2+23 =210+23 所以答案就是 x-210=233-210=23 ,23满足除以3余2、除以5余3、除以7余2。

刚好对应了原文“并之,得二百三十三,以二百一十減之,即得。”

然后假设如果存在整数 X,Y 都满足“除以3余a、除以5余b、除以7余c”。则根据取模运算法则有:

\\ \\\\ (X-Y) \\ \\% \\ 3=(X\\ \\% \\ 3-Y\\ \\% \\ 3) \\ \\% \\ 3=(a-a) \\ \\% \\ 3=0 \\\\ (X-Y) \\ \\% \\ 5=(X\\ \\% \\ 5-Y\\ \\% \\ 5) \\ \\% \\ 5=(b-b) \\ \\% \\ 5=0 \\\\ (X-Y) \\ \\% \\ 7=(X\\ \\% \\ 7-Y\\ \\% \\ 7) \\ \\% \\ 7=(c-c) \\ \\% \\ 7=0

所以 X-Y 满足“除以3余0、除以5余0、除以7余0”,因此 X-Y 一定是 3,5,7的公倍数,也就是3\\times5\\times7=105 的倍数。这也就说明,在“模105同余”的意义下,之前通过分解问题、组合解答的方法所得到的 x 恰恰就是唯一解

把这个问题推广到一般情况,假设整数 m_1,m_2,m_3, ... ,m_n 两两互素,则对于任意的整数 a_1,a_2,a_3, ... ,a_n,同余方程组

x\\equiv a_{1} \\ (mod \\ m_{1}) \\\\ x\\equiv a_{2} \\ (mod \\ m_{2}) \\\\ x\\equiv a_{3} \\ (mod \\ m_{3}) \\\\ ... \\\\ x\\equiv a_{n} \\ (mod \\ m_{n}) \\\\

都存在整数解,而且若 X,Y 都满足该方程组,则必有 X\\equiv Y(mod \\ N) ,其中 N=\\prod_{i=1}^{n}m_i 。而可以下面公式求解:

\\ \\\\ x\\equiv\\sum_{i=1}^{n}{a_i\\times\\frac{N}{m_i}\\times\\left[ \\left( \\frac{N}{m_i} \\right)^{-1} \\right]_{m_i}}\\pmod N

这就是中国剩余定理,如果弄懂了上面孙子算经的题目,应该就很容易理解这个求解公式了。

简单描述下一般情况求解 x 过程:

首先分别找到 m_1\\ …\\ m_{i-1},\\ m_{i+1} \\ …\\ m_n 的公倍数 N_i ,满足除以 m_i 余 1,然后 a_i\\times N_i 即可,而求解 N_i 就相当于先求 \\frac{N}{m_i}m_i 的逆元,然后再乘以 \\frac{N}{m_i},用扩展欧几里得算法求解即可,最终把所有 a_i\\times N_i 加起来再模 N 就得到结果了。

代码:

#include <iostream>

int exgcd(int a, int b, int &x, int &y) {
	if (b == 0) {
		x = 1;
		y = 0;
		return a;
	}
       int gcd = exgcd(b, a % b, x, y);
       int t = x;
       x = y;
       y = t - a / b * x;
	return gcd;
}

int get_crt(int *a, int *m, int len) {
	int r, y;

	int N = 1;
	for (int i = 0; i < len; ++i) {
		N *= m[i];
	}
	int X = 0;

	for(int i=0; i<len; ++i) {
		int Mi = N / m[i];
		int gcd = exgcd(Mi, m[i], r, y);
		X += a[i] * Mi * r;
	}
	
	return X % N;
}

int main(int argc, char *argv[]) {
	int m[3] = {3, 5, 7};
	int a[3] = {2, 3, 2};

	int X = get_crt(a, m, 3);

	printf("crt = %d \\n", X);

	return 0;
}

运行结果:

2.8、多项式的中国剩余定理

类似的中国剩余定理同样可以应用到多项式上,下面参考[28]给出多项式版本的中国剩余定理的定义:

假设存在理数系数的多项式 m_1(x),m_2(x),m_3(x), ... ,m_n(x) 它们之间两两互素,则对于任意的有理数系数的多项式a_1(x),a_2(x),a_3(x), ... ,a_n(x) ,同余方程组

f(x)\\equiv a_{1}(x) \\ (mod \\ m_{1}(x)) \\\\ f(x)\\equiv a_{2}(x) \\ (mod \\ m_{2}(x)) \\\\ f(x)\\equiv a_{3}(x) \\ (mod \\ m_{3}(x)) \\\\ ... \\\\ f(x)\\equiv a_{n}(x) \\ (mod \\ m_{n}(x)) \\\\

都存在有理数系数的多项式解,且若 A(x),B(x) 都满足该同余方程组,则必有 A(x)\\equiv B(x)(mod \\ N(x)) ,其中N(x)=\\prod_{i=1}^{n}{m_i(x)} 。而求解方法与整数中国剩余定理类似:

f(x)\\equiv\\sum_{i=1}^{n}{a_i(x)\\times\\frac{N(x)}{m_i(x)}\\times\\left[ \\left( \\frac{N(x)}{m_i(x)} \\right)^{-1} \\right]_{m_i(x)}}(mod \\ N(x))

简单描述下一般情况求解 f(x) 过程:

首先分别找到 m_1(x)\\ …\\ m_{i-1}(x),\\ m_{i+1}(x) \\ …\\ m_n(x) 的公倍式 N_i(x) ,满足除以 m_i(x) 余 1,然后 a_i(x)\\times N_i(x) 即可,而求解 N_i(x) 就相当于先求 \\frac{N(x)}{m_i(x)}m_i(x) 的逆元,然后再乘以 \\frac{N(x)}{m_i(x)} ,应用扩展欧几里得算法求解即可,最终把所有 a_i(x)\\times N_i(x) 加起来再模 N(x) 就得到结果了。

下面举个实际例子,假设有如下同余方程组:

\\ \\\\ f(x)\\equiv x+1(mod \\ x^2+x+1) \\\\ f(x)\\equiv \\ 4 \\ (mod \\ x^2+1) \\ \\ \\ \\ \\ \\\\

先求 N(x)=(x^2+x+1)(x^2+1)=x^4+x^3+2x^2+x+1 ,然后利用扩展欧几里得算法求得逆元 \\left[ \\left( x^2+1 \\right)^{-1} \\right]_{x^2+x+1}=x+1\\left[ \\left( x^2+x+1 \\right)^{-1} \\right]_{x^2+1}=-x

然后套公式

f(x)=[(x+1)(x^2+1)(x+1)+4(x^2+x+1)(-x)](mod \\ N(x))\\\\ =(x^4-2x^3-2x^2-2x+1)(mod \\ N(x))\\\\ =-3x^3-4x^2-3x

验证结果

f(x)=-3x^3-4x^2-3x=(x^2+x+1)(-3x-1)+(x+1)\\equiv x+1(mod \\ x^2+x+1)\\\\ f(x)=-3x^3-4x^2-3x=(x^2+1)(-3x-4)+4\\equiv4(mod \\ x^2+1)

3、多项式的中国剩余定理的应用

3.1、卷积操作与中国剩余定理的联系

终于到了本文最重点的部分了,在开始看本节之前确保已经理解了前面提到的数学知识。通过前面的介绍我们已经知道了卷积操作等价于多项式乘法,下面简要描述下卷积是怎么和中国剩余定理的产生联系的,这也是我理解的Winograd这个算法的核心。

需要注意的是下面的一些结论是我根据实际例子比如F(2,3)和F(4,3)推导得到的结论不一定正确。

假设现在有两个离散序列 [h_0,h_1,...,h_{n-1}][p_0,p_1,...,p_{k-1}] 做卷积操作,首先把这两个序列转化为两个多项式 h(x)=h_0+...+h_{n-1}x^{n-1}p(x)=p_0+...+p_{k-1}x^{k-1}

h(x)p(x) 最高分别为 n-1k-1 次,然后卷积操作就可以变为多项式乘法 s(x)=h(x)p(x) ,最高次为 n+k-2

我们先有个概念就是Winograd是一个构造式的算法,是人为去构造一个计算 s(x) 的等价变换,下面介绍如何构造。

首先构造 n+k-2 个互素的多项式 m_0(x),m_1(x),...m_{n+k-3}(x) ,假设是 x,x-1,x+1,x-2,x+2... ,然后有

N(x)=m_0(x)m_1(x)....m_{n+k-3}(x)

s(x) 可以表示为 s(x)=h_{n-1}p_{k-1}N(x)+s'(x) ,因为 N(x)s(x) 同次,所以余式 s'(x) 次数小于 N(x) ,而为啥商是 h_{n-1}p_{k-1} 是因为 N(x)=x^{n+k-2}+.... ,而 s(x)=h(x)p(x)=h_{n-1}p_{k-1}x^{n+k-2}+... ,所以应用多项式除法,商就是 h_{n-1}p_{k-1} ,余式就是 s'(x)=s(x)-h_{n-1}p_{k-1}N(x) 。所以问题就转化为求余式的问题。

然后现在已知 h(x)p(x) ,所以可以求得 h(x)p(x) 除以这些互素多项式的余式

h(x)\\equiv h_0(x)(mod \\ m_0(x)) \\\\ h(x)\\equiv h_1(x)(mod \\ m_1(x)) \\\\ ... \\\\ h(x)\\equiv h_{n+k-3}(x)(mod \\ m_{n+k-3}(x)) \\\\ \\ \\\\ p(x)\\equiv p_0(x)(mod \\ m_0(x)) \\\\ p(x)\\equiv p_1(x)(mod \\ m_1(x)) \\\\ ... \\\\ p(x)\\equiv p_{n+k-3}(x)(mod \\ m_{n+k-3}(x))

接着根据取模运算法则有

s(x)\\equiv h(x)p(x)(mod \\ m_0(x))=(h(x)mod \\ m_0(x)*p(x)mod \\ m_0(x))(mod \\ m_0(x)))=h_0(x)p_0(x)(mod \\ m_0(x)) \\\\ s(x)\\equiv h(x)p(x)(mod \\ m_1(x))=(h(x)mod \\ m_1(x)*p(x)mod \\ m_1(x))(mod \\ m_1(x)))=h_1(x)p_1(x)(mod \\ m_1(x)) \\\\ ... \\\\ s(x)\\equiv h(x)p(x)(mod \\ m_{n+k-3}(x))=(h(x)mod \\ m_{n+k-3}(x)*p(x)mod \\ m_{n+k-3}(x))(mod \\ m_{n+k-3}(x)))=h_{n+k-3}(x)p_{n+k-3}(x)(mod \\ m_{n+k-3}(x))

然后因为 N(x) 可以整除 m_{i}(x) 然后有

s(x)(mod \\ m_0(x))= (h_{n-1}p_{k-1}N(x)+s'(x))(mod \\ m_0(x))=s'(x)(mod \\ m_0(x))=h_0(x)p_0(x)(mod \\ m_0(x)) \\\\ s(x)(mod \\ m_1(x))= (h_{n-1}p_{k-1}N(x)+s'(x))(mod \\ m_1(x))=s'(x)(mod \\ m_1(x))=h_1(x)p_1(x)(mod \\ m_1(x)) \\\\ ... \\\\ s(x)(mod \\ m_{n+k-3}(x))= (h_{n-1}p_{k-1}N(x)+s'(x))(mod \\ m_{n+k-3}(x))=s'(x)(mod \\ m_{n+k-3}(x))=h_{n+k-3}(x)p_{n+k-3}(x)(mod \\ m_{n+k-3}(x))

然后求余式 s'(x) 就变成求解同余方程组的问题,

s'(x)\\equiv h_0(x)p_0(x)(mod \\ m_0(x)) \\\\ s'(x)\\equiv h_1(x)p_1(x)(mod \\ m_1(x)) \\\\ ... \\\\ s'(x)\\equiv h_{n+k-3}(x)p_{n+k-3}(x)(mod \\ m_{n+k-3}(x))

就可以套用中国剩余定理去求解 s'(x)

s'(x)\\equiv\\sum_{i=0}^{n+k-3}{h_i(x)p_i(x)\\times\\frac{N(x)}{m_i(x)}\\times\\left[ \\left( \\frac{N(x)}{m_i(x)} \\right)^{-1} \\right]_{m_i(x)}}(mod \\ N(x))

所以原来的多项式乘法就化为:

s(x)=h_{n-1}p_{k-1}N(x)+s'(x) \\\\ =h_{n-1}p_{k-1}N(x) + \\sum_{i=0}^{n+k-3}{h_i(x)p_i(x)\\times\\frac{N(x)}{m_i(x)}\\times\\left[ \\left( \\frac{N(x)}{m_i(x)} \\right)^{-1} \\right]_{m_i(x)}}(mod \\ N(x))

再套用到具体情况比如2x3, 4x3卷积的时候,如果变换之后等式右边的所需的乘法次数小于 h(x)p(x) 的乘法次数就能达到加速的目的。

3.2、Winograd F(2,3)变换矩阵推导

现在来看下具体到F(2,3)的变换矩阵是如何得到的。

首先假设有两个长度是2和3的离散序列 [h_0,h_1][p_0,p_1,p_2] 和它们等价的多项式表示 h(x)=h_0+h_1xp(x)=p_0+p_1x+p_2x^2

然后相乘的结果:

s(x)=h(x)p(x)=(h_0+h_1x)(p_0+p_1x+p_2x^2)=h_0p_0+(h_0p_1+h_1p_0)x+(h_0p_2+h_1p_1)x^2+h_1p_2x^3

所以从系数的计算上可以看出总共有6次乘法和2次加法。

然后构造2+3-2=3个互素多项式, m_0(x)=x,m_1(x)=x-1,m_2(x)=x+1 ,它们的乘积 N(x)=x(x-1)(x+1)=x^3-x 。所以有

s(x)=h_1p_2N(x)+s'(x)

然后求 h(x),p(x) 除以这3个互素多项式的余数:

h_0(x)=h(x)\\ mod \\ m_0(x)=(h_0+h_1x)\\ mod \\ x=h_0 \\\\ h_1(x)=h(x)\\ mod \\ m_1(x)=(h_0+h_1x)\\ mod \\ (x-1)=h_0+h_1 \\\\ h_2(x)=h(x)\\ mod \\ m_2(x)=(h_0+h_1x)\\ mod \\ (x+1)=h_0-h_1 \\\\ \\ \\\\ p_0(x)=p(x)\\ mod \\ m_0(x)=(p_0+p_1x+p_2x^2)\\ mod \\ x=p_0 \\\\ p_1(x)=p(x)\\ mod \\ m_1(x)=(p_0+p_1x+p_2x^2)\\ mod \\ (x-1)=p_0+p_1+p_2 \\\\ p_2(x)=p(x)\\ mod \\ m_2(x)=(p_0+p_1x+p_2x^2)\\ mod \\ (x+1)=p_0-p_1+p_2

然后就可以得到关于 s'(x) 的同余方程组:

s'(x)\\equiv h_0(x)p_0(x)(mod \\ m_0(x))=h_0p_0=s'_0(x) \\\\ s'(x)\\equiv h_1(x)p_1(x)(mod \\ m_1(x)) =(h_0+h_1)(p_0+p_1+p_2)=s'_1(x)\\\\ s'(x)\\equiv h_2(x)p_2(x)(mod \\ m_2(x))=(h_0-h_1)(p_0-p_1+p_2 )=s'_2(x)

然后套用中国剩余定理,首先求逆元 \\left[ \\left( \\frac{N(x)}{m_i(x)} \\right)^{-1} \\right]_{m_i(x)} ,用扩展欧几里得算法求解

\\left[ \\left( \\frac{N(x)}{m_0(x)} \\right)^{-1} \\right]_{m_0(x)}=\\left[( x^2-1)^{-1}\\right]_{x} \\\\ \\left[ \\left( \\frac{N(x)}{m_1(x)} \\right)^{-1} \\right]_{m_1(x)}=\\left[( x^2+x)^{-1}\\right]_{x-1} \\\\ \\left[ \\left( \\frac{N(x)}{m_2(x)} \\right)^{-1} \\right]_{m_2(x)}=\\left[( x^2-x)^{-1}\\right]_{x+1}

\\left[ \\left( \\frac{N(x)}{m_0(x)} \\right)^{-1} \\right]_{m_0(x)}求解过程:

\\ \\\\ 相当于求解方程 (x^2-1)u(x)+xv(x)=1 的解 \\\\ 第一步,x^2-1=x*x-1,商是x,余式是-1 \\\\ 第二步,x=(-1)*(-x)+0,商是-x,余式是0,停止递归,设置u(x)=1,v(x)=0 \\\\ 第三步,t(x)=u(x)=1,u(x)=v(x)=0,v(x)=t(x)-(-x)u(x)=1 \\\\ 第四步,t(x)=u(x)=0,u(x)=v(x)=1,v(x)=t(x)-xu(x)=-x得解 \\\\ (x^2-1)*1+x*(-x)=-1,两边同乘-1得(x^2-1)*(-1)+x*x=1,求得逆元是-1

所以\\left[ \\left( \\frac{N(x)}{m_0(x)} \\right)^{-1} \\right]_{m_0(x)}=-1

\\left[ \\left( \\frac{N(x)}{m_1(x)} \\right)^{-1} \\right]_{m_1(x)}求解过程:

相当于求解方程 (x^2+x)u(x)+(x-1)v(x)=1 的解 \\\\ 第一步,x^2+x=(x-1)*(x+2)+2,商是x+2,余式是2 \\\\ 第二步,x-1=(2)*(\\frac{1}{2}x-\\frac{1}{2})+0,商是\\frac{1}{2}x-\\frac{1}{2},余式是0,停止递归,设置u(x)=1,v(x)=0 \\\\ 第三步,t(x)=u(x)=1,u(x)=v(x)=0,v(x)=t(x)-(\\frac{1}{2}x-\\frac{1}{2})u(x)=1 \\\\ 第四步,t(x)=u(x)=0,u(x)=v(x)=1,v(x)=t(x)-(x+2)u(x)=-x-2得解 \\\\ (x^2+x)*1+(x-1)*(-x-2)=2,两边同除以2得 (x^2+x)*(\\frac{1}{2})+(x-1)*(-\\frac{1}{2}x-1)=1,求得逆元是\\frac{1}{2}

所以 \\left[ \\left( \\frac{N(x)}{m_1(x)} \\right)^{-1} \\right]_{m_1(x)}=\\frac{1}{2}

\\left[ \\left( \\frac{N(x)}{m_2(x)} \\right)^{-1} \\right]_{m_2(x)}求解过程:

相当于求解方程 (x^2-x)u(x)+(x+1)v(x)=1 的解 \\\\ 第一步,x^2-x=(x+1)*(x-2)+2,商是x-2,余式是2 \\\\ 第二步,x+1=(2)*(\\frac{1}{2}x+\\frac{1}{2})+0,商是\\frac{1}{2}x+\\frac{1}{2},余式是0,停止递归,设置u(x)=1,v(x)=0 \\\\ 第三步,t(x)=u(x)=1,u(x)=v(x)=0,v(x)=t(x)-(\\frac{1}{2}x+\\frac{1}{2})u(x)=1 \\\\ 第四步,t(x)=u(x)=0,u(x)=v(x)=1,v(x)=t(x)-(x-2)u(x)=-x+2得解 \\\\ (x^2-x)*1+(x+1)*(-x+2)=2,两边同除以2得(x^2-x)*\\frac{1}{2}+(x+1)*(-\\frac{1}{2}x+1)=1,求得逆元是\\frac{1}{2}

所以 \\left[ \\left( \\frac{N(x)}{m_2(x)} \\right)^{-1} \\right]_{m_2(x)}=\\frac{1}{2}

然后套用中国剩余定理求解 s'(x)

s'(x)=\\sum_{i=0}^{2}h_i(x)p_i(x)\\times\\frac{N(x)}{m_i(x)}\\times\\left[ \\left( \\frac{N(x)}{m_i(x)} \\right)^{-1} \\right]_{m_i(x)}\\\\ =s'_0(x)(x^2-1)(-1)+s'_1(x)(x^2+x)\\frac{1} {2}+s'_2(x)(x^2-x)\\frac{1}{2} \\\\ =h_0p_0(x^2-1)(-1)+(h_0+h_1)(p_0+p_1+p_2)(x^2+x)\\frac{1} {2}+(h_0-h_1)(p_0-p_1+p_2 )(x^2-x)\\frac{1}{2}

所以

s(x)=h_1p_2N(x)+s'(x) \\\\ =h_1p_2(x^3-x)+s'_0(x)(x^2-1)(-1)+s'_1(x)(x^2+x)\\frac{1} {2}+s'_2(x)(x^2-x)\\frac{1}{2} \\\\ =s'_0(x)+(\\frac{s'_1(x)}{2}-\\frac{s'_2(x)}{2}-h_1p_2)x+(\\frac{s'_1(x)}{2}+\\frac{s'_2(x)}{2}-s'_0(x))x^2+h_1p_2x^3

所以用上式替代了原始下式的计算:

s(x)=h(x)p(x)=h_0p_0+(h_0p_1+h_1p_0)x+(h_0p_2+h_1p_1)x^2+h_1p_2x^3

然后我们把新的式子需要做乘法的地方写在下面:

\\ \\\\ h_0p_0=s'_0(x) \\\\ (h_0+h_1)(p_0+p_1+p_2)=s'_1(x)\\\\ (h_0-h_1)(p_0-p_1+p_2 )=s'_2(x) \\\\ h_1p_2

所以就是4次乘法和9次加法,除2的操作的开销可以在实际应用的时候把除2操作放到权值变换那里,就可以把运行时的开销去掉了。可以看到比原来的6次乘法和2次加法,少两次乘法,但是加法次数变多了。

然后看下如何提取出变换矩阵,首先对公式作一些改动,把除2操作移动到 s'_1(x),s'_2(x) 的计算里面:

s(x)=h_1p_2N(x)+s'(x) \\\\ =s'_0(x)+(s'_1(x)-s'_2(x)-h_1p_2)x+(-s'_0(x)+s'_1(x)+s'_2(x))x^2+h_1p_2x^3
h_0p_0=s'_0(x) \\\\ (h_0+h_1)(\\frac{p_0}{2}+\\frac{p_1}{2}+\\frac{p_2}{2})=s'_1(x)\\\\ (h_0-h_1)(\\frac{p_0}{2}-\\frac{p_1}{2}+\\frac{p_2}{2} )=s'_2(x) \\\\ h_1p_2

通过观察上面的式子就能抽取出序列 h=[h_0,h_1] , p=[p_0,p_1,p_2] 各自的变换矩阵 A,G ,还有最后输出变换矩阵 B

\\begin{equation*} A = \\begin{bmatrix} 1 & 0 \\\\ 1 & 1 \\\\ 1 & -1 \\\\ 0 & 1 \\end{bmatrix} \\end{equation*}
\\begin{equation*} G = \\begin{bmatrix} 1 & 0 & 0\\\\ \\frac{1}{2} & \\frac{1}{2} & \\frac{1}{2} \\\\ \\frac{1}{2} & -\\frac{1}{2} & \\frac{1}{2} \\\\ 0 & 0 & 1 \\end{bmatrix} \\end{equation*}
\\begin{equation*} B = \\begin{bmatrix} 1 & 0 & 0 & 0\\\\ 0 & 1 & -1 & -1 \\\\ -1 & 1 & 1 & 0 \\\\ 0 & 0 & 0 & 1 \\end{bmatrix} \\end{equation*}

然后就可以用以下公式计算 s(x) 的系数了:

s=B*[G *p \\ \\odot \\ A*h]

*” 表示矩阵向量乘法而 “\\odot” 表示向量点乘,下面验证下:

A*h=[h_0,h_0+h_1,h_0-h_1,h_1] \\\\ G*p=[p_0,\\frac{p_0+p_1+p_2}{2},\\frac{p_0-p_1+p_2}{2},p_2] \\\\ G *p \\ \\odot \\ A*h=[h_0p_0,(h_0+h_1)(\\frac{p_0+p_1+p_2}{2}),(h_0-h_1)(\\frac{p_0-p_1+p_2}{2}),h_1p_2]

所以 B*[G *p \\ \\odot \\ A*h] 结果为:

[h_0p_0,(h_0+h_1)(\\frac{p_0+p_1+p_2}{2})-(h_0-h_1)(\\frac{p_0-p_1+p_2}{2})-h_1p_2,-h_0p_0+(h_0+h_1)(\\frac{p_0+p_1+p_2}{2})+(h_0-h_1)(\\frac{p_0-p_1+p_2}{2}),h_1p_2]
=[h_0p_0, h_0p_1+h_1p_0,h_0p_2+h_1p_1,h_1p_2]

和直接卷积结果一致。

然后来看下这个变换是如何应用到深度学习中的卷积(Correlation)里面的,对于F(2,3)的应用,是用在1x3或者3x1卷积里面,长度是3的卷积核连续卷积两次得到两个输出,输入序列长度是4,刚好是把Winograd的变换矩阵反着来用的,为了和上面的公式对应,这里用 s,p,h 分别表示,输入,权值和输出:

\\ \\\\ h=A^{T}*[G *p \\ \\odot \\ B^{T}*s]

把矩阵 AB 做转置:

\\begin{equation*} A^T = \\begin{bmatrix} 1 & 1 & 1 & 0 \\\\ 0 & 1 & -1 & 1 \\end{bmatrix} \\end{equation*}
\\begin{equation*} B^T = \\begin{bmatrix} 1 & 0 & -1 & 0\\\\ 0 & 1 & 1 & 0 \\\\ 0 & -1 & 1 & 0 \\\\ 0 & -1 & 0 & 1 \\end{bmatrix} \\end{equation*}

验证下,假设输入 s=[s_0,s_1,s_2,s_3] ,权值 p=[p_0,p_1,p_2] 和输出 h=[h_0,h_1] ,直接做correlation的结果是:

h=[s_0p_0+s_1p_1+s_2p_2,s_1p_0+s_2p_1+s_3p_2]

然后验证下

B^T*s=[s_0-s_2,s_1+s_2,s_2-s_1,s_3-s_1] \\\\ G*p=[p_0,\\frac{p_0+p_1+p_2}{2},\\frac{p_0-p_1+p_2}{2},p_2] \\\\ G *p \\ \\odot \\ B^T*s=[(s_0-s_2)p_0,(s_1+s_2)(\\frac{p_0+p_1+p_2}{2}),(s_2-s_1)(\\frac{p_0-p_1+p_2}{2}),(s_3-s_1)p_2]

所以 h=A^{T}*[G *p \\ \\odot \\ B^{T}*s] 结果 为:

[(s_0-s_2)p_0+(s_1+s_2)(\\frac{p_0+p_1+p_2}{2})+(s_2-s_1)(\\frac{p_0-p_1+p_2}{2}),(s_1+s_2)(\\frac{p_0+p_1+p_2}{2})-(s_2-s_1)(\\frac{p_0-p_1+p_2}{2})+(s_3-s_1)p_2]
=[s_0p_0+s_1p_1+s_2p_2,s_1p_0+s_2p_1+s_3p_2]

结果与直接做Correlation一致。

其实这里有一点没想明白的地方是,卷积操作中的Winograd变换公式是如何变成用在Correlation中变换公式的,直接推导的话推不出来,感觉中间还缺了一环,但是确实结论是正确的,实际推导结果也正确。

终于写完了,真的是第一次写那么长的博客,而且公式也比较多,如果有哪里写的不对或者公式错误的地方,请读者见谅。

4、参考资料

[1] 油管--The Winograd Transformation

[2] 深度加速(一)——概述, Winograd(1)

[3] 深度加速(二)——Winograd(2)

[4] Arxiv--Fast Algorithms for Convolutional Neural Networks

[5] https://github.com/andravin/wincnn/blob/master/2464-supp.pdf

[6] 知乎--源于《孙子算经》的Cudnn

[7] 向量卷积与多项式乘法

[8] Convolution Vs Correlation

[9]The difference between convolution and cross-correlation from a signal-analysis point of view

[10] 卷积神经网络中的Winograd快速卷积算法

[11] 輾轉相除法

[12] 维基百科--质数

[13] 欧几里德算法与扩展欧几里德算法

[14] 我終於頓悟輾轉相除法求最大公約數的原理了

[15] 取模运算涉及的算法

[16] 多项式除法竖式应当如何理解?

[17] 多项式长除法

[18] 用辗转相除法求多项式的最大公因式

[19] 维基百科--裴蜀定理

[20] 数论小结

[21] 扩展欧几里得算法 有限域上多项式求逆

[22] 维基百科--模逆元

[23] 「扩展欧几里得算法」与「模逆元」详解

[24] 知乎--中国剩余定理(CRT )

[25] 扩展欧几里得算法与中国剩余定理

[26] 中国剩余定理算法详解(余数互质和不互质)

[27] 百科故事--《孙子算经》卷下

[28] 知乎--多项式也有CRT么?

[29] Winograd数学原理【卡住了>_<】

本站文章资源均来源自网络,除非特别声明,否则均不代表站方观点,并仅供查阅,不作为任何参考依据!
如有侵权请及时跟我们联系,本站将及时删除!
如遇版权问题,请查看 本站版权声明
THE END
分享
二维码
海报
详解Winograd变换矩阵生成原理
文本首发知乎:https://zhuanlan.zhihu.com/p/87516875
<<上一篇
下一篇>>