オーストラリアで勉強してきたMLデザイナーの口語自由詩

主に、データ分析・機械学習・ベイズ・統計について自由に書く。

対数確率関数の計算をインターセプター処理 [f223e4e] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

TL;DR

  • 変数に対する処理は基本的にインターセプターで対応する方針のため, 対数確率も同様に対応.

コミット

2018/07/01のコミットです.

以下が変更対象ファイルです.

  • pymc4/init.py
    • import追加: from .inference.sampling.sample import sample
  • pymc4/model/base.py
  • pymc4/util/interceptors.py

target_log_prob_fn - pymc4/model/base.py

states 変数に unobservedobserved のkey:valueを格納し, interceptors.CollectLogProbstatesインターセプト処理をしている. この CollectLogProb 内で対数確率の計算を行っている. その結果がリストで返るため, 最後に sum で合計を計算している.

    def target_log_prob_fn(self, *args, **kwargs):  # pylint: disable=unused-argument
        """
        Pass the states of the RVs as args in alphabetical order of the RVs.
        Compatible as `target_log_prob_fn` for tfp samplers.
        """

        def log_joint_fn(*args, **kwargs):  # pylint: disable=unused-argument
            states = dict(zip(self.unobserved.keys(), args))
            states.update(self.observed)
            log_probs = []
            interceptor = interceptors.CollectLogProb(states)
            with ed.interception(interceptor):
                self._f(self._cfg)

            log_prob = sum(interceptor.log_probs)
            return log_prob
        return log_joint_fn

CollectLogProb - pymc4/util/interceptors.py

  • 次回読みます.

参考資料