II. Model Architecture
3. cifar10_train.py: trains the CIFAR-10 model on a CPU or GPU
# -*- coding: utf-8 -*-
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A binary to train CIFAR-10 using a single GPU.

Accuracy:
cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
data) as judged by cifar10_eval.py.

Speed: With batch_size 128.

System        | Step Time (sec/batch)  |     Accuracy
------------------------------------------------------------------
1 Tesla K20m  | 0.35-0.60              | ~86% at 60K steps  (5 hours)
1 Tesla K40m  | 0.25-0.35              | ~86% at 100K steps (4 hours)

Usage:
Please see the tutorial and website for how to download the CIFAR-10
data set, compile the program and train the model.

http://tensorflow.org/tutorials/deep_cnn/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import os.path
import time
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
#from tensorflow.models.image.cifar10 import cifar10
import cifar10
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 10000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Get the images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build a graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate the loss.
    loss = cifar10.loss(logits, labels)

    # Build a graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Create a saver.
    #saver = tf.train.Saver(tf.all_variables())
    saver = tf.train.Saver(tf.global_variables())

    # Build the summary operation based on the TF collection of Summaries.
    #summary_op = tf.merge_all_summaries()
    summary_op = tf.summary.merge_all()

    # Build the initialization operation.
    #init = tf.initialize_all_variables()
    init = tf.global_variables_initializer()

    # Start running operations on the graph.
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    #summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      # Log throughput every 10 steps.
      if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.now(), step, loss_value,
                            examples_per_sec, sec_per_batch))

      # Write summaries every 100 steps.
      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  # Start from an empty training directory on every run.
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()


if __name__ == '__main__':
  tf.app.run()
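The bookkeeping inside the training loop can be checked without TensorFlow. The sketch below is a plain-Python illustration and is not part of cifar10_train.py: the helper names (`epochs_seen`, `throughput`, `is_checkpoint_step`) are hypothetical, and the 50,000-image training-set size is an assumption based on the standard CIFAR-10 split rather than anything stated in the listing.

```python
# Plain-Python sketch of the arithmetic used by the training loop.
# The helper names are hypothetical; they do not exist in cifar10_train.py.

TRAIN_SET_SIZE = 50000  # assumption: standard CIFAR-10 training split


def epochs_seen(steps, batch_size, train_set_size=TRAIN_SET_SIZE):
    """Number of passes over the training set after `steps` batches."""
    return steps * batch_size / train_set_size


def throughput(batch_size, duration):
    """Return (examples/sec, sec/batch) for one step, as in the log line."""
    return batch_size / duration, float(duration)


def is_checkpoint_step(step, max_steps):
    """Mirror the listing's rule: every 1000 steps, plus the final step."""
    return step % 1000 == 0 or (step + 1) == max_steps


if __name__ == "__main__":
    print(epochs_seen(100000, 128))
    print(throughput(128, 0.25))
    print([s for s in range(2500) if is_checkpoint_step(s, 2500)])
```

Under that training-set assumption, 100K steps at batch size 128 works out to 256 epochs, which agrees with the "256 epochs of data" figure in the docstring. The checkpoint rule also shows why the last step is always saved even when `max_steps` is not a multiple of 1000.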
4. cifar10_multi_gpu_train.py: trains the CIFAR-10 model on multiple GPUs.
5. cifar10_eval.py: evaluates the predictive performance of a CIFAR-10 model.
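As a rough illustration of what "predictive performance" can mean for a classifier like this one, the sketch below computes top-1 precision: the fraction of examples whose highest-scoring logit matches the true label. It is a plain-Python sketch under that assumption, not the code of cifar10_eval.py; the function name `precision_at_1` and the sample logits are hypothetical.

```python
def precision_at_1(logits, labels):
    """Fraction of rows whose arg-max logit equals the true label."""
    if not labels or len(logits) != len(labels):
        raise ValueError("logits and labels must be non-empty and equal in length")
    correct = 0
    for row, label in zip(logits, labels):
        # Index of the largest score in this row (first one wins on ties).
        predicted = max(range(len(row)), key=row.__getitem__)
        if predicted == label:
            correct += 1
    return correct / len(labels)


if __name__ == "__main__":
    # Hypothetical 3-class scores for four examples.
    sample_logits = [
        [0.1, 2.0, -1.0],
        [1.5, 0.2, 0.3],
        [0.0, 0.1, 0.9],
        [0.4, 0.3, 0.2],
    ]
    sample_labels = [1, 0, 2, 1]
    print(precision_at_1(sample_logits, sample_labels))
```

In the sample data the first three rows are classified correctly and the last is not, so the result is three out of four.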