Tensorflow入门——改进RNN预测牛奶产量

in #cn-stem5 years ago (edited)

photo of milk bottle lot

image source from unsplash by Mehrshad Rajabi

上一篇文章我们用Keras搭建GRU神经网络,通过对前13年牛奶产量的学习,成功预测了地最后1年牛奶的产量。

该模型是多对一的输入/输出结构,也就意味着12个月的数据输入,只能输出1个月的数据。有没有可能改进模型,让输出输入的数量一致,以提高预测效率呢?这篇文章我们就来改进GRU模型,实现多对多的结构。

同样的,为了方便与读者交流,所有的代码都放在了这里:

Repository:

https://github.com/zht007/tensorflow-practice

1. 数据预处理

数据的导入,训练集测试集分离以及归一化与之前一致,就不赘述了。需要改变的是GRU输入输出Shape。

  • 设计一个连续的数据窗口,窗口中包含24个月的数据。
  • 由于是多对多的结构,前12个月数据X为输入的Feature,后12个月数据为label与神经网络的输出做对比。
  • 数据窗口按月平移,这样一共可以产生133个组数据。

采用相同的帮助函数,仅仅改变future_monthes的数量

def build_train_data(data, past_monthes = 12, future_monthes = 12):
  X_train, Y_train = [],[]
  
  for i in range(data.shape[0] - past_monthes - future_monthes):
    X_train.append(np.array(data[i:i + past_monthes]))
    Y_train.append(np.array(data[i + past_monthes:i + past_monthes + future_monthes]))
    
  return np.array(X_train).reshape([-1,12]), np.array(Y_train).reshape([-1,12])
     

调用帮助函数获得输入和输出

x, y = build_train_data(train_scaled)

2. GRU神经网络

2.1 None-Stateful结构

多对多结构的GRU与多对一的GRU结构没有太大的变化,唯一的区别是最后的Dense层需要用 layers.TimeDistributed()连接,以便将所有时间序列上的输出都传给Dense层,而不仅仅是最后一位。

model_layers = [
    layers.Reshape((SEQLEN,1),input_shape=(SEQLEN,)),
    layers.GRU(RNN_CELLSIZE, return_sequences=True),
    layers.GRU(RNN_CELLSIZE, return_sequences=True),
    layers.TimeDistributed(layers.Dense(1)),
    layers.Flatten()
    
]
model = Sequential(model_layers)
model.summary()

部分代码参考 github with lisence Apache-2.0

GRU的结构如下图所示

image-20190416121948606

训练1000个epoch后的loss变化如图:

image-20190416122309696

我们发现loss在下降的过程中噪音非常大,而且这反复的变化似乎是成规律的。这是由于我们在训练的过程中,每个Batch是相对独立的,其训练之后产生的状态(State)并没有传到下一个Batch中。要解决这个问题,我们需要在GRU中开启Stateful。

2.2 Stateful结构

开启Stateful之后,我们必须在第一个输入层指定batch_size,为了方便后面的预测,这里的batch_size 设定为1。记得GRU的每一层都需要开启Stateful。

RNN_CELLSIZE = 10
SEQLEN = 12
BATCHSIZE = 1

model_layers = [
    layers.Reshape((SEQLEN,1),input_shape=(SEQLEN,),batch_size = BATCHSIZE),
    layers.GRU(RNN_CELLSIZE, return_sequences=True, stateful=True),
    layers.GRU(RNN_CELLSIZE, return_sequences=True, stateful=True),
    layers.TimeDistributed(layers.Dense(1)),
    layers.Flatten()
    
]
model = Sequential(model_layers)
model.summary()

部分代码参考 github with lisence Apache-2.0

如下图所示,这里我们仅仅训练了100个epoch,loss的下降就非常平滑了。

image-20190416123012594

3. 模型预测

同样的,我们用模型预测最后一年12个月牛奶的产量,这里我们将三个模型的预测结果做了对比。

image-20190416124233619

可以看到多对多的模型,尤其是开启了Stateful之后的预测更加地准确。

image-20190416124158893

用第一年的数据生成后面13年的数据如上图所示,可以发现,多对多的模型同样更具优势。


参考资料

[1]https://codelabs.developers.google.com/codelabs/cloud-tensorflow-mnist/#0

[2]https://github.com/GoogleCloudPlatform/tensorflow-without-a-phd.git

[3]https://www.tensorflow.org/api_docs/

[4]https://datamarket.com/data/set/22ox/monthly-milk-production-pounds-per-cow-jan-62-dec-75#!ds=22ox&display=line


相关文章

Tensorflow入门——RNN预测牛奶产量

AI学习笔记——循环神经网络(RNN)的基本概念

Tensorflow入门——单层神经网络识别MNIST手写数字

Tensorflow入门——多层神经网络MNIST手写数字识别

AI学习笔记——Tensorflow中的Optimizer

Tensorflow入门——分类问题cross_entropy的选择

AI学习笔记——Tensorflow入门

Tensorflow入门——Keras简介和上手


同步到我的简书

https://www.jianshu.com/u/bd506afc6fc1

Sort:  

修改建议:

  1. 参考资料部分的序号存在错误。

谢谢指出,已修复



This post has been voted on by the SteemSTEM curation team and voting trail. It is elligible for support from @curie.

If you appreciate the work we are doing, then consider supporting our witness stem.witness. Additional witness support to the curie witness would be appreciated as well.

For additional information please join us on the SteemSTEM discord and to get to know the rest of the community!

Please consider setting @steemstem as a beneficiary to your post to get a stronger support.

Please consider using the steemstem.io app to get a stronger support.

Congratulations @hongtao! You have completed the following achievement on the Steem blockchain and have been rewarded with new badge(s) :

You received more than 6000 upvotes. Your next target is to reach 7000 upvotes.

You can view your badges on your Steem Board and compare to others on the Steem Ranking
If you no longer want to receive notifications, reply to this comment with the word STOP

You can upvote this notification to help all Steem users. Learn how here!