accuracy関数
ここではモデルがどれくらい精度が高いかを評価する accuracy関数 を実装していきます。これはただモデルの性能を測るだけの関数なのでモデルの学習には使用しません。つまり、バックプロパゲーション対象外の微分をしない関数ということです。
モデルの精度を測る関数はいくつか存在しますが、この accuracy関数 は一番オーソドックスな関数で、処理としてはただ単に正解数をデータ数で割り正解率を求めるというシンプルな計算ですが、この正解数を求める処理を考える必要があります。
そもそもモデルが正しく答えを出せたとはどういうことでしょうか。まずはそこから考えます。モデルの学習とはいかに誤差を小さくするか、そして、もっと言えば正解ラベルに近い出力データを出せるかということです。そのような中で、私たちはモデルの予測値が答えにどれほど近ければ正解したとするかを定義しなければなりません。先ほど softmax関数 の中で確率に変換すると説明しましたが、つまりモデルは予測値をあくまで確率として出力するのです。例えば、クラス数が3なら、0の確率は0.1、1は0.3、2は0.6という感じです。なので、私たちは、モデルが正解の値を一番高い確率で出力できたら正解したとします。この場合、モデルは確率の値として一番高い2だと予測したとらえ、それが正解ラベルと一致しているかを確かめればよいのです。
モデルの正解の定義を考えましたが、次はそれを実装していきます。重要なことは正解ラベルの値はインデックスを表しているということです。
pub fn accuracy(y: ArrayView2<f32>, t: ArrayView2<f32>) -> f32 {
if y.shape() != t.shape() {
panic!("交差エントロピー誤差でのxとtの形状が異なります。tがone-hotベクトルでない可能性があります。")
}
let data_size = y.shape()[0] as f32;
let num_class = t.shape()[1];
let argmax_vec: Vec<u32> = y
.outer_iter()
.map(|row: ArrayBase<ViewRepr<&f32>, Dim<[usize; 1]>>| row.argmax().unwrap() as u32)
.collect();
let max_index = Array::from_vec(argmax_vec);
let one_hot_y = arr1d_to_one_hot(max_index.view(), num_class);
assert_eq!(one_hot_y.shape(), t.shape());
let acc_matrix = &one_hot_y * &t;
let accuracy = acc_matrix.sum() / data_size; // 正解数 / データ数
accuracy
}
//二つの行列をかけることで正解数がわかる。
// 例
// one_hot_y = [[0.0,0.0,1.0],
// [0.0,1.0,0.0],
// [1.0,0.0,0.0],]
//
// t = [[0.0,0.0,1.0],
// [1.0,0.0,0.0],
// [1.0,0.0,0.0],]
//
//one_hot_y* t = [[0.0,0.0,1.0], 2行目は1.0の位置が違うのでかけてすべてゼロ。
// [0.0,0.0,0.0],
// [1.0,0.0,0.0],] 行列を要素ごとにかけることで、正しい答えだった場合のみ1.0になり、その他はすべて0.0になる。
// この行列の合計値sum関数で正解数がわかる。この場合、sumの値は2なので正解数は2。
コメントでも説明していますが、予測値で一番値が高いラベルのところを1にし、他を0にするというone_hotベクトル化をします。これにより、正解ラベルの行列と掛けることで正解したところだけ1が残ります。あとは行列でSum を取れば、正解した個数がわかります。あとはそれをデータ数(正確にはバッチ数)で割れば、そのデータ数(バッチ数)での正解率が求まります。
ここで一つ気になる点として行列の次元を静的で扱っていることが挙げられます。この理由は主に二つあります。一つ目は、最大値を取るインデックスを求める、numpyでいうargmax() のような処理(軸を指定した最大値のインデックスを求める関数はrustのndarrayには実装されていないので、自分で処理を書きます)を行う際、静的な次元で行う必要があるからです。二つ目は、基本的に正解ラベルの行列の形状は2次元だからです。というのも、正解のデータは0,1,2,・・というように行列ではなく、数値で表せるからです。しかし、今後、高次元の正解ラベル(行列の正解データはもはやラベルとは呼べないかもしれませんが)を扱う新たなアルゴリズムでが登場した場合は対応できませんので、より柔軟な関数が求めらます。
また、この関数はバックプロパゲーション対象外なので、Function構造体 や RcVariable といったものを考える必要はなく、ただ単にndarrayの行列計算に集中してよいというわけです。
コードで登場した arr1d_to_one_hot() は一次元の行列から2次元のone_hotベクトルの行列を生成します。損失関数の one_hotベクトル の\(T\)から\(T’\)の変換説明を見るとわかりやすいと思います。この arr1d_to_one_hot() については 補足 で説明します。
TODO:arr1d_to_one_hot()補足
では実際に正解率を求められるかテストします。
fn main() {
let a: Array2<f32> = array![
[1.0f32, 2.0, 3.0], // ○
[6.0, 4.0, 5.0], // ○
[1.0, 2.0, 10.0], // ×
[12.0, 15.0, 5.0] // ○
];
let b = array![2, 0, 1, 1];
let b = to_one_hot(b.view(), 3);
let acc = accuracy(a.view(), b.view());
println!("{}", acc);
}