オーストラリアで勉強してきたMLデザイナーの口語自由詩

主に、データ分析・機械学習・ベイズ・統計について自由に書く。

対数確率計算用インターセプタ(実装はちょっと怪しい)[f223e4e] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

TL;DR

  • 対数確率計算用インターセプタの実装を読んだが, なんだか実装が怪しい気がする.

コミット

2018/07/01のコミットです.

以下が変更対象ファイルです.

  • pymc4/init.py
    • import追加: from .inference.sampling.sample import sample
  • pymc4/model/base.py
  • pymc4/util/interceptors.py

CollectLogProb - pymc4/util/interceptors.py

CollectLogProb クラスの親クラスは SetState なのでそちらから確認する.

class SetState(Interceptor):
    def __init__(self, state):
        self.state = state

    def before(self, f, *args, **kwargs):
        if kwargs['name'] in self.state:
            kwargs['value'] = self.state[kwargs['name']]
        return f, args, kwargs

以前の記事 edward2のinterception処理 を読み返すと, 次のような関数の引数が **kwargs に対応する. 上記の SetState では name がkeyword引数で渡される. 例) name=“y”.

ed.Poisson(rate=1.5, name="y")

擬似コードで説明してみる.

# keyword引数
kwargs = {"name": "y"}

# すべてのRV変数の辞書
state = {"x": "foo", "y": "bar"}

if kwargs['name'] in state: # kwargs['name'] == "y"
    kwargs['value'] = state[kwargs['name']]

print(kwargs) # => {"name": "y", "value": "bar"}

次に CollectLogProb の方を見てみる. 単にRV変数すべての対数確率を計算して足してるだけのはずです.

class CollectLogProb(SetState):
    def __init__(self, states):
        super().__init__(states)
        self.log_probs = []

    def before(self, f, *args, **kwargs):
        if kwargs['name'] not in self.state:
            raise RuntimeError(kwargs.get('name'), 'All RV should be present in state dict')
        return super().before(f, *args, **kwargs)

    def after(self, rv, *args, **kwargs):
        name = kwargs.get("name")
        for name in self.state:
            value = self.state[name]
            if kwargs.get("name") == name:
                kwargs["value"] = value
        log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value))
        self.log_probs.append(log_prob)
        return rv

    @property
    def result(self):
        return self.log_probs

しかし, after の処理は少し怪しい気がする.

    def after(self, rv, *args, **kwargs):
        name = kwargs.get("name")
        for name in self.state:
            value = self.state[name]
            if kwargs.get("name") == name:
                kwargs["value"] = value
        log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value))
        self.log_probs.append(log_prob)
        return rv

例えば, kwargs.get("name") の処理がおかしい気がする. name の値を辿ってみる.

name = kwargs.get("name") # 仮にname変数の値を `test` とする.
for name in self.state: # self.stateの中に `test` というkeyが存在するかチェック(ifでよいのでは?)
    value = self.state[name]
    if kwargs.get("name") == name: # name変数の値は `test` で、かつ `kwargs.get("name")` と同じ値なのでif文は不要では?
        kwargs["value"] = value

いろいろ気になるが, 次のように書き換えられる気がする. log_prob の処理もif文の中に含めないと, state に存在しないrv変数の対数確率も計算してしまう.

    def after(self, rv, *args, **kwargs):
        name = kwargs.get("name")
        if name in self.state:
            value = self.state[name]
            kwargs["value"] = value # TODO: このvalueをどこで使用してるかよくわかっていない
            log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value))
            self.log_probs.append(log_prob)
        return rv

TODO

  • [ ] Modelクラスの self.observed + self.unobserved = self.varaibles のように見えるので, states = self.variables と単にしてよいのではないかという疑問
  • [ ] CollectLogProbafter 処理の挙動が怪しい気がするので確認したい

参考資料

対数確率関数の計算をインターセプター処理 [f223e4e] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

TL;DR

  • 変数に対する処理は基本的にインターセプターで対応する方針のため, 対数確率も同様に対応.

コミット

2018/07/01のコミットです.

以下が変更対象ファイルです.

  • pymc4/init.py
    • import追加: from .inference.sampling.sample import sample
  • pymc4/model/base.py
  • pymc4/util/interceptors.py

target_log_prob_fn - pymc4/model/base.py

states 変数に unobservedobserved のkey:valueを格納し, interceptors.CollectLogProbstatesインターセプト処理をしている. この CollectLogProb 内で対数確率の計算を行っている. その結果がリストで返るため, 最後に sum で合計を計算している.

    def target_log_prob_fn(self, *args, **kwargs):  # pylint: disable=unused-argument
        """
        Pass the states of the RVs as args in alphabetical order of the RVs.
        Compatible as `target_log_prob_fn` for tfp samplers.
        """

        def log_joint_fn(*args, **kwargs):  # pylint: disable=unused-argument
            states = dict(zip(self.unobserved.keys(), args))
            states.update(self.observed)
            log_probs = []
            interceptor = interceptors.CollectLogProb(states)
            with ed.interception(interceptor):
                self._f(self._cfg)

            log_prob = sum(interceptor.log_probs)
            return log_prob
        return log_joint_fn

CollectLogProb - pymc4/util/interceptors.py

  • 次回読みます.

参考資料

pycodestyle追加ともろもろ[a7cef9b, bd381b1, 89edc5c, 9f46878] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

TL;DR

  • pycodestyle でコードチェックの追加
  • 不要なテストの削除

コミット

2018/07/01のコミットです.

以下が各コミットの概要です.

  • a7cef9b
  • bd381b1
    • pycodestyle でコードスタイルをチェックした対応
  • 89edc5c
  • 9f46878
    • pycodestyle をrequirements-dev.txtに追加

pycodestyle - 9f46878

pycodestyleとは?

pycodestyle: GitHub - PyCQA/pycodestyle: Simple Python style checker in one Python file

一言で言うと, pep8がpycodestyleという名前に変わった だけらしいです.

see: pep8 が pycodestyle に変わった話

以下のようにコードチェックをできるようです.

例)複数importする際には, , 区切りは不可

$ pycodestyle --show-source --show-pep8 testsuite/E40.py
testsuite/E40.py:2:10: E401 multiple imports on one line
import os, sys
         ^
    Imports should usually be on separate lines.

    Okay: import os\nimport sys
    E401: import sys, os

参考資料

(細かい修正なのであまり重要ではない) [c27e97f, a8e3dae, 4f5382f, e06d946, ca9f334] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

TL;DR

  • 細かい修正が多いのであまり重要ではないです.

コミット

2018/06/26から2018/06/30の間のコミットです.

minor fixes - c27e97f

  • pymc4/inference_sampling_sample.py
    • dictのfor文の処理で不要な変数を減らしているだけです.
  • pymc4/model/base.py

add some tests - a8e3dae

add tests and fix model.configure() - 4f5382f

次のようなテストが複数追加されている.

    model = pm.Model()
    @model.define
    def simple(cfg):
        ed.Normal(0., 1., name='normal')
    
    assert len(model.variables) == 1
    assert len(model.unobserved) == 1
    assert "normal" in model.variables

fix lint errors - e06d946

lint対応なので省略: #pylint: disable=unused-argument でlintをdisableしている箇所もある.

add tests - ca9f334

テストが複数追加されている.

例えば, 次のテストは(対数)確率関数の値を近似的(approx)にチェックしている.

def test_model_log_prob_fn():
    model = pm.Model()
    @model.define
    def simple(cfg):
        mu = ed.Normal(0., 1., name="mu")
    
    log_prob_fn = model.target_log_prob_fn()
    with tf.Session():
        assert -0.91893853 == pytest.approx(log_prob_fn(0).eval(), 0.00001)

念のためにscipy.statsのdistributionモジュールで値をチェックしたみた.

from scipy.stats import norm
print(norm(0, 1).logpdf(0)) # => -0.9189385332046727

参考資料

target_log_prob_fn: MCMCサンプリングの実装 [66f95ca, 8aaa0ff, cded7c5] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

TL;DR

  • 前回の log_prob_fn と同様に複数の未観測変数の対数確率の合計を計算している様子なので細かい点は省略

コミット

2018/06/21から2018/06/23の間のコミットです.

target_log_prob_fn

前回の log_prob_fn と同様に複数の未観測変数の対数確率の合計を計算しているように見えます. ただ今回の target_log_prob_fn では値ではなく関数を返している点が異なります.

コメントを読むと Unnormalized target density as a function of unobserved states. とあるため、未観測変数の正規化していない対数確率(正規化していないので正確には確率とは呼ばないはず)を計算しているはずです.

def target_log_prob_fn(self, *args, **kwargs):
    """
    Unnormalized target density as a function of unobserved states.
    """

    def log_joint_fn(*args, **kwargs):
        states = dict(zip(self.unobserved.keys(), args))
        states.update(self.observed)
        log_probs = []

        def interceptor(f, *args, **kwargs):
            name = kwargs.get("name")
            for name in states:
                value = states[name]
                if kwargs.get("name") == name:
                    kwargs["value"] = value
            rv = f(*args, **kwargs)
            log_prob = tf.reduce_sum(rv.distribution.log_prob(rv.value))
            log_probs.append(log_prob)
            return rv

        with ed.interception(interceptor):
            self._f(self._cfg)
        log_prob = sum(log_probs)
        return log_prob
    
    return log_joint_fn

TODO

  • [ ] states.update(self.observed) で観測変数も対象にしている理由が不明
  • [ ] log_prob_fn との使い分けが不明

参考資料

log_prob_fn: MCMCサンプリングの実装 [66f95ca, 8aaa0ff, cded7c5] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

TL;DR

  • 複数の未観測変数の対数確率の合計を計算しています. これがMCMCサンプリングのコアの部分です.

コミット

2018/06/21から2018/06/23の間のコミットです.

以下ファイルが修正されています.

  • pymc4/model/base.py: 対数確率を計算するメソッドを2つ実装
    • log_prob_fn <= 今回はここを読みます.
    • target_log_prob_fn <= 次回
  • pymc4/inference_sampling_sample.py
    • iteritems() の代わりに keys() を使うことで未使用の変数を減らしているだけです.

Model - model/base.py

log_prob_fn

複数の未観測変数の対数確率の合計を計算しています.

注意) x 引数は関数内で使用されていないので, 後のコミットで削除されています.

def log_prob_fn(self, x,  *args, **kwargs):
    logp = 0
    for i in self.unobserved.keys():
        logp += self.unobserved[i].rv.distribution.log_prob(value=kwargs.get(i))

    return logp

前提条件として, 次のように掛け算(x)は値のlogを取ることで足し算(+)に変換できます.

  • 変数1の確率関数 x 変数2の確率関数 x 変数3の確率関数 x … x 変数nの確率関数 ↓
  • 変数1の対数確率関数 + 変数2の対数確率関数 + 変数3の対数確率関数 + … x 変数nの対数確率関数

参考資料

PyMC4のInstallation failsというissueに対応するPR送った

TL;DR

次のissueにあるように, 現状だと依存性の解決の部分でfailしてインストールできないのでとりあえずforkして dependency-resolution というbranchで修正してみた.

see: Installation fails · Issue #23 · pymc-devs/pymc4 · GitHub

修正点

これを

tf-nightly==1.9.0.dev20180607
tfp-nightly==0.3.0.dev20180725
tb-nightly==1.9.0a20180613

こうするだけの簡単なお仕事.

tf-nightly
tfp-nightly
tb-nightly

とりあえず本体にPR送った(`・ω・´)

Specified the latest versions of tf/tfp/tfb-nightly by yukinagae · Pull Request #24 · pymc-devs/pymc4 · GitHub

おまけ

pip コマンドよくわかってないオプションがあったので軽く調べた.

pip install --user git+https://github.com/pymc-devs/pymc4.git#egg=pymc4
  • —user
    • ユーザのローカルディレクトリ( ~/.local/ など)に対象のパッケージをインストールする
  • egg=[プロジェクト名を明示的に指定]

また, 次のようにbranch名も指定できる * @[branch名]

pip install --user git+https://github.com/yukinagae/pymc4@dependecny-resolution.git#egg=pymc4

参考資料