Keyboard shortcuts

Press or to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Reshape関数の実装

graph LR
 arrx["$$x:\begin{pmatrix}x_0 & x_1 & x_2 \\\ x_3 & x_4 & x_5\end{pmatrix}$$"] --> reshape[Reshape]
 reshape --> arry["$$y:\begin{pmatrix}x_0 & x_1 & x_2 & x_3 & x_4 & x_5\end{pmatrix}$$"]
 
 style arrx stroke-width:0px
 style arry stroke-width:0px
graph RL
 arrgy["$$gy:\begin{pmatrix}gy_0 & gy_1 & gy_2 & gy_3 & gy_4 & gy_5\end{pmatrix}$$"] --> reshape["Reshape'"]
 reshape --> arrgx["$$gx:\begin{pmatrix}gy_0 & gy_1 & gy_2 \\\ gy_3 & gy_4 & gy_5\end{pmatrix}$$"] 

 style arrgx stroke-width:0px
 style arrgy stroke-width:0px

reshape関数は行列の形状だけを変更する関数です。つまり数値計算を行わず、値をそのまま流すだけなので、微分の値は1であり、上流からきた微分の値をそのまま流します。このような微分が1の関数は今後も行列計算では出てきます。ここで重要なことは、返す微分の行列はinputの行列と同じ形状でなければならないことです。 つまり、xとgx、yとgyの形状はそれぞれ一致していなければなりません。 なので、backwardの場合、形状を戻して渡す必要があるのです。

struct Reshape {
    inputs: Vec<RcVariable>,
    output: Option<Weak<RefCell<Variable>>>,
    shape: IxDyn,
    generation: i32,
    id: usize,
}

impl Function for Reshape {
    fn call(&mut self) -> RcVariable {
        let inputs = &self.inputs;
        if inputs.len() != 1 {
            panic!("Reshapeは一変数関数です。inputsの個数が一つではありません。")
        }

        let output = self.forward(inputs);

        if get_grad_status() == true {
            //inputのgenerationで一番大きい値をFuncitonのgenerationとする
            self.generation = inputs.iter().map(|input| input.generation()).max().unwrap();

            //  outputを弱参照(downgrade)で覚える
            self.output = Some(output.downgrade());

            let self_f: Rc<RefCell<dyn Function>> = Rc::new(RefCell::new(self.clone()));

            //outputsに自分をcreatorとして覚えさせる
            output.0.borrow_mut().set_creator(self_f.clone());
        }

        output
    }

    fn forward(&self, xs: &[RcVariable]) -> RcVariable {
        let x = &xs[0];
        let y_shape = self.shape.clone();
        let y_data = x.data().to_shape(y_shape).unwrap().to_owned();

        y_data.rv()
    }

    fn backward(&self, gy: &RcVariable) -> Vec<RcVariable> {
        let x = &self.inputs[0];
        let x_shape = IxDyn(x.data().shape());
        let gx = reshape(gy, x_shape);
        let gxs = vec![gx];

        gxs
    }

    fn get_inputs(&self) -> &[RcVariable] {
        &self.inputs
    }

    fn get_output(&self) -> RcVariable {
        let output;
        output = self
            .output
            .as_ref()
            .unwrap()
            .upgrade()
            .as_ref()
            .unwrap()
            .clone();

        RcVariable(output)
    }

    fn get_generation(&self) -> i32 {
        self.generation
    }
    fn get_id(&self) -> usize {
        self.id
    }
}
impl Reshape {
    fn new(inputs: &[RcVariable], shape: IxDyn) -> Rc<RefCell<Self>> {
        Rc::new(RefCell::new(Self {
            inputs: inputs.to_vec(),
            output: None,
            shape: shape,
            generation: 0,
            id: id_generator(),
        }))
    }
}

fn reshape_f(xs: &[RcVariable], shape: IxDyn) -> RcVariable {
    Reshape::new(xs, shape).borrow_mut().call()
}

pub fn reshape(x: &RcVariable, shape: IxDyn) -> RcVariable {
    let y = reshape_f(&[x.clone()], shape);
    y
}

backward で形状を戻すために、inputの形状を覚えておく必要がありますが、inputのデータは保持しており、そこから呼び出せばよいので特に考えません。

変形させたい形状を引数として受け取るために、呼び出す関数の引数に shape を設定します。このshapeの型は IxDyn で動的な形状を指します。また、それを保持するために、shapeフィールド を持ちます。フィールドとして保持し、forwardでinputのArrayのデータの形状を変形します。backwardでは上流からきた微分の値である gy をinputの形状に変形させるために、xを呼び出して形状を求め、それによって変形させます。backward内の reshape 関数は今まさに実装している Reshape構造体 です。

では実装した Reshape 関数をテストしてみましょう。微分の値や形状などに着目してください。

#[test]
    fn reshape_test() {
        use crate::core_new::ArrayDToRcVariable;

        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]].rv();
        let b = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]].rv();

        let mut y0 = reshape(&a, Dim(IxDyn(&[1, 6])));

        let mut y1 = b.reshape(Dim(IxDyn(&[1, 6])));

        println!("y0 = {}", y0.data()); //[[1,2,3,4,5,6]] shape(1,6)
        println!("y1 = {}", y1.data()); //[[1,2,3,4,5,6]] shape(1,6)

        y0.backward(false);
        y1.backward(false);
        println!("a_grad = {:?}", a.grad().unwrap().data()); // [[1.0,1.0,1.0],[1.0,1.0,1.0]] shape(2,3)
        println!("b_grad = {:?}", b.grad().unwrap().data()); // [[1.0,1.0,1.0],[1.0,1.0,1.0]] shape(2,3)
    }

すると、上の図のようにaとa_gradの形状が一致しているのがわかります。