テストサンプルの生成 [e334115, d07338e, 93bc07b] - pymc4のソースコード読んでみた
概要
Model
クラスのサンプル生成のメソッドを読んでみます。test_point
コミット
2018/06/09から2018/06/11の間のコミットです。
- tmp · pymc-devs/pymc4@e334115 · GitHub
- restructure + test point implementation · pymc-devs/pymc4@d07338e · GitHub
- fixes · pymc-devs/pymc4@93bc07b · GitHub
Model - pymc4/model/base.py
test_point
まずはこの test_point
メソッドがでどのように呼ばれるか調べます。pymc4のプロジェクト内では以下の tests/test_model.py
のみで使用されています。
@pm.inline def model(cfg): ed.Normal(0., 1., name='normal') testval_random = model.test_point() testval_mode = model.test_point(sample=False) assert testval_mode['normal'] == 0. assert testval_mode['normal'] != testval_random['normal']
ここでは2種類の呼び方があるみたいです。
- test_point(): デフォルトでsample=True
- test_point(sample=False)
事前情報として、ed.Normal(0., 1., name='normal')
というのは、正規分布で平均が0かつ標準偏差が1という意味です。
test_point(sample=True)
Modelに設定されているRandomVariableのインスタンスの確率分布に沿って、1つの乱数が生成されます。
例) 正規分布で平均が0かつ標準偏差が1の確率分布から乱数を生成する
test_point(sample=False)
Modelに設定されているRandomVariableのインスタンスの確率分布に沿って、中央値が返されます。
例) 正規分布で平均が0かつ標準偏差が1の確率分布の中央値は0なので、常に0が返される
さっと test_point
メソッドの実装も見てみます。
以下がそのメソッドの全容です。
def test_point(self, sample=True): # 1 def not_observed(var, *args, **kwargs): return kwargs['name'] not in self.observed values_collector = interceptors.CollectVariables(filter=not_observed) chain = [values_collector] # 2 if not sample: def get_mode(state, rv, *args, **kwargs): return rv.distribution.mode() chain.insert(0, interceptors.Generic(after=get_mode)) # 3 with self.graph.as_default(), ed.interception(interceptors.Chain(*chain)): self._f(self.cfg) # 4 with self.session.as_default(): returns = self.session.run(list(values_collector.result.values())) return dict(zip(values_collector.result.keys(), returns))
一つひとつバラバラに見ていきます。
#1
この辺りの挙動はあまりよくわかっていないので飛ばします。
def not_observed(var, *args, **kwargs): return kwargs['name'] not in self.observed values_collector = interceptors.CollectVariables(filter=not_observed) chain = [values_collector]
#2
test_point(sample=False)
の場合の処理です。やはり予想通り、中央値を返すようになっています。
if not sample: def get_mode(state, rv, *args, **kwargs): return rv.distribution.mode() chain.insert(0, interceptors.Generic(after=get_mode))
#3
以前、 edward2のinterception処理 の記事で確認したので飛ばします。
with self.graph.as_default(), ed.interception(interceptors.Chain(*chain)):
self._f(self.cfg)
#4
tensorflowのsessionを実行して、実際にGraph内の演算処理を実行しています。この辺りも今後もう少し理解できてくるはずです。
with self.session.as_default(): returns = self.session.run(list(values_collector.result.values())) return dict(zip(values_collector.result.keys(), returns))
The tf.Session.run method is the main mechanism for running a tf.Operation or evaluating a tf.Tensor. You can pass one or more tf.Operation or tf.Tensor objects to tf.Session.run, and TensorFlow will execute the operations that are needed to compute the result.
see: Graphs and Sessions | TensorFlow
このコミットを読み続けるのも疲れたので、次回から次のコミットに進みます。
TODO
- [ ] #1 の
not_observed
の処理の内容確認