オーストラリアで勉強してきたデータサイエンティストの口語自由詩

主に、ベイズ・統計・データ分析・機械学習について自由に書く。

tensorflowのグラフ構造 [e334115, d07338e, 93bc07b] - pymc4のソースコード読んでみた

f:id:yukinagae:20180927082722g:plain

概要

  • まずは Model クラスの初期化処理系のメソッドを読んでみます。
    • _init_variables: 今回はここの self.graph.as_default() の処理を読みます

コミット

2018/06/09から2018/06/11の間のコミットです。

Model - pymc4/model/base.py

_init_variables

今回は _init_variables 内の self.graph.as_default() の処理を見ていきます。

info_collector = interceptors.CollectVariablesInfo() # info_collector() で呼べるcallable
with self.graph.as_default(), ed.interception(info_collector):
    self._f(self.cfg)

self._variables = info_collector.result

self.graphModel クラス内で以下のように定義されているので、実際には self.session.graph == tf.session.graphと同等です。

@property
def graph(self):
    return self.session.graph

また、今回の読む対象に絞って [余計な情報を省略] + [変数を読み替える]、を行うと以下のコードと同等になります。

with tf.Session.graph().as_default():
    ed.Normal(0., 1., name='normal') <= ここはed.XXX()

上記の XX 部分は tensorflow_probabilitydistribution クラスをedward2の RandomVariable クラスでwrapしたものです。

see: probability/generated_random_variables.py at master · tensorflow/probability · GitHub

結局これらを理解するには tensorflow のGraphの動作を理解する方が早そうです。

f:id:yukinagae:20180927082722g:plain

see: Graphs and Sessions  |  TensorFlow

実際に as_default() の使用例を見てみると、 以下のように with スコープで実行された tf.XXX の処理の結果がそのスコープのGraphに追加されていくようです。

g_1 = tf.Graph()
with g_1.as_default():
  # Operations created in this scope will be added to `g_1`.
  c = tf.constant("Node in g_1")

  # Sessions created in this scope will run operations from `g_1`.
  sess_1 = tf.Session()

g_2 = tf.Graph()
with g_2.as_default():
  # Operations created in this scope will be added to `g_2`.
  d = tf.constant("Node in g_2")

# Alternatively, you can pass a graph when constructing a <a href="../api_docs/python/tf/Session"><code>tf.Session</code></a>:
# `sess_2` will run operations from `g_2`.
sess_2 = tf.Session(graph=g_2)

assert c.graph is g_1
assert sess_1.graph is g_1

assert d.graph is g_2
assert sess_2.graph is g_2

つまり以下のコードを簡略的に理解すると、 g1(graph) -> c(constant) という紐付きをGraphとして構造化することができるということです。

with g_1.as_default():
  # Operations created in this scope will be added to `g_1`.
  c = tf.constant("Node in g_1")

同様に以下のコードも、tf.Session.graph() -> ed.Normal() と紐付けることができます。

with tf.Session.graph().as_default():
    ed.Normal(0., 1., name='normal') <= ここはed.XXX()

参考資料