Task
- 弄懂概念
- 初探原理推导
- 初探ctc预测
- 初探代码实现
Reference
- CTC——下雨天和RNN更配哦
- 详解CTC
- CTC Algorithm Explained Part 1:Training the Network(CTC算法详解之训练篇)
- CTC Algorithm Explained Part 2:Decoding the Network(CTC算法详解之解码篇)
注:个人认为若只需理解,看第一篇即可;理解全面,看第二篇;再深入则看三和四。
概念
在语音识别和OCR场景中,给定输入序列 以及对应的标签数据 ,例如语音识别中的音频文件和文本文件。我们需要找到 到 的一个映射。
CTC提供了解决方案,对于一个给定的输入序列 ,CTC给出所有可能的 的输出分布。根据这个分布,我们可以输出最可能的结果或者给出某个输出的概率。
CTC的对齐策略:引入了空白字符 ,对齐涉及去除重复字母和去除 两部分
CTC的时间片的输出和输出序列的映射如图:
前文提及输出最可能的结果或者给出某个输出的概率
,
这里再提:输出序列和最终的label之间存在多对一的映射关系
损失函数概要搬运:
给定输入序列 ,我们希望最大化 的后验概率 , 应该是可导的,这样我们能执行梯度下降算法;
测试:给定一个训练好的模型和输入序列 ,我们希望输出概率最高的 :
当然,在测试时,我们希望 能够尽快的被搜索到。
关键词: 极大似然估计, 信息熵
对应标签 ,其关于输入 的后验概率可以表示为所有映射为 的路径之和,我们的目标就是最大化 关于 的后验概率 。假设每个时间片的输出是相互独立的,则路径的后验概率是每个时间片概率的累积,公式及其详细含义如图
路径之和示意:
本人概率论菜,详细损失原理推论请见Fefrence
的三
其中问题定义,CTC Loss定义和CTC Loss计算是理解入门很值得一看的~
放一个最终的模型结构图:
解决计算Loss时任务量巨大问题:
用了动态规划的思想来对查找路径进行剪枝,确保路径的唯一
示意图如下:
路径制定具有规则,示意图:
如何求解这些路径的概率总和?采用动态规划求解
注:分3情况yo~
得到Loss后自然是对其求导,这里搬运大神的求导核心部分,详细请移步大神博客:
完整的训练过程
概述
当我们训练好一个RNN模型时,给定一个输入序列 ,我们需要找到最可能的输出,也就是求解
求解最可能的输出有两种方案:
Greedy Search
方法:每个时间片均取该时间片概率最高的节点作为输出
缺点:忽略了一个输出可能对应多个对齐方式
beam search
搬运:Beam Search是寻找全局最优值和Greedy Search在查找时间和模型精度的一个折中。一个简单的beam search在每个时间片计算所有可能假设的概率,并从中选出最高的几个作为一组。然后再从这组假设的基础上产生概率最高的几个作为一组假设,依次进行,直到达到最后一个时间片,下图是beam search的宽度为3的搜索过程,红线为选中的假设。
再深入请见大神博客
TODO
1
2
3
41.求导部分再看懂,有空推一遍
2.BPTT反向传播理论学习
3.预测部分学习
4.代码学习补充:代码学习
1
2
3
4
5
6
7
8
9
10
11x = Dense(n_class, init='he_normal', activation='softmax')(x)
base_model = Model(input=input_tensor, output=x)
labels = Input(name='the_labels', shape=[n_len], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')
loss_out = Lambda(ctc_lambda_func, output_shape=(1,),
name='ctc')([x, labels, input_length, label_length])
model = Model(input=[input_tensor, labels, input_length, label_length], output=[loss_out])
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adadelta')也就是,写法里x是[batch_size * n_class], labels是[batch_size, 1]。
训练时写法见此,预测时候K.ctc_decode
1 | characters2 = characters + ' ' |
那为什么Desnet+CTC可以成功吗?难道模型对一张图片能输出不同的预测结果?
原因在于CTC,CTC带编辑距离算法,效果类似RNN~