您当前的位置: 首页 >  tensorflow

Better Bench

暂无认证

  • 4浏览

    0关注

    695博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

【Tensorflow 2】Keras API+Estimator的使用

Better Bench 发布时间:2021-03-10 18:00:28 ,浏览量:4

1 原因

提高GPU利用率

2 Example

参考官网的介绍通过 Keras 模型创建 Estimator

# 通过keras API 构建模型
model  = build_model()
# 产生训练集sample 和label
x,y = generator_data(data_size,SNRdb)
# 用Dataset封装,加快训练
dataset_xy=tf.data.Dataset.from_tensor_slices((x,y)).shuffle(5000).batch(batchs).prefetch(tf.data.experimental.AUTOTUNE).repeat() 
# 临时文件
import tempfile
model_dir = tempfile.mkdtemp()
# 用Estimator进行训练
keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model,model_dir=model_dir)

# 预测
valid_data =...#Dataset格式的验证集
eval_result = keras_estimator.evaluate(input_fn=valid_data, steps=10)
print('Eval result: {}'.format(eval_result))
关注
打赏
1665674626
查看更多评论
立即登录/注册

微信扫码登录

0.5099s