オーストラリアで勉強してきたデータサイエンティストの口語自由詩

主に、ベイズ・統計・データ分析・機械学習について自由に書く。

(細かい修正なのであまり重要ではない) [c27e97f, a8e3dae, 4f5382f, e06d946, ca9f334] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

TL;DR

  • 細かい修正が多いのであまり重要ではないです.

コミット

2018/06/26から2018/06/30の間のコミットです.

minor fixes - c27e97f

  • pymc4/inference_sampling_sample.py
    • dictのfor文の処理で不要な変数を減らしているだけです.
  • pymc4/model/base.py

add some tests - a8e3dae

add tests and fix model.configure() - 4f5382f

次のようなテストが複数追加されている.

    model = pm.Model()
    @model.define
    def simple(cfg):
        ed.Normal(0., 1., name='normal')
    
    assert len(model.variables) == 1
    assert len(model.unobserved) == 1
    assert "normal" in model.variables

fix lint errors - e06d946

lint対応なので省略: #pylint: disable=unused-argument でlintをdisableしている箇所もある.

add tests - ca9f334

テストが複数追加されている.

例えば, 次のテストは(対数)確率関数の値を近似的(approx)にチェックしている.

def test_model_log_prob_fn():
    model = pm.Model()
    @model.define
    def simple(cfg):
        mu = ed.Normal(0., 1., name="mu")
    
    log_prob_fn = model.target_log_prob_fn()
    with tf.Session():
        assert -0.91893853 == pytest.approx(log_prob_fn(0).eval(), 0.00001)

念のためにscipy.statsのdistributionモジュールで値をチェックしたみた.

from scipy.stats import norm
print(norm(0, 1).logpdf(0)) # => -0.9189385332046727

参考資料