Using Tensorflow SavedModel Format to Save and Do Predictions
Published: 2019-06-27

We are now trying to deploy our deep learning model to Google Cloud, using Cloud Functions to trigger the predictions. However, when a pre-trained model is stored in the cloud, we cannot point to an exact directory path and restore a TensorFlow session the way we would on a local machine.

 

So we turned to SavedModel, which works rather like a 'prediction mode' of TensorFlow. According to the official tutorial: a SavedModel contains a complete TensorFlow program, including weights and computation. It does not require the original model-building code to run, which makes it useful for sharing or deploying.

 

The definition of our graph, shown here mainly for the input and output tensors:

'''RNN Model Definition'''
tf.reset_default_graph()

# define inputs
tf_x = tf.placeholder(tf.float32, [None, window_size, 1], name='x')
tf_y = tf.placeholder(tf.int32, [None, 2], name='y')

cells = [tf.keras.layers.LSTMCell(units=n) for n in num_units]
stacked_rnn_cell = tf.keras.layers.StackedRNNCells(cells)
outputs, (h_c, h_n) = tf.nn.dynamic_rnn(
    stacked_rnn_cell,           # cell you have chosen
    tf_x,                       # input
    initial_state=None,         # the initial hidden state
    dtype=tf.float32,           # must be given if initial_state=None
    time_major=False,           # False: (batch, time step, input); True: (time step, batch, input)
)

l1 = tf.layers.dense(outputs[:, -1, :], 32, activation=tf.nn.relu, name='l1')
l2 = tf.layers.dense(l1, 8, activation=tf.nn.relu, name='l6')
pred = tf.layers.dense(l2, 2, activation=tf.nn.relu, name='pred')

with tf.name_scope('loss'):
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf_y, logits=pred)
    loss = tf.reduce_mean(cross_entropy)
    tf.summary.scalar("loss", tensor=loss)

train_op = tf.train.AdamOptimizer(LR).minimize(loss)
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(tf_y, axis=1), tf.argmax(pred, axis=1)), tf.float32))

init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
saver = tf.train.Saver()

 

Train and save the model; we use simple_save:

sess = tf.Session()
sess.run(init_op)
for i in range(0, n):
    sess.run(train_op, {tf_x: batch_X, tf_y: batch_y})
    ...
tf.saved_model.simple_save(sess, 'simple_save/model',
                           inputs={"x": tf_x}, outputs={"pred": pred})
sess.close()
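If the export succeeds, simple_save writes a self-contained directory (sketched below for the path used in the snippet above); you can also inspect the exported signature with `saved_model_cli show --dir simple_save/model --all`:

```text
simple_save/model/
├── saved_model.pb                      # serialized graph and signature(s)
└── variables/
    ├── variables.data-00000-of-00001   # trained weights
    └── variables.index
```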

 

Restore and Predict:

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ["serve"], 'simple_save_test/model')
    batch = sess.run('pred/Relu:0', feed_dict={'x:0': dataX.reshape([-1, 24, 1])})
    print(batch)
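Hard-coding the tensor name 'pred/Relu:0' is brittle; the names can instead be read from the 'serving_default' signature that simple_save writes. Below is a minimal self-contained sketch of this lookup, using a tiny stand-in graph (not the RNN above) and the tf.compat.v1 API so it also runs under TF 2.x; the shapes and paths here are illustrative assumptions:

```python
import tempfile
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

export_dir = tempfile.mkdtemp() + '/model'  # simple_save requires a fresh dir

# Build and export a tiny stand-in graph with the same input/output names.
with tf.Graph().as_default(), tf.Session() as sess:
    x = tf.placeholder(tf.float32, [None, 3], name='x')
    pred = tf.layers.dense(x, 2, activation=tf.nn.relu, name='pred')
    sess.run(tf.global_variables_initializer())
    tf.saved_model.simple_save(sess, export_dir,
                               inputs={'x': x}, outputs={'pred': pred})

# Reload and read the tensor names from the 'serving_default' signature,
# instead of hard-coding them.
with tf.Session(graph=tf.Graph()) as sess:
    meta = tf.saved_model.loader.load(sess, ['serve'], export_dir)
    sig = meta.signature_def['serving_default']
    in_name = sig.inputs['x'].name      # e.g. 'x:0'
    out_name = sig.outputs['pred'].name
    out = sess.run(out_name, feed_dict={in_name: np.zeros((1, 3), np.float32)})
    print(in_name, out_name, out.shape)
```

The same lookup works for the RNN model in this post: replace the stand-in graph with the exported 'simple_save/model' directory and feed the (batch, 24, 1)-shaped input.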

 

Reference:

Medium post:

The official tutorial of TensorFlow:

Reposted from: https://www.cnblogs.com/rhyswang/p/10971237.html
