Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/google/cloud/firestore_v1/base_vector_query.py: 37%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

67 statements  

1# Copyright 2024 Google LLC All rights reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Classes for representing vector queries for the Google Cloud Firestore API. 

16""" 

17from __future__ import annotations 

18 

19import abc 

20from abc import ABC 

21from enum import Enum 

22from typing import TYPE_CHECKING, Any, Coroutine, Optional, Sequence, Tuple, Union 

23 

24from google.api_core import gapic_v1 

25from google.api_core import retry as retries 

26 

27from google.cloud.firestore_v1 import _helpers 

28from google.cloud.firestore_v1.types import query 

29from google.cloud.firestore_v1.vector import Vector 

30 

31if TYPE_CHECKING: # pragma: NO COVER 

32 from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator 

33 from google.cloud.firestore_v1.base_document import DocumentSnapshot 

34 from google.cloud.firestore_v1.query_profile import ExplainOptions 

35 from google.cloud.firestore_v1.query_results import QueryResultsList 

36 from google.cloud.firestore_v1.stream_generator import StreamGenerator 

37 

38 

class DistanceMeasure(Enum):
    """Distance measures selectable for a vector (nearest-neighbor) query.

    Each member is translated by ``BaseVectorQuery._to_protobuf`` into the
    identically named ``StructuredQuery.FindNearest.DistanceMeasure`` value.
    """

    EUCLIDEAN = 1
    COSINE = 2
    DOT_PRODUCT = 3

43 

44 

class BaseVectorQuery(ABC):
    """Represents a vector query to the Firestore API.

    A vector query wraps an existing query (``nested_query``) and adds a
    ``find_nearest`` stage to the protobuf that query produces. The search
    parameters are supplied through :meth:`find_nearest`; until that method
    has been called, :meth:`_to_protobuf` raises ``ValueError`` because no
    distance measure has been set.
    """

    def __init__(self, nested_query) -> None:
        """Store the wrapped query and reset all vector-search parameters.

        Args:
            nested_query: The query whose ``_to_protobuf()`` output is
                extended with the ``find_nearest`` stage. Its ``_parent``
                attribute is kept as the collection reference.
        """
        self._nested_query = nested_query
        self._collection_ref = nested_query._parent
        # Everything below stays ``None`` until ``find_nearest`` is called.
        self._vector_field: Optional[str] = None
        self._query_vector: Optional[Vector] = None
        self._limit: Optional[int] = None
        self._distance_measure: Optional[DistanceMeasure] = None
        self._distance_result_field: Optional[str] = None
        self._distance_threshold: Optional[float] = None

    @property
    def _client(self):
        """The client of the collection reference this query belongs to."""
        return self._collection_ref._client

    def _to_protobuf(self) -> query.StructuredQuery:
        """Build the ``StructuredQuery`` protobuf for this vector query.

        The protobuf of the nested query is used as the base and its
        ``find_nearest`` field is populated from the parameters stored by
        :meth:`find_nearest`.

        Returns:
            The nested query's protobuf with ``find_nearest`` set.

        Raises:
            ValueError: If the stored distance measure is not one of the
                :class:`DistanceMeasure` members (including when it is still
                ``None`` because :meth:`find_nearest` was never called).
        """
        proto_measures = query.StructuredQuery.FindNearest.DistanceMeasure
        if self._distance_measure == DistanceMeasure.EUCLIDEAN:
            distance_measure_proto = proto_measures.EUCLIDEAN
        elif self._distance_measure == DistanceMeasure.COSINE:
            distance_measure_proto = proto_measures.COSINE
        elif self._distance_measure == DistanceMeasure.DOT_PRODUCT:
            distance_measure_proto = proto_measures.DOT_PRODUCT
        else:
            raise ValueError("Invalid distance_measure")

        # Coerce ints to floats as required by the protobuf.
        distance_threshold_proto = None
        if self._distance_threshold is not None:
            distance_threshold_proto = float(self._distance_threshold)

        # Start from the nested query's protobuf; only ``find_nearest`` is
        # added on top of it.
        pb = self._nested_query._to_protobuf()
        pb.find_nearest = query.StructuredQuery.FindNearest(
            vector_field=query.StructuredQuery.FieldReference(
                field_path=self._vector_field
            ),
            query_vector=_helpers.encode_value(self._query_vector),
            distance_measure=distance_measure_proto,
            limit=self._limit,
            distance_result_field=self._distance_result_field,
            distance_threshold=distance_threshold_proto,
        )
        return pb

    def _prep_stream(
        self,
        transaction=None,
        retry: Union[retries.Retry, retries.AsyncRetry, object, None] = None,
        timeout: Optional[float] = None,
        explain_options: Optional[ExplainOptions] = None,
    ) -> Tuple[dict, str, dict]:
        """Assemble the request and call options shared by get/stream.

        Args:
            transaction: Passed to ``_helpers.get_transaction_id`` to obtain
                the ``"transaction"`` entry of the request.
            retry: Retry setting forwarded to
                ``_helpers.make_retry_timeout_kwargs``.
            timeout: Timeout forwarded to
                ``_helpers.make_retry_timeout_kwargs``.
            explain_options: When given, its ``_to_dict()`` result is added
                to the request under ``"explain_options"``.

        Returns:
            A ``(request, expected_prefix, kwargs)`` tuple: the request
            dict, the prefix returned by the collection reference's
            ``_parent_info()``, and the retry/timeout keyword arguments.

        Raises:
            ValueError: Propagated from :meth:`_to_protobuf` when no valid
                distance measure has been configured.
        """
        parent_path, expected_prefix = self._collection_ref._parent_info()
        request = {
            "parent": parent_path,
            "structured_query": self._to_protobuf(),
            "transaction": _helpers.get_transaction_id(transaction),
        }
        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

        if explain_options is not None:
            request["explain_options"] = explain_options._to_dict()

        return request, expected_prefix, kwargs

    @abc.abstractmethod
    def get(
        self,
        transaction=None,
        retry: retries.Retry
        | retries.AsyncRetry
        | object
        | None = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
    ) -> (
        QueryResultsList[DocumentSnapshot]
        | Coroutine[Any, Any, QueryResultsList[DocumentSnapshot]]
    ):
        """Runs the vector query.

        Abstract: concrete subclasses must override this method. The base
        implementation only raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def find_nearest(
        self,
        vector_field: str,
        query_vector: Union[Vector, Sequence[float]],
        limit: int,
        distance_measure: DistanceMeasure,
        *,
        distance_result_field: Optional[str] = None,
        distance_threshold: Optional[float] = None,
    ):
        """Finds the closest vector embeddings to the given query vector.

        The arguments are only stored on this instance; they are turned into
        the ``find_nearest`` stage when :meth:`_to_protobuf` is called.

        Args:
            vector_field: Field path used as the ``vector_field`` reference
                of the ``find_nearest`` stage.
            query_vector: The vector to search with. Anything that is not
                already a ``Vector`` is wrapped in one.
            limit: Forwarded as the ``limit`` of the ``find_nearest`` stage.
            distance_measure: The :class:`DistanceMeasure` to use.
            distance_result_field: Optional value forwarded as
                ``distance_result_field``.
            distance_threshold: Optional value forwarded as
                ``distance_threshold`` (converted to ``float`` when the
                protobuf is built).

        Returns:
            This same instance, so calls can be chained.
        """
        if isinstance(query_vector, Vector):
            self._query_vector = query_vector
        else:
            self._query_vector = Vector(query_vector)
        self._vector_field = vector_field
        self._limit = limit
        self._distance_measure = distance_measure
        self._distance_result_field = distance_result_field
        self._distance_threshold = distance_threshold
        return self

    def stream(
        self,
        transaction=None,
        retry: retries.Retry
        | retries.AsyncRetry
        | object
        | None = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
    ) -> StreamGenerator[DocumentSnapshot] | AsyncStreamGenerator[DocumentSnapshot]:
        """Reads the documents in the collection that match this query.

        The base implementation always raises ``NotImplementedError``;
        subclasses are expected to provide the real behavior.
        """
        raise NotImplementedError