Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/google/cloud/firestore_v1/base_aggregation.py: 40%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

104 statements  

1# Copyright 2023 Google LLC All rights reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Classes for representing aggregation queries for the Google Cloud Firestore API. 

16 

17A :class:`~google.cloud.firestore_v1.aggregation.AggregationQuery` can be created directly from 

18a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be 

19a more common way to create an aggregation query than direct usage of the constructor. 

20""" 

21from __future__ import annotations 

22 

23import abc 

24 

25from abc import ABC 

26from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union 

27 

28from google.api_core import gapic_v1 

29from google.api_core import retry as retries 

30 

31from google.cloud.firestore_v1 import _helpers 

32from google.cloud.firestore_v1.field_path import FieldPath 

33from google.cloud.firestore_v1.types import ( 

34 StructuredAggregationQuery, 

35) 

36 

37# Types needed only for Type Hints 

38if TYPE_CHECKING: # pragma: NO COVER 

39 from google.cloud.firestore_v1 import transaction 

40 from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator 

41 from google.cloud.firestore_v1.query_profile import ExplainOptions 

42 from google.cloud.firestore_v1.query_results import QueryResultsList 

43 from google.cloud.firestore_v1.stream_generator import ( 

44 StreamGenerator, 

45 ) 

46 

47 import datetime 

48 

49 

class AggregationResult:
    """Single named value produced by an aggregation query.

    :type alias: str
    :param alias: The alias for the aggregation.
    :type value: int or float
    :param value: The resulting value from the aggregation.
    :param read_time: The read time stored alongside the value
        (defaults to ``None``).
    """

    def __init__(self, alias: str, value: float, read_time=None):
        # Plain attribute storage; no validation or conversion is applied.
        self.alias, self.value, self.read_time = alias, value, read_time

    def __repr__(self):
        # Assemble the "key=value" pairs first, then wrap them in the tag.
        fields = ", ".join(
            (
                f"alias={self.alias}",
                f"value={self.value}",
                f"readtime={self.read_time}",
            )
        )
        return f"<Aggregation {fields}>"

68 

69 

class BaseAggregation(abc.ABC):
    """Abstract base for a single aggregation operation.

    :type alias: str or None
    :param alias: Optional alias, stored on the instance as ``alias``.
    """

    def __init__(self, alias: str | None = None):
        self.alias = alias

    @abc.abstractmethod
    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""

77 

78 

class CountAggregation(BaseAggregation):
    """Aggregation whose protobuf form carries a ``Count`` operator.

    :type alias: str or None
    :param alias: Optional alias copied onto the protobuf when truthy.
    """

    def __init__(self, alias: str | None = None):
        super().__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        aggregation_cls = StructuredAggregationQuery.Aggregation
        message = aggregation_cls()
        # An empty or missing alias is left unset on the message.
        if self.alias:
            message.alias = self.alias
        message.count = aggregation_cls.Count()
        return message

90 

91 

class SumAggregation(BaseAggregation):
    """Aggregation whose protobuf form carries a ``Sum`` operator on a field.

    :type field_ref: str or FieldPath
    :param field_ref: The field to aggregate; a ``FieldPath`` is converted
        to its string form via ``to_api_repr()``.
    :type alias: str or None
    :param alias: Optional alias copied onto the protobuf when truthy.
    """

    def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
        # Normalise to the string representation before storing.
        if isinstance(field_ref, FieldPath):
            field_ref = field_ref.to_api_repr()
        self.field_ref: str = field_ref
        super().__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        aggregation_cls = StructuredAggregationQuery.Aggregation
        message = aggregation_cls()
        # An empty or missing alias is left unset on the message.
        if self.alias:
            message.alias = self.alias
        message.sum = aggregation_cls.Sum()
        message.sum.field.field_path = self.field_ref
        return message

109 

110 

class AvgAggregation(BaseAggregation):
    """Aggregation whose protobuf form carries an ``Avg`` operator on a field.

    :type field_ref: str or FieldPath
    :param field_ref: The field to aggregate; a ``FieldPath`` is converted
        to its string form via ``to_api_repr()``.
    :type alias: str or None
    :param alias: Optional alias copied onto the protobuf when truthy.
    """

    def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
        # Normalise to the string representation before storing.
        if isinstance(field_ref, FieldPath):
            field_ref = field_ref.to_api_repr()
        self.field_ref: str = field_ref
        super().__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        aggregation_cls = StructuredAggregationQuery.Aggregation
        message = aggregation_cls()
        # An empty or missing alias is left unset on the message.
        if self.alias:
            message.alias = self.alias
        message.avg = aggregation_cls.Avg()
        message.avg.field.field_path = self.field_ref
        return message

128 

129 

def _query_response_to_result(
    response_pb,
) -> List[AggregationResult]:
    """Convert one aggregation response message into result objects.

    One :class:`AggregationResult` is produced for every key in
    ``response_pb.result.aggregate_fields``; each result carries the
    response's ``read_time``.

    :param response_pb: A response message exposing ``result.aggregate_fields``
        and ``read_time``.
    :rtype: List[AggregationResult]
    :returns: The results, one per aggregate field.
    """
    aggregate_fields = response_pb.result.aggregate_fields
    results = []
    for key in aggregate_fields.pb.keys():
        field = aggregate_fields[key]
        # Select by field presence rather than truthiness: a legitimate
        # integer result of 0 is falsy, and ``integer_value or double_value``
        # would turn it into the float default 0.0.
        if "integer_value" in field:
            value = field.integer_value
        else:
            value = field.double_value
        results.append(
            AggregationResult(
                alias=key,
                value=value,
                read_time=response_pb.read_time,
            )
        )

    return results

144 

145 

class BaseAggregationQuery(ABC):
    """Represents an aggregation query to the Firestore API."""

    def __init__(self, nested_query, alias: str | None = None) -> None:
        self._nested_query = nested_query
        self._alias = alias
        # The nested query's parent is kept as the collection reference.
        self._collection_ref = nested_query._parent
        self._aggregations: List[BaseAggregation] = []

    @property
    def _client(self):
        """The client held by the collection reference."""
        collection = self._collection_ref
        return collection._client

    def count(self, alias: str | None = None):
        """
        Adds a count over the nested query and returns this query.
        """
        self._aggregations.append(CountAggregation(alias=alias))
        return self

    def sum(self, field_ref: str | FieldPath, alias: str | None = None):
        """
        Adds a sum over the nested query and returns this query.
        """
        self._aggregations.append(SumAggregation(field_ref, alias=alias))
        return self

    def avg(self, field_ref: str | FieldPath, alias: str | None = None):
        """
        Adds an avg over the nested query and returns this query.
        """
        self._aggregations.append(AvgAggregation(field_ref, alias=alias))
        return self

    def add_aggregation(self, aggregation: BaseAggregation) -> None:
        """
        Adds an aggregation operation to the nested query

        :type aggregation: :class:`google.cloud.firestore_v1.aggregation.BaseAggregation`
        :param aggregation: An aggregation operation, e.g. a CountAggregation
        """
        self._aggregations.append(aggregation)

    def add_aggregations(self, aggregations: List[BaseAggregation]) -> None:
        """
        Adds a list of aggregations to the nested query

        :type aggregations: list
        :param aggregations: a list of aggregation operations
        """
        self._aggregations.extend(aggregations)

    def _to_protobuf(self) -> StructuredAggregationQuery:
        """Build the protobuf for the nested query plus every aggregation."""
        query_pb = StructuredAggregationQuery()
        query_pb.structured_query = self._nested_query._to_protobuf()

        # Aggregations are appended in the order they were registered.
        for item in self._aggregations:
            query_pb.aggregations.append(item._to_protobuf())
        return query_pb

    def _prep_stream(
        self,
        transaction=None,
        retry: Union[retries.Retry, retries.AsyncRetry, None, object] = None,
        timeout: float | None = None,
        explain_options: Optional[ExplainOptions] = None,
        read_time: Optional[datetime.datetime] = None,
    ) -> Tuple[dict, dict]:
        """Assemble the request dict and the retry/timeout keyword arguments.

        The request always carries ``parent``,
        ``structured_aggregation_query`` and ``transaction``;
        ``explain_options`` is added when truthy and ``read_time`` when it
        is not ``None``.

        Returns:
            Tuple[dict, dict]: The request and the keyword arguments built by
            ``_helpers.make_retry_timeout_kwargs``.
        """
        # Only the parent path is needed; the second element is unused here.
        parent_path, _ = self._collection_ref._parent_info()

        request = {}
        request["parent"] = parent_path
        request["structured_aggregation_query"] = self._to_protobuf()
        request["transaction"] = _helpers.get_transaction_id(transaction)
        if explain_options:
            request["explain_options"] = explain_options._to_dict()
        if read_time is not None:
            request["read_time"] = read_time

        return request, _helpers.make_retry_timeout_kwargs(retry, timeout)

    @abc.abstractmethod
    def get(
        self,
        transaction=None,
        retry: Union[
            retries.Retry, retries.AsyncRetry, None, object
        ] = gapic_v1.method.DEFAULT,
        timeout: float | None = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
        read_time: Optional[datetime.datetime] = None,
    ) -> (
        QueryResultsList[AggregationResult]
        | Coroutine[Any, Any, List[List[AggregationResult]]]
    ):
        """Runs the aggregation query.

        This sends a ``RunAggregationQuery`` RPC and returns a list of
        aggregation results in the stream of ``RunAggregationQueryResponse``
        messages.

        Args:
            transaction
                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                An existing transaction that this query will run in.
                If a ``transaction`` is used and it already has write operations
                added, this method cannot be used (i.e. read-after-write is not
                allowed).
            retry (google.api_core.retry.Retry): Designation of what errors, if any,
                should be retried. Defaults to a system-specified policy.
            timeout (float): The timeout for this request. Defaults to a
                system-specified value.
            explain_options
                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
                Options to enable query profiling for this query. When set,
                explain_metrics will be available on the returned generator.
            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery
                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no
                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC.

        Returns:
            The aggregation query results, in the form given by the return
            annotation (a results list, or a coroutine producing one).
        """

    @abc.abstractmethod
    def stream(
        self,
        transaction: Optional[transaction.Transaction] = None,
        retry: retries.Retry
        | retries.AsyncRetry
        | object
        | None = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
        read_time: Optional[datetime.datetime] = None,
    ) -> (
        StreamGenerator[List[AggregationResult]]
        | AsyncStreamGenerator[List[AggregationResult]]
    ):
        """Runs the aggregation query.

        This sends a ``RunAggregationQuery`` RPC and returns a generator in
        the stream of ``RunAggregationQueryResponse`` messages.

        Args:
            transaction
                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                An existing transaction that this query will run in.
            retry (Optional[google.api_core.retry.Retry]): Designation of what
                errors, if any, should be retried. Defaults to a
                system-specified policy.
            timeout (Optional[float]): The timeout for this request. Defaults
                to a system-specified value.
            explain_options
                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
                Options to enable query profiling for this query. When set,
                explain_metrics will be available on the returned generator.
            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery
                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no
                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC.

        Returns:
            StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]]:
            A generator of the query results.
        """