gpt-2: how do I freeze a checkpoint graph to .pb format?

Posted by k97glaaz 9 months ago in Other

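I'm trying to freeze a fine-tuned GPT-2 model, but I can't figure out what the output node name should be. Using this code as a reference, I put the following together: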

```python
import fire
import json
import os
import numpy as np
import tensorflow as tf
import model, sample, encoder  # modules from the gpt-2 repository's src/ directory

seed = None
length = 40
temperature = 1
top_k = 0

# Load the hyperparameters of the 345M model
hparams = model.default_hparams()
with open('models/345M/hparams.json') as f:
    hparams.override_from_dict(json.load(f))

with tf.Session(graph=tf.Graph()) as sess:
    # Input: a batch of one token sequence of arbitrary length
    context = tf.placeholder(tf.int32, [1, None])
    np.random.seed(seed)
    tf.set_random_seed(seed)
    output = sample.sample_sequence(
        hparams=hparams, length=length,
        context=context,
        batch_size=1,
        temperature=temperature, top_k=top_k
    )

    # Restore the weights from the latest checkpoint
    saver = tf.train.Saver()
    ckpt = tf.train.latest_checkpoint(os.path.join('models', '345M'))
    saver.restore(sess, ckpt)
    print([n.name for n in tf.get_default_graph().as_graph_def().node])

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, [output.name])

    # Save the frozen graph
    with open('output_graph.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
```

But I get:

AssertionError: sample_sequence/while/Exit_3:0 is not in graph

So what should I pass as the third argument (the output node names) when freezing the graph?

Answer #1 (yeotifhr)

`output.name` gives you a tensor name ('sample_sequence/while/Exit_3:0'), not a node name. I think you should pass ['sample_sequence/while/Exit_3'] as the third argument to tf.graph_util.convert_variables_to_constants.
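To make the distinction concrete, here is a small pure-Python sketch (no TensorFlow needed) of the naming convention: a tensor name has the form `<node_name>:<output_index>`, while the GraphDef's node list contains only the part before the colon. The helpers `tensor_to_node_name` and `check_output_nodes` and the short node list are hypothetical, made up for illustration; `check_output_nodes` only roughly imitates the membership check behind the AssertionError. In TF 1.x code you can also get the node name directly from the tensor with `output.op.name`.

```python
def tensor_to_node_name(tensor_name: str) -> str:
    """Strip a leading '^' (control-input marker) and the ':<output_index>' suffix."""
    name = tensor_name[1:] if tensor_name.startswith("^") else tensor_name
    return name.split(":", 1)[0]


def check_output_nodes(node_names, requested):
    """Return the requested names that are NOT node names in the graph.

    Hypothetical helper: a rough imitation of the "... is not in graph" check.
    """
    present = set(node_names)
    return [name for name in requested if name not in present]


# Hypothetical excerpt of the node-name list printed by the question's code.
graph_nodes = [
    "Placeholder",
    "sample_sequence/while/Exit",
    "sample_sequence/while/Exit_3",
]

output_tensor_name = "sample_sequence/while/Exit_3:0"

# The tensor name is not in the node list...
print(check_output_nodes(graph_nodes, [output_tensor_name]))
# ...but the node name obtained by dropping ':0' is.
print(check_output_nodes(graph_nodes, [tensor_to_node_name(output_tensor_name)]))
```

The first call reports the tensor name as missing; the second returns an empty list, because the stripped name matches an entry in the printed node list and is the form expected for output node names.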

Answer #2 (csga3l58)

Dear @ChintanTrivedi, did you manage to freeze the GPT-2 model correctly?
