博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
神经网络5:循环神经网络1
阅读量:6720 次
发布时间:2019-06-25

本文共 2100 字,大约阅读时间需要 7 分钟。

import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("MNIST_data", one_hot=True)learn_rate = 0.001train_iters = 100000batch_size = 128n_inputs = 28n_steps = 28n_hidden_units = 128n_classes = 10in_x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])in_label = tf.placeholder(tf.float32, [None, n_classes])weights = {    'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),    'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))}biases = {    'in': tf.Variable(tf.constant(value=0.1, shape=[n_hidden_units, ])),    'out': tf.Variable(tf.constant(value=0.1, shape=[n_classes, ]))}def RNN(inputs, weights, biases):    inputs = tf.reshape(inputs, [-1, n_inputs])    inputs_in = tf.matmul(inputs, weights['in']) + biases['in']    inputs_in = tf.reshape(inputs_in, [-1, n_steps, n_hidden_units])    # cell    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=n_hidden_units, forget_bias=1.0, state_is_tuple=True)    _init_state = lstm_cell.zero_state(batch_size=batch_size, dtype=tf.float32)    outputs, states = tf.nn.dynamic_rnn(lstm_cell, inputs_in, initial_state=_init_state, time_major=False)    results = tf.matmul(states[1], weights['out']) + biases['out']    return resultspred = RNN(in_x, weights, biases)cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=in_label))train = tf.train.AdamOptimizer(learn_rate).minimize(cost)correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(in_label, 1))accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))init = tf.global_variables_initializer()with tf.Session() as sess:    sess.run(init)    step = 0    while step * batch_size < train_iters:        batch_x, batch_label = mnist.train.next_batch(batch_size)        batch_x = batch_x.reshape([batch_size, n_steps, n_inputs])        sess.run([train], feed_dict={in_x: batch_x, in_label: batch_label})        if step % 20 == 0:            print(sess.run(accuracy, feed_dict={in_x: batch_x, in_label: batch_label}))        step += 1

 

转载于:https://www.cnblogs.com/infoo/p/9509017.html

你可能感兴趣的文章
Quartz在Spring中如何动态配置时间
查看>>
css实现正方形
查看>>
高性能Socket服务器编程-01
查看>>
gentoo系统安装(详细)
查看>>
Spring Cloud(二)Consul 服务治理实现
查看>>
mysql备份还原(视图、存储过程)
查看>>
快速配置oralce11g安装环境脚本
查看>>
int.Parse
查看>>
光纤跳线
查看>>
day02:管道符、shell及环境变量
查看>>
php设计模式——适配器模式
查看>>
C#文件、文件夹操作
查看>>
MySQL编译安装加入service
查看>>
以rsync进行同步镜像备份
查看>>
热烈祝贺VMware View4.5荣获“2010年度最佳产品”大奖
查看>>
ORACLE 11G 中表空间传输 TransportableTablespace
查看>>
自动化1
查看>>
Jenkins 2.32.3参数化构建maven项目
查看>>
使用Oracle存储过程批量生成测试数据
查看>>
正则表达式 - ×××
查看>>