オーストラリアで勉強してきたMLデザイナーの口語自由詩

主に、データ分析・機械学習・ベイズ・統計について自由に書く。

Modelの_init_vatiables内のInterceptor処理 [e334115, d07338e, 93bc07b] - pymc4のソースコード読んでみた

f:id:yukinagae:20171122095115p:plain

概要

  • まずは Model クラスの初期化処理系のメソッドを読んでみます。
    • _init_variables: 今回はここを読みます

コミット

2018/06/09から2018/06/11の間のコミットです。

Model - pymc4/model/base.py

_init_variables

前回わからなかった、 interceptors 周りの処理を読んでいきます。

from pymc4.util import interceptors

(中略)

    def _init_variables(self):
        info_collector = interceptors.CollectVariablesInfo()
        with self.graph.as_default(), ed.interception(info_collector):
            self._f(self.cfg)
        self._variables = info_collector.result

ここでの interceptorsutil 内で自前で定義している CollectVariablesInfo を読んでみます。

CollectVariablesInfo - pymc4/util/interceptors.py

CollectVariablesInfo クラスは Interceptor クラスを継承しているので、まずはそちらを読みます。

VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape')

class Interceptor(object):
    def name_scope(self):
        return tf.name_scope(self.__class__.__name__.lower())

    def __call__(self, f, *args, **kwargs):
        if kwargs.get('name') is None:
            raise SyntaxError('Every random variable should have a name')
        f, args, kwargs = self.before(f, *args, **kwargs)
        rv = f(*args, **kwargs)
        return self.after(rv, *args, **kwargs)

    def before(self, f, *args, **kwargs):
        return f, args, kwargs

    def after(self, rv, *args, **kwargs):
        return rv

__call__ が定義されているので、おそらく以下のように使用できるはずです。

intercptor = Interceptor() # Interceptorのインスタンス生成
rv = intercptor(f, name="hoge") # intercptor.__call__(f, *args, **kwargs) が呼べる

keyword引数に name が無いと SyntaxError('Every random variable should have a name') になります。

if kwargs.get('name') is None:
    raise SyntaxError('Every random variable should have a name')

あとはただ beforeafter メソッドが呼ばれているだけです。ぱっと見ると、 rvf 関数で生成して、 after で何かしてるようです。

beforeafter は子クラスでoverrideされるメソッドです。)

    f, args, kwargs = self.before(f, *args, **kwargs)
    rv = f(*args, **kwargs)
    return self.after(rv, *args, **kwargs)

    def before(self, f, *args, **kwargs):
        return f, args, kwargs

    def after(self, rv, *args, **kwargs):
        return rv

CollectVariablesInfo も見てみましょう。この中では親クラスの after メソッドがoverrideされています。

class CollectVariablesInfo(Interceptor):
    def __init__(self):
        self.result = collections.OrderedDict()

    def after(self, rv, *args, **kwargs):
        name = kwargs["name"]
        if name not in self.result:
            self.result[name] = VariableDescription(rv.distribution.__class__, rv.shape)
        else:
            raise KeyError(name, 'Duplicate name')
        return rv

親クラスの __call__ メソッド内で既にkeyword引数に name が含まれていることのチェックが行われているため、ここでは単に name = kwargs["name"] で値を取得できることが保証されています。

あとは以下の通り、 result プロパティ(OrderedDict)に name をkeyに、valuerv(RandomVariable) の情報を VariableDescription として rv のクラス名とshapeをセットしています。

self.result[name] = VariableDescription(rv.distribution.__class__, rv.shape)

VariableDescription はただのnamedtuple(名前付きタプル型)です。

VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape')

次回は _init_variables 内のtensorflow及びtensorflow_probabilityに関連する処理を見ていきます。 ed.interception がわかればなんとなく処理がわかってくるはずです。

see: probability/interceptor.py at master · tensorflow/probability · GitHub

info_collector = interceptors.CollectVariablesInfo() # info_collector() で呼べるcallable
with self.graph.as_default(), ed.interception(info_collector):
    self._f(self.cfg)

参考資料