Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/google/cloud/firestore_v1/base_aggregation.py: 43%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

127 statements  

1# Copyright 2023 Google LLC All rights reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Classes for representing aggregation queries for the Google Cloud Firestore API. 

16 

17A :class:`~google.cloud.firestore_v1.aggregation.AggregationQuery` can be created directly from 

18a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be 

19a more common way to create an aggregation query than direct usage of the constructor. 

20""" 

21from __future__ import annotations 

22 

23import abc 

24import itertools 

25 

26from abc import ABC 

27from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union, Iterable 

28 

29from google.api_core import gapic_v1 

30from google.api_core import retry as retries 

31 

32from google.cloud.firestore_v1 import _helpers 

33from google.cloud.firestore_v1.field_path import FieldPath 

34from google.cloud.firestore_v1.types import ( 

35 StructuredAggregationQuery, 

36) 

37from google.cloud.firestore_v1.pipeline_expressions import AggregateFunction 

38from google.cloud.firestore_v1.pipeline_expressions import Count 

39from google.cloud.firestore_v1.pipeline_expressions import AliasedExpression 

40from google.cloud.firestore_v1.pipeline_expressions import Field 

41 

42# Types needed only for Type Hints 

43if TYPE_CHECKING: # pragma: NO COVER 

44 from google.cloud.firestore_v1 import transaction 

45 from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator 

46 from google.cloud.firestore_v1.query_profile import ExplainOptions 

47 from google.cloud.firestore_v1.query_results import QueryResultsList 

48 from google.cloud.firestore_v1.stream_generator import ( 

49 StreamGenerator, 

50 ) 

51 from google.cloud.firestore_v1.pipeline_source import PipelineSource 

52 

53 import datetime 

54 

55 

class AggregationResult(object):
    """
    A class representing result from Aggregation Query

    :type alias: str
    :param alias: The alias for the aggregation.
    :type value: int | float
    :param value: The resulting value from the aggregation.
    :type read_time: Optional[datetime.datetime]
    :param read_time: The time at which the result was read, or ``None``
        when not provided by the response.
    """

    def __init__(self, alias: str, value: float, read_time=None):
        self.alias = alias
        self.value = value
        self.read_time = read_time

    def __repr__(self):
        return f"<Aggregation alias={self.alias}, value={self.value}, readtime={self.read_time}>"

    def _to_dict(self):
        """Return a single-entry mapping of this result's alias to its value."""
        return {self.alias: self.value}

77 

78 

class BaseAggregation(ABC):
    """Abstract base for a single aggregation operation over a query."""

    def __init__(self, alias: str | None = None):
        # Optional caller-supplied alias; when absent, _pipeline_alias
        # generates one lazily from the autoindexer.
        self.alias = alias

    @abc.abstractmethod
    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""

    @abc.abstractmethod
    def _to_pipeline_expr(
        self, autoindexer: Iterable[int]
    ) -> AliasedExpression[AggregateFunction]:
        """
        Convert this instance to a pipeline expression for use with pipeline.aggregate()

        Args:
            autoindexer: If an alias isn't supplied, one should be created with the format "field_n"
                The autoindexer is an iterable that provides the `n` value to use for each expression
        """

    def _pipeline_alias(self, autoindexer):
        """Return the alias to use in a pipeline expression.

        Falls back to an auto-generated "field_n" name when no alias
        was supplied at construction time.
        """
        if self.alias is None:
            return f"field_{next(autoindexer)}"
        return self.alias

107 

108 

class CountAggregation(BaseAggregation):
    """Counts the documents matched by the nested query."""

    def __init__(self, alias: str | None = None):
        super().__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        pb = StructuredAggregationQuery.Aggregation()
        if self.alias:
            pb.alias = self.alias
        pb.count = StructuredAggregationQuery.Aggregation.Count()
        return pb

    def _to_pipeline_expr(self, autoindexer: Iterable[int]):
        """Build the equivalent pipeline ``Count`` expression."""
        alias = self._pipeline_alias(autoindexer)
        return Count().as_(alias)

123 

124 

class SumAggregation(BaseAggregation):
    """Sums a numeric field across the documents matched by the nested query."""

    def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
        # Normalize FieldPath instances to their string API representation.
        if isinstance(field_ref, FieldPath):
            field_ref = field_ref.to_api_repr()
        self.field_ref: str = field_ref
        super().__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        pb = StructuredAggregationQuery.Aggregation()
        if self.alias:
            pb.alias = self.alias
        pb.sum = StructuredAggregationQuery.Aggregation.Sum()
        pb.sum.field.field_path = self.field_ref
        return pb

    def _to_pipeline_expr(self, autoindexer: Iterable[int]):
        """Build the equivalent pipeline sum expression."""
        alias = self._pipeline_alias(autoindexer)
        return Field.of(self.field_ref).sum().as_(alias)

145 

146 

class AvgAggregation(BaseAggregation):
    """Averages a numeric field across the documents matched by the nested query."""

    def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
        # Normalize FieldPath instances to their string API representation.
        if isinstance(field_ref, FieldPath):
            field_ref = field_ref.to_api_repr()
        self.field_ref: str = field_ref
        super().__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        pb = StructuredAggregationQuery.Aggregation()
        if self.alias:
            pb.alias = self.alias
        pb.avg = StructuredAggregationQuery.Aggregation.Avg()
        pb.avg.field.field_path = self.field_ref
        return pb

    def _to_pipeline_expr(self, autoindexer: Iterable[int]):
        """Build the equivalent pipeline average expression."""
        alias = self._pipeline_alias(autoindexer)
        return Field.of(self.field_ref).average().as_(alias)

167 

168 

def _query_response_to_result(
    response_pb,
) -> List[AggregationResult]:
    """Unpack one ``RunAggregationQueryResponse`` into AggregationResult objects.

    Each aggregate field in the response becomes one AggregationResult keyed
    by its alias. A falsy (zero/unset) ``integer_value`` falls through to
    ``double_value``, matching the wire format where only one is populated.
    """
    aggregate_fields = response_pb.result.aggregate_fields
    read_time = response_pb.read_time
    results: List[AggregationResult] = []
    for alias in aggregate_fields.pb.keys():
        value_pb = aggregate_fields[alias]
        results.append(
            AggregationResult(
                alias=alias,
                value=value_pb.integer_value or value_pb.double_value,
                read_time=read_time,
            )
        )
    return results

183 

184 

class BaseAggregationQuery(ABC):
    """Represents an aggregation query to the Firestore API."""

    def __init__(self, nested_query, alias: str | None = None) -> None:
        self._nested_query = nested_query
        self._alias = alias
        self._collection_ref = nested_query._parent
        # Aggregations accumulate via count()/sum()/avg()/add_aggregation*.
        self._aggregations: List[BaseAggregation] = []

    @property
    def _client(self):
        return self._collection_ref._client

    def count(self, alias: str | None = None):
        """
        Adds a count over the nested query
        """
        count_aggregation = CountAggregation(alias=alias)
        self._aggregations.append(count_aggregation)
        return self

    def sum(self, field_ref: str | FieldPath, alias: str | None = None):
        """
        Adds a sum over the nested query
        """
        sum_aggregation = SumAggregation(field_ref, alias=alias)
        self._aggregations.append(sum_aggregation)
        return self

    def avg(self, field_ref: str | FieldPath, alias: str | None = None):
        """
        Adds an avg over the nested query
        """
        avg_aggregation = AvgAggregation(field_ref, alias=alias)
        self._aggregations.append(avg_aggregation)
        return self

    def add_aggregation(self, aggregation: BaseAggregation) -> None:
        """
        Adds an aggregation operation to the nested query

        :type aggregation: :class:`google.cloud.firestore_v1.aggregation.BaseAggregation`
        :param aggregation: An aggregation operation, e.g. a CountAggregation
        """
        self._aggregations.append(aggregation)

    def add_aggregations(self, aggregations: List[BaseAggregation]) -> None:
        """
        Adds a list of aggregations to the nested query

        :type aggregations: list
        :param aggregations: a list of aggregation operations
        """
        self._aggregations.extend(aggregations)

    def _to_protobuf(self) -> StructuredAggregationQuery:
        """Build the ``StructuredAggregationQuery`` protobuf for this query."""
        pb = StructuredAggregationQuery()
        pb.structured_query = self._nested_query._to_protobuf()

        for aggregation in self._aggregations:
            aggregation_pb = aggregation._to_protobuf()
            pb.aggregations.append(aggregation_pb)
        return pb

    def _prep_stream(
        self,
        transaction=None,
        retry: Union[retries.Retry, retries.AsyncRetry, None, object] = None,
        timeout: float | None = None,
        explain_options: Optional[ExplainOptions] = None,
        read_time: Optional[datetime.datetime] = None,
    ) -> Tuple[dict, dict]:
        """Shared request-building step for ``get``/``stream`` implementations.

        Returns the ``RunAggregationQuery`` request dict and the
        retry/timeout kwargs for the GAPIC call.
        """
        parent_path, expected_prefix = self._collection_ref._parent_info()
        request = {
            "parent": parent_path,
            "structured_aggregation_query": self._to_protobuf(),
            "transaction": _helpers.get_transaction_id(transaction),
        }
        if explain_options:
            request["explain_options"] = explain_options._to_dict()
        if read_time is not None:
            request["read_time"] = read_time
        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

        return request, kwargs

    @abc.abstractmethod
    def get(
        self,
        transaction=None,
        retry: Union[
            retries.Retry, retries.AsyncRetry, None, object
        ] = gapic_v1.method.DEFAULT,
        timeout: float | None = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
        read_time: Optional[datetime.datetime] = None,
    ) -> (
        QueryResultsList[AggregationResult]
        | Coroutine[Any, Any, List[List[AggregationResult]]]
    ):
        """Runs the aggregation query.

        This sends a ``RunAggregationQuery`` RPC and returns a list of
        aggregation results in the stream of ``RunAggregationQueryResponse``
        messages.

        Args:
            transaction
                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                An existing transaction that this query will run in.
                If a ``transaction`` is used and it already has write operations
                added, this method cannot be used (i.e. read-after-write is not
                allowed).
            retry (google.api_core.retry.Retry): Designation of what errors, if any,
                should be retried. Defaults to a system-specified policy.
            timeout (float): The timeout for this request. Defaults to a
                system-specified value.
            explain_options
                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
                Options to enable query profiling for this query. When set,
                explain_metrics will be available on the returned generator.
            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery
                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no
                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC.

        Returns:
            (QueryResultsList[List[AggregationResult]] | Coroutine[Any, Any, List[List[AggregationResult]]]):
            The aggregation query results.
        """

    @abc.abstractmethod
    def stream(
        self,
        transaction: Optional[transaction.Transaction] = None,
        retry: retries.Retry
        | retries.AsyncRetry
        | object
        | None = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
        read_time: Optional[datetime.datetime] = None,
    ) -> (
        StreamGenerator[List[AggregationResult]]
        | AsyncStreamGenerator[List[AggregationResult]]
    ):
        """Runs the aggregation query.

        This sends a ``RunAggregationQuery`` RPC and returns a generator in the stream of ``RunAggregationQueryResponse`` messages.

        Args:
            transaction
                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                An existing transaction that this query will run in.
            retry (Optional[google.api_core.retry.Retry]): Designation of what
                errors, if any, should be retried. Defaults to a
                system-specified policy.
            timeout (Optional[float]): The timeout for this request. Defaults
                to a system-specified value.
            explain_options
                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
                Options to enable query profiling for this query. When set,
                explain_metrics will be available on the returned generator.
            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery
                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no
                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC.

        Returns:
            StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]]:
            A generator of the query results.
        """

    def _build_pipeline(self, source: "PipelineSource"):
        """
        Convert this query into a Pipeline

        Args:
            source: the PipelineSource to build the pipeline off of
        Returns:
            a Pipeline representing the query
        """
        # use autoindexer to keep track of which field number to use for un-aliased fields
        autoindexer = itertools.count(start=1)
        exprs = [a._to_pipeline_expr(autoindexer) for a in self._aggregations]
        return self._nested_query._build_pipeline(source).aggregate(*exprs)