1# Copyright 2025 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""
15.. warning::
16 **Preview API**: Firestore Pipelines is currently in preview and is
17 subject to potential breaking changes in future releases.
18"""
19
20from __future__ import annotations
21from typing import (
22 Any,
23 AsyncIterable,
24 AsyncIterator,
25 Iterable,
26 Iterator,
27 List,
28 Generic,
29 MutableMapping,
30 Type,
31 TypeVar,
32 TYPE_CHECKING,
33)
34from google.cloud.firestore_v1 import _helpers
35from google.cloud.firestore_v1.field_path import get_nested_value
36from google.cloud.firestore_v1.field_path import FieldPath
37from google.cloud.firestore_v1.query_profile import ExplainStats
38from google.cloud.firestore_v1.query_profile import QueryExplainError
39from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest
40from google.cloud.firestore_v1.types.document import Value
41
42if TYPE_CHECKING: # pragma: NO COVER
43 import datetime
44 from google.cloud.firestore_v1.async_client import AsyncClient
45 from google.cloud.firestore_v1.client import Client
46 from google.cloud.firestore_v1.base_client import BaseClient
47 from google.cloud.firestore_v1.async_transaction import AsyncTransaction
48 from google.cloud.firestore_v1.transaction import Transaction
49 from google.cloud.firestore_v1.base_document import BaseDocumentReference
50 from google.protobuf.timestamp_pb2 import Timestamp
51 from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse
52 from google.cloud.firestore_v1.types.document import Value as ValueProto
53 from google.cloud.firestore_v1.vector import Vector
54 from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
55 from google.cloud.firestore_v1.base_pipeline import _BasePipeline
56 from google.cloud.firestore_v1.pipeline import Pipeline
57 from google.cloud.firestore_v1.pipeline_expressions import Constant
58 from google.cloud.firestore_v1.query_profile import PipelineExplainOptions
59
60
class PipelineResult:
    """
    A single result produced by executing a Firestore Pipeline. Use `data()`
    to obtain every field at once, or `get()` to read one field at a time.

    `ref` is only populated when the result corresponds to a document; for
    non-document results it is `None`.
    """

    def __init__(
        self,
        client: BaseClient,
        fields_pb: MutableMapping[str, ValueProto],
        ref: BaseDocumentReference | None = None,
        execution_time: Timestamp | None = None,
        create_time: Timestamp | None = None,
        update_time: Timestamp | None = None,
    ):
        """
        Instances are produced by `pipeline.execute()`; manual construction
        is not expected.

        Args:
            client: Firestore client, used later to decode field values.
            fields_pb: Mapping of field names to protobuf Value messages.
            ref: Document reference when this result maps to a document.
            execution_time: When the pipeline execution that produced this
                result occurred.
            create_time: Document creation time, when applicable.
            update_time: Document last-update time, when applicable.
        """
        self._client = client
        self._ref = ref
        self._fields_pb = fields_pb
        self._execution_time = execution_time
        self._create_time = create_time
        self._update_time = update_time

    def __repr__(self):
        return f"{type(self).__name__}(data={self.data()})"

    @property
    def ref(self) -> BaseDocumentReference | None:
        """The `BaseDocumentReference` for document results, else `None`."""
        return self._ref

    @property
    def id(self) -> str | None:
        """The document ID for document results, else `None`."""
        if self._ref:
            return self._ref.id
        return None

    @property
    def create_time(self) -> Timestamp | None:
        """Creation time of the document. `None` if not applicable."""
        return self._create_time

    @property
    def update_time(self) -> Timestamp | None:
        """Last update time of the document. `None` if not applicable."""
        return self._update_time

    @property
    def execution_time(self) -> Timestamp:
        """
        When the pipeline that produced this result was executed.

        Raise:
            ValueError: if not set
        """
        if self._execution_time is None:
            raise ValueError("'execution_time' is expected to exist, but it is None.")
        return self._execution_time

    def __eq__(self, other: object) -> bool:
        """
        Equality check against another object.

        Two `PipelineResult` objects are equal when both their document
        references (possibly `None`) and their raw protobuf field maps match.
        """
        if isinstance(other, PipelineResult):
            return self._ref == other._ref and self._fields_pb == other._fields_pb
        return NotImplemented

    def data(self) -> dict | "Vector" | None:
        """
        Return every field of this result.

        Returns:
            The decoded data as a dictionary, or `None` when there is no
            underlying field map.
        """
        return (
            None
            if self._fields_pb is None
            else _helpers.decode_dict(self._fields_pb, self._client)
        )

    def get(self, field_path: str | FieldPath) -> Any:
        """
        Return the value stored at `field_path`.

        Args:
            field_path: A field path (e.g. 'foo' or 'foo.bar') identifying a
                single field.

        Returns:
            The data at that location, decoded into Python types.
        """
        if isinstance(field_path, str):
            str_path = field_path
        else:
            str_path = field_path.to_api_repr()
        raw_value = get_nested_value(str_path, self._fields_pb)
        return _helpers.decode_value(raw_value, self._client)
172
173
# Result element type for the containers below; bounding it to PipelineResult
# lets subclasses of PipelineResult flow through the containers generically.
T = TypeVar("T", bound=PipelineResult)
175
176
class _PipelineResultContainer(Generic[T]):
    """Base class to hold shared attributes for PipelineSnapshot and PipelineStream"""

    def __init__(
        self,
        return_type: Type[T],
        pipeline: Pipeline | AsyncPipeline,
        transaction: Transaction | AsyncTransaction | None,
        read_time: datetime.datetime | None,
        explain_options: PipelineExplainOptions | None,
        additional_options: dict[str, Constant | Value],
    ):
        """
        Args:
            return_type: PipelineResult subclass instantiated for each result.
            pipeline: the pipeline being executed.
            transaction: optional transaction to execute the pipeline within.
            read_time: optional timestamp to read data as of that time.
            explain_options: optional options requesting query explanation.
            additional_options: extra pipeline options, given as Constant
                expressions or raw protobuf Values.
        """
        # public
        self.transaction = transaction
        self.pipeline: _BasePipeline = pipeline
        # populated from the first response that carries an execution_time
        self.execution_time: Timestamp | None = None
        # private
        self._client: Client | AsyncClient = pipeline._client
        self._started: bool = False
        self._read_time = read_time
        self._explain_stats: ExplainStats | None = None
        self._explain_options: PipelineExplainOptions | None = explain_options
        self._return_type = return_type
        # normalize every option value to a protobuf Value
        # (attribute renamed from the misspelled "_additonal_options")
        self._additional_options = {
            k: v if isinstance(v, Value) else v._to_pb()
            for k, v in additional_options.items()
        }

    @property
    def explain_stats(self) -> ExplainStats:
        """
        Explain statistics returned by the backend for this execution.

        Raises:
            QueryExplainError: if explain_options was not set, if the query
                has not finished executing, or if the backend returned no
                stats.
        """
        if self._explain_stats is not None:
            return self._explain_stats
        elif self._explain_options is None:
            raise QueryExplainError("explain_options not set on query.")
        elif not self._started:
            raise QueryExplainError(
                "explain_stats not available until query is complete"
            )
        else:
            raise QueryExplainError("explain_stats not found")

    def _build_request(self) -> ExecutePipelineRequest:
        """
        Shared logic for creating an ExecutePipelineRequest from the
        attributes captured at construction time.
        """
        database_name = (
            f"projects/{self._client.project}/databases/{self._client._database}"
        )
        transaction_id = (
            _helpers.get_transaction_id(self.transaction, read_operation=False)
            if self.transaction is not None
            else None
        )
        options = {}
        if self._explain_options:
            options["explain_options"] = self._explain_options._to_value()
        if self._additional_options:
            options.update(self._additional_options)
        request = ExecutePipelineRequest(
            database=database_name,
            transaction=transaction_id,
            structured_pipeline=self.pipeline._to_pb(**options),
            read_time=self._read_time,
        )
        return request

    def _process_response(self, response: ExecutePipelineResponse) -> Iterable[T]:
        """
        Shared logic for processing an individual response from a stream.

        Records explain stats and execution time on `self` as side effects,
        then yields one `self._return_type` instance per entry in
        `response.results`.
        """
        if response.explain_stats:
            self._explain_stats = ExplainStats(response.explain_stats)
        execution_time = response._pb.execution_time
        # keep the first non-empty execution_time seen for this container
        if execution_time and not self.execution_time:
            self.execution_time = execution_time
        for doc in response.results:
            # entries without a name are non-document results; they get no ref
            ref = self._client.document(doc.name) if doc.name else None
            yield self._return_type(
                self._client,
                doc.fields,
                ref,
                execution_time,
                doc._pb.create_time if doc.create_time else None,
                doc._pb.update_time if doc.update_time else None,
            )
260
261
class PipelineSnapshot(_PipelineResultContainer[T], List[T]):
    """
    A list of pipeline results together with the metadata of the
    pipeline.execute() call that produced them.
    """

    def __init__(self, results_list: List[T], source: _PipelineResultContainer[T]):
        # Take over every attribute from the source container.
        self.__dict__.update(vars(source))
        list.__init__(self, results_list)
        # snapshots are always complete
        self._started = True
272
273
class PipelineStream(_PipelineResultContainer[T], Iterable[T]):
    """
    Lazily yields the results of a pipeline.stream() operation, exposing the
    related metadata as iteration proceeds.
    """

    def __iter__(self) -> Iterator[T]:
        # The stream is single-use: refuse to restart once consumed.
        if self._started:
            raise RuntimeError(f"{self.__class__.__name__} can only be iterated once")
        self._started = True
        responses = self._client._firestore_api.execute_pipeline(self._build_request())
        for response_pb in responses:
            yield from self._process_response(response_pb)
287
288
class AsyncPipelineStream(_PipelineResultContainer[T], AsyncIterable[T]):
    """
    Async counterpart of `PipelineStream`: lazily yields the results of an
    async pipeline.stream() operation along with related metadata.
    """

    async def __aiter__(self) -> AsyncIterator[T]:
        # The stream is single-use: refuse to restart once consumed.
        if self._started:
            raise RuntimeError(f"{self.__class__.__name__} can only be iterated once")
        self._started = True
        responses = await self._client._firestore_api.execute_pipeline(
            self._build_request()
        )
        async for response_pb in responses:
            for item in self._process_response(response_pb):
                yield item