オーストラリアで勉強してきたMLデザイナーの口語自由詩

主に、データ分析・機械学習・ベイズ・統計について自由に書く。

Modelクラスの初期化処理 [e334115, d07338e, 93bc07b] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

TL;DR

  • まずは Model クラスの初期化処理系のメソッドを読んでみます。
    • __init__: 初期化処理いろいろ
    • define: self._f を設定して変数初期化しているようですが、今のところテスト用のヘルパー関数に見えます
    • configure: 設定を上書きして変数初期化
    • _init_variables: 次回読みます

コミット

2018/06/09から2018/06/11の間のコミットです。

Model - pymc4/model/base.py

init

初期化処理内では、以下プロパティを設定しているみたいです。

  • _cfg: 設定情報
  • name: モデルの名称(任意)
  • _f: None を設定してるので、後で使用されると推測
  • _variables: None を設定してるので、後で使用されると推測
  • _observed: dict() で初期化してるので、 key(RandomVariableの名前)value(RandomVariableのインスタンス) が入ると推測
  • session: tensorflowの session を初期化もしくは引数から設定しているだけです。
class Model(object):
    def __init__(self, name=None, graph=None, session=None, **config):
        self._cfg = Config(**config)
        self.name = name
        self._f = None
        self._variables = None
        self._observed = dict()
        if session is None:
            session = tf.Session(graph=graph)
        self.session = session

configure

_cfg を上書きして _init_variables を呼んでいるだけです。

    def configure(self, **override):
        self._cfg.update(**override)
        self._init_variables()
        return self

define

引数の f を設定して、 _init_variables を呼んでいるだけです。

    def define(self, f):
        self._f = f
        self._init_variables()
        return f

pymc4/model/base.py の最下段で以下のように define が使われています。

@biwrap.biwrap
def inline(f, **kwargs):
    model = Model(**kwargs)
    model.define(f)
    return model

inlinetests/test_model.py 内の以下にある通り、まずはテスト実行時のヘルパー関数のように使用されています。

def test_testvalue():
    @pm.inline
    def model(cfg):
        ed.Normal(0., 1., name='normal')

つまり、define 内の self._f には def model(cfg) の関数が設定されることがわかります(以下は擬似コードです)

self._f = def model(cfg):
                ed.Normal(0., 1., name='normal')

_init_variables

この辺りは interceptors 周りの動きが分からないと理解できなそうなので、次回読もうと思います。

    def _init_variables(self):
        info_collector = interceptors.CollectVariablesInfo()
        with self.graph.as_default(), ed.interception(info_collector):
            self._f(self.cfg)
        self._variables = info_collector.result

参考資料