オーストラリアで勉強してきたMLデザイナーの口語自由詩

主に、データ分析・機械学習・ベイズ・統計について自由に書く。

テストサンプルの生成 [e334115, d07338e, 93bc07b] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

概要

  • Model クラスのサンプル生成のメソッドを読んでみます。
    • test_point

コミット

2018/06/09から2018/06/11の間のコミットです。

Model - pymc4/model/base.py

test_point

まずはこの test_point メソッドがでどのように呼ばれるか調べます。pymc4のプロジェクト内では以下の tests/test_model.py のみで使用されています。

@pm.inline
def model(cfg):
    ed.Normal(0., 1., name='normal')

testval_random = model.test_point()
testval_mode = model.test_point(sample=False)

assert testval_mode['normal'] == 0.
assert testval_mode['normal'] != testval_random['normal']

ここでは2種類の呼び方があるみたいです。

  • test_point(): デフォルトでsample=True
  • test_point(sample=False)

事前情報として、ed.Normal(0., 1., name='normal') というのは、正規分布で平均が0かつ標準偏差が1という意味です。

test_point(sample=True)

Modelに設定されているRandomVariableのインスタンスの確率分布に沿って、1つの乱数が生成されます。

例) 正規分布で平均が0かつ標準偏差が1の確率分布から乱数を生成する

test_point(sample=False)

Modelに設定されているRandomVariableのインスタンスの確率分布に沿って、中央値が返されます。

例) 正規分布で平均が0かつ標準偏差が1の確率分布の中央値は0なので、常に0が返される


さっと test_point メソッドの実装も見てみます。

以下がそのメソッドの全容です。

def test_point(self, sample=True):
    # 1
    def not_observed(var, *args, **kwargs):
        return kwargs['name'] not in self.observed
    values_collector = interceptors.CollectVariables(filter=not_observed)
    chain = [values_collector]

    # 2
    if not sample:
        def get_mode(state, rv, *args, **kwargs):
            return rv.distribution.mode()
        chain.insert(0, interceptors.Generic(after=get_mode))

    # 3
    with self.graph.as_default(), ed.interception(interceptors.Chain(*chain)):
        self._f(self.cfg)

    # 4
    with self.session.as_default():
        returns = self.session.run(list(values_collector.result.values()))
    return dict(zip(values_collector.result.keys(), returns))

一つひとつバラバラに見ていきます。

#1

この辺りの挙動はあまりよくわかっていないので飛ばします。

    def not_observed(var, *args, **kwargs):
        return kwargs['name'] not in self.observed
    values_collector = interceptors.CollectVariables(filter=not_observed)
    chain = [values_collector]

#2

test_point(sample=False) の場合の処理です。やはり予想通り、中央値を返すようになっています。

    if not sample:
        def get_mode(state, rv, *args, **kwargs):
            return rv.distribution.mode()
        chain.insert(0, interceptors.Generic(after=get_mode))

#3

以前、 edward2のinterception処理 の記事で確認したので飛ばします。

    with self.graph.as_default(), ed.interception(interceptors.Chain(*chain)):
        self._f(self.cfg)

#4

tensorflowのsessionを実行して、実際にGraph内の演算処理を実行しています。この辺りも今後もう少し理解できてくるはずです。

    with self.session.as_default():
        returns = self.session.run(list(values_collector.result.values()))
    return dict(zip(values_collector.result.keys(), returns))

The tf.Session.run method is the main mechanism for running a tf.Operation or evaluating a tf.Tensor. You can pass one or more tf.Operation or tf.Tensor objects to tf.Session.run, and TensorFlow will execute the operations that are needed to compute the result.

see: Graphs and Sessions  |  TensorFlow


このコミットを読み続けるのも疲れたので、次回から次のコミットに進みます。


TODO

  • [ ] #1not_observed の処理の内容確認

参考資料