Modelの_init_vatiables内のInterceptor処理 [e334115, d07338e, 93bc07b] - pymc4のソースコード読んでみた
概要
- まずは
Model
クラスの初期化処理系のメソッドを読んでみます。_init_variables
: 今回はここを読みます
コミット
2018/06/09から2018/06/11の間のコミットです。
- tmp · pymc-devs/pymc4@e334115 · GitHub
- restructure + test point implementation · pymc-devs/pymc4@d07338e · GitHub
- fixes · pymc-devs/pymc4@93bc07b · GitHub
Model - pymc4/model/base.py
_init_variables
前回わからなかった、 interceptors
周りの処理を読んでいきます。
from pymc4.util import interceptors (中略) def _init_variables(self): info_collector = interceptors.CollectVariablesInfo() with self.graph.as_default(), ed.interception(info_collector): self._f(self.cfg) self._variables = info_collector.result
ここでの interceptors
は util
内で自前で定義している CollectVariablesInfo
を読んでみます。
CollectVariablesInfo - pymc4/util/interceptors.py
CollectVariablesInfo
クラスは Interceptor
クラスを継承しているので、まずはそちらを読みます。
VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape') class Interceptor(object): def name_scope(self): return tf.name_scope(self.__class__.__name__.lower()) def __call__(self, f, *args, **kwargs): if kwargs.get('name') is None: raise SyntaxError('Every random variable should have a name') f, args, kwargs = self.before(f, *args, **kwargs) rv = f(*args, **kwargs) return self.after(rv, *args, **kwargs) def before(self, f, *args, **kwargs): return f, args, kwargs def after(self, rv, *args, **kwargs): return rv
__call__
が定義されているので、おそらく以下のように使用できるはずです。
intercptor = Interceptor() # Interceptorのインスタンス生成 rv = intercptor(f, name="hoge") # intercptor.__call__(f, *args, **kwargs) が呼べる
keyword引数に name
が無いと SyntaxError('Every random variable should have a name')
になります。
if kwargs.get('name') is None: raise SyntaxError('Every random variable should have a name')
あとはただ before
と after
メソッドが呼ばれているだけです。ぱっと見ると、 rv
を f
関数で生成して、 after
で何かしてるようです。
( before
と after
は子クラスでoverrideされるメソッドです。)
f, args, kwargs = self.before(f, *args, **kwargs) rv = f(*args, **kwargs) return self.after(rv, *args, **kwargs) def before(self, f, *args, **kwargs): return f, args, kwargs def after(self, rv, *args, **kwargs): return rv
CollectVariablesInfo
も見てみましょう。この中では親クラスの after
メソッドがoverrideされています。
class CollectVariablesInfo(Interceptor): def __init__(self): self.result = collections.OrderedDict() def after(self, rv, *args, **kwargs): name = kwargs["name"] if name not in self.result: self.result[name] = VariableDescription(rv.distribution.__class__, rv.shape) else: raise KeyError(name, 'Duplicate name') return rv
親クラスの __call__
メソッド内で既にkeyword引数に name
が含まれていることのチェックが行われているため、ここでは単に name = kwargs["name"]
で値を取得できることが保証されています。
あとは以下の通り、 result
プロパティ(OrderedDict)に name
をkeyに、valueに rv(RandomVariable)
の情報を VariableDescription
として rv
のクラス名とshapeをセットしています。
self.result[name] = VariableDescription(rv.distribution.__class__, rv.shape)
VariableDescription
はただのnamedtuple(名前付きタプル型)です。
VariableDescription = collections.namedtuple('VariableDescription', 'Dist,shape')
次回は _init_variables
内のtensorflow及びtensorflow_probabilityに関連する処理を見ていきます。 ed.interception
がわかればなんとなく処理がわかってくるはずです。
see: probability/interceptor.py at master · tensorflow/probability · GitHub
info_collector = interceptors.CollectVariablesInfo() # info_collector() で呼べるcallable with self.graph.as_default(), ed.interception(info_collector): self._f(self.cfg)