対数確率関数の計算をインターセプター処理 [f223e4e] - pymc4のソースコード読んでみた
TL;DR
- 変数に対する処理は基本的にインターセプターで対応する方針のため, 対数確率も同様に対応.
コミット
2018/07/01のコミットです.
以下が変更対象ファイルです.
- pymc4/init.py
- import追加:
from .inference.sampling.sample import sample
- import追加:
- pymc4/model/base.py
- pymc4/util/interceptors.py
target_log_prob_fn - pymc4/model/base.py
states
変数に unobserved
と observed
のkey:valueを格納し, interceptors.CollectLogProb
で states
のインターセプト処理をしている. この CollectLogProb
内で対数確率の計算を行っている. その結果がリストで返るため, 最後に sum
で合計を計算している.
def target_log_prob_fn(self, *args, **kwargs): # pylint: disable=unused-argument """ Pass the states of the RVs as args in alphabetical order of the RVs. Compatible as `target_log_prob_fn` for tfp samplers. """ def log_joint_fn(*args, **kwargs): # pylint: disable=unused-argument states = dict(zip(self.unobserved.keys(), args)) states.update(self.observed) log_probs = [] interceptor = interceptors.CollectLogProb(states) with ed.interception(interceptor): self._f(self._cfg) log_prob = sum(interceptor.log_probs) return log_prob return log_joint_fn
CollectLogProb - pymc4/util/interceptors.py
- 次回読みます.