SumTo関数の実装
先ほど、ブロードキャストで形状を正しく変更する関数をBroadcastTo を説明しました。その過程で、BroadcastTo関数のBackwardにおいて新たな関数、SumTo が必要になりました。ここではこの関数について解説します。
ではこのSumTo関数 のFowwardとBackwardを見てみましょう。すると、もうお気づきかもしれませんが、実は、BroadcastToとSumTo は ForwardとBackwardにおいて表裏一体の関係 なのです。
SumTo関数のForwardとBackward
Forward
graph LR
arrx["$$x:\begin{pmatrix}2 & 2 & 2 \\\ 2 & 2 & 2\end{pmatrix}$$"] --> f[SumTo]
f --> arry["$$y:\begin{pmatrix}12 \end{pmatrix}$$"]
style arrx stroke-width:0px
style arry stroke-width:0px
Backward
graph RL
arrgy["$$gy:\begin{pmatrix}1\end{pmatrix}$$"] --> reshape["BroadcastTo ( = SumTo')"]
reshape --> arrgx["$$gx:\begin{pmatrix}1 & 1 & 1 \\\ 1 & 1 & 1\end{pmatrix}$$"]
style arrgx stroke-width:0px
style arrgy stroke-width:0px
もう一つの場合
Forward
graph LR
arrx["$$x:\begin{pmatrix}3 & 3 & 3 \\\ 3 & 3 & 3\end{pmatrix}$$"] --> f[SumTo]
f --> arry["$$y:\begin{pmatrix}6 & 6 & 6\end{pmatrix}$$"]
style arrx stroke-width:0px
style arry stroke-width:0px
Backward
graph RL
arrgy["$$gy:\begin{pmatrix}1 & 1 &1 \end{pmatrix}$$"] --> reshape["BroadcastTo ( = SumTo')"]
reshape --> arrgx["$$gx:\begin{pmatrix}1 & 1 & 1 \\\ 1 & 1 & 1\end{pmatrix}$$"]
style arrgx stroke-width:0px
style arrgy stroke-width:0px
これらのグラフと前のBroadcastToの時のグラフを比べて見ると、 input をブロードキャストで拡張したのを、SumTo で戻し、逆にSumTo で和を取ったものを、BroadcastTo で拡張して戻すというお互いに依存しあう関係になっています。なので、これら二つの関数を実装する際は同時に実装する必要があります。
struct SumTo {
inputs: Vec<RcVariable>,
output: Option<Weak<RefCell<Variable>>>,
shape: IxDyn,
generation: i32,
id: usize,
}
impl Function for SumTo {
fn call(&mut self) -> RcVariable {
let inputs = &self.inputs;
if inputs.len() != 1 {
panic!("SumToは一変数関数です。inputsの個数が一つではありません。")
}
let output = self.forward(inputs);
if get_grad_status() == true {
//inputのgenerationで一番大きい値をFuncitonのgenerationとする
self.generation = inputs.iter().map(|input| input.generation()).max().unwrap();
// outputを弱参照(downgrade)で覚える
self.output = Some(output.downgrade());
let self_f: Rc<RefCell<dyn Function>> = Rc::new(RefCell::new(self.clone()));
//outputsに自分をcreatorとして覚えさせる
output.0.borrow_mut().set_creator(self_f.clone());
}
output
}
fn forward(&self, xs: &[RcVariable]) -> RcVariable {
let x = &xs[0];
let y_shape = self.shape.clone();
let y_data = array_sum_to(&x.data().view(), y_shape);
y_data.rv()
}
fn backward(&self, gy: &RcVariable) -> Vec<RcVariable> {
let x = &self.inputs[0];
let x_shape = IxDyn(x.data().shape());
let gx = broadcast_to(gy, x_shape);
let gxs = vec![gx];
gxs
}
fn get_inputs(&self) -> &[RcVariable] {
&self.inputs
}
fn get_output(&self) -> RcVariable {
let output;
output = self
.output
.as_ref()
.unwrap()
.upgrade()
.as_ref()
.unwrap()
.clone();
RcVariable(output)
}
fn get_generation(&self) -> i32 {
self.generation
}
fn get_id(&self) -> usize {
self.id
}
}
impl SumTo {
fn new(inputs: &[RcVariable], shape: IxDyn) -> Rc<RefCell<Self>> {
Rc::new(RefCell::new(Self {
inputs: inputs.to_vec(),
output: None,
shape: shape,
generation: 0,
id: id_generator(),
}))
}
}
fn array_sum_to(x_array: &ArrayViewD<f32>, shape: IxDyn) -> ArrayD<f32> {
let x_shape = x_array.shape();
let mut axes_to_sum = HashSet::new();
// 合計する軸を特定する
for i in 0..x_shape.len() {
if i >= shape.ndim() || x_shape[i] != shape[i] {
axes_to_sum.insert(i);
}
}
let mut y = x_array.to_owned();
// HashSetの要素をVecに収集し、ソートして逆順にイテレーションする
let mut sorted_axes: Vec<_> = axes_to_sum.into_iter().collect();
sorted_axes.sort_unstable();
// 特定した軸を合計する
for &axis in sorted_axes.iter().rev() {
y = y.sum_axis(Axis(axis)).insert_axis(Axis(axis));
}
y
}
fn sum_to_f(xs: &[RcVariable], shape: IxDyn) -> RcVariable {
SumTo::new(xs, shape).borrow_mut().call()
}
pub fn sum_to(x: &RcVariable, shape: IxDyn) -> RcVariable {
let y;
let x_shape = IxDyn(x.data().shape());
if x_shape == shape { // 形状が変化しないならそのまま流す
y = x.clone();
} else {
y = sum_to_f(&[x.clone()], shape);
}
y
}
基本的にはBroadcastTo関数と同様に、引数として求める形状を渡します。しかし、違う点はBroadcastTo を実際に処理するメソッド、broadcast() がArray型 には装備されていたのに対して、sumto というメソッドは現状実装されていません。つまり、自分でその処理を書かなければなりません。その処理が array_sum_to() です。この処理については詳しくは解説しませんが、簡単な説明として、inputの行列のある軸に対して和をとる際、求める形状の軸に対してどこが違うのか、すなわちどの軸を対象とすべきかという軸を求め、それを配列として保持し、その軸においてsum_axis() メソッドで和をとります。また、sum_to() 関数では、形状が変化しない場合は、そのまま流すようにします。
BroadcastToも実装したら、 SumTo関数関数を二つの場合でテストしてみましょう。
TODO: コード検証必要
#[test]
fn sum_to_test() {
use crate::core_new::ArrayDToRcVariable;
let x = array![[3.0, 3.0, 3.0],[ 3.0, 3.0,3.0]].rv();
let mut y = sum_to(&x,Dim(IxDyn(&[1, 3])));
println!("y = {}", y.data());
y.backward(false);
println!("x_grad = {:?}", x.grad().unwrap().data());
}