古い記事
ランダムジャンプ
新しい記事
かなり昔に作った Naive Bayes(単純ベイズ)による文書分類のサンプルプログラムを整理したので公開しておきます。Perl で書かれています。Pure Perl。

Naive Bayes についての詳細は下記のサイトをどうぞ。
- 単純ベイズ - 機械学習の「朱鷺の杜Wiki」
- Wikipedia:単純ベイズ分類器

さて、Naive Bayes で分類するときには下記の式を用いるわけです。


Pについての定義は下記:


実装を簡単にするために log をとって足し算にしています(argmax ですが実際はマイナスかけて argmin で実装)。


学習データ


分類したいカテゴリごとに1行。
各行はカテゴリのラベル(LABEL)とそのカテゴリの文書に現れる単語(WORD)の頻度(TF)。
各行のフォーマットは:
^CAT\tWORD:TF,WORD:TF,WORD:TF,...$
(ref. [2010-02-13-4])

t.c2w(学習データ):TSV+CSV
政治	選挙:6,外交:3,防衛:2,予算:1,オスプレイ:2,独立:1,環境:1,医療:1,エネルギー:2
スポーツ	野球:9,移籍:2,速報:1,日本代表:8,サッカー:9,スタジアム:3,エネルギー:1
エンタメ	選挙:1,移籍:1,速報:1,不倫:4,結婚:6,離婚:7,ものまね:4,破綻:2,独立:1
科学	感染症:2,環境:2,医療:4,エネルギー:5,遺伝子:3,惑星:3,ニュートリノ:1

転置ファイル


推定時に使いやすいように転置したデータを作ります。
変換プログラム mkdat4nb.pl で
カテゴリごとの単語頻度合計や各単語のカテゴリ別出現頻度が入った
ファイル「*.nb」を作ります。
手抜き実装のため、行頭がスペースのものがカテゴリの情報、そうでないものが単語別の情報となっています。

mkdat4nb.pl
#!/usr/bin/perl
use strict;
use warnings;

my %cnt_wc;
my %cnt_c;
my $cnt_w;

while (<>) {
    next if not /^(.+?)\t(.+?)$/;
    my ($cat, $ws) = ($1, $2);
    my @words = split(",", $ws);
    foreach (@words) {
        next if not /^(.+):([^:]+)$/;
        my ($w, $tf) = ($1, $2);
        $cnt_wc{$w}{$cat} += $tf;
        $cnt_c{$cat} += $tf;
        $cnt_w += $tf;
    }
}

foreach my $cat (sort keys %cnt_c){
    print " $cat\t$cnt_c{$cat}\t$cnt_w\n";
}

foreach my $w (sort keys %cnt_wc){
    print "$w\t".join(",", map {"$_:$cnt_wc{$w}{$_}"}
                      sort {$cnt_wc{$w}{$b} <=> $cnt_wc{$w}{$a}}
                      keys %{$cnt_wc{$w}})."\n";
}

実行例:
% ./mkdat4nb.pl t.c2w > t.nb

t.nb(転置ファイル):TSV
 エンタメ	27	99
 スポーツ	33	99
 政治	19	99
 科学	20	99
ものまね	エンタメ:4
エネルギー	科学:5,政治:2,スポーツ:1
オスプレイ	政治:2
サッカー	スポーツ:9
スタジアム	スポーツ:3
ニュートリノ	科学:1
不倫	エンタメ:4
予算	政治:1
医療	科学:4,政治:1
外交	政治:3
惑星	科学:3
感染症	科学:2
日本代表	スポーツ:8
独立	エンタメ:1,政治:1
環境	科学:2,政治:1
破綻	エンタメ:2
移籍	スポーツ:2,エンタメ:1
結婚	エンタメ:6
速報	エンタメ:1,スポーツ:1
選挙	政治:6,エンタメ:1
遺伝子	科学:3
野球	スポーツ:9
防衛	政治:2
離婚	エンタメ:7

入力データ


判定単位は行。
単語がスペースで区切られているだけのフォーマット。
同じ行に含まれている単語集合にたいして、カテゴリ分類を行います。

t.txt(入力データ):
選挙 移籍 離婚 破綻
環境 医療 エネルギー 速報 遺伝子
防衛 オスプレイ 外交 破綻
速報 サッカー 日本代表 移籍

分類


分類プログラム nb.pl 転置ファイルと入力データを受け取って、
結果を出力します。
スコアの高い順にカテゴリが出てきます。
スコアは前述の log とって足し算してマイナスかけたもの。
なお、各カテゴリでの未出現単語の頻度は 0.1 ($min_freq) にしてあります。
適当です。

nb.pl
#!/usr/bin/perl
use strict;
use warnings;

my $dat_fn = shift;

my $min_freq = 0.1;

my %cnt_wc;
my %cnt_c;
my $cnt_w;

open(my $fh, "<", $dat_fn) or die;
while (<$fh>) {
    if (/^\s(.+?)\t(\d+)\t(\d+)$/) {
        $cnt_c{$1} = $2;
        $cnt_w = $3;
    } elsif (/^(.+?)\t(.+?)$/) {
        my ($w, $cs) = ($1, $2);
        my @cats = split(",", $cs);
        foreach (@cats) {
            next if not /^(.+):(\d+)$/;
            my ($c, $f) = ($1, $2);
            $cnt_wc{$w}{$c} = $f;
        }
    }
}
close($fh);

while (<>) {
    chomp;
    my @ws = split(/ +/, $_);
    print "> @ws\n";
    my %val;
    my %c2w;
    foreach my $w (@ws) {
        next if not defined $cnt_wc{$w};
        foreach my $c (keys %cnt_c) {
            my $wc = $cnt_wc{$w}{$c} || $min_freq;
            $val{$c} +=  -1 * log($wc / $cnt_c{$c});
            $c2w{$c}{$w} = $wc if $wc >= 1;
        }
    }
    foreach my $c (keys %val) {
        $val{$c} += -1 * log($cnt_c{$c} / $cnt_w);
    }
    foreach my $c (sort {$val{$a} <=> $val{$b}} keys %val) {
        next if not defined $c2w{$c};
        my $v = int($val{$c} * 1000)/1000;
        print "$c\t$v\t".join(",", map {"$_:$c2w{$c}{$_}"} 
                              sort {$c2w{$c}{$b} <=> $c2w{$c}{$a}}
                              keys %{$c2w{$c}})."\n";
    }
    print "\n";
}

実行結果:
% ./nb.pl t.nb t.txt
> 選挙 移籍 離婚 破綻
エンタメ	11.843	離婚:7,破綻:2,選挙:1,移籍:1
政治	18.544	選挙:6
スポーツ	21.299	移籍:2

> 環境 医療 エネルギー 速報 遺伝子
科学	14.093	エネルギー:5,医療:4,遺伝子:3,環境:2
政治	20.284	エネルギー:2,環境:1,医療:1
スポーツ	25.488	エネルギー:1,速報:1
エンタメ	26.988	速報:1

> 防衛 オスプレイ 外交 破綻
政治	13.246	外交:3,オスプレイ:2,防衛:2
エンタメ	20.697	破綻:2

> 速報 サッカー 日本代表 移籍
スポーツ	10.114	サッカー:9,日本代表:8,移籍:2,速報:1
エンタメ	19.087	速報:1,移籍:1


「選挙 移籍 離婚 破綻」はエンタメカテゴリっぽい、「速報 サッカー 日本代表 移籍」はスポーツカテゴリっぽい、などという想定通りの結果が出ています。