Updated: 2023-12-02 19:41:22
Maybe the author no longer needs an answer, but I was able to save and load a DNNClassifier using TensorFlow 2.1:
# training.py
from pathlib import Path
import tensorflow as tf
....
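# Creating the estimator
estimator = tf.estimator.DNNClassifier(
    model_dir=<model_dir>,
    hidden_units=[1000, 500],
    feature_columns=feature_columns,  # this is a list defined earlier
    n_classes=2,
    optimizer='adam')
# Train before exporting: export_saved_model needs a checkpoint in model_dir
# estimator.train(input_fn=...)
# Describe the serialized tf.train.Example inputs the exported model will accept
feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
export_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
# export_saved_model writes a new timestamped subdirectory under <model_dir>
# and returns that directory's path as bytes
servable_model_path = Path(estimator.export_saved_model(<model_dir>, export_input_fn).decode('utf8'))
print(f'Model saved at {servable_model_path}')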
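Each call to export_saved_model creates a fresh subdirectory named after the current Unix timestamp (written first as a temp-<timestamp> directory, then renamed), so after a few runs the base directory holds several exports, and the saved_model.pb lives inside one of them rather than in <model_dir> itself. A minimal sketch for picking the newest export, assuming that naming scheme; latest_export is a hypothetical helper, not part of TensorFlow:

```python
from pathlib import Path


def latest_export(export_dir_base):
    """Return the newest timestamped SavedModel directory under export_dir_base.

    Hypothetical helper: assumes exports are named with an integer Unix
    timestamp, as Estimator.export_saved_model does. Unfinished "temp-..."
    directories are skipped because their names are not all digits.
    """
    base = Path(export_dir_base)
    candidates = [
        p for p in base.iterdir()
        if p.is_dir() and p.name.isdigit() and (p / "saved_model.pb").exists()
    ]
    if not candidates:
        raise FileNotFoundError(f"no SavedModel export found under {base}")
    # The largest timestamp is the most recent export
    return max(candidates, key=lambda p: int(p.name))
```

In testing.py, something like tf.saved_model.load(str(latest_export(<model_dir>))) would then avoid copying the printed path by hand.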
For loading, you already found the right method; you just need to retrieve predict_fn from the loaded model's 'predict' signature:
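# testing.py
import tensorflow as tf
import pandas as pd

def predict_input_fn(test_df):
    '''Convert your dataframe using tf.train.Example() and tf.train.Features()'''
    examples = []
    ....
    # examples must hold serialized tf.train.Example protos (bytes), one per row
    return tf.constant(examples)

test_df = pd.read_csv('test.csv', ...)
# Loading the exported model: use the timestamped export directory printed by
# training.py (the one containing saved_model.pb), not the checkpoint model_dir.
# Keep a reference to the loaded object while using its signatures.
imported = tf.saved_model.load(<servable_model_path>)
predict_fn = imported.signatures['predict']
# Predict: 'examples' is the input name created by build_parsing_serving_input_receiver_fn
predictions = predict_fn(examples=predict_input_fn(test_df))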
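The body of predict_input_fn is elided above. One possible way to fill it in, as a sketch under these assumptions: numeric features were declared with tf.feature_column.numeric_column (default dtype float32, so every number is written as a FloatList) and categorical features hold strings (written as a BytesList). _to_feature is a hypothetical helper name. Extra columns such as a label are written too, but the parsing spec only reads the keys it lists.

```python
import numbers

import pandas as pd
import tensorflow as tf


def _to_feature(value):
    """Wrap one DataFrame cell in a tf.train.Feature (hypothetical helper).

    Strings become a BytesList; any other real number becomes a FloatList,
    matching numeric_column's default float32 dtype.
    """
    if isinstance(value, str):
        value = value.encode("utf-8")
    if isinstance(value, bytes):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    if isinstance(value, numbers.Real):
        return tf.train.Feature(float_list=tf.train.FloatList(value=[float(value)]))
    raise TypeError(f"unsupported cell type: {type(value)!r}")


def predict_input_fn(test_df: pd.DataFrame) -> tf.Tensor:
    """Serialize each row of test_df as a tf.train.Example."""
    examples = []
    for row in test_df.to_dict("records"):
        features = tf.train.Features(
            feature={name: _to_feature(value) for name, value in row.items()})
        examples.append(tf.train.Example(features=features).SerializeToString())
    # A 1-D string tensor of serialized protos, the shape the 'examples' input expects
    return tf.constant(examples)
```

Integer-valued categorical columns (for example categorical_column_with_identity) would need an Int64List instead, so adjust _to_feature to match your own feature_columns.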
Hope this helps other people too (: