1.遇到关于tf.get_variable()的问题
今天用tensorflow写一个模型,过程中遇到一个很坑的问题,只运行下面这行代码会报错:
W = tf.get_variable('W', (3, 1), initializer=tf.constant_initializer())
# 报错:
ValueError: Variable W already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at:
当时懵了,我只写了一行代码,上面没有建立名为W
的变量,这里共享变量W
却报错,说已经存在了,蛇皮玩意???
查了很多资料也没解决,后来意识到,是这串代码运行了多次,W
变量已经在内存中了,相当于写了两行W = tf.get_variable('W', (3, 1), initializer=tf.constant_initializer())
代码,我们知道,同名的共享变量不能创建两次,因为他不像Variable()
能够自动处理命名重复的问题。
2. tf.get_variable()和tf.Variable()的区别
先看一段代码:
import tensorflow as tf
w_1 = tf.Variable(3,name="w_1")
w_2 = tf.Variable(1,name="w_1")
print(w_1.name)
print(w_2.name)
#输出
#w_1:0
#w_1_1:0
import tensorflow as tf
w_1 = tf.get_variable(name="w_1",initializer=1)
w_2 = tf.get_variable(name="w_1",initializer=2)
#错误信息
#ValueError: Variable w_1 already exists, disallowed. Did
#you mean to set reuse=True in VarScope?
import tensorflow as tf
with tf.variable_scope("scope1"):
w1 = tf.get_variable("w1", shape=[])
w2 = tf.Variable(0.0, name="w2")
with tf.variable_scope("scope1", reuse=True):
w1_p = tf.get_variable("w1", shape=[])
w2_p = tf.Variable(1.0, name="w2")
print(w1 is w1_p, w2 is w2_p)
#输出
#True False
区别:
- 使用tf.Variable时,如果检测到命名冲突,系统会自己处理。使用tf.get_variable()时,系统不会处理冲突,而会报错。
- tf.Variable()每次都在创建新的对象,与name没有关系。而tf.get_variable()对于已经创建的同样name的变量对象,就直接把那个变量对象返回(类似于:共享变量),tf.get_variable() 会检查当前命名空间下是否存在同样name的变量,可以方便共享变量。
- tf.get_variable():对于在上下文管理器中已经生成一个v的变量,若想通过tf.get_variable函数获取其变量,则可以通过reuse参数的设定为True来获取。
- 还有一点,tf.get_variable()必须写name,否则报错
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/84744.html