Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/client/session.py: 21%

624 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""A client interface for TensorFlow.""" 

16 

17import collections 

18import functools 

19import re 

20import threading 

21import warnings 

22 

23import numpy as np 

24import wrapt 

25 

26from tensorflow.core.protobuf import config_pb2 

27from tensorflow.core.protobuf import rewriter_config_pb2 

28from tensorflow.python.client import pywrap_tf_session as tf_session 

29from tensorflow.python.eager import context 

30from tensorflow.python.eager import monitoring 

31from tensorflow.python.framework import device 

32from tensorflow.python.framework import error_interpolation 

33from tensorflow.python.framework import errors 

34from tensorflow.python.framework import indexed_slices 

35from tensorflow.python.framework import ops 

36from tensorflow.python.framework import sparse_tensor 

37from tensorflow.python.framework import stack 

38from tensorflow.python.ops import session_ops 

39from tensorflow.python.platform import tf_logging as logging 

40from tensorflow.python.training.experimental import mixed_precision_global_state 

41from tensorflow.python.util import compat 

42from tensorflow.python.util import nest 

43from tensorflow.python.util.compat import collections_abc 

44from tensorflow.python.util.tf_export import tf_export 

45 

# Process-wide metric incremented once per `BaseSession.__init__` call; used
# to monitor how many Python-level sessions this process has created.
_python_session_create_counter = monitoring.Counter(
    '/tensorflow/api/python/session_create_counter',
    'Counter for number of sessions created in Python.')

49 

50 

class SessionInterface(object):
  """Abstract interface implemented by TensorFlow client sessions.

  Concrete subclasses (e.g. `BaseSession`) must override every member below;
  the base implementations only raise `NotImplementedError`.
  """

  @property
  def graph(self):
    """The underlying TensorFlow graph, used when building Operations."""
    raise NotImplementedError('graph')

  @property
  def sess_str(self):
    """The TensorFlow process this session connects to."""
    raise NotImplementedError('sess_str')

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Runs operations in the session; see `BaseSession.run()` for details."""
    raise NotImplementedError('run')

  def partial_run_setup(self, fetches, feeds=None):
    """Sets up the feeds and fetches for partial runs in the session."""
    raise NotImplementedError('partial_run_setup')

  def partial_run(self, handle, fetches, feed_dict=None):
    """Continues the execution with additional feeds and fetches."""
    raise NotImplementedError('partial_run')

75 

76 

def _get_indexed_slices_value_from_fetches(fetched_vals):
  """Packs raw fetched component arrays into an `IndexedSlicesValue`.

  Args:
    fetched_vals: Sequence of 2 (values, indices) or 3 (values, indices,
      dense_shape) fetched components.

  Returns:
    An `indexed_slices.IndexedSlicesValue`; `dense_shape` is None when only
    two components were fetched.
  """
  values = fetched_vals[0]
  indices = fetched_vals[1]
  dense_shape = fetched_vals[2] if len(fetched_vals) == 3 else None
  return indexed_slices.IndexedSlicesValue(values, indices, dense_shape)

81 

82 

83def _get_feeds_for_indexed_slices(feed, feed_val): 

84 return list( 

85 zip([feed.values, feed.indices] if feed.dense_shape is None else 

86 [feed.values, feed.indices, feed.dense_shape], feed_val)) 

87 

88 

89# List of extensions supported to convert run arguments into actual fetches and 

90# feeds. 

91# 

92# Each element in the list is a tuple of (Type, fetch_fn, feed_fn1, feed_fn2), 

93# where the function signatures are: 

94# fetch_fn : Type -> (list of Tensors, 

95# lambda: list of fetched np.ndarray -> TypeVal) 

96# feed_fn1 : Type, TypeVal -> list of (Tensor, value) 

97# feed_fn2 : Type -> list of Tensors 

98# 

99# `fetch_fn` describes how to expand fetch into its 

100# component Tensors and how to contract the fetched results back into 

101# a single return value. 

102# 

103# Each feed function describes how to unpack a single fed value and map it to 

104# feeds of one or more tensors and their corresponding values: `feed_fn1` is 

105# used to feed a run, `feed_fn2` to set up a partial run. 

106# 

107# TODO(touts): We could reimplement these as specialized _FeedMapper 

108# implementations after we refactor the feed handling code to use them. 

109# 

110# Eventually, this registration could be opened up to support custom Tensor 

111# expansions. 

112# pylint: disable=g-long-lambda 

# Each entry is (type, fetch_fn, feed_fn1, feed_fn2); searched in order by
# `_FetchMapper.for_fetch`, so more specific types must precede `object`.
_REGISTERED_EXPANSIONS = [
    # SparseTensors are fetched as SparseTensorValues. They can be fed
    # SparseTensorValues or normal tuples.
    (sparse_tensor.SparseTensor, lambda fetch: ([
        fetch.indices, fetch.values, fetch.dense_shape
    ], lambda fetched_vals: sparse_tensor.SparseTensorValue(*fetched_vals)),
     lambda feed, feed_val: list(
         zip([feed.indices, feed.values, feed.dense_shape], feed_val)),
     lambda feed: [feed.indices, feed.values, feed.dense_shape]),
    # IndexedSlices are fetched as IndexedSlicesValues. They can be fed
    # IndexedSlicesValues or normal tuples.
    # The dense_shape component is optional, so both fetch and feed handle the
    # two-component and three-component forms.
    (indexed_slices.IndexedSlices,
     lambda fetch: ([fetch.values, fetch.indices] if fetch.dense_shape is None
                    else [fetch.values, fetch.indices, fetch.dense_shape
                         ], _get_indexed_slices_value_from_fetches),
     _get_feeds_for_indexed_slices,
     lambda feed: [feed.values, feed.indices] if feed.dense_shape is None else
     [feed.values, feed.indices, feed.dense_shape]),
    # The default catches all other types and performs no expansions.
    # This entry must stay last: `isinstance(fetch, object)` is always True.
    (object, lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
     lambda feed, feed_val: [(feed, feed_val)], lambda feed: [feed])
]

135 

136# pylint: enable=g-long-lambda 

137 

138 

139def _convert_to_numpy_obj(numpy_dtype, obj): 

140 """Explicitly convert obj based on numpy type except for string type.""" 

141 return numpy_dtype(obj) if numpy_dtype is not object else str(obj) 

142 

143 

def register_session_run_conversion_functions(
    tensor_type,
    fetch_function,
    feed_function=None,
    feed_function_for_partial_run=None):
  """Register fetch and feed conversion functions for `tf.Session.run()`.

  This function registers a triple of conversion functions for fetching and/or
  feeding values of user-defined types in a call to tf.Session.run().

  An example

  ```python
  class SquaredTensor(object):
    def __init__(self, tensor):
      self.sq = tf.square(tensor)
  #you can define conversion functions as follows:
  fetch_function = lambda squared_tensor:([squared_tensor.sq],
                                          lambda val: val[0])
  feed_function = lambda feed, feed_val: [(feed.sq, feed_val)]
  feed_function_for_partial_run = lambda feed: [feed.sq]
  #then after invoking this register function, you can use as follows:
  session.run(squared_tensor1,
              feed_dict = {squared_tensor2 : some_numpy_array})
  ```

  Args:
    tensor_type: The type for which you want to register a conversion function.
    fetch_function: A callable that takes an object of type `tensor_type` and
      returns a tuple, where the first element is a list of `tf.Tensor` objects,
      and the second element is a callable that takes a list of ndarrays and
      returns an object of some value type that corresponds to `tensor_type`.
      fetch_function describes how to expand fetch into its component Tensors
      and how to contract the fetched results back into a single return value.
    feed_function: A callable that takes feed_key and feed_value as input, and
      returns a list of tuples (feed_tensor, feed_val), feed_key must have type
      `tensor_type`, and feed_tensor must have type `tf.Tensor`. Each feed
      function describes how to unpack a single fed value and map it to feeds of
      one or more tensors and their corresponding values.
    feed_function_for_partial_run: A callable for specifying tensor values to
      feed when setting up a partial run, which takes a `tensor_type` type
      object as input, and returns a list of Tensors.

  Raises:
    ValueError: If `tensor_type` has already been registered.
  """
  # Reject duplicates: an existing entry whose type is a subclass of the new
  # type would shadow (or be shadowed by) the new registration.
  for registered_type, _, _, _ in _REGISTERED_EXPANSIONS:
    if issubclass(registered_type, tensor_type):
      raise ValueError(f'{tensor_type} has already been registered so ignore '
                       'it.')

  # Prepend so the new, more specific type is matched before the catch-all
  # `object` entry at the end of the list.
  new_entry = (tensor_type, fetch_function, feed_function,
               feed_function_for_partial_run)
  _REGISTERED_EXPANSIONS.insert(0, new_entry)

197 

198 

199def _is_attrs_instance(obj): 

200 """Returns True if the given obj is an instance of attrs-decorated class.""" 

201 return getattr(obj.__class__, '__attrs_attrs__', None) is not None 

202 

203 

204def _get_attrs_values(obj): 

205 """Returns the list of values from an attrs instance.""" 

206 attrs = getattr(obj.__class__, '__attrs_attrs__') 

207 return [getattr(obj, a.name) for a in attrs] 

208 

209 

class _FetchMapper(object):
  """Definition of the interface provided by fetch mappers.

  Fetch mappers are utility classes used by the _FetchHandler to handle
  arbitrary structures for the `fetch` argument to `Session.run()`.

  A `fetch` may be a single tensor or op, or an arbitrarily nested list,
  tuple, namedtuple, or dict of fetches. The low level run() API only wants a
  flat list of tensor or op names, so the `_FetchMapper` subclasses below
  uniquify the fetches and rebuild results with the original structure.
  """

  def unique_fetches(self):
    """Return the list of unique tensors or ops needed by this fetch mapper.

    Returns:
      A list of tensors or ops.
    """
    raise NotImplementedError(
        'unique_fetches must be implemented by subclasses')

  def build_results(self, values):
    """Build results that match the original shape of the fetch.

    Args:
      values: List of values returned by run(). The values correspond exactly to
        the list tensors or ops returned by unique_fetches().

    Returns:
      A struct of the same shape as the original fetch object handled by
      this fetch mapper. In the returned struct, the original fetches are
      replaced by their fetched values.
    """
    raise NotImplementedError('build_results must be implemented by subclasses')

  @staticmethod
  def for_fetch(fetch):
    """Creates fetch mapper that handles the structure of `fetch`.

    The default graph must be the one from which we want to fetch values when
    this function is called.

    Args:
      fetch: An arbitrary fetch structure: singleton, list, tuple, namedtuple,
        or dict.

    Returns:
      An instance of a subclass of `_FetchMapper` that handles the shape.
    """
    if fetch is None:
      raise TypeError(f'Argument `fetch` = {fetch} has invalid type '
                      f'"{type(fetch).__name__}". Cannot be None')
    if isinstance(fetch, (list, tuple)):
      # NOTE(touts): This is also the code path for namedtuples.
      return _ListFetchMapper(fetch)
    if isinstance(fetch, collections_abc.Mapping):
      return _DictFetchMapper(fetch)
    if _is_attrs_instance(fetch):
      return _AttrsFetchMapper(fetch)
    # Leaf value: look for a handler in the registered expansions. The
    # catch-all `object` entry guarantees a match for graph elements.
    for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS:
      if isinstance(fetch, tensor_type):
        fetches, contraction_fn = fetch_fn(fetch)
        return _ElementFetchMapper(fetches, contraction_fn)
    # Did not find anything.
    raise TypeError(f'Argument `fetch` = {fetch} has invalid type '
                    f'"{type(fetch).__name__}"')

281 

282 

class _ElementFetchMapper(_FetchMapper):
  """Fetch mapper for singleton tensors and ops."""

  def __init__(self, fetches, contraction_fn):
    """Creates an _ElementFetchMapper.

    This is the fetch mapper used for leaves in the fetch struct. Because of
    the expansions mechanism, a leaf can actually fetch more than one tensor.

    Also note that the fetches here can be just strings (tensor or op names) or
    any other object that the graph knows how to convert to a tensor, such as a
    Variable. So we have to run each fetch through `as_graph_element()` to get
    the corresponding tensor or op.

    Args:
      fetches: List of objects, as returned by a fetch_fn defined in
        _REGISTERED_EXPANSIONS.
      contraction_fn: Callable as returned by a fetch_fn.

    Raises:
      TypeError: If a fetch has a type the graph cannot convert.
      ValueError: If a fetch cannot be interpreted as a tensor or op.
    """
    self._unique_fetches = []
    for fetch in fetches:
      try:
        self._unique_fetches.append(ops.get_default_graph().as_graph_element(
            fetch, allow_tensor=True, allow_operation=True))
      except TypeError as e:
        # Chain the original error so the cause is visible in the traceback.
        raise TypeError(f'Argument `fetch` = {fetch} has invalid type '
                        f'"{type(fetch).__name__}" must be a string or Tensor. '
                        f'({str(e)})') from e
      except (ValueError, KeyError) as e:
        # `as_graph_element` may raise either error for an unresolvable name;
        # both cases mean the fetch cannot be interpreted, so handle them
        # identically (previously two duplicated handlers).
        raise ValueError(f'Argument `fetch` = {fetch} cannot be interpreted as '
                         f'a Tensor. ({str(e)})') from e
    self._contraction_fn = contraction_fn

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    if not values:
      # 'Operation' case
      return None
    else:
      return self._contraction_fn(values)

328 

329 

330def _uniquify_fetches(fetch_mappers): 

331 """Uniquifies fetches from a list of fetch_mappers. 

332 

333 This is a utility function used by _ListFetchMapper and _DictFetchMapper. It 

334 gathers all the unique fetches from a list of mappers and builds a list 

335 containing all of them but without duplicates (unique_fetches). 

336 

337 It also returns a 2-D list of integers (values_indices) indicating at which 

338 index in unique_fetches the fetches of the mappers are located. 

339 

340 This list is as follows: 

341 values_indices[mapper_index][mapper_fetch_index] = unique_fetches_index 

342 

343 Args: 

344 fetch_mappers: list of fetch mappers. 

345 

346 Returns: 

347 A list of fetches. 

348 A 2-D list of integers. 

349 """ 

350 unique_fetches = [] 

351 value_indices = [] 

352 seen_fetches = {} 

353 for m in fetch_mappers: 

354 m_value_indices = [] 

355 for f in m.unique_fetches(): 

356 j = seen_fetches.get(id(f)) 

357 if j is None: 

358 j = len(seen_fetches) 

359 seen_fetches[id(f)] = j 

360 unique_fetches.append(f) 

361 m_value_indices.append(j) 

362 value_indices.append(m_value_indices) 

363 return unique_fetches, value_indices 

364 

365 

class _ListFetchMapper(_FetchMapper):
  """Fetch mapper for lists, tuples, and namedtuples."""

  def __init__(self, fetches):
    """Creates a _ListFetchMapper.

    Args:
      fetches: List, tuple, or namedtuple of fetches.
    """
    # Unwrap wrapt proxies so results are rebuilt with the real container
    # type rather than the proxy type.
    if isinstance(fetches, wrapt.ObjectProxy):
      self._fetch_type = type(fetches.__wrapped__)
    else:
      self._fetch_type = type(fetches)
    self._mappers = [_FetchMapper.for_fetch(item) for item in fetches]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # Rebuild each element's result from its slots in the flat value list.
    results = [
        mapper.build_results([values[j] for j in slots])
        for mapper, slots in zip(self._mappers, self._value_indices)
    ]
    # Return a value of the original type of the fetches.
    if issubclass(self._fetch_type, list):
      return results
    if self._fetch_type == tuple:
      return tuple(results)
    # This is the code path for namedtuple.
    return self._fetch_type(*results)

398 

399 

class _DictFetchMapper(_FetchMapper):
  """Fetch mapper for dicts."""

  def __init__(self, fetches):
    """Creates a _DictFetchMapper.

    Args:
      fetches: Dict of fetches.
    """
    self._fetch_type = type(fetches)
    # defaultdict cannot be rebuilt from pairs alone; carry its factory along.
    if isinstance(fetches, collections.defaultdict):
      self._type_ctor = functools.partial(collections.defaultdict,
                                          fetches.default_factory)
    else:
      self._type_ctor = self._fetch_type

    self._keys = fetches.keys()
    self._mappers = [
        _FetchMapper.for_fetch(value) for value in fetches.values()
    ]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # Rebuild a mapping of the original type, pairing each key with its
    # mapper's reconstructed result.
    pairs = (
        (key, mapper.build_results([values[j] for j in slots]))
        for key, mapper, slots in zip(self._keys, self._mappers,
                                      self._value_indices))
    return self._type_ctor(pairs)

432 

433 

class _AttrsFetchMapper(_FetchMapper):
  """Fetch mapper for attrs decorated classes."""

  def __init__(self, fetches):
    """Creates a _AttrsFetchMapper.

    Args:
      fetches: An instance of an attrs decorated class.
    """
    self._fetch_type = type(fetches)
    attr_values = _get_attrs_values(fetches)
    self._mappers = [_FetchMapper.for_fetch(value) for value in attr_values]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # Reconstruct the attrs class positionally from each attribute's result.
    sub_results = [
        mapper.build_results([values[j] for j in slots])
        for mapper, slots in zip(self._mappers, self._value_indices)
    ]
    return self._fetch_type(*sub_results)

456 

457 

class _FetchHandler(object):
  """Handler for structured fetches.

  Given a graph, a user-provided structure for fetches, and a feed dict, this
  class takes care of generating a list of tensor names to fetch and op names
  to run for a low level `run()` call.

  Given the results of the low level run call, this class can also rebuild a
  result structure matching the user-provided structure for fetches, but
  containing the corresponding results.
  """

  # TODO(touts): Make this class also take care of destructuring the feed
  # dict instead of doing it in the callers.

  def __init__(self, graph, fetches, feeds, feed_handles=None):
    """Creates a fetch handler.

    Args:
      graph: Graph of the fetches. Used to check for fetchability and to
        convert all fetches to tensors or ops as needed.
      fetches: An arbitrary fetch structure: singleton, list, tuple, namedtuple,
        or dict.
      feeds: A feed dict where keys are Tensors.
      feed_handles: A dict from feed Tensors to TensorHandle objects used as
        direct feeds.
    """
    # Build the fetch mapper under `graph` so string fetches are resolved
    # against the right graph.
    with graph.as_default():
      self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    self._fetches = []
    self._targets = []
    self._feeds = feeds
    self._feed_handles = feed_handles or {}
    # `self._ops` mirrors the flat unique-fetch list: True for an Operation
    # (run as a target, result None), False for a Tensor (fetched value).
    self._ops = []
    self._fetch_handles = {}
    for fetch in self._fetch_mapper.unique_fetches():
      if isinstance(fetch, ops.Operation):
        self._assert_fetchable(graph, fetch)
        self._targets.append(fetch)
        self._ops.append(True)
      else:
        self._assert_fetchable(graph, fetch.op)
        self._fetches.append(fetch)
        self._ops.append(False)
        # Remember the fetch if it is for a tensor handle.
        if (isinstance(fetch, ops.Tensor) and
            (fetch.op.type == 'GetSessionHandle' or
             fetch.op.type == 'GetSessionHandleV2')):
          self._fetch_handles[fetch.ref()] = fetch.op.inputs[0].dtype
    # Tensors that are also fed need not be fetched from the runtime; their
    # values come straight from the feed dict in build_results().
    self._final_fetches = [x for x in self._fetches if x.ref() not in feeds]

  def _assert_fetchable(self, graph, op):
    # Raises if the graph has marked `op` as unfetchable (e.g. it belongs to
    # an inner function/control-flow scope).
    if not graph.is_fetchable(op):
      raise errors.InaccessibleTensorError(
          f'Operation {op.name} has been marked as not fetchable. Typically '
          'this happens when it is defined in another function or code block. '
          'Use return values, explicit Python locals or TensorFlow collections '
          'to access it.')

  def fetches(self):
    """Return the unique names of tensors to fetch.

    Returns:
      A list of strings.
    """
    return self._final_fetches

  def targets(self):
    """Return the unique names of ops to run.

    Returns:
      A list of strings.
    """
    return self._targets

  def build_results(self, session, tensor_values):
    """Build results matching the original fetch shape.

    `tensor_values` must be a list of the same length as
    the one returned by `fetches()`, and holding the requested
    fetch values.

    This method builds a struct with the same shape as the original `fetches`
    passed to the constructor, in which the fetches are replaced by their
    fetched value.

    Args:
      session: The enclosing session.  Used for tensor handles.
      tensor_values: List of values matching the list returned by fetches().

    Returns:
      A structure of the same shape as the original `fetches` argument but
        containing tensors or None (for fetched ops).
    """
    full_values = []
    assert len(self._final_fetches) == len(tensor_values)
    # Two cursors walk in lockstep with `self._ops`: `i` indexes
    # `self._fetches` (all tensor fetches), `j` indexes `tensor_values`
    # (only tensors actually fetched from the runtime, i.e. not fed).
    i = 0
    j = 0
    for is_op in self._ops:
      if is_op:
        full_values.append(None)
      else:
        # If the fetch was in the feeds, use the fed value, otherwise
        # use the returned value.
        if self._fetches[i].ref() in self._feed_handles:
          # A fetch had a corresponding direct TensorHandle feed. Call eval()
          # to obtain the Tensor value from the TensorHandle.
          value = self._feed_handles[self._fetches[i].ref()].eval()
        else:
          value = self._feeds.get(self._fetches[i].ref())
        if value is None:
          value = tensor_values[j]
          j += 1
        dtype = self._fetch_handles.get(self._fetches[i].ref())
        if dtype:
          # Wrap raw handle fetches (GetSessionHandle ops) in a TensorHandle.
          full_values.append(session_ops.TensorHandle(value, dtype, session))
        else:
          full_values.append(value)
        i += 1
    assert j == len(tensor_values)
    return self._fetch_mapper.build_results(full_values)

579 

580 

def _name_list(tensor_list):
  """Utility function for transitioning to the new session API.

  Args:
    tensor_list: a list of `Tensor`s.

  Returns:
    A list of each `Tensor`s name (as byte arrays).
  """
  names = []
  for tensor in tensor_list:
    names.append(compat.as_bytes(tensor.name))
  return names

591 

592 

593class _DeviceAttributes(object): 

594 """Struct-like object describing a device's attributes. 

595 

596 Each device has 3 key properties: 

597 - name: the fully-qualified TensorFlow path to the device. For 

598 example: /job:worker/replica:0/task:3/device:CPU:0 

599 - device_type: the type of the device (e.g. CPU, GPU, TPU, etc.) 

600 - memory_limit_bytes: the maximum amount of memory available on the device 

601 (in bytes). 

602 """ 

603 

604 def __init__(self, name, device_type, memory_limit_bytes, incarnation): 

605 self._name = device.canonical_name(name) 

606 self._device_type = device_type 

607 self._memory_limit_bytes = memory_limit_bytes 

608 self._incarnation = incarnation 

609 

610 @property 

611 def name(self): 

612 return self._name 

613 

614 @property 

615 def device_type(self): 

616 return self._device_type 

617 

618 @property 

619 def memory_limit_bytes(self): 

620 return self._memory_limit_bytes 

621 

622 @property 

623 def incarnation(self): 

624 return self._incarnation 

625 

626 def __repr__(self): 

627 return '_DeviceAttributes(%s, %s, %d, %d)' % ( 

628 self.name, 

629 self.device_type, 

630 self.memory_limit_bytes, 

631 self.incarnation, 

632 ) 

633 

634 

635class BaseSession(SessionInterface): 

636 """A class for interacting with a TensorFlow computation. 

637 

638 The BaseSession enables incremental graph building with inline 

639 execution of Operations and evaluation of Tensors. 

640 """ 

641 

  def __init__(self, target='', graph=None, config=None):
    """Constructs a new TensorFlow session.

    Args:
      target: (Optional) The TensorFlow execution engine to connect to.
      graph: (Optional) The graph to be used. If this argument is None, the
        default graph will be used.
      config: (Optional) ConfigProto proto used to configure the session. If no
        config is specified, the global default will be used. The global default
        can be configured via the tf.config APIs.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow session.
      TypeError: If one of the arguments has the wrong type.
    """
    # Record session creation in the process-wide monitoring counter.
    _python_session_create_counter.get_cell().increase_by(1)
    if graph is None:
      self._graph = ops.get_default_graph()
    else:
      if not isinstance(graph, ops.Graph):
        raise TypeError('Argument `graph` must be a tf.Graph, but got '
                        f'"{type(graph).__name__}"')
      self._graph = graph

    self._closed = False

    if target is not None:
      try:
        self._target = compat.as_bytes(target)
      except TypeError:
        # Most common user mistake: passing a ConfigProto positionally where
        # `target` is expected; give a targeted hint for that case.
        if isinstance(target, config_pb2.ConfigProto):
          raise TypeError('Argument `target` must be a string, but got '
                          f'"{type(target).__name__}". Did you do '
                          '"Session(config)" instead of '
                          '"Session(config=config)"?')
        raise TypeError('Argument `target` must be a string, but got '
                        f'"{type(target).__name__}"')
    else:
      self._target = None

    # Lock and list used when deleting dead tensor handles; see the handle
    # machinery elsewhere in this file.
    self._delete_lock = threading.Lock()
    self._dead_handles = []

    if config is None:
      config = context.context().config

    if not isinstance(config, config_pb2.ConfigProto):
      raise TypeError('Argument `config` must be a tf.ConfigProto, but got '
                      f'"{type(config).__name__}"')

    # If the auto mixed precision graph rewrite was enabled globally and the
    # caller's config does not explicitly turn it OFF, force it ON — copying
    # the config first so the caller's proto is never mutated.
    if (mixed_precision_global_state.is_mixed_precision_graph_rewrite_enabled()
        and config.graph_options.rewrite_options.auto_mixed_precision !=
        rewriter_config_pb2.RewriterConfig.OFF):
      new_config = config_pb2.ConfigProto()
      new_config.CopyFrom(config)
      new_config.graph_options.rewrite_options.auto_mixed_precision = (
          rewriter_config_pb2.RewriterConfig.ON)
      config = new_config
    elif (config.graph_options.rewrite_options.auto_mixed_precision !=
          rewriter_config_pb2.RewriterConfig.ON):
      # NOTE(review): records that a session without the rewrite was created —
      # presumably so later enabling of mixed precision can warn; confirm
      # against mixed_precision_global_state.
      mixed_precision_global_state.set_non_mixed_precision_session_created(True)

    self._config = config
    self._add_shapes = config.graph_options.infer_shapes

    self._session = None
    opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
    try:
      # pylint: disable=protected-access
      with self._graph._c_graph.get() as c_graph:
        self._session = tf_session.TF_NewSessionRef(c_graph, opts)
      # pylint: enable=protected-access
    finally:
      # Session options are always freed, even if session creation fails.
      tf_session.TF_DeleteSessionOptions(opts)

717 

718 def list_devices(self): 

719 """Lists available devices in this session. 

720 

721 ```python 

722 devices = sess.list_devices() 

723 for d in devices: 

724 print(d.name) 

725 ``` 

726 

727 Where: 

728 Each element in the list has the following properties 

729 name: A string with the full name of the device. ex: 

730 `/job:worker/replica:0/task:3/device:CPU:0` 

731 device_type: The type of the device (e.g. `CPU`, `GPU`, `TPU`.) 

732 memory_limit: The maximum amount of memory available on the device. 

733 Note: depending on the device, it is possible the usable memory could 

734 be substantially less. 

735 

736 Raises: 

737 tf.errors.OpError: If it encounters an error (e.g. session is in an 

738 invalid state, or network errors occur). 

739 

740 Returns: 

741 A list of devices in the session. 

742 """ 

743 raw_device_list = tf_session.TF_SessionListDevices(self._session) 

744 device_list = [] 

745 size = tf_session.TF_DeviceListCount(raw_device_list) 

746 for i in range(size): 

747 name = tf_session.TF_DeviceListName(raw_device_list, i) 

748 device_type = tf_session.TF_DeviceListType(raw_device_list, i) 

749 memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i) 

750 incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i) 

751 device_list.append( 

752 _DeviceAttributes(name, device_type, memory, incarnation)) 

753 tf_session.TF_DeleteDeviceList(raw_device_list) 

754 return device_list 

755 

756 def close(self): 

757 """Closes this session. 

758 

759 Calling this method frees all resources associated with the session. 

760 

761 Raises: 

762 tf.errors.OpError: Or one of its subclasses if an error occurs while 

763 closing the TensorFlow session. 

764 """ 

765 if self._session and not self._closed: 

766 self._closed = True 

767 tf_session.TF_CloseSession(self._session) 

768 

  def __del__(self):
    """Best-effort finalizer: closes and deletes the underlying C session."""
    # cleanly ignore all exceptions
    try:
      self.close()
    except Exception:  # pylint: disable=broad-except
      pass
    if self._session is not None:
      try:
        tf_session.TF_DeleteSession(self._session)
      except (AttributeError, TypeError):
        # At shutdown, `c_api_util`, `tf_session`, or
        # `tf_session.TF_DeleteSession` may have been garbage collected, causing
        # the above method calls to fail. In this case, silently leak since the
        # program is about to terminate anyway.
        pass
      # Clear the handle so a second finalization cannot double-delete.
      self._session = None

785 

  @property
  def graph(self):
    """The graph that was launched in this session.

    Returns:
      The `tf.Graph` passed to the constructor, or the default graph captured
      at construction time when none was given.
    """
    return self._graph

790 

791 @property 

792 def graph_def(self): 

793 """A serializable version of the underlying TensorFlow graph. 

794 

795 Returns: 

796 A graph_pb2.GraphDef proto containing nodes for all of the Operations in 

797 the underlying TensorFlow graph. 

798 """ 

799 return self._graph.as_graph_def(add_shapes=self._add_shapes) 

800 

  @property
  def sess_str(self):
    """The TensorFlow process to which this session will connect.

    Returns:
      The `target` constructor argument as bytes, or None if no target was
      supplied.
    """
    return self._target

804 

  def as_default(self):
    """Returns a context manager that makes this object the default session.

    Use with the `with` keyword to specify that calls to
    `tf.Operation.run` or `tf.Tensor.eval` should be executed in
    this session.

    ```python
    c = tf.constant(..)
    sess = tf.compat.v1.Session()

    with sess.as_default():
      assert tf.compat.v1.get_default_session() is sess
      print(c.eval())
    ```

    To get the current default session, use `tf.compat.v1.get_default_session`.

    *N.B.* The `as_default` context manager *does not* close the
    session when you exit the context, and you must close the session
    explicitly.

    ```python
    c = tf.constant(...)
    sess = tf.compat.v1.Session()
    with sess.as_default():
      print(c.eval())
    # ...
    with sess.as_default():
      print(c.eval())

    sess.close()
    ```

    Alternatively, you can use `with tf.compat.v1.Session():` to create a
    session that is automatically closed on exiting the context,
    including when an uncaught exception is raised.

    *N.B.* The default session is a property of the current thread. If you
    create a new thread, and wish to use the default session in that
    thread, you must explicitly add a `with sess.as_default():` in that
    thread's function.

    *N.B.* Entering a `with sess.as_default():` block does not affect
    the current default graph. If you are using multiple graphs, and
    `sess.graph` is different from the value of
    `tf.compat.v1.get_default_graph`, you must explicitly enter a
    `with sess.graph.as_default():` block to make `sess.graph` the default
    graph.

    Returns:
      A context manager using this session as the default session.
    """
    # The default-session stack is per-thread; see the N.B. notes above.
    return stack.default_session(self)

859 

860 def run(self, fetches, feed_dict=None, options=None, run_metadata=None): 

861 """Runs operations and evaluates tensors in `fetches`. 

862 

863 This method runs one "step" of TensorFlow computation, by 

864 running the necessary graph fragment to execute every `Operation` 

865 and evaluate every `Tensor` in `fetches`, substituting the values in 

866 `feed_dict` for the corresponding input values. 

867 

868 The `fetches` argument may be a single graph element, or an arbitrarily 

869 nested list, tuple, namedtuple, dict, or OrderedDict containing graph 

870 elements at its leaves. A graph element can be one of the following types: 

871 

872 * A `tf.Operation`. 

873 The corresponding fetched value will be `None`. 

874 * A `tf.Tensor`. 

875 The corresponding fetched value will be a numpy ndarray containing the 

876 value of that tensor. 

877 * A `tf.sparse.SparseTensor`. 

878 The corresponding fetched value will be a 

879 `tf.compat.v1.SparseTensorValue` 

880 containing the value of that sparse tensor. 

881 * A `get_tensor_handle` op. The corresponding fetched value will be a 

882 numpy ndarray containing the handle of that tensor. 

883 * A `string` which is the name of a tensor or operation in the graph. 

884 

885 The value returned by `run()` has the same shape as the `fetches` argument, 

886 where the leaves are replaced by the corresponding values returned by 

887 TensorFlow. 

888 

889 Example: 

890 

891 ```python 

892 a = tf.constant([10, 20]) 

893 b = tf.constant([1.0, 2.0]) 

894 # 'fetches' can be a singleton 

895 v = session.run(a) 

896 # v is the numpy array [10, 20] 

897 # 'fetches' can be a list. 

898 v = session.run([a, b]) 

899 # v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the 

900 # 1-D array [1.0, 2.0] 

901 # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts: 

902 MyData = collections.namedtuple('MyData', ['a', 'b']) 

903 v = session.run({'k1': MyData(a, b), 'k2': [b, a]}) 

904 # v is a dict with 

905 # v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and 

906 # 'b' (the numpy array [1.0, 2.0]) 

907 # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array 

908 # [10, 20]. 

909 ``` 

910 

911 The optional `feed_dict` argument allows the caller to override 

912 the value of tensors in the graph. Each key in `feed_dict` can be 

913 one of the following types: 

914 

915 * If the key is a `tf.Tensor`, the 

916 value may be a Python scalar, string, list, or numpy ndarray 

917 that can be converted to the same `dtype` as that 

918 tensor. Additionally, if the key is a 

919 `tf.compat.v1.placeholder`, the shape of 

920 the value will be checked for compatibility with the placeholder. 

921 * If the key is a 

922 `tf.sparse.SparseTensor`, 

923 the value should be a 

924 `tf.compat.v1.SparseTensorValue`. 

925 * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value 

926 should be a nested tuple with the same structure that maps to their 

927 corresponding values as above. 

928 

929 Each value in `feed_dict` must be convertible to a numpy array of the dtype 

930 of the corresponding key. 

931 

932 The optional `options` argument expects a [`RunOptions`] proto. The options 

933 allow controlling the behavior of this particular step (e.g. turning tracing 

934 on). 

935 

936 The optional `run_metadata` argument expects a [`RunMetadata`] proto. When 

937 appropriate, the non-Tensor output of this step will be collected there. For 

938 example, when users turn on tracing in `options`, the profiled info will be 

939 collected into this argument and passed back. 

940 

941 Args: 

942 fetches: A single graph element, a list of graph elements, or a dictionary 

943 whose values are graph elements or lists of graph elements (described 

944 above). 

945 feed_dict: A dictionary that maps graph elements to values (described 

946 above). 

947 options: A [`RunOptions`] protocol buffer 

948 run_metadata: A [`RunMetadata`] protocol buffer 

949 

950 Returns: 

951 Either a single value if `fetches` is a single graph element, or 

952 a list of values if `fetches` is a list, or a dictionary with the 

953 same keys as `fetches` if that is a dictionary (described above). 

954 Order in which `fetches` operations are evaluated inside the call 

955 is undefined. 

956 

957 Raises: 

958 RuntimeError: If this `Session` is in an invalid state (e.g. has been 

959 closed). 

960 TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type. 

961 ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a 

962 `Tensor` that doesn't exist. 

963 """ 

964 options_ptr = tf_session.TF_NewBufferFromString( 

965 compat.as_bytes(options.SerializeToString())) if options else None 

966 run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None 

967 

968 try: 

969 result = self._run(None, fetches, feed_dict, options_ptr, 

970 run_metadata_ptr) 

971 if run_metadata: 

972 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 

973 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 

974 finally: 

975 if run_metadata_ptr: 

976 tf_session.TF_DeleteBuffer(run_metadata_ptr) 

977 if options: 

978 tf_session.TF_DeleteBuffer(options_ptr) 

979 return result 

980 

  def partial_run(self, handle, fetches, feed_dict=None):
    """Continues the execution with more feeds and fetches.

    This is EXPERIMENTAL and subject to change.

    To use partial execution, a user first calls `partial_run_setup()` and
    then a sequence of `partial_run()`. `partial_run_setup` specifies the
    list of feeds and fetches that will be used in the subsequent
    `partial_run` calls.

    The optional `feed_dict` argument allows the caller to override
    the value of tensors in the graph. See run() for more information.

    Below is a simple example:

    ```python
    a = array_ops.placeholder(dtypes.float32, shape=[])
    b = array_ops.placeholder(dtypes.float32, shape=[])
    c = array_ops.placeholder(dtypes.float32, shape=[])
    r1 = math_ops.add(a, b)
    r2 = math_ops.multiply(r1, c)

    h = sess.partial_run_setup([r1, r2], [a, b, c])
    res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
    res = sess.partial_run(h, r2, feed_dict={c: res})
    ```

    Args:
      handle: A handle for a sequence of partial runs.
      fetches: A single graph element, a list of graph elements, or a dictionary
        whose values are graph elements or lists of graph elements (see
        documentation for `run`).
      feed_dict: A dictionary that maps graph elements to values (described
        above).

    Returns:
      Either a single value if `fetches` is a single graph element, or
      a list of values if `fetches` is a list, or a dictionary with the
      same keys as `fetches` if that is a dictionary
      (see documentation for `run`).

    Raises:
      tf.errors.OpError: Or one of its subclasses on error.
    """
    # TODO(touts): Support feeding and fetching the same tensor.
    # A non-None `handle` routes `_run` onto the partial-run code path
    # (no RunOptions/RunMetadata are supported there).
    return self._run(handle, fetches, feed_dict, None, None)

1027 

  def partial_run_setup(self, fetches, feeds=None):
    """Sets up a graph with feeds and fetches for partial run.

    This is EXPERIMENTAL and subject to change.

    Note that contrary to `run`, `feeds` only specifies the graph elements.
    The tensors will be supplied by the subsequent `partial_run` calls.

    Args:
      fetches: A single graph element, or a list of graph elements.
      feeds: A single graph element, or a list of graph elements.

    Returns:
      A handle for partial run.

    Raises:
      RuntimeError: If this `Session` is in an invalid state (e.g. has been
        closed).
      TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
      tf.errors.OpError: Or one of its subclasses if a TensorFlow error happens.
    """

    def _feed_fn(feed):
      # Dispatch on the feed's type using the registered expansion table;
      # the selected function expands the feed into its component tensors.
      for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS:
        if isinstance(feed, tensor_type):
          return feed_fn(feed)
      raise TypeError(f'Feed argument {feed} has invalid type '
                      f'"{type(feed).__name__}"')

    # Check session.
    if self._closed:
      raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
      raise RuntimeError('The Session graph is empty. Add operations to the '
                         'graph before calling run().')

    if feeds is None:
      feeds = []
    # Create request.
    feed_list = []

    # Validate and process feed_list.
    # A single (non-list) feed is normalized to a one-element list.
    is_list_feed = isinstance(feeds, (list, tuple))
    if not is_list_feed:
      feeds = [feeds]
    for feed in feeds:
      for subfeed in _feed_fn(feed):
        try:
          subfeed_t = self.graph.as_graph_element(
              subfeed, allow_tensor=True, allow_operation=False)
          # pylint: disable=protected-access
          feed_list.append(subfeed_t._as_tf_output())
          # pylint: enable=protected-access
        except Exception as e:
          # NOTE(review): this assumes the caught exception has a `.message`
          # attribute; a plain Exception without one would raise
          # AttributeError here instead — confirm intended exception types.
          e.message = ('Cannot interpret argument `feed` key as Tensor: '
                       f'{e.message}')
          e.args = (e.message,)
          raise e

    # Validate and process fetches.
    # TODO(touts): Support feeding and fetching the same tensor.
    fetch_handler = _FetchHandler(self._graph, fetches, {})

    # Set up a graph with feeds and fetches for partial run.
    def _setup_fn(session, feed_list, fetch_list, target_list):
      self._extend_graph()
      return tf_session.TF_SessionPRunSetup_wrapper(session, feed_list,
                                                    fetch_list, target_list)

    # pylint: disable=protected-access
    final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()]
    final_targets = [op._c_op for op in fetch_handler.targets()]
    # pylint: enable=protected-access

    return self._do_call(_setup_fn, self._session, feed_list, final_fetches,
                         final_targets)

1104 

  def _run(self, handle, fetches, feed_dict, options, run_metadata):
    """Perform either run or partial_run, depending the presence of `handle`.

    Args:
      handle: A partial-run handle from `partial_run_setup`, or None for a
        regular run.
      fetches: Fetch structure, as accepted by `run` (see its docstring).
      feed_dict: A dictionary mapping graph elements to feed values, or None.
      options: A (pointer to a) `RunOptions` buffer, or None.
      run_metadata: A (pointer to a) `RunMetadata` buffer, or None.

    Returns:
      Results with the same structure as `fetches`, assembled by
      `_FetchHandler.build_results`.
    """

    def _feed_fn(feed, feed_val):
      # Dispatch on the feed's type using the registered expansion table;
      # the selected function yields (subfeed, subfeed_value) pairs.
      for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS:
        if isinstance(feed, tensor_type):
          return feed_fn(feed, feed_val)
      raise TypeError(f'{feed} in argument `feed_dict` has invalid type '
                      f'"{type(feed).__name__}"')

    # Check session.
    if self._closed:
      raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
      raise RuntimeError('The Session graph is empty. Add operations to the '
                         'graph before calling run().')

    # Create request.
    feed_dict_tensor = {}
    feed_map = {}

    # Validate and process feed_dict.
    feed_handles = {}
    if feed_dict:
      # Expands nested-tuple keys/values into individual element->value items.
      feed_dict = nest.flatten_dict_items(feed_dict)
      for feed, feed_val in feed_dict.items():
        for subfeed, subfeed_val in _feed_fn(feed, feed_val):
          try:
            subfeed_t = self.graph.as_graph_element(
                subfeed, allow_tensor=True, allow_operation=False)
          except Exception as e:
            raise TypeError(
                f'Cannot interpret feed_dict key as Tensor: {e.args[0]}')

          if isinstance(subfeed_val, ops.Tensor):
            raise TypeError(
                'The value of a feed cannot be a tf.Tensor object. Acceptable '
                'feed values include Python scalars, strings, lists, numpy '
                'ndarrays, or TensorHandles. For reference, the tensor object '
                f'was {str(feed_val)} which was passed to the argument '
                f'`feed_dict` with key {str(feed)}.')

          subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
          # Reject int feeds whose value would change when converted to the
          # tensor's numpy dtype (i.e. silent truncation/overflow).
          if isinstance(subfeed_val, int) and _convert_to_numpy_obj(
              subfeed_dtype, subfeed_val) != subfeed_val:
            raise TypeError(
                f'Type of feed value {str(subfeed_val)} with type ' +
                f'{str(type(subfeed_val))} is not compatible with Tensor type '
                f'{str(subfeed_dtype)}. Try explicitly setting the type of the '
                'feed tensor to a larger type (e.g. int64).')

          is_tensor_handle_feed = isinstance(subfeed_val,
                                             session_ops.TensorHandle)
          if is_tensor_handle_feed:
            np_val = subfeed_val.to_numpy_array()
            feed_handles[subfeed_t.ref()] = subfeed_val
          else:
            np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)

          # Tensor-handle feeds bypass the static shape-compatibility check.
          if (not is_tensor_handle_feed and
              not subfeed_t.get_shape().is_compatible_with(np_val.shape)):
            raise ValueError(
                f'Cannot feed value of shape {str(np_val.shape)} for Tensor '
                f'{subfeed_t.name}, which has shape '
                f'{str(subfeed_t.get_shape())}')
          if not self.graph.is_feedable(subfeed_t):
            raise ValueError(f'Tensor {subfeed_t.name} may not be fed.')

          feed_dict_tensor[subfeed_t.ref()] = np_val
          feed_map[compat.as_bytes(subfeed_t.name)] = (subfeed_t, subfeed_val)

    # Create a fetch handler to take care of the structure of fetches.
    fetch_handler = _FetchHandler(
        self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)

    # Run request and get response.
    # We need to keep the returned movers alive for the following _do_run().
    # These movers are no longer needed when _do_run() completes, and
    # are deleted when `movers` goes out of scope when this _run() ends.
    # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding
    # of a handle from a different device as an error.
    _ = self._update_with_movers(feed_dict_tensor, feed_map)
    final_fetches = fetch_handler.fetches()
    final_targets = fetch_handler.targets()
    # We only want to really perform the run if fetches or targets are provided,
    # or if the call is a partial run that specifies feeds.
    if final_fetches or final_targets or (handle and feed_dict_tensor):
      results = self._do_run(handle, final_targets, final_fetches,
                             feed_dict_tensor, options, run_metadata)
    else:
      results = []
    return fetch_handler.build_results(self, results)

1197 

  def make_callable(self, fetches, feed_list=None, accept_options=False):
    """Returns a Python callable that runs a particular step.

    The returned callable will take `len(feed_list)` arguments whose types
    must be compatible feed values for the respective elements of `feed_list`.
    For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th
    argument to the returned callable must be a numpy ndarray (or something
    convertible to an ndarray) with matching element type and shape. See
    `tf.Session.run` for details of the allowable feed key and value types.

    The returned callable will have the same return type as
    `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`,
    the callable will return a numpy ndarray; if `fetches` is a `tf.Operation`,
    it will return `None`.

    Args:
      fetches: A value or list of values to fetch. See `tf.Session.run` for
        details of the allowable fetch types.
      feed_list: (Optional.) A list of `feed_dict` keys. See `tf.Session.run`
        for details of the allowable feed key types.
      accept_options: (Optional.) If `True`, the returned `Callable` will be
        able to accept `tf.compat.v1.RunOptions` and `tf.compat.v1.RunMetadata`
        as optional keyword arguments `options` and `run_metadata`,
        respectively, with the same syntax and semantics as `tf.Session.run`,
        which is useful for certain use cases (profiling and debugging) but will
        result in measurable slowdown of the `Callable`'s
        performance. Default: `False`.

    Returns:
      A function that when called will execute the step defined by
      `feed_list` and `fetches` in this session.

    Raises:
      TypeError: If `fetches` or `feed_list` cannot be interpreted
        as arguments to `tf.Session.run`.
    """
    if feed_list is not None:
      if not isinstance(feed_list, (list, tuple)):
        raise TypeError('Argument `feed_list` must be a list or tuple. '
                        f'Received: feed_list={feed_list}')
      # Delegate any non-empty feed lists to the existing `run()` logic.
      # TODO(mrry): Refactor the feed handling logic from
      # `Session._run()` so that we can convert the feeds to a list of
      # strings here.
      def _generic_run(*feed_args, **kwargs):
        # Positional args are paired with `feed_list` entries in order.
        feed_dict = {
            feed: feed_val for feed, feed_val in zip(feed_list, feed_args)
        }
        return self.run(fetches, feed_dict=feed_dict, **kwargs)

      return _generic_run

    # Ensure any changes to the graph are reflected in the runtime.
    # Note that we don't need to do this on subsequent calls to the
    # returned object, because the arguments to `fetches` must already be
    # in the graph.
    self._extend_graph()

    # Create a fetch handler to take care of the structure of fetches.
    fetch_handler = _FetchHandler(self._graph, fetches, {})
    # pylint: disable=protected-access
    fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
    target_list = [op._c_op for op in fetch_handler.targets()]

    # pylint: enable=protected-access

    def _callable_template_with_options_and_metadata(fetch_list,
                                                     target_list,
                                                     fetch_handler,
                                                     options=None,
                                                     run_metadata=None):
      """Template callable that accepts RunOptions and RunMetadata."""
      options_ptr = tf_session.TF_NewBufferFromString(
          compat.as_bytes(options.SerializeToString())) if options else None
      run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
      try:
        results = self._call_tf_sessionrun(options_ptr, {}, fetch_list,
                                           target_list, run_metadata_ptr)
        if fetch_handler:
          results = fetch_handler.build_results(self, results)
        else:
          results = results[0] if results else None
        if run_metadata:
          proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
          run_metadata.ParseFromString(compat.as_bytes(proto_data))
      finally:
        # Release the C buffers even when the run raises.
        if run_metadata_ptr:
          tf_session.TF_DeleteBuffer(run_metadata_ptr)
        if options:
          tf_session.TF_DeleteBuffer(options_ptr)
      return results

    if accept_options:
      return functools.partial(_callable_template_with_options_and_metadata,
                               fetch_list, target_list, fetch_handler)
    elif isinstance(fetches, ops.Operation):
      # Special case for fetching a single operation, because the
      # function will have no return value.
      assert not fetch_list
      assert len(target_list) == 1

      def _single_operation_run():
        self._call_tf_sessionrun(None, {}, [], target_list, None)

      return _single_operation_run
    elif isinstance(fetches, ops.Tensor):
      # Special case for fetching a single tensor, because the
      # function can return the result of `TF_Run()` directly.
      assert len(fetch_list) == 1
      assert not target_list

      def _single_tensor_run():
        results = self._call_tf_sessionrun(None, {}, fetch_list, [], None)
        return results[0]

      return _single_tensor_run
    else:
      # In all other cases, we must use `fetch_handler` to build the
      # results for us.
      def _fetch_handler_run():
        results = self._call_tf_sessionrun(None, {}, fetch_list, target_list,
                                           None)
        return fetch_handler.build_results(self, results)

      return _fetch_handler_run

1323 

  # Captures the name of a node in an error status. The regex below matches
  # both the old and the new formats:
  # Old format: [[Node: <node_name> = ...]]
  # New format: [[{{node <node_name>}} = ...]]
  # Group 3 holds the node name itself; `_do_call` uses `m.group(3)` to look
  # the failing op up in the graph.
  _NODEDEF_NAME_RE = re.compile(
      r'\[\[(Node: )?(\{\{node )?([^\} ]*)(\}\})?\s*=*')

1330 

1331 def _do_run(self, handle, target_list, fetch_list, feed_dict, options, 

1332 run_metadata): 

1333 """Runs a step based on the given fetches and feeds. 

1334 

1335 Args: 

1336 handle: a handle for partial_run. None if this is just a call to run(). 

1337 target_list: A list of operations to be run, but not fetched. 

1338 fetch_list: A list of tensors to be fetched. 

1339 feed_dict: A dictionary that maps tensors to numpy ndarrays. 

1340 options: A (pointer to a) [`RunOptions`] protocol buffer, or None 

1341 run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None 

1342 

1343 Returns: 

1344 A list of numpy ndarrays, corresponding to the elements of 

1345 `fetch_list`. If the ith element of `fetch_list` contains the 

1346 name of an operation, the first Tensor output of that operation 

1347 will be returned for that element. 

1348 

1349 Raises: 

1350 tf.errors.OpError: Or one of its subclasses on error. 

1351 """ 

1352 # pylint: disable=protected-access 

1353 feeds = dict((t.deref()._as_tf_output(), v) for t, v in feed_dict.items()) 

1354 fetches = [t._as_tf_output() for t in fetch_list] 

1355 targets = [op._c_op for op in target_list] 

1356 

1357 # pylint: enable=protected-access 

1358 

1359 def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata): 

1360 # Ensure any changes to the graph are reflected in the runtime. 

1361 self._extend_graph() 

1362 return self._call_tf_sessionrun(options, feed_dict, fetch_list, 

1363 target_list, run_metadata) 

1364 

1365 def _prun_fn(handle, feed_dict, fetch_list): 

1366 if target_list: 

1367 raise RuntimeError('partial_run() requires empty `target_list`. ' 

1368 f'Received: target_list={target_list} (non-empty)') 

1369 return self._call_tf_sessionprun(handle, feed_dict, fetch_list) 

1370 

1371 if handle is None: 

1372 return self._do_call(_run_fn, feeds, fetches, targets, options, 

1373 run_metadata) 

1374 else: 

1375 return self._do_call(_prun_fn, handle, feeds, fetches) 

1376 

  def _do_call(self, fn, *args):
    """Calls `fn(*args)`, enriching any raised `OpError` with graph context.

    On `OpError`, the failing node's name is extracted from the error
    message, looked up in this session's graph, and the error is re-raised
    as the same `OpError` subclass carrying the node's `node_def`/`op`
    together with an interpolated message.
    """
    try:
      return fn(*args)
    except errors.OpError as e:
      message = compat.as_text(e.message)
      m = BaseSession._NODEDEF_NAME_RE.search(message)
      node_def = None
      op = None
      if m is not None:
        # Group 3 of the regex is the node name (see _NODEDEF_NAME_RE).
        node_name = m.group(3)
        try:
          op = self._graph.get_operation_by_name(node_name)
          node_def = op.node_def
        except KeyError:
          # Node not found in this graph; re-raise with message info only.
          pass
      message = error_interpolation.interpolate(message, self._graph)
      if 'only supports NHWC tensor format' in message:
        message += ('\nA possible workaround: Try disabling Grappler optimizer'
                    '\nby modifying the config for creating the session eg.'
                    '\nsession_config.graph_options.rewrite_options.'
                    'disable_meta_optimizer = True')
      raise type(e)(node_def, op, message)  # pylint: disable=no-value-for-parameter

1399 

  def _extend_graph(self):
    """Propagates pending Python-graph changes to the underlying C session."""
    # The graph's session-run lock serializes extension against concurrent
    # run calls.
    with self._graph._session_run_lock():  # pylint: disable=protected-access
      tf_session.ExtendSession(self._session)

1403 

  # The threshold to run garbage collection to delete dead tensors.
  # Dead handles accumulate in `_dead_handles` (see `_register_dead_handle`)
  # and are deleted in one batched `run()` once this many are pending.
  _DEAD_HANDLES_THRESHOLD = 10

1406 

1407 def _register_dead_handle(self, handle): 

1408 # Register a dead handle in the session. Delete the dead tensors when 

1409 # the number of dead tensors exceeds certain threshold. 

1410 tensors_to_delete = None 

1411 with self._delete_lock: 

1412 self._dead_handles.append(handle) 

1413 if len(self._dead_handles) == BaseSession._DEAD_HANDLES_THRESHOLD: 

1414 tensors_to_delete = self._dead_handles 

1415 self._dead_handles = [] 

1416 # Delete the dead tensors. 

1417 if tensors_to_delete: 

1418 feeds = {} 

1419 fetches = [] 

1420 for deleter_key, tensor_handle in enumerate(tensors_to_delete): 

1421 holder, deleter = session_ops._get_handle_deleter( 

1422 self.graph, deleter_key, tensor_handle) 

1423 feeds[holder] = tensor_handle 

1424 fetches.append(deleter) 

1425 self.run(fetches, feed_dict=feeds) 

1426 

1427 def _update_with_movers(self, feed_dict, feed_map): 

1428 # If a tensor handle that is fed to a device incompatible placeholder, 

1429 # we move the tensor to the right device, generate a new tensor handle, 

1430 # and update `feed_dict` to use the new handle. 

1431 handle_movers = [] 

1432 for feed_name, val in feed_map.items(): 

1433 mover = session_ops._get_handle_mover(self.graph, *val) 

1434 if mover: 

1435 handle_movers.append((feed_name, val[1], mover)) 

1436 # Transfer a tensor to the right device if needed. 

1437 if not handle_movers: 

1438 return [] 

1439 else: 

1440 feeds = {} 

1441 fetches = [] 

1442 for _, handle, mover in handle_movers: 

1443 feeds[mover[0]] = handle 

1444 fetches.append(mover[1]) 

1445 handles = self.run(fetches, feed_dict=feeds) 

1446 for handle_mover, handle in zip(handle_movers, handles): 

1447 np_val = np.array(handle.handle, dtype=np.object_) 

1448 feed_name = handle_mover[0] 

1449 feed_tensor = feed_map[feed_name][0] 

1450 feed_dict[feed_tensor.ref()] = np_val 

1451 return handles 

1452 

  def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
                          run_metadata):
    """Invokes the C-API `TF_SessionRun` wrapper for this session."""
    return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
                                            fetch_list, target_list,
                                            run_metadata)

1458 

  def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):
    """Invokes the C-API `TF_SessionPRun` wrapper for this session."""
    return tf_session.TF_SessionPRun_wrapper(self._session, handle, feed_dict,
                                             fetch_list)

1462 

1463 # pylint: disable=protected-access 

1464 class _Callable(object): 

1465 """Experimental wrapper for the C++ `Session::MakeCallable()` API.""" 

1466 

1467 def __init__(self, session, callable_options): 

1468 self._session = session 

1469 self._handle = None 

1470 options_ptr = tf_session.TF_NewBufferFromString( 

1471 compat.as_bytes(callable_options.SerializeToString())) 

1472 try: 

1473 self._handle = tf_session.TF_SessionMakeCallable( 

1474 session._session, options_ptr) 

1475 finally: 

1476 tf_session.TF_DeleteBuffer(options_ptr) 

1477 

1478 def __call__(self, *args, **kwargs): 

1479 run_metadata = kwargs.get('run_metadata', None) 

1480 try: 

1481 run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None 

1482 ret = tf_session.TF_SessionRunCallable(self._session._session, 

1483 self._handle, args, 

1484 run_metadata_ptr) 

1485 if run_metadata: 

1486 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 

1487 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 

1488 finally: 

1489 if run_metadata_ptr: 

1490 tf_session.TF_DeleteBuffer(run_metadata_ptr) 

1491 return ret 

1492 

1493 def __del__(self): 

1494 # NOTE(mrry): It is possible that `self._session.__del__()` could be 

1495 # called before this destructor, in which case `self._session._session` 

1496 # will be `None`. 

1497 if (self._handle is not None and self._session._session is not None and 

1498 not self._session._closed): 

1499 tf_session.TF_SessionReleaseCallable(self._session._session, 

1500 self._handle) 

1501 

1502 # pylint: enable=protected-access 

1503 

  def _make_callable_from_options(self, callable_options):
    """Returns a handle to a "callable" with the given options.

    Args:
      callable_options: A `CallableOptions` protocol buffer message describing
        the computation that will be performed by the callable.

    Returns:
      A handle to the new callable.
    """
    # Ensure any changes to the graph are reflected in the runtime before
    # the callable is created.
    self._extend_graph()
    return BaseSession._Callable(self, callable_options)

1516 

1517 

1518@tf_export(v1=['Session']) 

1519class Session(BaseSession): 

1520 """A class for running TensorFlow operations. 

1521 

1522 A `Session` object encapsulates the environment in which `Operation` 

1523 objects are executed, and `Tensor` objects are evaluated. For 

1524 example: 

1525 

1526 ```python 

1527 tf.compat.v1.disable_eager_execution() # need to disable eager in TF2.x 

1528 # Build a graph. 

1529 a = tf.constant(5.0) 

1530 b = tf.constant(6.0) 

1531 c = a * b 

1532 

1533 # Launch the graph in a session. 

1534 sess = tf.compat.v1.Session() 

1535 

1536 # Evaluate the tensor `c`. 

1537 print(sess.run(c)) # prints 30.0 

1538 ``` 

1539 

1540 A session may own resources, such as 

1541 `tf.Variable`, `tf.queue.QueueBase`, 

1542 and `tf.compat.v1.ReaderBase`. It is important to release 

1543 these resources when they are no longer required. To do this, either 

1544 invoke the `tf.Session.close` method on the session, or use 

1545 the session as a context manager. The following two examples are 

1546 equivalent: 

1547 

1548 ```python 

1549 # Using the `close()` method. 

1550 sess = tf.compat.v1.Session() 

1551 sess.run(...) 

1552 sess.close() 

1553 

1554 # Using the context manager. 

1555 with tf.compat.v1.Session() as sess: 

1556 sess.run(...) 

1557 ``` 

1558 

1559 The 

1560 [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto) 

1561 protocol buffer exposes various configuration options for a 

1562 session. For example, to create a session that uses soft constraints 

1563 for device placement, and log the resulting placement decisions, 

1564 create a session as follows: 

1565 

1566 ```python 

1567 # Launch the graph in a session that allows soft device placement and 

1568 # logs the placement decisions. 

1569 sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto( 

1570 allow_soft_placement=True, 

1571 log_device_placement=True)) 

1572 ``` 

1573 

1574 @compatibility(TF2) 

1575 `Session` does not work with either eager execution or `tf.function`, and you 

1576 should not invoke it directly. To migrate code that uses sessions to TF2, 

1577 rewrite the code without it. See the 

1578 [migration 

1579 guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls) 

1580 on replacing `Session.run` calls. 

1581 @end_compatibility 

1582 """ 

1583 

def __init__(self, target='', graph=None, config=None):
  """Creates a new TensorFlow session.

  When no `graph` is supplied, the session launches the current default
  graph. A process that works with several `tf.Graph` objects needs one
  session per graph (though a single graph may be shared by many sessions);
  in that situation it is clearer to hand the intended graph to this
  constructor explicitly.

  Args:
    target: (Optional.) The execution engine to connect to. Defaults to
      using an in-process engine. See
      [Distributed TensorFlow](https://tensorflow.org/deploy/distributed)
      for more examples.
    graph: (Optional.) The `Graph` to be launched (described above).
    config: (Optional.) A
      [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
      protocol buffer with configuration options for the session.
  """
  super(Session, self).__init__(target, graph, config=config)
  # Both context managers are created lazily on the first `__enter__`;
  # building them here would introduce a reference cycle.
  self._default_session_context_manager = None
  self._default_graph_context_manager = None

1609 

def __enter__(self):
  """Enters this session's graph and the session itself as defaults.

  Returns:
    The session, as returned by entering the `as_default()` manager.

  Raises:
    RuntimeError: on re-entry; `Session` context managers are single-use.
  """
  # Guard clause: a non-None graph manager means we are already inside a
  # `with session:` block, which is not supported.
  if self._default_graph_context_manager is not None:
    raise RuntimeError('Session context managers are not re-entrant. '
                       'Use `Session.as_default()` if you want to enter '
                       'a session multiple times.')
  self._default_graph_context_manager = self.graph.as_default()
  if self._default_session_context_manager is None:
    self._default_session_context_manager = self.as_default()
  # Enter the graph first, then the session; `__exit__` unwinds in reverse.
  self._default_graph_context_manager.__enter__()
  return self._default_session_context_manager.__enter__()

1621 

def __exit__(self, exec_type, exec_value, exec_tb):
  """Exits the default-session and default-graph managers, then closes.

  Args:
    exec_type: Exception type raised in the `with` body, or `None`.
    exec_value: Exception instance, or `None`.
    exec_tb: Traceback, or `None`.
  """
  if exec_type is errors.OpError:
    logging.error('Session closing due to OpError: %s', (exec_value,))
  try:
    # Unwind in reverse order of `__enter__`: session manager first.
    self._default_session_context_manager.__exit__(exec_type, exec_value,
                                                   exec_tb)
  except RuntimeError as error:
    if error == exec_value:
      # NOTE(skyewm): for some reason, in Python3,
      # _default_session_context_manager.__exit__ will re-raise the "not
      # re-entrant" exception raised in __enter__ above (note that if we're
      # here, we're in the outer session context manager, since __exit__ is
      # not called when __enter__ raises an exception). We still want to
      # continue cleaning up this context manager before the exception is
      # further propagated, so we ignore it here (note that it'll continue
      # being propagated after this method completes).
      pass
    else:
      raise
  self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb)

  # Drop both managers so a later `__enter__` sees a fresh state (and still
  # raises the non-re-entrant error only on true nesting).
  self._default_session_context_manager = None
  self._default_graph_context_manager = None

  # If we are closing due to an exception, set a time limit on our Close() to
  # avoid blocking forever.
  # TODO(b/120204635) remove this when deadlock is fixed.
  if exec_type:
    # Daemon thread: if close() deadlocks, the interpreter can still exit.
    close_thread = threading.Thread(
        name='SessionCloseThread', target=self.close)
    close_thread.daemon = True
    close_thread.start()
    close_thread.join(30.0)
    if close_thread.is_alive():
      logging.error(
          'Session failed to close after 30 seconds. Continuing after this '
          'point may leave your program in an undefined state.')
  else:
    self.close()

1661 

@staticmethod
def reset(target, containers=None, config=None):
  """Resets resource containers on `target`, and close all connected sessions.

  A resource container is distributed across all workers in the
  same cluster as `target`. Resetting a container on `target` clears the
  resources associated with it; in particular, every Variable held in the
  container becomes undefined, losing both its value and its shape.

  NOTE:
  (i) reset() is currently only implemented for distributed sessions.
  (ii) Any sessions on the master named by `target` will be closed.

  When `containers` is omitted, every container is reset.

  Args:
    target: The execution engine to connect to.
    containers: A list of resource container name strings, or `None` to
      reset all containers.
    config: (Optional.) Protocol buffer with configuration options.

  Raises:
    tf.errors.OpError: Or one of its subclasses if an error occurs while
      resetting containers.
  """
  if target is not None:
    target = compat.as_bytes(target)
  # An empty list tells the C API to reset every container.
  containers = ([compat.as_bytes(c) for c in containers]
                if containers is not None else [])
  tf_session.TF_Reset(target, containers, config)

1695 

1696 

1697@tf_export(v1=['InteractiveSession']) 

1698class InteractiveSession(BaseSession): 

1699 """A TensorFlow `Session` for use in interactive contexts, such as a shell. 

1700 

1701 The only difference with a regular `Session` is that an `InteractiveSession` 

1702 installs itself as the default session on construction. 

1703 The methods `tf.Tensor.eval` 

1704 and `tf.Operation.run` 

1705 will use that session to run ops. 

1706 

1707 This is convenient in interactive shells and [IPython 

1708 notebooks](http://ipython.org), as it avoids having to pass an explicit 

1709 `Session` object to run ops. 

1710 

1711 For example: 

1712 

1713 ```python 

1714 sess = tf.compat.v1.InteractiveSession() 

1715 a = tf.constant(5.0) 

1716 b = tf.constant(6.0) 

1717 c = a * b 

1718 # We can just use 'c.eval()' without passing 'sess' 

1719 print(c.eval()) 

1720 sess.close() 

1721 ``` 

1722 

1723 Note that a regular session installs itself as the default session when it 

1724 is created in a `with` statement. The common usage in non-interactive 

1725 programs is to follow that pattern: 

1726 

1727 ```python 

1728 a = tf.constant(5.0) 

1729 b = tf.constant(6.0) 

1730 c = a * b 

1731 with tf.compat.v1.Session(): 

1732 # We can also use 'c.eval()' here. 

1733 print(c.eval()) 

1734 ``` 

1735 """ 

1736 

1737 _count_lock = threading.Lock() 

1738 _active_session_count = 0 # GUARDED_BY(_count_lock) 

1739 

def __init__(self, target='', graph=None, config=None):
  """Creates a new interactive TensorFlow session.

  When no `graph` is supplied, the session launches the current default
  graph. A process that works with several `tf.Graph` objects needs one
  session per graph (though a single graph may be shared by many sessions);
  in that situation it is clearer to hand the intended graph to this
  constructor explicitly.

  Args:
    target: (Optional.) The execution engine to connect to. Defaults to
      using an in-process engine.
    graph: (Optional.) The `Graph` to be launched (described above).
    config: (Optional) `ConfigProto` proto used to configure the session.
  """
  if not config:
    # No config given: pick interactive-friendly defaults — grow GPU memory
    # as needed, at the cost of fragmentation.
    config = config_pb2.ConfigProto(
        gpu_options=config_pb2.GPUOptions(allow_growth=True))
  # Interactive sessions always place pruned graphs (this also mutates a
  # caller-supplied config, matching long-standing behavior).
  config.graph_options.place_pruned_graph = True

  super(InteractiveSession, self).__init__(target, graph, config)
  with InteractiveSession._count_lock:
    if InteractiveSession._active_session_count > 0:
      warnings.warn('An interactive session is already active. This can '
                    'cause out-of-memory errors in some cases. You must '
                    'explicitly call `InteractiveSession.close()` to release '
                    'resources held by the other session(s).')
    InteractiveSession._active_session_count += 1
  # NOTE(mrry): We do not use `Session._closed` here because it has unhelpful
  # semantics (in particular, it is not set to true if `Session.close()` is
  # called on a session that has not been "opened" by running a step) and we
  # cannot change those semantics without breaking existing code.
  self._explicitly_closed = False

  # Install this session — and, when one was passed, its graph — as the
  # process-wide defaults; `enforce_nesting = False` lets `close()` pop
  # them even though they are exited out of strict stack order.
  self._default_session = self.as_default()
  self._default_session.enforce_nesting = False
  self._default_session.__enter__()
  self._explicit_graph = graph
  if self._explicit_graph is not None:
    self._default_graph = graph.as_default()
    self._default_graph.enforce_nesting = False
    self._default_graph.__enter__()

1789 

def close(self):
  """Closes an `InteractiveSession`."""
  super(InteractiveSession, self).close()
  with InteractiveSession._count_lock:
    if self._explicitly_closed:
      # Second (or later) close: the defaults were already popped and the
      # active count already decremented, so there is nothing left to undo.
      return
    InteractiveSession._active_session_count -= 1
    self._explicitly_closed = True
  # Pop the defaults installed by __init__, graph first when one was
  # explicitly provided.
  if self._explicit_graph is not None:
    self._default_graph.__exit__(None, None, None)
    self._default_graph = None
  self._default_session.__exit__(None, None, None)
  self._default_session = None