1# -*- coding: utf-8 -*-
2#
3# Copyright 2019 Google LLC
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Define resources for the BigQuery ML Models API."""
18
19from __future__ import annotations # type: ignore
20
21import copy
22import datetime
23import typing
24from typing import Any, Dict, Optional, Sequence, Union
25
26import google.cloud._helpers # type: ignore
27from google.cloud.bigquery import _helpers
28from google.cloud.bigquery import standard_sql
29from google.cloud.bigquery.encryption_configuration import EncryptionConfiguration
30
31
class Model:
    """Model represents a machine learning model resource.

    See
    https://cloud.google.com/bigquery/docs/reference/rest/v2/models

    Args:
        model_ref:
            A pointer to a model. If ``model_ref`` is a string, it must
            include a project ID, dataset ID, and model ID, each separated
            by ``.``. :data:`None` creates a model with no reference.
    """

    _PROPERTY_TO_API_FIELD = {
        "expires": "expirationTime",
        "friendly_name": "friendlyName",
        # Even though it's not necessary for field mapping to map when the
        # property name equals the resource name, we add these here so that we
        # have an exhaustive list of all mutable properties.
        "labels": "labels",
        "description": "description",
        "encryption_configuration": "encryptionConfiguration",
    }

    def __init__(self, model_ref: Union["ModelReference", str, None]):
        # Use _properties on read-write properties to match the REST API
        # semantics. The BigQuery API makes a distinction between an unset
        # value, a null value, and a default value (0 or ""), but the protocol
        # buffer classes do not.
        self._properties: Dict[str, Any] = {}

        if isinstance(model_ref, str):
            model_ref = ModelReference.from_string(model_ref)

        if model_ref:
            self._properties["modelReference"] = model_ref.to_api_repr()

    @staticmethod
    def _datetime_from_millis(value: Any) -> Optional[datetime.datetime]:
        """Convert a millisecond timestamp from the resource to a datetime.

        Args:
            value:
                Milliseconds as a number or numeric string (the ``expires``
                setter stores a string), or :data:`None`.

        Returns:
            The converted datetime, or :data:`None` when ``value`` is
            :data:`None`.
        """
        if value is None:
            return None
        # The resource stores milliseconds; the helper expects microseconds.
        return google.cloud._helpers._datetime_from_microseconds(1000.0 * float(value))

    @property
    def reference(self) -> Optional["ModelReference"]:
        """A model reference pointing to this model, or :data:`None` if unset.

        Read-only.
        """
        resource = self._properties.get("modelReference")
        if resource is None:
            return None
        return ModelReference.from_api_repr(resource)

    @property
    def project(self) -> Optional[str]:
        """Project bound to the model."""
        ref = self.reference
        return ref.project if ref is not None else None

    @property
    def dataset_id(self) -> Optional[str]:
        """ID of dataset containing the model."""
        ref = self.reference
        return ref.dataset_id if ref is not None else None

    @property
    def model_id(self) -> Optional[str]:
        """The model ID."""
        ref = self.reference
        return ref.model_id if ref is not None else None

    @property
    def path(self) -> Optional[str]:
        """URL path for the model's APIs."""
        ref = self.reference
        return ref.path if ref is not None else None

    @property
    def location(self) -> Optional[str]:
        """The geographic location where the model resides.

        This value is inherited from the dataset.

        Read-only.
        """
        return typing.cast(Optional[str], self._properties.get("location"))

    @property
    def etag(self) -> Optional[str]:
        """ETag for the model resource (:data:`None` until set from the server).

        Read-only.
        """
        return typing.cast(Optional[str], self._properties.get("etag"))

    @property
    def created(self) -> Optional[datetime.datetime]:
        """Datetime at which the model was created (:data:`None` until set from the server).

        Read-only.
        """
        return self._datetime_from_millis(self._properties.get("creationTime"))

    @property
    def modified(self) -> Optional[datetime.datetime]:
        """Datetime at which the model was last modified (:data:`None` until set from the server).

        Read-only.
        """
        return self._datetime_from_millis(self._properties.get("lastModifiedTime"))

    @property
    def model_type(self) -> str:
        """Type of the model resource (``"MODEL_TYPE_UNSPECIFIED"`` if unset).

        Read-only.
        """
        return typing.cast(
            str, self._properties.get("modelType", "MODEL_TYPE_UNSPECIFIED")
        )

    @property
    def training_runs(self) -> Sequence[Dict[str, Any]]:
        """Information for all training runs in increasing order of start time.

        Dictionaries are in REST API format. See:
        https://cloud.google.com/bigquery/docs/reference/rest/v2/models#trainingrun

        Read-only.
        """
        return typing.cast(
            Sequence[Dict[str, Any]], self._properties.get("trainingRuns", [])
        )

    @property
    def feature_columns(self) -> Sequence[standard_sql.StandardSqlField]:
        """Input feature columns that were used to train this model.

        Read-only.
        """
        resource: Sequence[Dict[str, Any]] = typing.cast(
            Sequence[Dict[str, Any]], self._properties.get("featureColumns", [])
        )
        return [
            standard_sql.StandardSqlField.from_api_repr(column) for column in resource
        ]

    @property
    def transform_columns(self) -> Sequence[TransformColumn]:
        """The output transform columns used to train this model.

        See REST API:
        https://cloud.google.com/bigquery/docs/reference/rest/v2/models#transformcolumn

        Read-only.
        """
        resources: Sequence[Dict[str, Any]] = typing.cast(
            Sequence[Dict[str, Any]], self._properties.get("transformColumns", [])
        )
        return [TransformColumn(resource) for resource in resources]

    @property
    def label_columns(self) -> Sequence[standard_sql.StandardSqlField]:
        """Label columns that were used to train this model.

        The output of the model will have a ``predicted_`` prefix to these columns.

        Read-only.
        """
        resource: Sequence[Dict[str, Any]] = typing.cast(
            Sequence[Dict[str, Any]], self._properties.get("labelColumns", [])
        )
        return [
            standard_sql.StandardSqlField.from_api_repr(column) for column in resource
        ]

    @property
    def best_trial_id(self) -> Optional[int]:
        """The best trial_id across all training runs.

        .. deprecated::
            This property is deprecated!

        Read-only.
        """
        value = typing.cast(Optional[int], self._properties.get("bestTrialId"))
        if value is not None:
            # The resource may carry the ID as a string; normalize to int.
            value = int(value)
        return value

    @property
    def expires(self) -> Optional[datetime.datetime]:
        """The datetime when this model expires.

        If not present, the model will persist indefinitely. Expired models will be
        deleted and their storage reclaimed.
        """
        return self._datetime_from_millis(self._properties.get("expirationTime"))

    @expires.setter
    def expires(self, value: Optional[datetime.datetime]):
        if value is None:
            value_to_store: Optional[str] = None
        else:
            # Stored as a string of milliseconds, the same encoding the getter reads.
            value_to_store = str(google.cloud._helpers._millis_from_datetime(value))
        # TODO: Consider using typing.TypedDict when only Python 3.8+ is supported.
        self._properties["expirationTime"] = value_to_store  # type: ignore

    @property
    def description(self) -> Optional[str]:
        """Description of the model (defaults to :data:`None`)."""
        return typing.cast(Optional[str], self._properties.get("description"))

    @description.setter
    def description(self, value: Optional[str]):
        # TODO: Consider using typing.TypedDict when only Python 3.8+ is supported.
        self._properties["description"] = value  # type: ignore

    @property
    def friendly_name(self) -> Optional[str]:
        """Title of the model (defaults to :data:`None`)."""
        return typing.cast(Optional[str], self._properties.get("friendlyName"))

    @friendly_name.setter
    def friendly_name(self, value: Optional[str]):
        # TODO: Consider using typing.TypedDict when only Python 3.8+ is supported.
        self._properties["friendlyName"] = value  # type: ignore

    @property
    def labels(self) -> Dict[str, str]:
        """Labels for the model.

        This method always returns a dict. To change a model's labels, modify the dict,
        then call ``Client.update_model``. To delete a label, set its value to
        :data:`None` before updating.
        """
        # setdefault stores the empty dict so in-place edits are kept on the model.
        return self._properties.setdefault("labels", {})

    @labels.setter
    def labels(self, value: Optional[Dict[str, str]]):
        if value is None:
            value = {}
        self._properties["labels"] = value

    @property
    def encryption_configuration(self) -> Optional[EncryptionConfiguration]:
        """Custom encryption configuration for the model.

        Custom encryption configuration (e.g., Cloud KMS keys) or :data:`None`
        if using default encryption.

        See `protecting data with Cloud KMS keys
        <https://cloud.google.com/bigquery/docs/customer-managed-encryption>`_
        in the BigQuery documentation.
        """
        prop = self._properties.get("encryptionConfiguration")
        # Compare against None, not truthiness: an empty configuration dict
        # must still be wrapped instead of leaking out as a raw ``{}``.
        if prop is None:
            return None
        return EncryptionConfiguration.from_api_repr(prop)

    @encryption_configuration.setter
    def encryption_configuration(self, value: Optional[EncryptionConfiguration]):
        api_repr = value.to_api_repr() if value else value
        self._properties["encryptionConfiguration"] = api_repr

    @classmethod
    def from_api_repr(cls, resource: Dict[str, Any]) -> "Model":
        """Factory: construct a model resource given its API representation.

        Args:
            resource:
                Model resource representation from the API. It is deep-copied,
                so later changes to ``resource`` do not affect the model.

        Returns:
            Model parsed from ``resource``.
        """
        this = cls(None)
        this._properties = copy.deepcopy(resource)
        return this

    def _build_resource(self, filter_fields):
        """Generate a resource for ``update``.

        Args:
            filter_fields:
                Passed through unchanged to
                ``_helpers._build_resource_from_properties``.
        """
        return _helpers._build_resource_from_properties(self, filter_fields)

    def __repr__(self):
        return f"Model(reference={self.reference!r})"

    def to_api_repr(self) -> Dict[str, Any]:
        """Construct the API resource representation of this model.

        Returns:
            A deep copy of the model represented as an API resource.
        """
        return copy.deepcopy(self._properties)
345
346
class ModelReference:
    """ModelReferences are pointers to models.

    See
    https://cloud.google.com/bigquery/docs/reference/rest/v2/models#modelreference
    """

    def __init__(self):
        self._properties: Dict[str, Any] = {}

    @property
    def project(self) -> Optional[str]:
        """str: Project bound to the model."""
        return self._properties.get("projectId")

    @property
    def dataset_id(self) -> Optional[str]:
        """str: ID of dataset containing the model."""
        return self._properties.get("datasetId")

    @property
    def model_id(self) -> Optional[str]:
        """str: The model ID."""
        return self._properties.get("modelId")

    @property
    def path(self) -> str:
        """URL path for the model's APIs."""
        return f"/projects/{self.project}/datasets/{self.dataset_id}/models/{self.model_id}"

    @classmethod
    def from_api_repr(cls, resource: Dict[str, Any]) -> "ModelReference":
        """Factory: construct a model reference given its API representation.

        Args:
            resource:
                Model reference representation returned from the API. It is
                deep-copied, so later changes to ``resource`` do not alter
                this reference's equality or hash.

        Returns:
            Model reference parsed from ``resource``.
        """
        ref = cls()
        # Copy so the (hashable) reference does not alias a caller-owned dict,
        # e.g. the "modelReference" entry inside a Model's properties.
        ref._properties = copy.deepcopy(resource)
        return ref

    @classmethod
    def from_string(
        cls, model_id: str, default_project: Optional[str] = None
    ) -> "ModelReference":
        """Construct a model reference from model ID string.

        Args:
            model_id:
                A model ID in standard SQL format. If ``default_project``
                is not specified, this must include a project ID, dataset
                ID, and model ID, each separated by ``.``.
            default_project:
                The project ID to use when ``model_id`` does not include
                a project ID.

        Returns:
            Model reference parsed from ``model_id``.

        Raises:
            ValueError:
                If ``model_id`` is not a fully-qualified model ID in
                standard SQL format.
        """
        proj, dset, model = _helpers._parse_3_part_id(
            model_id, default_project=default_project, property_name="model_id"
        )
        return cls.from_api_repr(
            {"projectId": proj, "datasetId": dset, "modelId": model}
        )

    def to_api_repr(self) -> Dict[str, Any]:
        """Construct the API resource representation of this model reference.

        Returns:
            A deep copy of the model reference as an API resource.
        """
        return copy.deepcopy(self._properties)

    def _key(self):
        """Unique key for this model.

        This is used for hashing a ModelReference.
        """
        return self.project, self.dataset_id, self.model_id

    def __eq__(self, other):
        if not isinstance(other, ModelReference):
            return NotImplemented
        return self._properties == other._properties

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        return hash(self._key())

    def __repr__(self):
        return "ModelReference(project_id='{}', dataset_id='{}', model_id='{}')".format(
            self.project, self.dataset_id, self.model_id
        )
452
453
class TransformColumn:
    """A single transform column feature of a BigQuery ML model.

    Wraps one entry of a model resource's ``transformColumns`` list. See
    https://cloud.google.com/bigquery/docs/reference/rest/v2/models#transformcolumn

    Args:
        resource:
            A dictionary representing a transform column feature. The
            constructor keeps a reference to it rather than copying it.
    """

    def __init__(self, resource: Dict[str, Any]):
        self._properties = resource

    @property
    def name(self) -> Optional[str]:
        """Name of the column, or :data:`None` when absent."""
        return self._properties.get("name")

    @property
    def type_(self) -> Optional[standard_sql.StandardSqlDataType]:
        """Data type of the column after the transform.

        Returns:
            Optional[google.cloud.bigquery.standard_sql.StandardSqlDataType]:
                Data type of the column, or :data:`None` when the resource
                has no ``type`` entry.
        """
        raw_type = self._properties.get("type")
        return (
            None
            if raw_type is None
            else standard_sql.StandardSqlDataType.from_api_repr(raw_type)
        )

    @property
    def transform_sql(self) -> Optional[str]:
        """SQL expression applied to produce this column, or :data:`None`."""
        return self._properties.get("transformSql")

    @classmethod
    def from_api_repr(cls, resource: Dict[str, Any]) -> "TransformColumn":
        """Build a transform column from its API representation.

        Args:
            resource:
                Transform column feature representation from the API. A deep
                copy is stored, so the caller's dict is never shared.

        Returns:
            Transform column feature parsed from ``resource``.
        """
        column = cls({})
        column._properties = copy.deepcopy(resource)
        return column
506
507
def _model_arg_to_model_ref(value, default_project=None):
    """Normalize a model argument into a ModelReference where possible.

    A ``str`` is parsed with ``ModelReference.from_string`` (using
    ``default_project`` when the string omits the project). A ``Model`` is
    replaced by its ``reference``. Anything else, including an existing
    ModelReference, is returned unchanged.
    """
    if isinstance(value, Model):
        return value.reference
    if isinstance(value, str):
        return ModelReference.from_string(value, default_project=default_project)
    # Pass-through for ModelReference and any other object type.
    return value