tensorflowのグラフ構造 [e334115, d07338e, 93bc07b] - pymc4のソースコード読んでみた
概要
- まずは
Model
クラスの初期化処理系のメソッドを読んでみます。_init_variables
: 今回はここのself.graph.as_default()
の処理を読みます
コミット
2018/06/09から2018/06/11の間のコミットです。
- tmp · pymc-devs/pymc4@e334115 · GitHub
- restructure + test point implementation · pymc-devs/pymc4@d07338e · GitHub
- fixes · pymc-devs/pymc4@93bc07b · GitHub
Model - pymc4/model/base.py
_init_variables
今回は _init_variables
内の self.graph.as_default()
の処理を見ていきます。
info_collector = interceptors.CollectVariablesInfo() # info_collector() で呼べるcallable with self.graph.as_default(), ed.interception(info_collector): self._f(self.cfg) self._variables = info_collector.result
self.graph
は Model
クラス内で以下のように定義されているので、実際には self.session.graph
== tf.session.graph
と同等です。
@property def graph(self): return self.session.graph
また、今回の読む対象に絞って [余計な情報を省略] + [変数を読み替える]、を行うと以下のコードと同等になります。
with tf.Session.graph().as_default(): ed.Normal(0., 1., name='normal') <= ここはed.XXX()
上記の XX
部分は tensorflow_probability
の distribution
クラスをedward2の RandomVariable
クラスでwrapしたものです。
see: probability/generated_random_variables.py at master · tensorflow/probability · GitHub
結局これらを理解するには tensorflow
のGraphの動作を理解する方が早そうです。
see: Graphs and Sessions | TensorFlow
実際に as_default()
の使用例を見てみると、 以下のように with
スコープで実行された tf.XXX
の処理の結果がそのスコープのGraphに追加されていくようです。
g_1 = tf.Graph() with g_1.as_default(): # Operations created in this scope will be added to `g_1`. c = tf.constant("Node in g_1") # Sessions created in this scope will run operations from `g_1`. sess_1 = tf.Session() g_2 = tf.Graph() with g_2.as_default(): # Operations created in this scope will be added to `g_2`. d = tf.constant("Node in g_2") # Alternatively, you can pass a graph when constructing a <a href="../api_docs/python/tf/Session"><code>tf.Session</code></a>: # `sess_2` will run operations from `g_2`. sess_2 = tf.Session(graph=g_2) assert c.graph is g_1 assert sess_1.graph is g_1 assert d.graph is g_2 assert sess_2.graph is g_2
つまり以下のコードを簡略的に理解すると、 g1(graph)
-> c(constant)
という紐付きをGraphとして構造化することができるということです。
with g_1.as_default(): # Operations created in this scope will be added to `g_1`. c = tf.constant("Node in g_1")
同様に以下のコードも、tf.Session.graph()
-> ed.Normal()
と紐付けることができます。
with tf.Session.graph().as_default(): ed.Normal(0., 1., name='normal') <= ここはed.XXX()