
Sequence Labeling in Practice

Sequence labeling is a classic family of NLP problems; typical applications include word segmentation, part-of-speech tagging, and named entity recognition.

The problem is defined as follows: given a sequence $S$, we want a model $M$ to produce the label sequence $Z$, containing one tag for each token in $S$. Each token in $S$ can take one of $n_{word}$ possible values, and each tag in $Z$ can take one of $n_{tag}$ possible values.

This post models sequence labeling from the perspective of the noisy channel model and uses the Viterbi algorithm to make the search for the optimal path efficient.

Define $S=w_1w_2w_3 \ldots w_N$, and let $Z=z_1z_2z_3 \ldots z_N$ be an arbitrary candidate label sequence; the goal is to find the optimal label sequence $\hat{Z}$. By Bayes' rule, and since $P(S)$ does not depend on $Z$:

\[\begin{aligned} \hat{Z}&=\mathop{\arg\max}_{Z} P(Z\mid S) \\ &=\mathop{\arg\max}_{Z} P(S\mid Z)\cdot P(Z) \\ &=\mathop{\arg\max}_{Z} P(w_1w_2 \ldots w_N\mid z_1z_2 \ldots z_N)\cdot P(z_1z_2 \ldots z_N) \end{aligned}\]

In the product above, the first factor is the translation model (TM) and the second is the language model (LM).

Next, introduce an independence assumption in the TM and a bi-gram (first-order Markov) assumption in the LM.
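
Written out, the two assumptions are:

\[P(w_1w_2 \ldots w_N\mid z_1z_2 \ldots z_N) \approx \prod_{i=1}^{N}P(w_i\mid z_i),\qquad P(z_1z_2 \ldots z_N) \approx P(z_1)\cdot\prod_{j=2}^{N}P(z_j\mid z_{j-1})\]

Substituting them into the objective gives: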

\[\hat{Z}=\mathop{\arg\max}_{Z}\Big(\prod_{i=1}^{N}P(w_i\mid z_i)\cdot P(z_1)\cdot\prod_{j=2}^{N}P(z_j\mid z_{j-1})\Big)\]

Taking logarithms turns the products into sums:

\[\hat{Z}=\mathop{\arg\max}_{Z}\Big(\sum_{i=1}^{N}\log P(w_i\mid z_i)+\log P(z_1)+\sum_{j=2}^{N}\log P(z_j\mid z_{j-1})\Big)\]

Of the three summed terms, the first is the conditional probability of a word given its tag (the emission probabilities), denoted $A$; the second is the probability of a tag appearing at the start of a sentence, denoted $\pi$; and the third is the transition probability between tags, denoted $B$.

On a labeled dataset, all three sets of probabilities can be estimated by simple counting.
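
Concretely, these are the maximum-likelihood estimates that the build_params function below computes by counting and row-normalizing:

\[\pi(z)=\frac{\mathrm{count}(z\ \text{at sentence start})}{\mathrm{count}(\text{sentences})},\qquad A(w\mid z)=\frac{\mathrm{count}(z\ \text{emits}\ w)}{\sum_{w'}\mathrm{count}(z\ \text{emits}\ w')},\qquad B(z'\mid z)=\frac{\mathrm{count}(z\rightarrow z')}{\sum_{z''}\mathrm{count}(z\rightarrow z'')}\]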

For an unlabeled sentence $S_u$ of length $N$, there are $n_{tag}^N$ possible label sequences. In principle we could compute the probability $P(Z_i\mid S)$ for every one of them and take the label sequence with the highest probability.
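
For intuition, here is a minimal brute-force sketch of that enumeration; the helper brute_force_decode is hypothetical (it is not part of the implementation below) and is only feasible for tiny $n_{tag}$ and $N$:

import itertools

import numpy as np


def brute_force_decode(x, pi, a, b):
    """Score every possible tag sequence and keep the best one.
    x: list of word ids; pi, a, b: the HMM parameters defined above.
    Complexity is O(n_tag ** len(x)) -- for illustration only."""
    n_tag = len(pi)

    def safe_log(p):
        # floor zero probabilities, mirroring the log() helper used later
        return np.log(p) if p > 0 else np.log(1e-8)

    best_score, best_seq = -np.inf, None
    for z in itertools.product(range(n_tag), repeat=len(x)):
        # log P(S, Z) = log pi(z_1) + sum_i log A(w_i|z_i) + sum_j log B(z_j|z_{j-1})
        score = safe_log(pi[z[0]]) + safe_log(a[z[0], x[0]])
        for i in range(1, len(x)):
            score += safe_log(b[z[i - 1], z[i]]) + safe_log(a[z[i], x[i]])
        if score > best_score:
            best_score, best_seq = score, list(z)
    return best_seq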

In practice, enumerating every possibility is extremely inefficient, so a more efficient algorithm is needed to search the space of paths. The Viterbi algorithm is a dynamic-programming algorithm for choosing the optimal path in problems that consist of multiple steps with multiple choices at each step. For every choice at the current step, it keeps the minimum cost (or maximum gain) over all predecessors from the previous step plus the local cost (gain) of the current choice, and it finally recovers the optimal path by backtracking. This reduces the complexity from $O(n_{tag}^N)$ to $O(N\cdot n_{tag}^2)$.
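
In the notation above, the recurrence that the implementation below fills into its dp table (over log-probabilities) is:

\[dp[1][j]=\log \pi(z_j)+\log A(w_1\mid z_j),\qquad dp[i][j]=\max_{k}\big(dp[i-1][k]+\log B(z_j\mid z_k)\big)+\log A(w_i\mid z_j)\]

The best final score is $\max_j dp[N][j]$, and the $\arg\max$ index recorded at each step is what backtracking follows to recover $\hat{Z}$.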

Below is one implementation of the above procedure for the part-of-speech tagging task.

The format of the training data is one word/tag pair per line.
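
For illustration, the first lines of such a file might look like this (a hypothetical sample using Penn Treebank tags, not the actual training file):

The/DT
president/NN
will/MD
have/VB
the/DT
strength/NN
./.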

Implementation:

import numpy as np


def build_mapper(train_data):
    """ build token mapper and tag mapper """
    tag2id = {}
    id2tag = {}
    word2id = {}
    id2word = {}

    word_count = 0
    tag_count = 0

    with open(train_data) as fr:
        for line in fr.readlines():
            if not line.strip():
                continue
            # rsplit keeps words that themselves contain "/" intact
            word, tag = line.strip().rsplit("/", 1)

            if word not in word2id:
                word2id[word] = word_count
                id2word[word_count] = word

                word_count += 1

            if tag not in tag2id:
                tag2id[tag] = tag_count
                id2tag[tag_count] = tag

                tag_count += 1

    n_word = len(word2id)
    n_tag = len(tag2id)
    print("word size:{}".format(n_word))
    print("tag size:{}".format(n_tag))

    return word2id, id2word, tag2id, id2tag, n_word, n_tag


def build_params(word2id, tag2id, n_tag, n_word, train_data):
    """
    build params of HMM: theta = (pi, A, B)
    - pi is a vector marks the probe of each tag to be the first tag
      of a sentence, size: [1, n_tag]
    - A is a matrix of condition probe of words given tags,size: [n_tag, n_word]
    - B is a matrix of transition probe between tags, size: [n_tag, n_tag]"""

    pi = np.zeros(n_tag)  # probe of starting tag
    a = np.zeros((n_tag, n_word))  # condition probe p(word|tag)
    b = np.zeros((n_tag, n_tag))  # transition probe p(tag_i|tag_j)

    at_start_of_sentence = True
    last_tag_id = None

    with open(train_data) as fr:
        for line in fr.readlines():
            if not line.strip():
                continue
            # use the same rsplit as build_mapper so words containing "/" don't crash
            word, tag = line.strip().rsplit("/", 1)
            word_id = word2id.get(word)
            tag_id = tag2id.get(tag)

            if at_start_of_sentence:
                pi[tag_id] += 1  # starting prob
                a[tag_id, word_id] += 1  # cond. prob
                at_start_of_sentence = False

            else:
                a[tag_id, word_id] += 1  # cond. prob
                b[last_tag_id, tag_id] += 1  # trans. prob

            if word == ".":
                at_start_of_sentence = True
                last_tag_id = None
            else:
                last_tag_id = tag_id

    # done counting; normalize counts into probabilities
    pi /= pi.sum()

    for i in range(n_tag):
        if a[i].sum() > 0:
            a[i] /= a[i].sum()
        if b[i].sum() > 0:  # a tag with no outgoing transitions keeps an all-zero row
            b[i] /= b[i].sum()

    return pi, a, b


def log(num):
    """log with a floor at 1e-8 to avoid log(0) for unseen events"""
    if num == 0:
        return np.log(1e-8)
    return np.log(num)


def viterbi(sentence, pi, a, b, word2id, id2tag):
    """
    decode with viterbi
    :param sentence: sentence to decode
    :param pi: init probe of tag
    :param a: cond probe of words given tags
    :param b: trans probe between tags
    :return:
    """
    x = [word2id[word] for word in sentence.split()]  # words of sentence

    t = len(x)
    n_tag = len(tag2id)

    dp = np.zeros((t, n_tag))
    path_record = np.array([[0 for _ in range(n_tag)] for _ in range(n_tag)])

    for j in range(n_tag):
        dp[0][j] = log(pi[j]) + log(a[j, x[0]])

    # dp: fill the table word by word
    for i in range(1, t):  # for each word
        for j in range(n_tag):  # for each candidate tag
            dp[i, j] = -1e6
            for k in range(n_tag):  # for each previous tag
                score = dp[i - 1, k] + log(b[k, j]) + log(a[j, x[i]])
                if score > dp[i, j]:
                    dp[i, j] = score
                    path_record[i, j] = k

    # backtrack to recover the best tag sequence
    best_tag_sequence = [0] * t
    best_tag_sequence[t - 1] = int(np.argmax(dp[t - 1]))

    for i in range(t - 2, -1, -1):
        best_tag_sequence[i] = path_record[i + 1, best_tag_sequence[i + 1]]

    tag_sequence = [id2tag[idx] for idx in best_tag_sequence]

    return tag_sequence


train_data_path = "./resource/traindata.txt"
test_sentence = "The big question is whether the president will have the strength ."

# 1. build mapper of words and tags
word2id, id2word, tag2id, id2tag, n_word, n_tag = build_mapper(train_data_path)

# 2. \theta = (\pi, A, B), build them
pi, a, b = build_params(word2id, tag2id, n_tag, n_word, train_data_path)

# 3. find optimal path with viterbi

tag_sequence = viterbi(test_sentence, pi, a, b, word2id, id2tag)
print(tag_sequence)  # ['DT', 'JJ', 'NN', 'VBZ', 'IN', 'DT', 'NN', 'MD', 'VB', 'DT', 'NN', '.']