Updated: 2023-12-02 21:12:46
Finally solved it. This contains some S3-specific code and S3 instance calls (the `!` commands), but you should be able to strip those out and run the rest.
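The `!` lines are IPython/Jupyter shell escapes, so they only run inside a notebook. Outside one, the label copy and the copy-and-tar steps at the end of the script can be done in plain Python with `shutil` and `tarfile`. A minimal sketch, where `package_version_dir` is a hypothetical helper name and the S3 upload is left out (with `boto3`, `upload_file` can take the place of `aws s3 cp`):

```python
import os
import shutil
import tarfile


def package_version_dir(export_dir, label_file, artifact_path):
    """Hypothetical helper mirroring the notebook's `!cp` / `!tar` steps.

    Copies the label file into export_dir as saved_model.txt, then writes a
    .tar.gz whose top-level entry is the version directory's own name
    (e.g. "1"), like `tar -czvf artifact.tar.gz <VERSION_INT>` does.
    artifact_path should live outside export_dir.
    """
    # Equivalent of: !cp {modelLabel} {modelLabelDest}
    shutil.copy(label_file, os.path.join(export_dir, "saved_model.txt"))
    # Add the directory under its base name, so no intermediate copy
    # in the working directory (and no later cleanup of it) is needed
    version_name = os.path.basename(os.path.normpath(export_dir))
    with tarfile.open(artifact_path, "w:gz") as tar:
        tar.add(export_dir, arcname=version_name)
    return artifact_path
```

Called as `package_version_dir(exportDir, modelLabel, modelArtifact)`, this would stand in for the `!cp`, `!cp -R` and `!tar` lines and make the cleanup of the local copy of the version directory unnecessary.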
```python
#!python3
"""
Assumes we've defined:
- A directory for our working files to live in, CONTAINER_DIR
- an arbitrary integer VERSION_INT
- We have established local and S3 paths for our model and their labels as variables, particularly `modelLabel` and `modelPath`
- bucketPath, the S3 destination for the packaged artifact
"""
import os
import shutil

import tensorflow as tf  # TensorFlow 1.x API (tf.GraphDef, tf.Session)

# Create a versioned path for the models to live in
# See https://***.com/a/54014480/1877527
# os.path.join needs strings, so convert the integer version
exportDir = os.path.join(CONTAINER_DIR, str(VERSION_INT))
if os.path.exists(exportDir):
    shutil.rmtree(exportDir)
os.mkdir(exportDir)


def load_graph(model_file, returnElements=None):
    """
    Code from v1.6.0 of Tensorflow's label_image.py example
    """
    graph = tf.Graph()
    graph_def = tf.GraphDef()
    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    returns = None
    with graph.as_default():
        returns = tf.import_graph_def(graph_def, return_elements=returnElements)
    if returnElements is None:
        return graph
    return graph, returns


# Add the serving metagraph tag
# We need the inputLayerName; in Inception we're feeding the resized tensor
# corresponding to resized_input_tensor_name
# May be able to get away with auto-determining this if not using Inception,
# but for Inception this is the 11th layer
inputLayerName = "Mul:0"
# Load the graph
if inputLayerName is None:
    graph = load_graph(modelPath)
    inputTensor = None
else:
    graph, returns = load_graph(modelPath, returnElements=[inputLayerName])
    inputTensor = returns[0]
with tf.Session(graph=graph) as sess:
    try:
        from tensorflow.compat.v1.saved_model import simple_save
    except (ModuleNotFoundError, ImportError):
        from tensorflow.saved_model import simple_save
    # Read the layers: the last node is taken as the output and, if no
    # input name was given, the first node as the input
    with graph.as_default():
        layers = [n.name for n in graph.as_graph_def().node]
        outName = layers.pop() + ":0"
        if inputLayerName is None:
            inputLayerName = layers.pop(0) + ":0"
    print("Checking outlayer", outName)
    outLayer = graph.get_tensor_by_name(outName)
    if inputTensor is None:
        print("Checking inlayer", inputLayerName)
        inputTensor = graph.get_tensor_by_name(inputLayerName)
    inputs = {
        inputLayerName: inputTensor
    }
    outputs = {
        outName: outLayer
    }
    simple_save(sess, exportDir, inputs, outputs)
print("Built a SavedModel")
# Put the model label into the artifact dir
modelLabelDest = os.path.join(exportDir, "saved_model.txt")
!cp {modelLabel} {modelLabelDest}
# Prep for serving
import datetime as dt
modelArtifact = f"livemodel_{dt.datetime.now().timestamp()}.tar.gz"
# Copy the version directory here to package
!cp -R {exportDir} ./
# gziptar it
!tar -czvf {modelArtifact} {VERSION_INT}
# Shove it back to S3 for serving
!aws s3 cp {modelArtifact} {bucketPath}
shutil.rmtree(str(VERSION_INT))  # Cleanup: local copy of the version dir
shutil.rmtree(exportDir)  # Cleanup: the export dir itself
```
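The auto-detection in the session block treats the last node of the imported graph as the output and, when no input name is given, the first node as the input, appending `:0` to turn an op name into the name of its first output tensor. That ordering holds for many simple frozen graphs, but nothing requires a GraphDef's node list to begin at the real input and end at the real output, so it is worth checking the printed names. Also, `tf.import_graph_def` prefixes imported names with `import/` by default, so names read back from `graph.as_graph_def()` carry that prefix. A TensorFlow-free sketch of just the naming logic, using the hypothetical function name `pick_io_tensor_names`:

```python
from typing import List, Optional, Tuple


def pick_io_tensor_names(
    node_names: List[str], input_name: Optional[str] = None
) -> Tuple[str, str]:
    """Hypothetical mirror of the script's heuristic: the last node is the
    output, the first node is the input unless an input tensor name is
    supplied. Returns (input_tensor_name, output_tensor_name)."""
    if not node_names:
        raise ValueError("graph has no nodes")
    names = list(node_names)  # copy so the caller's list is not mutated
    # ":0" selects the first output tensor of the named op
    output_name = names.pop() + ":0"
    if input_name is None:
        if not names:
            raise ValueError("need at least two nodes to infer an input")
        input_name = names.pop(0) + ":0"
    return input_name, output_name
```

For example, `pick_io_tensor_names(["import/a", "import/b"])` returns `("import/a:0", "import/b:0")`.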
This model can then be deployed as a SageMaker endpoint (or in any other TensorFlow Serving environment).
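TensorFlow Serving, which SageMaker's TensorFlow serving container runs, loads models from a base directory of numeric version subdirectories, each holding a `saved_model.pb`. Its default version policy serves the highest-numbered one, which is presumably why the script tars the `VERSION_INT` directory rather than the files inside it. A minimal pre-upload check of that layout, assuming the artifact has already been extracted into a scratch directory and using the hypothetical helper name `latest_servable_version`:

```python
import os
import re
from typing import Optional


def latest_servable_version(base_path: str) -> Optional[int]:
    """Hypothetical check: return the highest all-digit subdirectory of
    base_path that contains a saved_model.pb, mimicking TF Serving's
    default "latest version" policy. Returns None if none qualify."""
    versions = []
    for entry in os.listdir(base_path):
        full = os.path.join(base_path, entry)
        # Only ASCII-digit directory names count as versions here
        if re.fullmatch(r"[0-9]+", entry) and os.path.isdir(full):
            if os.path.isfile(os.path.join(full, "saved_model.pb")):
                versions.append(int(entry))
    # Compare numerically, so "10" beats "2"
    return max(versions) if versions else None
```

After extracting the `.tar.gz` into a scratch directory, `latest_servable_version(scratch_dir)` should return `VERSION_INT`; `None` means the archive does not have the expected layout.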