tensorflow.saved_model.save
SavedModel 형식 의 함수를 사용하여 모델을 저장하면 나중에이 모델에 사용 된 Tensorflow Ops를 검색 할 수 있습니다. 모델을 복원 할 수 있으므로 이러한 작업은 그래프에 저장되며 추측은 saved_model.pb
파일에 있습니다. 이 protobuf (전체 모델이 아님)를로드하면 protobuf의 라이브러리 부분에이 목록이 나열되지만 지금은 실험적인 기능으로 문서화 및 태그 지정되지 않았습니다. Tensorflow 1.x에서 생성 된 모델에는이 부분이 없습니다.
그렇다면 저장된 모델 형식의 모델에서 사용 된 작업 ( MatchingFiles
또는 유사 WriteFile
) 목록을 검색하는 빠르고 안정적인 방법은 무엇 입니까?
지금처럼 전체를 얼릴 수 있습니다 tensorflowjs-converter
. 또한 지원되는 작업을 확인합니다. LSTM이 모델에있는 경우 현재 작동하지 않습니다 ( 여기 참조) . 작전이 확실히 있기 때문에 더 좋은 방법이 있습니까?
예제 모델 :
class FileReader(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
input_scalar = tf.reshape(file_name, [])
output = tf.io.read_file(input_scalar)
return tf.stack([output], name='content')
file_reader = FileReader()
tf.saved_model.save(file_reader, 'file_reader')
이 경우 최소한 다음을 포함하여 모든 Ops를 출력 할 것으로 예상됩니다.
ReadFile
여기에 설명 된대로- …
답변
경우 saved_model.pb
A는 SavedModel
protobuf 메시지가, 당신은 거기에서 직접 작업을 얻을. 다음과 같이 모델을 생성한다고 가정 해 봅시다.
import tensorflow as tf
class FileReader(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
input_scalar = tf.reshape(file_name, [])
output = tf.io.read_file(input_scalar)
return tf.stack([output], name='content')
file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')
이제 다음과 같이 해당 모델에서 사용되는 작업을 찾을 수 있습니다.
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
# Add operations in the graph definition
model_op_names.update(node.op for node in meta_graph.graph_def.node)
# Go through the functions in the graph definition
for func in meta_graph.graph_def.library.function:
# Add operations in each function
model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin