オーストラリアで勉強してきたデータサイエンティストの口語自由詩

主に、ベイズ・統計・データ分析・機械学習について自由に書く。

target_log_prob_fn: MCMCサンプリングの実装 [66f95ca, 8aaa0ff, cded7c5] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

TL;DR

  • 前回の log_prob_fn と同様に複数の未観測変数の対数確率の合計を計算している様子なので細かい点は省略

コミット

2018/06/21から2018/06/23の間のコミットです.

target_log_prob_fn

前回の log_prob_fn と同様に複数の未観測変数の対数確率の合計を計算しているように見えます. ただ今回の target_log_prob_fn では値ではなく関数を返している点が異なります.

コメントを読むと Unnormalized target density as a function of unobserved states. とあるため、未観測変数の正規化していない対数確率(正規化していないので正確には確率とは呼ばないはず)を計算しているはずです.

def target_log_prob_fn(self, *args, **kwargs):
    """
    Unnormalized target density as a function of unobserved states.
    """

    def log_joint_fn(*args, **kwargs):
        states = dict(zip(self.unobserved.keys(), args))
        states.update(self.observed)
        log_probs = []

        def interceptor(f, *args, **kwargs):
            name = kwargs.get("name")
            for name in states:
                value = states[name]
                if kwargs.get("name") == name:
                    kwargs["value"] = value
            rv = f(*args, **kwargs)
            log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value))
            log_probs.append(log_prob)
            return rv

        with ed.interception(interceptor):
            self._f(self._cfg)
        log_prob = sum(log_probs)
        return log_prob
    
    return log_joint_fn

TODO

  • [ ] states.update(self.observed) で観測変数も対象にしている理由が不明
  • [ ] log_prob_fn との使い分けが不明

参考資料