設定 ONNX 匯出參數,包括 opset 版本與輸入/輸出名稱。
# Settings for converting the PyTorch model to ONNX. Like the inline
# annotation it replaces, this is evaluated once, when the definition is
# executed, so torch.randn runs at that point.
_ONNX_EXPORT_CONFIG = PyTorch2ONNXConfig(
    # Shape (1, 1, 224, 224); presumably (batch, channels, height, width) -- confirm.
    args=torch.randn(1, 1, 224, 224, requires_grad=True),
    export_params=True,  # store the trained parameter weights inside the exported model
    opset_version=10,  # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=["input"],  # the model's input names
    output_names=["output"],  # the model's output names
    # Variable-length axes: dimension 0 of both "input" and "output" is
    # named "batch_size" (each entry gets its own inner dict).
    dynamic_axes={axis_owner: {0: "batch_size"} for axis_owner in ("input", "output")},
)


def train() -> Annotated[PyTorch2ONNX, _ONNX_EXPORT_CONFIG]:
    """Wrap ``torch_model`` in a ``PyTorch2ONNX`` carrying the export settings.

    Despite its name, no training happens in this body: it only wraps the
    existing ``torch_model``. The ``Annotated`` return type attaches
    ``_ONNX_EXPORT_CONFIG`` to the ``PyTorch2ONNX`` return type.

    NOTE(review): ``torch_model`` is not defined in this snippet; presumably
    it is a module-level model defined elsewhere -- confirm.
    NOTE(review): no ``@task`` decorator is visible here; if Flyte is meant
    to apply the annotated export config, this likely needs to be a task --
    confirm.
    """
    return PyTorch2ONNX(model=torch_model)
ONNXFile 的使用方式與 FlyteFile、FlyteDirectory 相同:呼叫 download() 即可從 S3 下載檔案。
@task
def onnx_predict(model_file: ONNXFile) -> JPEGImageFile:
rt_session = onnxruntime.InferenceSession(model_file.download())