[python] TensorFlow 저장 / 파일에서 그래프로드

지금까지 수집 한 내용에서 TensorFlow 그래프를 파일에 덤핑 한 다음 다른 프로그램에로드하는 방법에는 여러 가지가 있지만 작동 방식에 대한 명확한 예제 / 정보를 찾을 수 없었습니다. 내가 이미 알고있는 것은 이것이다 :

  1. a를 사용하여 모델의 변수를 체크 포인트 파일 (.ckpt)에 저장 tf.train.Saver()하고 나중에 복원 ( source )
  2. 모델을 .pb 파일에 저장하고 tf.train.write_graph()tf.import_graph_def()( 소스 ) 를 사용하여 다시로드합니다.
  3. .pb 파일에서 모델을로드하고 다시 학습 한 다음 Bazel을 사용하여 새 .pb 파일에 덤프합니다 ( 소스 ).
  4. 그래프를 고정하여 그래프와 가중치를 함께 저장합니다 ( 소스 ).
  5. 사용 as_graph_def()모델을 저장 및 무게 / 변수 (상수로 매핑 소스 )

그러나 이러한 다른 방법에 대한 몇 가지 질문을 해결할 수 없었습니다.

  1. 체크 포인트 파일과 관련하여 모델의 훈련 된 가중치 만 저장합니까? 체크 포인트 파일을 새 프로그램에로드하여 모델을 실행하는 데 사용할 수 있습니까, 아니면 단순히 특정 시간 / 단계에서 모델의 가중치를 저장하는 방법으로 사용됩니까?
  2. 와 관련 tf.train.write_graph()하여 가중치 / 변수도 저장됩니까?
  3. Bazel과 관련하여 재교육을 위해 .pb 파일로만 저장 /로드 할 수 있습니까? 그래프를 .pb로 덤프하는 간단한 Bazel 명령이 있습니까?
  4. 고정과 관련하여 고정 된 그래프는 tf.import_graph_def()? 를 사용하여로드 할 수 있습니다 .
  5. TensorFlow 용 Android 데모는 .pb 파일에서 Google의 Inception 모델로로드됩니다. 내 자신의 .pb 파일을 대체하려면 어떻게해야합니까? 네이티브 코드 / 메서드를 변경해야합니까?
  6. 일반적으로이 모든 방법의 차이점은 정확히 무엇입니까? 또는 더 광범위하게 as_graph_def()/.ckpt/.pb 의 차이점은 무엇입니까?

요컨대, 제가 찾고있는 것은 그래프 (다양한 연산 등)와 가중치 / 변수를 파일에 저장하는 방법입니다. 그러면 그래프와 가중치를 다른 프로그램에로드하는 데 사용할 수 있습니다. , 사용을 위해 (반드시 계속 / 재교육하는 것은 아님).

이 주제에 대한 문서는 그다지 간단하지 않으므로 답변 / 정보를 보내 주시면 감사하겠습니다.



답변

TensorFlow에서 모델을 저장하는 문제에 접근하는 방법에는 여러 가지가 있으며, 이로 인해 약간 혼란 스러울 수 있습니다. 각 하위 질문을 차례로 수행 :

  1. 체크 포인트 파일 (예 : 객체 를 호출 saver.save()하여 생성됨 tf.train.Saver)에는 가중치와 동일한 프로그램에 정의 된 다른 변수 만 포함됩니다. 다른 프로그램에서 사용하려면 관련 그래프 구조를 다시 만들어야합니다 (예 : 코드를 실행하여 다시 빌드하거나를 호출하여 tf.import_graph_def()). 그러면 TensorFlow에 해당 가중치로 수행 할 작업을 알려줍니다. 또한를 호출 하면 그래프와 체크 포인트의 가중치를 해당 그래프와 연결하는 방법에 대한 세부 정보가 포함 saver.save()된 파일이 생성 MetaGraphDef됩니다. 자세한 내용 은 튜토리얼 을 참조하십시오.

  2. tf.train.write_graph()그래프 구조 만 작성합니다. 가중치가 아닙니다.

  3. Bazel은 TensorFlow 그래프를 읽거나 쓰는 것과 관련이 없습니다. (아마도 귀하의 질문을 오해하고 있습니다. 의견을 통해 자유롭게 질문하십시오.)

  4. 고정 된 그래프는 tf.import_graph_def(). 이 경우 가중치는 (일반적으로) 그래프에 포함되므로 별도의 체크 포인트를로드 할 필요가 없습니다.

  5. 주요 변경 사항은 모델에 공급되는 텐서의 이름과 모델에서 가져온 텐서의 이름을 업데이트하는 것입니다. TensorFlow Android 데모에서 이는 에 전달 되는 inputNameoutputName문자열에 해당합니다 TensorFlowClassifier.initializeTensorFlow().

  6. GraphDef일반적으로 교육 과정을 변경하지 않는 프로그램 구조입니다. 체크 포인트는 일반적으로 교육 프로세스의 모든 단계에서 변경되는 교육 프로세스 상태의 스냅 샷입니다. 결과적으로 TensorFlow는 이러한 유형의 데이터에 대해 서로 다른 저장 형식을 사용하고 저수준 API는 데이터를 저장하고로드하는 다양한 방법을 제공합니다. 같은과 같은 높은 수준의 도서관, MetaGraphDef도서관, Kerasskflow 이러한 메커니즘에 빌드 저장하고 전체 모델을 복원하는 편리한 방법을 제공합니다.


답변

다음 코드를 시도해 볼 수 있습니다.

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)


답변