Add model.target_log_prob_fn() sampling [a703c21] - pymc4のソースコード読んでみた
TL;DR
- ターゲットとなる
unobserved(未観測なRandomVariableインスタンス)
の対数確率の合計を返すメソッドを実装しています.
コミット
2018/06/18のコミットです.
以下ファイルが修正されています.
- pymc4/model/base.py:
Model
クラスに次の2つのメソッドが追加- target_log_prob_fn(self, args, kwargs)
- unobserved(self) *@property
- pymc4/util/interceptors.py
VariableDescription
クラスにrv(RandomVariableのインスタンス)
のプロパティが追加されただけです. この追加プロパティはtarget_log_prob_fn
で使用されています(後述).
target_log_prob_fn
unobserved(未観測なrv)
の log_probability(対数確率)
の合計を返します.
※ちなみにちょっとした罠ですが, 変数名の i
というのはこのコードの実装では index
ではなく, dictの key
なので, 単に k
などとした方がよいです.
def target_log_prob_fn(self, *args, **kwargs): logp = 0 for i in self.unobserved.keys(): print(kwargs.get(i)) logp += self.unobserved[i].rv.distribution.log_prob(value=kwargs.get(i)) return logp
unobserved
self.variables
(RandomVariableインスタンスのリスト)から unobserved
を抽出しているだけです(= observed
ではないものを抽出).
ちなみに, わざわざ collections.OrderedDict
で順序を保持したdictionary配列にしている理由はこの時点ではよくわかりません.
@property def unobserved(self): unobserved = {} for i in self.variables: if self.variables[i] not in self.observed.values(): unobserved[i] = self.variables[i] unobserved = collections.OrderedDict(unobserved) return unobserved
参考資料
- Get mcmc sampling to work by sharanry · Pull Request #9 · pymc-devs/pymc4 · GitHub
- 確率質量関数 - Wikipedia: probability mass function(PMF)- 離散値用の確率関数
- 確率密度関数 - Wikipedia: probability density function(PDF)- 連続値用の確率変数