オーストラリアで勉強してきたデータサイエンティストの口語自由詩

主に、ベイズ・統計・データ分析・機械学習について自由に書く。

対数確率計算用インターセプタ(実装はちょっと怪しい)[f223e4e] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

TL;DR

  • 対数確率計算用インターセプタの実装を読んだが, なんだか実装が怪しい気がする.

コミット

2018/07/01のコミットです.

以下が変更対象ファイルです.

  • pymc4/init.py
    • import追加: from .inference.sampling.sample import sample
  • pymc4/model/base.py
  • pymc4/util/interceptors.py

CollectLogProb - pymc4/util/interceptors.py

CollectLogProb クラスの親クラスは SetState なのでそちらから確認する.

class SetState(Interceptor):
    def __init__(self, state):
        self.state = state

    def before(self, f, *args, **kwargs):
        if kwargs['name'] in self.state:
            kwargs['value'] = self.state[kwargs['name']]
        return f, args, kwargs

以前の記事 edward2のinterception処理 を読み返すと, 次のような関数の引数が **kwargs に対応する. 上記の SetState では name がkeyword引数で渡される. 例) name=“y”.

ed.Poisson(rate=1.5, name="y")

擬似コードで説明してみる.

# keyword引数
kwargs = {"name": "y"}

# すべてのRV変数の辞書
state = {"x": "foo", "y": "bar"}

if kwargs['name'] in state: # kwargs['name'] == "y"
    kwargs['value'] = state[kwargs['name']]

print(kwargs) # => {"name": "y", "value": "bar"}

次に CollectLogProb の方を見てみる. 単にRV変数すべての対数確率を計算して足してるだけのはずです.

class CollectLogProb(SetState):
    def __init__(self, states):
        super().__init__(states)
        self.log_probs = []

    def before(self, f, *args, **kwargs):
        if kwargs['name'] not in self.state:
            raise RuntimeError(kwargs.get('name'), 'All RV should be present in state dict')
        return super().before(f, *args, **kwargs)

    def after(self, rv, *args, **kwargs):
        name = kwargs.get("name")
        for name in self.state:
            value = self.state[name]
            if kwargs.get("name") == name:
                kwargs["value"] = value
        log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value))
        self.log_probs.append(log_prob)
        return rv

    @property
    def result(self):
        return self.log_probs

しかし, after の処理は少し怪しい気がする.

    def after(self, rv, *args, **kwargs):
        name = kwargs.get("name")
        for name in self.state:
            value = self.state[name]
            if kwargs.get("name") == name:
                kwargs["value"] = value
        log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value))
        self.log_probs.append(log_prob)
        return rv

例えば, kwargs.get("name") の処理がおかしい気がする. name の値を辿ってみる.

name = kwargs.get("name") # 仮にname変数の値を `test` とする.
for name in self.state: # self.stateの中に `test` というkeyが存在するかチェック(ifでよいのでは?)
    value = self.state[name]
    if kwargs.get("name") == name: # name変数の値は `test` で、かつ `kwargs.get("name")` と同じ値なのでif文は不要では?
        kwargs["value"] = value

いろいろ気になるが, 次のように書き換えられる気がする. log_prob の処理もif文の中に含めないと, state に存在しないrv変数の対数確率も計算してしまう.

    def after(self, rv, *args, **kwargs):
        name = kwargs.get("name")
        if name in self.state:
            value = self.state[name]
            kwargs["value"] = value # TODO: このvalueをどこで使用してるかよくわかっていない
            log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value))
            self.log_probs.append(log_prob)
        return rv

TODO

  • [ ] Modelクラスの self.observed + self.unobserved = self.varaibles のように見えるので, states = self.variables と単にしてよいのではないかという疑問
  • [ ] CollectLogProbafter 処理の挙動が怪しい気がするので確認したい

参考資料