対数確率計算用インターセプタ(実装はちょっと怪しい)[f223e4e] - pymc4のソースコード読んでみた
TL;DR
- 対数確率計算用インターセプタの実装を読んだが, なんだか実装が怪しい気がする.
コミット
2018/07/01のコミットです.
以下が変更対象ファイルです.
- pymc4/init.py
- import追加:
from .inference.sampling.sample import sample
- import追加:
- pymc4/model/base.py
- pymc4/util/interceptors.py
CollectLogProb - pymc4/util/interceptors.py
CollectLogProb
クラスの親クラスは SetState
なのでそちらから確認する.
class SetState(Interceptor): def __init__(self, state): self.state = state def before(self, f, *args, **kwargs): if kwargs['name'] in self.state: kwargs['value'] = self.state[kwargs['name']] return f, args, kwargs
以前の記事 edward2のinterception処理 を読み返すと, 次のような関数の引数が **kwargs
に対応する. 上記の SetState
では name
がkeyword引数で渡される. 例) name=“y”
.
ed.Poisson(rate=1.5, name="y")
擬似コードで説明してみる.
# keyword引数 kwargs = {"name": "y"} # すべてのRV変数の辞書 state = {"x": "foo", "y": "bar"} if kwargs['name'] in state: # kwargs['name'] == "y" kwargs['value'] = state[kwargs['name']] print(kwargs) # => {"name": "y", "value": "bar"}
次に CollectLogProb
の方を見てみる. 単にRV変数すべての対数確率を計算して足してるだけのはずです.
class CollectLogProb(SetState): def __init__(self, states): super().__init__(states) self.log_probs = [] def before(self, f, *args, **kwargs): if kwargs['name'] not in self.state: raise RuntimeError(kwargs.get('name'), 'All RV should be present in state dict') return super().before(f, *args, **kwargs) def after(self, rv, *args, **kwargs): name = kwargs.get("name") for name in self.state: value = self.state[name] if kwargs.get("name") == name: kwargs["value"] = value log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value)) self.log_probs.append(log_prob) return rv @property def result(self): return self.log_probs
しかし, after
の処理は少し怪しい気がする.
def after(self, rv, *args, **kwargs): name = kwargs.get("name") for name in self.state: value = self.state[name] if kwargs.get("name") == name: kwargs["value"] = value log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value)) self.log_probs.append(log_prob) return rv
例えば, kwargs.get("name")
の処理がおかしい気がする. name
の値を辿ってみる.
name = kwargs.get("name") # 仮にname変数の値を `test` とする. for name in self.state: # self.stateの中に `test` というkeyが存在するかチェック(ifでよいのでは?) value = self.state[name] if kwargs.get("name") == name: # name変数の値は `test` で、かつ `kwargs.get("name")` と同じ値なのでif文は不要では? kwargs["value"] = value
いろいろ気になるが, 次のように書き換えられる気がする. log_prob
の処理もif文の中に含めないと, state
に存在しないrv変数の対数確率も計算してしまう.
def after(self, rv, *args, **kwargs): name = kwargs.get("name") if name in self.state: value = self.state[name] kwargs["value"] = value # TODO: このvalueをどこで使用してるかよくわかっていない log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value)) self.log_probs.append(log_prob) return rv
TODO
- [ ] Modelクラスの
self.observed + self.unobserved = self.varaibles
のように見えるので,states = self.variables
と単にしてよいのではないかという疑問 - [ ]
CollectLogProb
のafter
処理の挙動が怪しい気がするので確認したい