対数確率計算用インターセプタ(実装はちょっと怪しい)[f223e4e] - pymc4のソースコード読んでみた
TL;DR
- 対数確率計算用インターセプタの実装を読んだが, なんだか実装が怪しい気がする.
コミット
2018/07/01のコミットです.
以下が変更対象ファイルです.
- pymc4/init.py
- import追加:
from .inference.sampling.sample import sample
- import追加:
- pymc4/model/base.py
- pymc4/util/interceptors.py
CollectLogProb - pymc4/util/interceptors.py
CollectLogProb
クラスの親クラスは SetState
なのでそちらから確認する.
class SetState(Interceptor): def __init__(self, state): self.state = state def before(self, f, *args, **kwargs): if kwargs['name'] in self.state: kwargs['value'] = self.state[kwargs['name']] return f, args, kwargs
以前の記事 edward2のinterception処理 を読み返すと, 次のような関数の引数が **kwargs
に対応する. 上記の SetState
では name
がkeyword引数で渡される. 例) name=“y”
.
ed.Poisson(rate=1.5, name="y")
擬似コードで説明してみる.
# keyword引数 kwargs = {"name": "y"} # すべてのRV変数の辞書 state = {"x": "foo", "y": "bar"} if kwargs['name'] in state: # kwargs['name'] == "y" kwargs['value'] = state[kwargs['name']] print(kwargs) # => {"name": "y", "value": "bar"}
次に CollectLogProb
の方を見てみる. 単にRV変数すべての対数確率を計算して足してるだけのはずです.
class CollectLogProb(SetState): def __init__(self, states): super().__init__(states) self.log_probs = [] def before(self, f, *args, **kwargs): if kwargs['name'] not in self.state: raise RuntimeError(kwargs.get('name'), 'All RV should be present in state dict') return super().before(f, *args, **kwargs) def after(self, rv, *args, **kwargs): name = kwargs.get("name") for name in self.state: value = self.state[name] if kwargs.get("name") == name: kwargs["value"] = value log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value)) self.log_probs.append(log_prob) return rv @property def result(self): return self.log_probs
しかし, after
の処理は少し怪しい気がする.
def after(self, rv, *args, **kwargs): name = kwargs.get("name") for name in self.state: value = self.state[name] if kwargs.get("name") == name: kwargs["value"] = value log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value)) self.log_probs.append(log_prob) return rv
例えば, kwargs.get("name")
の処理がおかしい気がする. name
の値を辿ってみる.
name = kwargs.get("name") # 仮にname変数の値を `test` とする. for name in self.state: # self.stateの中に `test` というkeyが存在するかチェック(ifでよいのでは?) value = self.state[name] if kwargs.get("name") == name: # name変数の値は `test` で、かつ `kwargs.get("name")` と同じ値なのでif文は不要では? kwargs["value"] = value
いろいろ気になるが, 次のように書き換えられる気がする. log_prob
の処理もif文の中に含めないと, state
に存在しないrv変数の対数確率も計算してしまう.
def after(self, rv, *args, **kwargs): name = kwargs.get("name") if name in self.state: value = self.state[name] kwargs["value"] = value # TODO: このvalueをどこで使用してるかよくわかっていない log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value)) self.log_probs.append(log_prob) return rv
TODO
- [ ] Modelクラスの
self.observed + self.unobserved = self.varaibles
のように見えるので,states = self.variables
と単にしてよいのではないかという疑問 - [ ]
CollectLogProb
のafter
処理の挙動が怪しい気がするので確認したい
参考資料
対数確率関数の計算をインターセプター処理 [f223e4e] - pymc4のソースコード読んでみた
TL;DR
- 変数に対する処理は基本的にインターセプターで対応する方針のため, 対数確率も同様に対応.
コミット
2018/07/01のコミットです.
以下が変更対象ファイルです.
- pymc4/init.py
- import追加:
from .inference.sampling.sample import sample
- import追加:
- pymc4/model/base.py
- pymc4/util/interceptors.py
target_log_prob_fn - pymc4/model/base.py
states
変数に unobserved
と observed
のkey:valueを格納し, interceptors.CollectLogProb
で states
のインターセプト処理をしている. この CollectLogProb
内で対数確率の計算を行っている. その結果がリストで返るため, 最後に sum
で合計を計算している.
def target_log_prob_fn(self, *args, **kwargs): # pylint: disable=unused-argument """ Pass the states of the RVs as args in alphabetical order of the RVs. Compatible as `target_log_prob_fn` for tfp samplers. """ def log_joint_fn(*args, **kwargs): # pylint: disable=unused-argument states = dict(zip(self.unobserved.keys(), args)) states.update(self.observed) log_probs = [] interceptor = interceptors.CollectLogProb(states) with ed.interception(interceptor): self._f(self._cfg) log_prob = sum(interceptor.log_probs) return log_prob return log_joint_fn
CollectLogProb - pymc4/util/interceptors.py
- 次回読みます.
参考資料
pycodestyle追加ともろもろ[a7cef9b, bd381b1, 89edc5c, 9f46878] - pymc4のソースコード読んでみた
TL;DR
pycodestyle
でコードチェックの追加- 不要なテストの削除
コミット
2018/07/01のコミットです.
- remove some tests · pymc-devs/pymc4@a7cef9b · GitHub
- fix pycodestyle errors · pymc-devs/pymc4@bd381b1 · GitHub
- remove test_interceptors · pymc-devs/pymc4@89edc5c · GitHub
- add pycodestyle to dev requirements · pymc-devs/pymc4@9f46878 · GitHub
以下が各コミットの概要です.
- a7cef9b
- クラスのインスタンス化のテストなど, やり過ぎなテストを削除
- bd381b1
pycodestyle
でコードスタイルをチェックした対応
- 89edc5c
a7cef9b
と同様にクラスのインスタンス化のテストを削除
- 9f46878
pycodestyle
をrequirements-dev.txtに追加
pycodestyle - 9f46878
pycodestyleとは?
pycodestyle: GitHub - PyCQA/pycodestyle: Simple Python style checker in one Python file
一言で言うと, pep8がpycodestyleという名前に変わった
だけらしいです.
see: pep8 が pycodestyle に変わった話
以下のようにコードチェックをできるようです.
例)複数importする際には, ,
区切りは不可
$ pycodestyle --show-source --show-pep8 testsuite/E40.py testsuite/E40.py:2:10: E401 multiple imports on one line import os, sys ^ Imports should usually be on separate lines. Okay: import os\nimport sys E401: import sys, os
参考資料
(細かい修正なのであまり重要ではない) [c27e97f, a8e3dae, 4f5382f, e06d946, ca9f334] - pymc4のソースコード読んでみた
TL;DR
- 細かい修正が多いのであまり重要ではないです.
コミット
2018/06/26から2018/06/30の間のコミットです.
- minor fixes · pymc-devs/pymc4@c27e97f · GitHub
- add some tests · pymc-devs/pymc4@a8e3dae · GitHub
- add tests and fix model.configure() · pymc-devs/pymc4@4f5382f · GitHub
- fix lint errors · pymc-devs/pymc4@e06d946 · GitHub
- add tests · pymc-devs/pymc4@ca9f334 · GitHub
minor fixes - c27e97f
- pymc4/inference_sampling_sample.py
- dictのfor文の処理で不要な変数を減らしているだけです.
- pymc4/model/base.py
- target_log_prob_fnメソッド: この辺りは変更は多いがやっていることは変わらないため省略
- unobservedプロパティ: 順序付きdictの初期化の処理を修正
unobserved = {}
とdictを初期化するとpython3.5以下では順序無しのdictになる点を修正- see: Get mcmc sampling to work by sharanry · Pull Request #9 · pymc-devs/pymc4 · GitHub
add some tests - a8e3dae
- 実際にはテストの追加ではなくリファクタリングをしているだけなので省略
add tests and fix model.configure() - 4f5382f
次のようなテストが複数追加されている.
- Modelクラスのインスタンス化
@model.define
でed.Normal(0., 1., name='normal')
を生成したModelインスタンスにセット- assertでModelインスタンスのプロパティをチェック
model = pm.Model() @model.define def simple(cfg): ed.Normal(0., 1., name='normal') assert len(model.variables) == 1 assert len(model.unobserved) == 1 assert "normal" in model.variables
fix lint errors - e06d946
lint対応なので省略: #pylint: disable=unused-argument
でlintをdisableしている箇所もある.
add tests - ca9f334
テストが複数追加されている.
例えば, 次のテストは(対数)確率関数の値を近似的(approx)にチェックしている.
def test_model_log_prob_fn(): model = pm.Model() @model.define def simple(cfg): mu = ed.Normal(0., 1., name="mu") log_prob_fn = model.target_log_prob_fn() with tf.Session(): assert -0.91893853 == pytest.approx(log_prob_fn(0).eval(), 0.00001)
念のためにscipy.statsのdistributionモジュールで値をチェックしたみた.
from scipy.stats import norm print(norm(0, 1).logpdf(0)) # => -0.9189385332046727
参考資料
target_log_prob_fn: MCMCサンプリングの実装 [66f95ca, 8aaa0ff, cded7c5] - pymc4のソースコード読んでみた
TL;DR
- 前回の
log_prob_fn
と同様に複数の未観測変数の対数確率の合計を計算している様子なので細かい点は省略
コミット
2018/06/21から2018/06/23の間のコミットです.
- add target_log_prob_fn which works with the tff mcmc sampler · pymc-devs/pymc4@66f95ca · GitHub
- modify target_log_prob_fn · pymc-devs/pymc4@8aaa0ff · GitHub
- remove log_prob_fn · pymc-devs/pymc4@cded7c5 · GitHub
target_log_prob_fn
前回の log_prob_fn
と同様に複数の未観測変数の対数確率の合計を計算しているように見えます. ただ今回の target_log_prob_fn
では値ではなく関数を返している点が異なります.
コメントを読むと Unnormalized target density as a function of unobserved states.
とあるため、未観測変数の正規化していない対数確率(正規化していないので正確には確率とは呼ばないはず)を計算しているはずです.
def target_log_prob_fn(self, *args, **kwargs): """ Unnormalized target density as a function of unobserved states. """ def log_joint_fn(*args, **kwargs): states = dict(zip(self.unobserved.keys(), args)) states.update(self.observed) log_probs = [] def interceptor(f, *args, **kwargs): name = kwargs.get("name") for name in states: value = states[name] if kwargs.get("name") == name: kwargs["value"] = value rv = f(*args, **kwargs) log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value)) log_probs.append(log_prob) return rv with ed.interception(interceptor): self._f(self._cfg) log_prob = sum(log_probs) return log_prob return log_joint_fn
TODO
- [ ]
states.update(self.observed)
で観測変数も対象にしている理由が不明 - [ ]
log_prob_fn
との使い分けが不明
参考資料
log_prob_fn: MCMCサンプリングの実装 [66f95ca, 8aaa0ff, cded7c5] - pymc4のソースコード読んでみた
TL;DR
- 複数の未観測変数の対数確率の合計を計算しています. これがMCMCサンプリングのコアの部分です.
コミット
2018/06/21から2018/06/23の間のコミットです.
- add target_log_prob_fn which works with the tff mcmc sampler · pymc-devs/pymc4@66f95ca · GitHub
- modify target_log_prob_fn · pymc-devs/pymc4@8aaa0ff · GitHub
- remove log_prob_fn · pymc-devs/pymc4@cded7c5 · GitHub
以下ファイルが修正されています.
- pymc4/model/base.py: 対数確率を計算するメソッドを2つ実装
- log_prob_fn <= 今回はここを読みます.
- target_log_prob_fn <= 次回
- pymc4/inference_sampling_sample.py
iteritems()
の代わりにkeys()
を使うことで未使用の変数を減らしているだけです.
Model - model/base.py
log_prob_fn
複数の未観測変数の対数確率の合計を計算しています.
注意) x
引数は関数内で使用されていないので, 後のコミットで削除されています.
def log_prob_fn(self, x, *args, **kwargs): logp = 0 for i in self.unobserved.keys(): logp += self.unobserved[i].rv.distribution.log_prob(value=kwargs.get(i)) return logp
前提条件として, 次のように掛け算(x)は値のlogを取ることで足し算(+)に変換できます.
- 変数1の確率関数 x 変数2の確率関数 x 変数3の確率関数 x … x 変数nの確率関数 ↓
- 変数1の対数確率関数 + 変数2の対数確率関数 + 変数3の対数確率関数 + … x 変数nの対数確率関数
参考資料
PyMC4のInstallation failsというissueに対応するPR送った
TL;DR
次のissueにあるように, 現状だと依存性の解決の部分でfailしてインストールできないのでとりあえずforkして dependency-resolution
というbranchで修正してみた.
see: Installation fails · Issue #23 · pymc-devs/pymc4 · GitHub
修正点
これを
tf-nightly==1.9.0.dev20180607 tfp-nightly==0.3.0.dev20180725 tb-nightly==1.9.0a20180613
こうするだけの簡単なお仕事.
tf-nightly tfp-nightly tb-nightly
とりあえず本体にPR送った(`・ω・´)
おまけ
pip
コマンドよくわかってないオプションがあったので軽く調べた.
pip install --user git+https://github.com/pymc-devs/pymc4.git#egg=pymc4
- —user
- ユーザのローカルディレクトリ(
~/.local/
など)に対象のパッケージをインストールする
- ユーザのローカルディレクトリ(
- egg=[プロジェクト名を明示的に指定]
また, 次のようにbranch名も指定できる * @[branch名]
pip install --user git+https://github.com/yukinagae/pymc4@dependecny-resolution.git#egg=pymc4