Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/onnx/defs/__init__.py: 61%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

61 statements  

1# Copyright (c) ONNX Project Contributors 

2# 

3# SPDX-License-Identifier: Apache-2.0 

4from __future__ import annotations 

5 

# Public API of this module: the names exported by ``from onnx.defs import *``.
# ``onnx_ml_opset_version`` is listed alongside ``onnx_opset_version`` because
# it is an equally public helper defined in this module.
__all__ = [
    "C",
    "ONNX_DOMAIN",
    "ONNX_ML_DOMAIN",
    "AI_ONNX_PREVIEW_DOMAIN",
    "AI_ONNX_PREVIEW_TRAINING_DOMAIN",
    "has",
    "register_schema",
    "deregister_schema",
    "get_schema",
    "get_all_schemas",
    "get_all_schemas_with_history",
    "onnx_opset_version",
    "onnx_ml_opset_version",
    "get_function_ops",
    "OpSchema",
    "SchemaError",
]

23 

24import onnx 

25import onnx.onnx_cpp2py_export.defs as C # noqa: N812 

26 

# Operator-set domain identifiers used as keys into the schema version map.
# The default ``ai.onnx`` domain is spelled as the empty string.
ONNX_DOMAIN = ""
# Traditional machine-learning operators.
ONNX_ML_DOMAIN = "ai.onnx.ml"
# Preview (experimental) operator domains.
AI_ONNX_PREVIEW_DOMAIN = "ai.onnx.preview"
AI_ONNX_PREVIEW_TRAINING_DOMAIN = "ai.onnx.preview.training"

31 

32 

# Direct aliases of the C++ schema-registry bindings, exposed under
# Python-friendly names.
has = C.has_schema  # membership check
get_schema = C.get_schema  # single-schema lookup
get_all_schemas = C.get_all_schemas
get_all_schemas_with_history = C.get_all_schemas_with_history
deregister_schema = C.deregister_schema

38 

39 

def onnx_opset_version() -> int:
    """Return the current (highest) opset version of the ``ai.onnx`` domain.

    The schema version map holds a ``(min, max)`` pair per domain; the upper
    bound of that pair is the current opset.
    """
    _lowest, highest = C.schema_version_map()[ONNX_DOMAIN]
    return highest

43 

44 

def onnx_ml_opset_version() -> int:
    """Return the current (highest) opset version of the ``ai.onnx.ml`` domain.

    The schema version map holds a ``(min, max)`` pair per domain; the upper
    bound of that pair is the current opset.
    """
    _lowest, highest = C.schema_version_map()[ONNX_ML_DOMAIN]
    return highest

48 

49 

@property  # type: ignore[misc]
def _function_proto(self):
    """Decode the schema's serialized ``_function_body`` into a FunctionProto.

    A fresh message is built on every access, so callers may mutate the
    result without affecting the schema.
    """
    serialized = self._function_body
    decoded = onnx.FunctionProto()
    decoded.ParseFromString(serialized)
    return decoded

55 

56 

# Re-export the bound schema class, then attach a ``function_body`` property
# that returns the decoded FunctionProto instead of raw bytes.
OpSchema = C.OpSchema
setattr(OpSchema, "function_body", _function_proto)

59 

60 

@property  # type: ignore[misc]
def _non_deterministic(self):
    """Tell whether the operator is non-deterministic.

    Any ``node_determinism`` value other than
    ``OpSchema.NodeDeterminism.Deterministic`` counts as non-deterministic.
    """
    determinism = self.node_determinism
    return determinism != OpSchema.NodeDeterminism.Deterministic

65 

66 

# Expose the determinism check as a read-only ``non_deterministic`` property.
setattr(OpSchema, "non_deterministic", _non_deterministic)

68 

69 

@property  # type: ignore[misc]
def _attribute_default_value(self):
    """Decode the attribute's serialized ``_default_value`` into an AttributeProto."""
    serialized = self._default_value
    decoded = onnx.AttributeProto()
    decoded.ParseFromString(serialized)
    return decoded

75 

76 

# Replace the raw-bytes default with a property yielding a decoded AttributeProto.
setattr(OpSchema.Attribute, "default_value", _attribute_default_value)

78 

79 

80def _op_schema_repr(self) -> str: 

81 return f"""\ 

82OpSchema( 

83 name={self.name!r}, 

84 domain={self.domain!r}, 

85 since_version={self.since_version!r}, 

86 doc={self.doc!r}, 

87 type_constraints={self.type_constraints!r}, 

88 inputs={self.inputs!r}, 

89 outputs={self.outputs!r}, 

90 attributes={self.attributes!r} 

91)""" 

92 

93 

# Give bound OpSchema objects a readable, field-by-field repr.
setattr(OpSchema, "__repr__", _op_schema_repr)

95 

96 

97def _op_schema_formal_parameter_repr(self) -> str: 

98 return ( 

99 f"OpSchema.FormalParameter(name={self.name!r}, type_str={self.type_str!r}, " 

100 f"description={self.description!r}, param_option={self.option!r}, " 

101 f"is_homogeneous={self.is_homogeneous!r}, min_arity={self.min_arity!r}, " 

102 f"differentiation_category={self.differentiation_category!r})" 

103 ) 

104 

105 

# Give bound FormalParameter objects a readable repr.
setattr(OpSchema.FormalParameter, "__repr__", _op_schema_formal_parameter_repr)

107 

108 

109def _op_schema_type_constraint_param_repr(self) -> str: 

110 return ( 

111 f"OpSchema.TypeConstraintParam(type_param_str={self.type_param_str!r}, " 

112 f"allowed_type_strs={self.allowed_type_strs!r}, description={self.description!r})" 

113 ) 

114 

115 

# Give bound TypeConstraintParam objects a readable repr.
setattr(OpSchema.TypeConstraintParam, "__repr__", _op_schema_type_constraint_param_repr)

117 

118 

119def _op_schema_attribute_repr(self) -> str: 

120 return ( 

121 f"OpSchema.Attribute(name={self.name!r}, type={self.type!r}, description={self.description!r}, " 

122 f"default_value={self.default_value!r}, required={self.required!r})" 

123 ) 

124 

125 

# Give bound Attribute objects a readable repr.
setattr(OpSchema.Attribute, "__repr__", _op_schema_attribute_repr)

127 

128 

def get_function_ops() -> list[OpSchema]:
    """Return the registered operators that are defined as functions.

    A schema qualifies when it has a function body or a context-dependent
    function body.
    """
    function_ops = []
    for op_schema in C.get_all_schemas():
        is_function = op_schema.has_function or op_schema.has_context_dependent_function  # type: ignore[attr-defined]
        if is_function:
            function_ops.append(op_schema)
    return function_ops

137 

138 

# Exception type from the C++ bindings, re-exported so callers can catch it
# (presumably raised on schema registration/lookup failures -- confirm).
SchemaError = C.SchemaError

140 

141 

def register_schema(schema: OpSchema) -> None:
    """Register a user provided OpSchema.

    If the domain has no recorded opset range yet, or the schema's
    ``since_version`` falls outside that range, the range is first widened to
    include the version. The schema is then registered.

    Args:
        schema: The OpSchema to register.
    """
    known_ranges = C.schema_version_map()
    domain = schema.domain
    version = schema.since_version
    if domain in known_ranges:
        lowest, highest = known_ranges[domain]
        out_of_range = version < lowest or version > highest
    else:
        # Unknown domain: start a range covering exactly this version.
        lowest = highest = version
        out_of_range = True
    if out_of_range:
        C.set_domain_to_version(domain, min(lowest, version), max(highest, version))
    C.register_schema(schema)