Deep Karmaning

技術系の話から日常のことまで色々と書きます

Pytorchで分散表現の学習手法であるskipgram、skipgram with negative samplingの実装

1.概要

分散表現(distributed representation)の学習手法である、skipgramとskipgram with negative samplingをPytorchを使って実装したので、その紹介をしたいと思います。

2.理論

まずは 理論的な側面を簡単に紹介します。

2.1.分散表現(distributed representation)とは

分散表現は埋め込み(embedding)表現とも呼ばれたりするのですが、離散的な値を数値的な値で表現したものを指します。

言語における単語などは代表的だと思いますが、単語を単純にone hot encodingで離散的な値のまま使うと、かなり次元が高くなってしまいます。

一方で分散表現を学習することで、低次元かつ連続的な値として扱うことが可能です。

また以下の図で表現されているように、連続的な値のため類似度を計算することも容易になります。

f:id:rf00:20190316163934p:plain
出展:坪井ら, 2017, 深層学習による自然言語処理

2.2.skipgramとは

skipgramはニューラルネットを活用した分散表現の学習手法の一つになります。

もともと自然言語処理の分野で使われる手法で、文章内のターゲットの単語の周辺に出てくる単語を予測することで分散表現を得ようとします。

モデルとしては以下の図のような形になっていて、このモデルのw(t)がターゲットの単語でw(t±n)が周辺に出現する単語です。nの最大値、最小値はウィンドウサイズによって変わってきます。w(t±n)をコンテキストと呼んだりもするようです*1

f:id:rf00:20190316165423p:plain
出展:Mikolov et.al., 2013, 「Distributed Representations of Words and Phrases and their Compositionality」

目的関数は以下のようになります。


\frac{1}{T} \displaystyle{\Sigma_{t=1}^{T}} \displaystyle{\Sigma_{-c \leq j \leq c, j \neq 0}} \log p(w_{t+j} | w_t)

ここでcはウィンドウサイズ、w_t+jがコンテキスト、w_tがターゲット。

p(w_{t+j}|w_t)には以下のようなsoftmaxを使うのが定番です。


p(w_O|w_I) = \frac{ \exp( \nu'_{w_O} \top \nu_{w_I}) }{\Sigma_{w=1}^{W} \exp(\nu'_{w} \top \nu_{w_I}) }

この式の右辺の分母に注目すると、ボキャブラリサイズ分(W)合計をとっているのがわかるのですが、このせいで単純に学習しようとするとボキャブラリー数が増えるほどに学習が大変になっていくことになります。

そこで出てくるのが次に紹介するskipgram with negative samplingです。

2.3.skipgram with negative samplingとは

skipgram with negative samplingはskipgramの計算効率を高めたニューラルネットを活用した分散表現の学習手法です。

ターゲット単語から周辺に出てくるコンテキスト単語を学習するという、基本的なモデルのアーキテクチャは変わりません。

しかし、softmaxの代わりにsigmoidを利用し、負例をボキャブラリーでの出現頻度に基づいてサンプリングしたものを使うという点で以下のようになってきます。


\log \sigma(\nu'_{w_o} \top \nu_{w_I}) + \Sigma_{i=1}^{k} \mathbb{E}_{w_i} \sim P_n (w) \left[ \log \sigma(-\nu'_{w_i} \top \nu_{w_I})\right]

kがnegative sampling数になるため、通常のskipgramに比べて計算量が減ることがわかると思います。

3.実装

今回はpytorchにて実装してみました*2

またコードは念の為ライブラリぽく使えるように整理しています。

本記事では特に肝となるバッチ作成部分と学習部分だけを中心に見ていきますので、全体に興味がある方はコードをgithubに用意していますので見てみてください。

github.com

3.1.バッチ生成部分の実装

個人的にskipgram、skipgram with negative samplingを理解する第一歩としてはバッチ生成部分を理解することが肝だと思っています*3

バッチ生成部分を理解できれば、どのように学習をしているのかを大体理解できると思います。

ちなみにここでいうバッチというのは、コンテキストとターゲットの組み合わせを指していて、コンテキストとターゲットの組み合わせを生成しているのがバッチ生成部分です。

今回はskipgram.pyの中でSkipgramクラスの関数として以下のように実装しました。

def generate_batch(self, corpus, window_size, batch_size):
        row_idx = self.row_idx
        col_idx = self.col_idx
        context = collections.deque()
        target  = collections.deque()
        i = 0

        while i < batch_size:
            data = corpus.data[row_idx]
            target_ = data[col_idx]
            sentence_length = len(data) 

            if col_idx == 0: #first word
                start_idx = col_idx + 1
                start_idx = 0 if  start_idx < 0 else start_idx
                end_idx = col_idx + 1 + window_size
                end_idx = end_idx if  end_idx < (sentence_length )  else sentence_length
                for t in range(start_idx, end_idx):
                    if t > sentence_length - 1:break
                    context.append(data[t])
                    target.append(target_)
                    i += 1
            elif col_idx == len(data): #last word
                start_idx = col_idx - window_size
                start_idx = 0 if  start_idx < 0 else start_idx
                end_idx = col_idx + 1
                end_idx = end_idx if  end_idx < (sentence_length)  else sentence_length 
                for t in range(start_idx, end_idx):
                    if t > sentence_length - 1:break
                    context.append(data[t])
                    target.append(target_)
                    i += 1
            else:#mid word
                start_idx = col_idx - window_size
                start_idx = 0 if  start_idx < 0 else start_idx
                end_idx = col_idx + 1 + window_size
                end_idx = end_idx if  end_idx < (sentence_length )  else sentence_length 
                for t in range(start_idx, end_idx):
                    if t > sentence_length - 1:break
                    if t == col_idx:continue
                    context.append(data[t])
                    target.append(target_)
                    i += 1

            col_idx = (col_idx + 1)
            if col_idx == len(data):
                col_idx  = 0
                row_idx = row_idx + 1
            
            if row_idx == len(corpus.data):
                self.row_idx = 0
                self.col_idx = 0
                self.batch_end = 1
                break
            else:
                self.row_idx = row_idx
                self.col_idx = col_idx
        
        return  np.vstack((np.array(target), np.array(context))).T

やっていることは指定されたウィンドウサイズ内で前後の単語をとってくるという処理です。ただし、文頭と文末の場合だけは前か後ろかしかないため処理を変えています。

またバッチサイズを優先していて、バッチサイズを超えてもウィンドウサイズ内の単語をすべて取っていなければ終わるまで取るようにしています。

ここはこのやり方が正しいのか、あるいはウィンドウサイズのどこまで単語を取ったか記憶させて、次のバッチ生成で引き継いで開始するべきかはわかっていませんが、ひとまずこのままにしています。

3.2.skipgramとskipgram with negative samplingの実装

モデルの順伝播部分は以下のように実装しました。

def forward(self, batch, corpus = None):
        if self.sgns == 0:
            y_true = Variable(torch.from_numpy(np.array([batch[1]])).long())

            x1 = torch.LongTensor([[batch[0]]])
            x2 = torch.LongTensor([range(self.embedding_dim)])
            u_emb = self.u_embeddings(x1)
            v_emb = self.v_embeddings(x2)
            z = torch.matmul(u_emb, v_emb).view(-1) #view reshape

            log_softmax = F.log_softmax(z, dim = 0)
            loss = F.nll_loss(log_softmax.view(1,-1), y_true)
        else:        
            target = torch.LongTensor([[batch[0]]])
            context = torch.LongTensor([[batch[1]]])

            ns = self.negative_sampling(corpus)
            ns = torch.LongTensor([[ns]])

            if torch.cuda.is_available():
                target = Variable(target).cuda()
                context = Variable(context).cuda()
                ns = Variable(ns).cuda()

            #positive
            x1 = self.u_embeddings(target)
            x2 = self.v_embeddings(context)

            score = torch.sum(torch.mul(x1, x2))#inner product
            log_target = F.logsigmoid(score).squeeze()

            #negative
            x3 = self.v_embeddings(ns)
            neg_score = -1 * torch.sum(torch.mul(x1, x3), dim = 2)
            log_neg_sample = F.logsigmoid(neg_score).squeeze()

            loss = -1 * (log_target + log_neg_sample.sum())
    
        return loss

sgnsがフラグになっていて、1を指定するとnegative samplingを利用するようになっています。したがって12行目までがskipgram、14行目以降がskipgram with negative samplingになります。

通常のskipgramはsoftmaxを使っていて、skipgram with negative samplingはsigmoidを使っていることがわかると思います。

negative samplingの方は以下のように実装しています。

def negative_sampling(self, corpus):
        sampled = np.random.choice(corpus.negaive_sample_table_w, p = corpus.negaive_sample_table_p, size = self.negative_samples)
        negative_samples = np.array([corpus.dictionary[w] for w in sampled])
        return negative_samples    

データを読み込んだ際にあらかじめ単語ごとに頻度に基づいた確率に変換したテーブルを用意していて、そこから確率に基づいてサンプリングする実装にしています。

forward関数の17行目、18行目で毎バッチごとに実行しているのがわかると思います。

3.3.学習の実行

exampleフォルダにサンプルコードを用意したので実行してみます。初回のみデータをダウンロードするので時間がかかるので注意してください。

データは以下のTomas Mikolov氏*4のサイトからダウンロードしています。

www.fit.vutbr.cz

skipgramの方のサンプルコードはexample_skipgram.pyとなり、以下のようにしました。

import sys
sys.path.append('..')

import corpus as cp
import distributed_representation as dr

import utility

#data download
dl = utility.data_loader()
dl.dataload()

corpus = cp.Corpus(data = 'data/simple-examples/data/ptb.train.txt', mode = "l", 
                max_vocabulary_size = 5000, max_line = 10, 
               minimum_freq = 5)

window_size = 1
embedding_dims = 100
batch_size = 128

import time
start = time.time()

dr_sgns = dr.DistributedRepresentation(corpus, embedding_dims, window_size, batch_size, mode_type = 1, 
                                sgns = 0, trace = True)
dr_sgns.train(num_epochs = 11, learning_rate = 0.01)

process_time = time.time() - start
print(process_time)

13~15行目でデータを読み込んでいて、全部のファイルを使うと学習に時間がかかるため、max_line引数で最初の10行に絞っています。

24~26行目で学習の実行をしています。sgns=0にしているためnegative samplingを利用しないことにしています。

以下のように実行すると、

cd gdp/example/
python example_skipgram.py

以下のような結果が返ってくれば学習成功です。

Loss at epo 0: 37.58647918701172
Loss at epo 10: 17.378202438354492
3.9009947776794434

skipgram with negative samplingの方はコードがexample_sgns.pyで、

import sys
sys.path.append('..')

import corpus as cp
import distributed_representation as dr

import utility

#data download
dl = utility.data_loader()
dl.dataload()

corpus = cp.Corpus(data = 'data/simple-examples/data/ptb.train.txt', mode = "l", 
                max_vocabulary_size = 5000, max_line = 10, 
               minimum_freq = 5)

window_size = 1
embedding_dims = 100
batch_size = 128

import time
start = time.time()

dr_sgns = dr.DistributedRepresentation(corpus, embedding_dims, window_size, batch_size, mode_type = 1, 
                                sgns = 1, negative_samples = 5, trace = True)
dr_sgns.train(num_epochs = 11, learning_rate = 0.01)

process_time = time.time() - start
print(process_time)

と実装しています。25行目だけ引数の指定を変えています。

実行は、

cd gdp/example
python example_sgns.py

としてあげればよく、

Loss at epo 0: 1353.0279541015625
Loss at epo 10: 354.0135192871094
5.175678730010986

という結果が返ってくるはずです。

これを見るとskipgram with negative samplingのほうが計算効率が良いモデルのはずなのに遅くなっているのですが、これは使っているデータが少ないことに起因していると思っています。ボキャブラリー数が少ない状態だとネガティブサンプリングよりもsoftmaxで計算するほうが効率的だということだと思っています。

この点を確認するために、example_compare_time_skipgram.pyを用意しました。

こちらはデータのすべてを使って1epoch回すのにかかる時間を比較しているのですが、 比較したところ5~6倍程度skipgram with negative samplingが早かったので、うまく実装できているのかなと思います。興味ある方は回してみてください。

ちなみに実行環境によるとは思いますが終わるまで2時間くらいかかりますので注意してください。

4.まとめ

今回は分散表現手法である、

  • skipgram
  • skipgram with negative sampling

の理論と実装を紹介しました。

私自身はこれまでブラックボックス的に使ってきた分散表現なのですが、改めて調査し実装することで理解がだいぶ深まった気がしています。

また今後はCBOW等も実装していきたいと考えています。

それでは間違い等ありましたら、ご指摘お願いいたします。

参考文献

  • Mikolov, Tomas, Ilya Sutskever, Kai Chen, Greg S Corrado, and Jeff Dean. 2013. “Distributed Representations of Words and Phrases and Their Compositionality,”
  • 坪井祐太, 海野裕也と鈴木潤. 2017. 深層学習による自然言語処理. 講談社.

*1:逆向きの形で学習する手法にCBOWというものもあります

*2:実装してみたもののpytorchでの実装だと学習が終わるのに時間がかかるため、実務上はgensim等のライブラリを使うことをおすすめしておきます

*3:CBOWも同様

*4:skipgram等に関して論文を書いている人