性别识别专题二——Fisher分类器
好吧,我又来填坑了。。。为啥呢?因为这周我的任务好像差不多了,硬件客观条件所限,要干活最快也要到周六,想象自己给自己挖了这么多坑,补天于心不安啊!!
说到坑这个东西,大家可能有所不知,在这个BLOG里,外表看上去我挖的坑不多(但也不是没有),但是殊不知我的draft里面躺了一大堆草稿,觉得这个可以写来玩玩,那个又想写来分享什么的,而且有些想法发现像个无底洞,所以就被永久雪藏了,比如之前搞的那个统计信号处理专题,出了1之后就在也没动过了。。。
填坑,是个良心活!!
闲话就说这么多吧。。。
线性判别式分析(Linear Discriminant Analysis, LDA),也叫做Fisher线性判别,也就是数模里面常见的Fisher分类器,之前讲到的那个PCA,经常和LDA一起出现,解决了无数的问题,看成模式识别界最平民化的模范COUPLE,平民化一是因为用的多,二是简单,粗暴,易懂。。
这个分类器的思想是什么呢?其实很简单,其实我是想自己做个图的,但是懒,所以就上网找了个,下图如若侵权,那我先在此道歉了。。。我错了(低头)。。。
比如说上图,我们有两个类,他们是二维的点,Fisher分类器是想怎么做呢?它想找一条直线(一般而言是一个比原数据维数更低维的空间,如面到点),然后把源数据投影到这个线上,使得数据在这条低维空间上具有很好的“可区分性”,像上图中画得最佳投影方向,如果投影的是上图中的最不利的投影方向的话,那么数据在线上就基本无法区分了,对吧。
Fisher分类器就是根据你输入的数据,算出这一条“线”来,识别的时候就是往线上投影,根据阈值来判断属于哪一类。
那么Fisher分类器是怎么算出那条线来的呢?下面是简要说明,详细复杂的计算证明大家自己找文献去。。。
为了方便说明,我们用二维例子来说,也就是上面的二维的点,我们可以用一个1×2的向量w来表示一条线,这大家没异议吧。
Fisher的思想就是它定义了一个类内散度矩阵和类间散度矩阵这两个东西,顾名思义,就是投影之后每个类是不是很“聚拢”,类与类之间是不是尽量的远,首先我们可以算出每个类的均值,假设有两个类,均值分别是上图中的μ1,μ2,然后还有一个所有样本的均值μ,然后把这三个均值也投影到w这条线上,假设投影后分别是μ1‘,μ2‘,μ‘;
然后呢,类间散度SB就是:
N1(μ1‘-μ‘)2+N2(μ2‘-μ‘)2
其中两个N就是每一类的数目的个数,然后下面给出一个结论,不给证明,就是类间散度还可以等价表示成下面的形式:
SB = wTSBw
其中SB是:
(μ1‘-μ2‘)(μ1‘-μ2‘)T
上面证明略。。。因为博客敲公式麻烦。。不是Latex。。。(摊手~)
接下来就是类内散度了,下面语言描述,注意断句,每个类的类内散度的值就等于这个类里面(每个数据投影后和类均值投影后的差的平方)的和,懂了?
然后我们要求的总的类散度就是每个类的类散度的和。
然后下面还是不加证明的给出类内散度SW的化简形式:
SW = wTSWw
其中,SW是:
∑∑(xpj‘-μi‘)(xpj‘-μi‘)T
两个求和第一重是对所有类求和,第二重是对每个类内部每个元素求和,xpj‘表示的是第p类的第j个点在w上的投影;
有了这两个定义,下面是目标,就是要让类间散度除以类内散度,这个值J达到最大!!
下面干一件比较没节操的事情,就是,这个最大值怎么求呢?我直接告诉你,至于怎么出来的,查看一下文献,很容易就明白了,为什么这么干呢?还是还是那个原因,结果形式简单明了,中间过程敲公式太复杂,如果哪一天这里支持Latex语法的话,那我就谢天谢地了。。。
让这个J达到最大的w的形式是:
(Sw)-1(μ1-μ2)
嗯,你没看错,就是那么简单,正是因为结果那么简单,所以我上面才不想说那么多废话啊。。。
投影之后的判断的阈值一般可以用(μ1+μ2)Tw/2或者(N1μ1+N2μ2)Tw/2来算。
嘛,算法描述就到这里,下面照旧,简单的实验;
也是为了简单起见,我这个实验,简单,粗暴。。。
我直接画了一条直线,然后做出两个类,如下图所示:
这些点是二维的点,投影后变成一维的一个数,这些数分别如下图所分布:
红线是阈值对应的线,可以看出,目的基本达成。
好吧,下面是代码。。。不懂得孩子再自己研究研究吧。。。
code:
data = createdata(); opp = data(data(:,3)==1,1:2); neg = data(data(:,3)==-1,1:2); plot(opp(:,1),opp(:,2),'r.'); hold on; plot(neg(:,1),neg(:,2),'b.'); mean_opp = mean(opp); mean_neg = mean(neg); diff_opp = opp-repmat(mean_opp,size(opp,1),1); S1 = diff_opp'*diff_opp; diff_neg = neg-repmat(mean_neg,size(neg,1),1); S2 = diff_neg'*diff_neg; Sw = S1+S2; w = (mean_opp-mean_neg)*Sw^-1; thred = (mean_opp+mean_neg)/2*w'; %投影后的结果 opp_new = opp*w'; neg_new = neg*w'; figure; plot(opp_new,'g.'); hold on; plot(neg_new,'b.'); plot([1 5000],[thred thred],'r'); function data = createdata() data = []; N = 10000; x = 10*rand(1,N); y = 10*rand(1,N); for i = 1 : N if(2.5+0.5*x(i)-y(i)>0) data = [data;x(i),y(i),1]; else data = [data;x(i),y(i),-1]; end end
【完】
本文内容遵从CC版权协议,转载请注明出自http://www.kylen314.com