Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/defs/__init__.py: 61%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright (c) ONNX Project Contributors
2#
3# SPDX-License-Identifier: Apache-2.0
4from __future__ import annotations
# Public API of ``onnx.defs``: the names exported by ``from onnx.defs import *``.
# ``onnx_ml_opset_version`` is listed alongside ``onnx_opset_version`` because
# both are public, documented accessors defined in this module.
__all__ = [
    "C",
    "ONNX_DOMAIN",
    "ONNX_ML_DOMAIN",
    "AI_ONNX_PREVIEW_DOMAIN",
    "AI_ONNX_PREVIEW_TRAINING_DOMAIN",
    "has",
    "register_schema",
    "deregister_schema",
    "get_schema",
    "get_all_schemas",
    "get_all_schemas_with_history",
    "onnx_opset_version",
    "onnx_ml_opset_version",
    "get_function_ops",
    "OpSchema",
    "SchemaError",
]
24import onnx
25import onnx.onnx_cpp2py_export.defs as C # noqa: N812
# The default ONNX operator domain (``ai.onnx``) is keyed by the empty string.
ONNX_DOMAIN = ""
# Operator domain of ONNX-ML.
ONNX_ML_DOMAIN = "ai.onnx.ml"
# Preview operator domains; the training domain lives under the preview one.
AI_ONNX_PREVIEW_DOMAIN = "ai.onnx.preview"
AI_ONNX_PREVIEW_TRAINING_DOMAIN = f"{AI_ONNX_PREVIEW_DOMAIN}.training"
# Public aliases for schema-registry functions of the compiled
# ``onnx.onnx_cpp2py_export.defs`` module (imported as ``C``).
(
    has,
    get_schema,
    get_all_schemas,
    get_all_schemas_with_history,
    deregister_schema,
) = (
    C.has_schema,
    C.get_schema,
    C.get_all_schemas,
    C.get_all_schemas_with_history,
    C.deregister_schema,
)
def onnx_opset_version() -> int:
    """Return current opset for domain `ai.onnx`.

    Returns:
        The upper bound of the ``(min, max)`` opset range that
        ``C.schema_version_map()`` reports for ``ONNX_DOMAIN``.
    """
    _lowest, highest = C.schema_version_map()[ONNX_DOMAIN]
    return highest
def onnx_ml_opset_version() -> int:
    """Return current opset for domain `ai.onnx.ml`.

    Returns:
        The upper bound of the ``(min, max)`` opset range that
        ``C.schema_version_map()`` reports for ``ONNX_ML_DOMAIN``.
    """
    _lowest, highest = C.schema_version_map()[ONNX_ML_DOMAIN]
    return highest
@property  # type: ignore[misc]
def _function_proto(self):
    """Decode this schema's serialized function body into an ``onnx.FunctionProto``.

    A new message is built on every access, so callers may mutate the result
    freely.
    """
    proto = onnx.FunctionProto()
    # ``_function_body`` holds serialized protobuf bytes; decode them in place.
    proto.ParseFromString(self._function_body)
    return proto
# Public alias for the compiled schema class from the ``C`` extension module.
OpSchema = C.OpSchema
# Attach ``function_body``: a property that decodes the serialized function body
# into an ``onnx.FunctionProto``.
OpSchema.function_body = _function_proto  # type: ignore[method-assign]
@property  # type: ignore[misc]
def _non_deterministic(self):
    """Check if the operator is non-deterministic.

    Any ``node_determinism`` value other than
    ``OpSchema.NodeDeterminism.Deterministic`` counts as non-deterministic.
    """
    deterministic = OpSchema.NodeDeterminism.Deterministic
    return self.node_determinism != deterministic
# Expose the determinism check as a read-only ``non_deterministic`` property.
OpSchema.non_deterministic = _non_deterministic  # type: ignore[attr-defined]
@property  # type: ignore[misc]
def _attribute_default_value(self):
    """Decode the attribute's serialized default into an ``onnx.AttributeProto``.

    A new message is built on every access.
    """
    default = onnx.AttributeProto()
    # ``_default_value`` holds serialized protobuf bytes.
    default.ParseFromString(self._default_value)
    return default
# Install ``default_value`` on schema attributes as a property that returns the
# default decoded into an ``onnx.AttributeProto``.
OpSchema.Attribute.default_value = _attribute_default_value  # type: ignore[method-assign]
80def _op_schema_repr(self) -> str:
81 return f"""\
82OpSchema(
83 name={self.name!r},
84 domain={self.domain!r},
85 since_version={self.since_version!r},
86 doc={self.doc!r},
87 type_constraints={self.type_constraints!r},
88 inputs={self.inputs!r},
89 outputs={self.outputs!r},
90 attributes={self.attributes!r}
91)"""
# Replace the default repr of OpSchema with the multi-line field listing.
OpSchema.__repr__ = _op_schema_repr  # type: ignore[method-assign]
97def _op_schema_formal_parameter_repr(self) -> str:
98 return (
99 f"OpSchema.FormalParameter(name={self.name!r}, type_str={self.type_str!r}, "
100 f"description={self.description!r}, param_option={self.option!r}, "
101 f"is_homogeneous={self.is_homogeneous!r}, min_arity={self.min_arity!r}, "
102 f"differentiation_category={self.differentiation_category!r})"
103 )
# Give formal parameters a one-line, keyword-argument style repr.
OpSchema.FormalParameter.__repr__ = _op_schema_formal_parameter_repr  # type: ignore[method-assign]
109def _op_schema_type_constraint_param_repr(self) -> str:
110 return (
111 f"OpSchema.TypeConstraintParam(type_param_str={self.type_param_str!r}, "
112 f"allowed_type_strs={self.allowed_type_strs!r}, description={self.description!r})"
113 )
# Give type-constraint parameters a one-line, keyword-argument style repr.
OpSchema.TypeConstraintParam.__repr__ = _op_schema_type_constraint_param_repr  # type: ignore[method-assign]
119def _op_schema_attribute_repr(self) -> str:
120 return (
121 f"OpSchema.Attribute(name={self.name!r}, type={self.type!r}, description={self.description!r}, "
122 f"default_value={self.default_value!r}, required={self.required!r})"
123 )
# Give schema attributes a one-line, keyword-argument style repr.
OpSchema.Attribute.__repr__ = _op_schema_attribute_repr  # type: ignore[method-assign]
def get_function_ops() -> list[OpSchema]:
    """Return operators defined as functions.

    A schema qualifies when it reports either ``has_function`` or
    ``has_context_dependent_function``.
    """

    def defines_function(op):
        return op.has_function or op.has_context_dependent_function  # type: ignore[attr-defined]

    return [op for op in C.get_all_schemas() if defines_function(op)]
# Public alias for the ``SchemaError`` type exported by the compiled ``C``
# module (presumably the exception raised on schema failures; confirm in C++).
SchemaError = C.SchemaError
def register_schema(schema: OpSchema) -> None:
    """Register a user provided OpSchema.

    Before registering, make sure the opset version range recorded for
    ``schema.domain`` covers ``schema.since_version``:

    * if the domain has no range yet, record the single-version range
      ``(since_version, since_version)``;
    * if ``since_version`` lies outside the existing range, widen the range
      just enough to include it;
    * otherwise leave the range unchanged.

    Args:
        schema: The OpSchema to register.
    """
    known_ranges = C.schema_version_map()
    domain = schema.domain
    version = schema.since_version
    if domain not in known_ranges:
        # First schema seen for this domain: start a one-version range.
        C.set_domain_to_version(domain, version, version)
    else:
        lowest, highest = known_ranges[domain]
        if not (lowest <= version <= highest):
            C.set_domain_to_version(domain, min(lowest, version), max(highest, version))
    C.register_schema(schema)