target_log_prob_fn: MCMCサンプリングの実装 [66f95ca, 8aaa0ff, cded7c5] - pymc4のソースコード読んでみた
TL;DR
- 前回の
log_prob_fn
と同様に複数の未観測変数の対数確率の合計を計算している様子なので細かい点は省略
コミット
2018/06/21から2018/06/23の間のコミットです.
- add target_log_prob_fn which works with the tff mcmc sampler · pymc-devs/pymc4@66f95ca · GitHub
- modify target_log_prob_fn · pymc-devs/pymc4@8aaa0ff · GitHub
- remove log_prob_fn · pymc-devs/pymc4@cded7c5 · GitHub
target_log_prob_fn
前回の log_prob_fn
と同様に複数の未観測変数の対数確率の合計を計算しているように見えます. ただ今回の target_log_prob_fn
では値ではなく関数を返している点が異なります.
コメントを読むと Unnormalized target density as a function of unobserved states.
とあるため、未観測変数の正規化していない対数確率(正規化していないので正確には確率とは呼ばないはず)を計算しているはずです.
def target_log_prob_fn(self, *args, **kwargs): """ Unnormalized target density as a function of unobserved states. """ def log_joint_fn(*args, **kwargs): states = dict(zip(self.unobserved.keys(), args)) states.update(self.observed) log_probs = [] def interceptor(f, *args, **kwargs): name = kwargs.get("name") for name in states: value = states[name] if kwargs.get("name") == name: kwargs["value"] = value rv = f(*args, **kwargs) log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value)) log_probs.append(log_prob) return rv with ed.interception(interceptor): self._f(self._cfg) log_prob = sum(log_probs) return log_prob return log_joint_fn
TODO
- [ ]
states.update(self.observed)
で観測変数も対象にしている理由が不明 - [ ]
log_prob_fn
との使い分けが不明