Module geoengine.ml
Util functions for machine learning
Expand source code
'''
Util functions for machine learning
'''
from pathlib import Path
import tempfile
from dataclasses import dataclass
from onnx import TypeProto, TensorProto, ModelProto
from onnx.helper import tensor_dtype_to_string
from geoengine_openapi_client.models import MlModelMetadata, MlModel, RasterDataType
import geoengine_openapi_client
from geoengine.auth import get_session
from geoengine.datasets import UploadId
from geoengine.error import InputException
@dataclass
class MlModelConfig:
    '''
    Configuration for an ml model.

    Collects everything `register_ml_model` needs besides the ONNX model itself.
    '''
    # Name that is passed on to the `MlModel` that gets registered.
    name: str
    # Model metadata; its `input_type`, `output_type` and `num_input_bands`
    # are checked against the ONNX model before the upload.
    metadata: MlModelMetadata
    # Name of the temporary file the serialized model is written to and uploaded from.
    file_name: str = 'model.onnx'
    # Display name that is passed on to the registered `MlModel`.
    display_name: str = 'My Ml Model'
    # Description that is passed on to the registered `MlModel`.
    description: str = 'My Ml Model Description'
def register_ml_model(onnx_model: ModelProto,
                      model_config: MlModelConfig,
                      upload_timeout: int = 3600,
                      register_timeout: int = 60):
    '''
    Uploads an onnx file and registers it as an ml model.

    The model is first validated against the metadata of `model_config`
    (see `validate_model_config`), then serialized into a temporary file,
    uploaded, and finally registered under `model_config.name`.

    `upload_timeout` and `register_timeout` are handed to the API client as
    `_request_timeout` of the upload call and the registration call, respectively.

    Raises an `InputException` if the model does not match the metadata.
    '''
    metadata = model_config.metadata
    validate_model_config(
        onnx_model,
        input_type=metadata.input_type,
        output_type=metadata.output_type,
        num_input_bands=metadata.num_input_bands,
    )

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        # The serialized model only has to exist on disk for the duration of the upload.
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = Path(temp_dir) / model_config.file_name
            model_path.write_bytes(onnx_model.SerializeToString())

            response = geoengine_openapi_client.UploadsApi(api_client).upload_handler(
                [str(model_path)],
                _request_timeout=upload_timeout,
            )

        upload_id = UploadId.from_response(response)

        model = MlModel(
            name=model_config.name,
            upload=str(upload_id),
            metadata=model_config.metadata,
            display_name=model_config.display_name,
            description=model_config.description,
        )
        geoengine_openapi_client.MLApi(api_client).add_ml_model(model, _request_timeout=register_timeout)
def validate_model_config(onnx_model: ModelProto, *,
                          input_type: RasterDataType,
                          output_type: RasterDataType,
                          num_input_bands: int):
    '''
    Validates the model config. Raises an exception if the model config is invalid.

    The following properties of `onnx_model` are checked, in this order:

    * every opset import of the default domain (`''`) has version 9,
    * the graph has exactly one input, which is a tensor whose element type
      corresponds to `input_type`,
    * the input tensor has two dimensions and dimension 1 has a fixed length
      that equals `num_input_bands`,
    * the graph has at least one output and the first output is a tensor whose
      element type corresponds to `output_type`.

    Raises an `InputException` as soon as one of the checks fails.
    '''

    def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix: str):
        '''Checks that `data_type` is a tensor type with the element type expected for `expected_type`'''
        # Reading a protobuf sub-message field always yields a (truthy) message object,
        # even if the field is unset. Presence has to be tested with `HasField`.
        if not data_type.HasField('tensor_type'):
            raise InputException('Only tensor input types are supported')

        elem_type = data_type.tensor_type.elem_type
        if elem_type != RASTER_TYPE_TO_ONNX_TYPE[expected_type]:
            elem_type_str = tensor_dtype_to_string(elem_type)
            raise InputException(f'Model {prefix} type `{elem_type_str}` does not match the '
                                 f'expected type `{expected_type}`')

    for domain in onnx_model.opset_import:
        # Only the default ONNX operator set is restricted; other domains are ignored.
        if domain.domain != '':
            continue
        if domain.version != 9:
            raise InputException('Only ONNX models with opset version 9 are supported')

    model_inputs = onnx_model.graph.input
    model_outputs = onnx_model.graph.output

    if len(model_inputs) != 1:
        raise InputException('Models with multiple inputs are not supported')
    check_data_type(model_inputs[0].type, input_type, 'input')

    dims = model_inputs[0].type.tensor_type.shape.dim
    if len(dims) != 2:
        raise InputException('Only 2D input tensors are supported')
    # `dim_value` is 0 if the dimension has no fixed length.
    if not dims[1].dim_value:
        raise InputException('Dimension 1 of the input tensor must have a length')
    if dims[1].dim_value != num_input_bands:
        raise InputException(f'Model input has {dims[1].dim_value} bands, '
                             f'but {num_input_bands} bands are expected')

    if len(model_outputs) < 1:
        raise InputException('Models with no outputs are not supported')
    check_data_type(model_outputs[0].type, output_type, 'output')
# Maps each raster data type to the ONNX tensor element type that
# `validate_model_config` expects for a model input or output of that type.
RASTER_TYPE_TO_ONNX_TYPE = {
    RasterDataType.F32: TensorProto.FLOAT,
    RasterDataType.F64: TensorProto.DOUBLE,
    RasterDataType.U8: TensorProto.UINT8,
    RasterDataType.U16: TensorProto.UINT16,
    RasterDataType.U32: TensorProto.UINT32,
    RasterDataType.U64: TensorProto.UINT64,
    RasterDataType.I8: TensorProto.INT8,
    RasterDataType.I16: TensorProto.INT16,
    RasterDataType.I32: TensorProto.INT32,
    RasterDataType.I64: TensorProto.INT64,
}
Functions
def register_ml_model(onnx_model: onnx.onnx_ml_pb2.ModelProto, model_config: MlModelConfig, upload_timeout: int = 3600, register_timeout: int = 60)
-
Uploads an onnx file and registers it as an ml model
Expand source code
def register_ml_model(onnx_model: ModelProto,
                      model_config: MlModelConfig,
                      upload_timeout: int = 3600,
                      register_timeout: int = 60):
    '''
    Uploads an onnx file and registers it as an ml model.

    The model is first validated against the metadata of `model_config`
    (see `validate_model_config`), then serialized into a temporary file,
    uploaded, and finally registered under `model_config.name`.

    `upload_timeout` and `register_timeout` are handed to the API client as
    `_request_timeout` of the upload call and the registration call, respectively.

    Raises an `InputException` if the model does not match the metadata.
    '''
    metadata = model_config.metadata
    validate_model_config(
        onnx_model,
        input_type=metadata.input_type,
        output_type=metadata.output_type,
        num_input_bands=metadata.num_input_bands,
    )

    session = get_session()

    with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
        # The serialized model only has to exist on disk for the duration of the upload.
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = Path(temp_dir) / model_config.file_name
            model_path.write_bytes(onnx_model.SerializeToString())

            response = geoengine_openapi_client.UploadsApi(api_client).upload_handler(
                [str(model_path)],
                _request_timeout=upload_timeout,
            )

        upload_id = UploadId.from_response(response)

        model = MlModel(
            name=model_config.name,
            upload=str(upload_id),
            metadata=model_config.metadata,
            display_name=model_config.display_name,
            description=model_config.description,
        )
        geoengine_openapi_client.MLApi(api_client).add_ml_model(model, _request_timeout=register_timeout)
def validate_model_config(onnx_model: onnx.onnx_ml_pb2.ModelProto, *, input_type: geoengine_openapi_client.models.raster_data_type.RasterDataType, output_type: geoengine_openapi_client.models.raster_data_type.RasterDataType, num_input_bands: int)
-
Validates the model config. Raises an exception if the model config is invalid
Expand source code
def validate_model_config(onnx_model: ModelProto, *,
                          input_type: RasterDataType,
                          output_type: RasterDataType,
                          num_input_bands: int):
    '''
    Validates the model config. Raises an exception if the model config is invalid.

    The following properties of `onnx_model` are checked, in this order:

    * every opset import of the default domain (`''`) has version 9,
    * the graph has exactly one input, which is a tensor whose element type
      corresponds to `input_type`,
    * the input tensor has two dimensions and dimension 1 has a fixed length
      that equals `num_input_bands`,
    * the graph has at least one output and the first output is a tensor whose
      element type corresponds to `output_type`.

    Raises an `InputException` as soon as one of the checks fails.
    '''

    def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix: str):
        '''Checks that `data_type` is a tensor type with the element type expected for `expected_type`'''
        # Reading a protobuf sub-message field always yields a (truthy) message object,
        # even if the field is unset. Presence has to be tested with `HasField`.
        if not data_type.HasField('tensor_type'):
            raise InputException('Only tensor input types are supported')

        elem_type = data_type.tensor_type.elem_type
        if elem_type != RASTER_TYPE_TO_ONNX_TYPE[expected_type]:
            elem_type_str = tensor_dtype_to_string(elem_type)
            raise InputException(f'Model {prefix} type `{elem_type_str}` does not match the '
                                 f'expected type `{expected_type}`')

    for domain in onnx_model.opset_import:
        # Only the default ONNX operator set is restricted; other domains are ignored.
        if domain.domain != '':
            continue
        if domain.version != 9:
            raise InputException('Only ONNX models with opset version 9 are supported')

    model_inputs = onnx_model.graph.input
    model_outputs = onnx_model.graph.output

    if len(model_inputs) != 1:
        raise InputException('Models with multiple inputs are not supported')
    check_data_type(model_inputs[0].type, input_type, 'input')

    dims = model_inputs[0].type.tensor_type.shape.dim
    if len(dims) != 2:
        raise InputException('Only 2D input tensors are supported')
    # `dim_value` is 0 if the dimension has no fixed length.
    if not dims[1].dim_value:
        raise InputException('Dimension 1 of the input tensor must have a length')
    if dims[1].dim_value != num_input_bands:
        raise InputException(f'Model input has {dims[1].dim_value} bands, '
                             f'but {num_input_bands} bands are expected')

    if len(model_outputs) < 1:
        raise InputException('Models with no outputs are not supported')
    check_data_type(model_outputs[0].type, output_type, 'output')
Classes
class MlModelConfig (name: str, metadata: geoengine_openapi_client.models.ml_model_metadata.MlModelMetadata, file_name: str = 'model.onnx', display_name: str = 'My Ml Model', description: str = 'My Ml Model Description')
-
Configuration for an ml model
Expand source code
@dataclass
class MlModelConfig:
    '''
    Configuration for an ml model.

    Collects everything `register_ml_model` needs besides the ONNX model itself.
    '''
    # Name that is passed on to the `MlModel` that gets registered.
    name: str
    # Model metadata; its `input_type`, `output_type` and `num_input_bands`
    # are checked against the ONNX model before the upload.
    metadata: MlModelMetadata
    # Name of the temporary file the serialized model is written to and uploaded from.
    file_name: str = 'model.onnx'
    # Display name that is passed on to the registered `MlModel`.
    display_name: str = 'My Ml Model'
    # Description that is passed on to the registered `MlModel`.
    description: str = 'My Ml Model Description'
Class variables
var description : str
var display_name : str
var file_name : str
var metadata : geoengine_openapi_client.models.ml_model_metadata.MlModelMetadata
var name : str