オーストラリアで勉強してきたMLデザイナーの口語自由詩

主に、データ分析・機械学習・ベイズ・統計について自由に書く。

Add model.target_log_prob_fn() sampling [a703c21] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

TL;DR

  • ターゲットとなる unobserved(未観測なRandomVariableインスタンス) の対数確率の合計を返すメソッドを実装しています.

コミット

2018/06/18のコミットです.

以下ファイルが修正されています.

  • pymc4/model/base.py: Model クラスに次の2つのメソッドが追加
    • target_log_prob_fn(self, args, kwargs)
    • unobserved(self) *@property
  • pymc4/util/interceptors.py
    • VariableDescription クラスに rv(RandomVariableのインスタンス) のプロパティが追加されただけです. この追加プロパティは target_log_prob_fn で使用されています(後述).

target_log_prob_fn

unobserved(未観測なrv)log_probability(対数確率) の合計を返します.

※ちなみにちょっとした罠ですが, 変数名の i というのはこのコードの実装では index ではなく, dictの key なので, 単に k などとした方がよいです.

    def target_log_prob_fn(self, *args, **kwargs):
        logp = 0
        for i in self.unobserved.keys():
            print(kwargs.get(i))
            logp += self.unobserved[i].rv.distribution.log_prob(value=kwargs.get(i))
         return logp

unobserved

self.variables (RandomVariableインスタンスのリスト)から unobserved を抽出しているだけです(= observed ではないものを抽出).

ちなみに, わざわざ collections.OrderedDict で順序を保持したdictionary配列にしている理由はこの時点ではよくわかりません.

    @property
    def unobserved(self):
        unobserved = {}
        for i in self.variables:
            if self.variables[i] not in self.observed.values():
                unobserved[i] = self.variables[i]
         unobserved = collections.OrderedDict(unobserved)
        return unobserved

参考資料