[python] Tensorflow에서 그래프의 모든 Tensor 이름을 가져옵니다.

저는 Tensorflow및로 신경망을 만들고 있습니다 skflow. 어떤 이유로 그래서 내가 사용하고, 주어진 입력에 대한 몇 가지 내부 텐서의 값을 얻으려면 myClassifier.get_layer_value(input, "tensorName"), myClassifierskflow.estimators.TensorFlowEstimator.

그러나 이름을 알더라도 텐서 이름의 올바른 구문을 찾기가 어렵 기 때문에 (그리고 연산과 텐서 사이에 혼란스러워집니다) 그래프를 플로팅하고 이름을 찾기 위해 텐서 보드를 사용하고 있습니다.

텐서 보드를 사용하지 않고 그래프의 모든 텐서를 열거하는 방법이 있습니까?



답변

넌 할 수있어

[n.name for n in tf.get_default_graph().as_graph_def().node]

또한 IPython 노트북에서 프로토 타이핑하는 경우 노트북에 직접 그래프를 표시 할 수 있습니다 show_graph. Alexander ‘s Deep Dream 노트북 의 기능을 참조하십시오.


답변

get_operations 를 사용하여 Yaroslav의 답변보다 약간 더 빠르게 수행하는 방법이 있습니다 . 다음은 간단한 예입니다.

import tensorflow as tf

a = tf.constant(1.3, name='const_a')
b = tf.Variable(3.1, name='variable_b')
c = tf.add(a, b, name='addition')
d = tf.multiply(c, a, name='multiply')

for op in tf.get_default_graph().get_operations():
    print(str(op.name))


답변

대답을 요약하려고합니다.

모든 노드 를 가져 오려면 (유형 tensorflow.core.framework.node_def_pb2.NodeDef) :

all_nodes = [n for n in tf.get_default_graph().as_graph_def().node]

모든 작업 을 가져 오려면 (유형 tensorflow.python.framework.ops.Operation) :

all_ops = tf.get_default_graph().get_operations()

모든 변수 를 가져 오려면 (유형 tensorflow.python.ops.resource_variable_ops.ResourceVariable) :

all_vars = tf.global_variables()

모든 텐서 를 얻으려면 (유형 tensorflow.python.framework.ops.Tensor) :

all_tensors = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]


답변

tf.all_variables() 원하는 정보를 얻을 수 있습니다.

또한 오늘 TensorFlow Learn에서 만든 이 커밋get_variable_names모든 변수 이름을 쉽게 검색하는 데 사용할 수있는 추정기 의 기능을 제공합니다 .


답변

나는 이것도 할 것이라고 생각한다.

print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))

그러나 Salvado와 Yaroslav의 답변과 비교할 때 어느 것이 더 나은지 모르겠습니다.


답변

허용되는 대답은 이름이있는 문자열 목록 만 제공합니다. 텐서에 (거의) 직접 액세스 할 수있는 다른 접근 방식을 선호합니다.

graph = tf.get_default_graph()
list_of_tuples = [op.values() for op in graph.get_operations()]

list_of_tuples이제 튜플 내에있는 모든 텐서를 포함합니다. 텐서를 직접 가져 오도록 조정할 수도 있습니다.

graph = tf.get_default_graph()
list_of_tuples = [op.values()[0] for op in graph.get_operations()]


답변

OP가 작업 / 노드 목록 대신 텐서 목록을 요청했기 때문에 코드는 약간 달라야합니다.

graph = tf.get_default_graph()
tensors_per_node = [node.values() for node in graph.get_operations()]
tensor_names = [tensor.name for tensors in tensors_per_node for tensor in tensors]