Keyboard shortcuts

Press or to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Log関数の実装

ではクロスエントロピー誤差に必要な関数である Log関数 を実装していきます。この関数は以前の Function構造体 なので同じように実装します。ただし、オプションが多いので、微分の式など処理が多少複雑なので、説明します。


まず数学の分野では明示しなければLog関数の底は\(e\)というのが暗黙の了解なので、sum関数 の時ど同様、引数がNone の時を底\(e\)とします。そして底が他の値の場合はSome(base) として渡すようにします。もちろんですが、base の型はもちろん普通の数値と同じ f32 です。

また、この底の値で微分のふるまいが異なります。正確に言えば、統一した方法で計算するは可能ですが、分けた方が、数学的な理解としても、パフォーマンス的にしても良いので(\(e\)以外の底の値を使用することはまれなので)、分けます。

では実装する前にlogの微分を考えてます。今回は底で場合分けして考えます。まず\(y = \log x\)、もしくは\(y = \ln x\)の時の微分は

$$\frac{dy}{dx} = \frac{1}{x}$$

となります。次に底が指定された場合、つまり\(y = \log_a x\)の時の微分は

$$\frac{dy}{dx} = \frac{1}{x\cdot \ln a}$$

となります。ではこれらをもとに Log関数 を実装していきます。

struct Log {
    inputs: Vec<RcVariable>,
    output: Option<Weak<RefCell<Variable>>>,
    base: Option<f32>,
    generation: i32,
    id: usize,
}

impl Function for Log {
    fn call(&mut self) -> RcVariable {
        let inputs = &self.inputs;
        if inputs.len() != 1 {
            panic!("Logは一変数関数です。inputsの個数が一つではありません。")
        }

        let output = self.forward(inputs);

        if get_grad_status() == true {
            //inputのgenerationで一番大きい値をFuncitonのgenerationとする
            self.generation = inputs.iter().map(|input| input.generation()).max().unwrap();

            //  outputを弱参照(downgrade)で覚える
            self.output = Some(output.downgrade());

            let self_f: Rc<RefCell<dyn Function>> = Rc::new(RefCell::new(self.clone()));

            //outputsに自分をcreatorとして覚えさせる
            output.0.borrow_mut().set_creator(self_f.clone());
        }

        output
    }

    fn forward(&self, xs: &[RcVariable]) -> RcVariable {
        let base = self.base;
        let x = &xs[0];
        let y_data;

        //baseがeか他の値かで場合分け(eの場合、baseはNone)
        if let Some(base_data) = base {
            y_data = x.data().mapv(|x| x.log(base_data));
        } else {
            y_data = x.data().mapv(|x| x.ln());
        }
        y_data.rv()
    }

    fn backward(&self, gy: &RcVariable) -> Vec<RcVariable> {
        let x = &self.inputs[0];
        let gx;

        let base = self.base;

        //baseがeか他の値かで場合分け(eの場合、baseはNone)
        if let Some(base_data) = base {
            gx = 1.0.rv() / (x.clone() * base_data.ln().rv()) * gy.clone();
        } else {
            gx = (1.0.rv() / x.clone()) * gy.clone();
        }
        let gxs = vec![gx];
        gxs
    }

    fn get_inputs(&self) -> &[RcVariable] {
        &self.inputs
    }

    fn get_output(&self) -> RcVariable {
        let output;
        output = self
            .output
            .as_ref()
            .unwrap()
            .upgrade()
            .as_ref()
            .unwrap()
            .clone();

        RcVariable(output)
    }

    fn get_generation(&self) -> i32 {
        self.generation
    }
    fn get_id(&self) -> usize {
        self.id
    }
}
impl Log {
    fn new(inputs: &[RcVariable], base: Option<f32>) -> Rc<RefCell<Self>> {
        Rc::new(RefCell::new(Self {
            inputs: inputs.to_vec(),
            output: None,
            base: base,
            generation: 0,
            id: id_generator(),
        }))
    }
}

pub fn log(x: &RcVariable, base: Option<f32>) -> RcVariable {
    let y = log_f(&[x.clone()], base);
    y
}

fn log_f(xs: &[RcVariable], base: Option<f32>) -> RcVariable {
    Log::new(xs, base).borrow_mut().call()
}

底の指定を Sum関数 と同じようにOption型 で渡し、forward,backward で場合分けして処理します。log関数の計算に慣れていればそれほど難しくはないでしょう。

では底で場合分けしてテストします。

#[test]
    fn log_test() {
        use crate::core_new::ArrayDToRcVariable;

        let a = array![3.0, 3.0, 3.0].rv();
        let b = array![3.0, 3.0, 3.0].rv();

        let mut y0 = log(&a, None); //底がe
        let mut y1 = log(&b, Some(2.0)); //底が2.0

        println!("y0 = {}", y0.data()); // 1.098...
        println!("y1 = {}", y1.data()); // 1.584...

        y0.backward(false);
        y1.backward(false);

        println!("a_grad = {:?}", a.grad().unwrap().data()); // 0.3333...
        println!("b_grad = {:?}", b.grad().unwrap().data()); // 0.4808...
    }