オーストラリアで勉強してきたMLデザイナーの口語自由詩

主に、データ分析・機械学習・ベイズ・統計について自由に書く。

pymc4のソースコード読んでみた - Remove treedict dependency [10ea1aa]

f:id:yukinagae:20171122095115p:plain

TL;DR

model nesting が不要になったので、 treedict が削除されたみたいです。

コミット

2018/05/31のコミットです。

以前は Model クラスで treedict クラスが使用されていましたが、その処理が削除されています。その理由や流れを関連issueを読みながらまとめてみます。

10ea1aa - Remove treedict dependency

以下のissueで言及されています。

see: Model context manager, primitive default sampling, random variable class by sharanry · Pull Request #1 · pymc-devs/pymc4 · GitHub

The reason treedict is used in pymc3 is model nesting. Once we were thinking about building blocks that can contan variables. Model definition would come in init like in regular context manager

以前のバージョンpymc3で treedict が使用されていたのは model nesting が理由でした。

@ferrine We never used model nesting and I think we should drop it.

もう model nesting は使わない から treedict はなくていいよ。

実際にコードで動作の違いを見てみます。実際にjupyte notebookで実行して比較してみました。

see: pymc4_code_reading/summary-10ea1aa.ipynb at master · yukinagae/pymc4_code_reading · GitHub

  • treedictクラスに依存している場合、ModelがネストするとparentにもRandomVariableが伝搬する
    • model1(parent): [rv1, rv2, rv3]
    • model2(child): [rv2, rv3]
with WithTreeModel(name="model1") as model1:
    rv1 = WithTreeRandomVariable(name="rv1")
    with WithTreeModel(name="model2") as model2:
        rv2 = WithTreeRandomVariable(name="rv2")
        rv3 = WithTreeRandomVariable(name="rv3")
        print("model1: {}".format([v for v in model1.named_vars]))
        print("model2: {}".format([v for v in model2.named_vars]))

# 出力:
# => model1: ['rv1', 'rv2', 'rv3']
# => model2: ['rv2', 'rv3']
  • treedictクラスに依存していない場合、ModelがネストしてもparentにRandomVariableが伝搬しない
    • model1(parent): [rv1]
    • model2(child): [rv2, rv3]
with NoTreeModel(name="model1") as model1:
    rv1 = NoTreeRandomVariable(name="rv1")
    with NoTreeModel(name="model2") as model2:
        rv2 = NoTreeRandomVariable(name="rv2")
        rv3 = NoTreeRandomVariable(name="rv3")
        print("model1: {}".format([v for v in model1.named_vars]))
        print("model2: {}".format([v for v in model2.named_vars]))

# 出力:
# => model1: ['rv1']
# => model2: ['rv2', 'rv3']

参考資料