Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/experimental/rpc/rpc_ops.py: 32%

143 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Module to expose RPC APIs in tensorflow.""" 

16 

17from typing import Optional, Sequence, Union 

18 

19import tensorflow.distribute.experimental.rpc.kernels.gen_rpc_ops as gen_rpc_ops 

20from tensorflow.distribute.experimental.rpc.proto import tf_rpc_service_pb2 as rpc_pb2 

21from tensorflow.python.data.util import structure 

22from tensorflow.python.eager import context 

23from tensorflow.python.eager import def_function 

24from tensorflow.python.eager import function as tf_function 

25from tensorflow.python.framework import constant_op 

26from tensorflow.python.framework import dtypes 

27from tensorflow.python.framework import errors 

28from tensorflow.python.framework import type_spec 

29from tensorflow.python.ops import math_ops 

30from tensorflow.python.ops import resource_variable_ops 

31from tensorflow.python.saved_model import nested_structure_coder 

32from tensorflow.python.types import core as core_tf_types 

33from tensorflow.python.util import nest 

34from tensorflow.python.util.tf_export import tf_export 

35 

36 

def get_output_specs_from_function(func: tf_function.ConcreteFunction):
  """Returns the serialized output type specs of a concrete function.

  Every value in `func.structured_outputs` is mapped to its type spec, the
  resulting structure is encoded with `nested_structure_coder`, and the encoded
  proto is serialized.

  Args:
    func: The ConcreteFunction whose structured outputs are described.

  Returns:
    The result of `SerializeToString()` on the encoded structure proto.
  """
  specs = nest.map_structure(
      type_spec.type_spec_from_value, func.structured_outputs)
  return nested_structure_coder.encode_structure(specs).SerializeToString()

42 

43 

def get_input_specs_from_function(func: tf_function.ConcreteFunction):
  """Returns the serialized input specs of a concrete function.

  Only the first element of `func.structured_input_signature` is encoded; the
  second element of that pair is discarded.

  Args:
    func: The ConcreteFunction whose input signature is described.

  Returns:
    The result of `SerializeToString()` on the encoded structure proto.
  """
  positional_specs, _ = func.structured_input_signature
  encoded_specs = nested_structure_coder.encode_structure(positional_specs)
  return encoded_specs.SerializeToString()

48 

49 

@tf_export("distribute.experimental.rpc.Server", v1=[])
class Server(object):
  """A Server base class for accepting RPCs for registered tf.functions.

  Functions can be registered on the server and are exposed via RPCs.
  Instances are obtained through `Server.create`; `register` and `start` on
  this base class only raise `NotImplementedError`.
  """

  @staticmethod
  def create(rpc_layer, address):
    """Create TF RPC server at given address.

    Args:
      rpc_layer: Communication layer between client and server. Only "grpc" rpc
        layer is supported at the moment.
      address: Address where RPC server is hosted.

    Returns:
      An instance of `tf.distribute.experimental.rpc.Server` class.

    Raises:
      A ValueError if rpc_layer other than "grpc" is used. Only GRPC
        is supported at the moment.

    Example usage:

      >>> import portpicker
      >>> @tf.function(input_signature=[
      ...      tf.TensorSpec([], tf.int32),
      ...      tf.TensorSpec([], tf.int32)])
      ... def remote_fn(a, b):
      ...   return tf.add(a, b)

      >>> port = portpicker.pick_unused_port()
      >>> address = "localhost:{}".format(port)
      >>> server = tf.distribute.experimental.rpc.Server.create("grpc", address)
      >>> server.register("addition", remote_fn)
      >>> server.start()

    """
    if rpc_layer != "grpc":
      raise ValueError("Only GRPC backend is supported at the moment.")
    return GrpcServer(address=address)

  def register(self, method_name: str,
               func: Union[def_function.Function,
                           tf_function.ConcreteFunction]):
    """Method for registering tf.function on server.

    Registered methods can be invoked remotely from clients.

    Args:
      method_name: Name of the tf.function. Clients use this method_name to make
        RPCs.
      func: A `tf.function` or ConcreteFunction to register.

    Raises:
      NotImplementedError: Always, on this base class.
    """
    # The trailing space matters: adjacent string literals are concatenated
    # verbatim, so without it the message reads "create aconcrete subclass".
    raise NotImplementedError("Please use Server.create method to create a "
                              "concrete subclass of Server.")

  def start(self):
    """Starts the RPC server on provided address.

    Server listens for new requests from client, once it is started.

    Raises:
      NotImplementedError: Always, on this base class.
    """
    raise NotImplementedError("Please use Server.create method to create a "
                              "concrete subclass of Server.")

115 

116 

@tf_export("distribute.experimental.rpc.Client", v1=[])
class Client(object):
  """Client class for invoking RPCs to the server."""

  @staticmethod
  def create(rpc_layer, address, name="", timeout_in_ms=0):
    """Create TF RPC client to connect to the given address.

    Args:
      rpc_layer: Communication layer between client and server. Only "grpc" rpc
        layer is supported at the moment.
      address: Address of the server to connect the RPC client to.
      name: Name of the RPC Client. You can create multiple clients connecting
        to same server and distinguish them using different names.
      timeout_in_ms: The default timeout to use for outgoing RPCs from client. 0
        indicates no timeout. Exceeding timeout during RPC will raise
        DeadlineExceeded error.

    Returns:
      An instance of `tf.distribute.experimental.rpc.Client` with the following
      dynamically added methods for eagerly created clients:
        * `Registered methods` e.g. multiply(*args):
          If Client is created when executing eagerly, client will request the
          list of registered methods from server during client creation.
          The convenience methods for RPCs will be dynamically added to the
          created Client instance.

          For example, when a server has method "multiply" registered, the
          client object created in eager mode will have 'multiply' method
          available. Users can use client.multiply(..) to make RPC, instead of
          client.call("multiply", ...)

          Both "call" and "multiply" methods are non-blocking i.e. they return
          a StatusOrResult object which should be used to wait for getting
          value or error.

          Along with the above, blocking versions of the registered
          methods are also dynamically added to client instance.
          e.g. multiply_blocking(*args). These methods block till the RPC is
          finished and return response for successful RPC. Otherwise raise
          exception.

          These methods are not available when Client is created inside a
          tf.function.

    Raises:
        A ValueError if rpc_layer other than "grpc" is used. Only GRPC
          is supported at the moment.
        A DeadlineExceeded exception in eager mode if timeout exceeds while
          creating and listing client methods.

    Example usage:
      >>> # Have server already started.
      >>> import portpicker
      >>> @tf.function(input_signature=[
      ...      tf.TensorSpec([], tf.int32),
      ...      tf.TensorSpec([], tf.int32)])
      ... def remote_fn(a, b):
      ...   return tf.add(a, b)

      >>> port = portpicker.pick_unused_port()
      >>> address = "localhost:{}".format(port)
      >>> server = tf.distribute.experimental.rpc.Server.create("grpc", address)
      >>> server.register("addition", remote_fn)
      >>> server.start()

      >>> # Start client
      >>> client = tf.distribute.experimental.rpc.Client.create("grpc",
      ...      address=address, name="test_client")

      >>> a = tf.constant(2, dtype=tf.int32)
      >>> b = tf.constant(3, dtype=tf.int32)

      >>> result = client.call(
      ...    args=[a, b],
      ...    method_name="addition",
      ...    output_specs=tf.TensorSpec((), tf.int32))

      >>> if result.is_ok():
      ...   result.get_value()

      >>> result = client.addition(a, b)

      >>> if result.is_ok():
      ...   result.get_value()

      >>> value = client.addition_blocking(a, b)
    """
    if rpc_layer != "grpc":
      raise ValueError("Only GRPC backend is supported at the moment.")
    # Registered methods are only listed for clients created eagerly.
    return GrpcClient(
        address=address,
        name=name,
        list_registered_methods=context.executing_eagerly(),
        timeout_in_ms=timeout_in_ms)

  def call(self,
           method_name: str,
           args: Optional[Sequence[core_tf_types.Tensor]] = None,
           output_specs=None,
           timeout_in_ms=0):
    """Method for making RPC calls to remote server.

    This invokes RPC to the server, executing the registered method_name
    remotely.

    Args:
      method_name: Remote registered method to invoke
      args: List of arguments for the registered method.
      output_specs: Output specs for the output from method.
        For example, if tf.function is:

          @tf.function(input_signature=[
              tf.TensorSpec([], tf.int32),
              tf.TensorSpec([], tf.int32)])
          def multiply_fn(a, b):
            return tf.math.multiply(a, b)

        output_spec is: tf.TensorSpec((), tf.int32)

        If you have access to TF Function, the output specs can be generated
        from tf.function by calling:

          output_specs = tf.nest.map_structure(
              tf.type_spec_from_value,
              tf_function.get_concrete_function().structured_outputs)

        If output_specs are not provided, flattened list of tensors will be
        returned in response.
      timeout_in_ms: Timeout for this call. If 0, default client timeout will be
        used.

    Returns:
      An instance of `StatusOrResult` class with the following available
      methods.
        * `is_ok()`:
            Returns True if RPC was successful.
        * `get_error()`:
            Returns TF error_code and error message for the RPC.
        * `get_value()`:
            Returns the returned value from remote TF function execution
            when RPC is successful.

      Calling any of the above methods will block till RPC is completed and
      result is available.
    """
    raise NotImplementedError("Must be implemented in inherited classes.")

258 

259 

class GrpcServer(Server):
  """GrpcServer object encapsulates a resource with GRPC server.

  Functions can be registered locally and are exposed via RPCs.
  Example:
  ```
  server = rpc_ops.GrpcServer("host:port")
  @tf.function
  def add(a, b):
    return a + b

  server.register("add", add)
  server.start()
  ```
  """

  def __init__(self, address: str):
    """Creates the server resource for `address`.

    Args:
      address: Address passed to the `rpc_server` op.

    Raises:
      NotImplementedError: If not executing eagerly.
    """
    # Reject non-eager construction *before* creating the op, so a failed
    # construction does not leave a stray RpcServer op behind in the graph.
    if not context.executing_eagerly():
      raise NotImplementedError("Please create the server outside tf.function.")
    self._server_handle = gen_rpc_ops.rpc_server(address)
    self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._server_handle, handle_device=self._server_handle.device)

  def register(self, method_name: str,
               func: Union[def_function.Function,
                           tf_function.ConcreteFunction]):
    """Method for registering functions.

    Args:
      method_name: Name under which the function is registered on the server.
      func: A `tf.function` or a ConcreteFunction.

    Raises:
      ValueError: If `func` is a `tf.function` that takes arguments but has no
        input signature, or if `func` is neither a `tf.function` nor a
        ConcreteFunction.
    """
    # Resolve a single concrete function first so the registration op is
    # issued from exactly one place.
    if isinstance(func, def_function.Function):
      takes_args = func._function_spec.arg_names  # pylint: disable=protected-access
      if takes_args and func.input_signature is None:
        raise ValueError("Input signature not specified for the function.")
      concrete_fn = func.get_concrete_function()
    elif isinstance(func, tf_function.ConcreteFunction):
      concrete_fn = func
    else:
      # Python functions
      # TODO(b/186762191): Add an implementation to support python functions.
      raise ValueError("Only TF functions are supported with Register method")

    gen_rpc_ops.rpc_server_register(
        self._server_handle,
        method_name=method_name,
        captured_inputs=concrete_fn.captured_inputs,
        input_specs=get_input_specs_from_function(concrete_fn),
        output_specs=get_output_specs_from_function(concrete_fn),
        f=concrete_fn)

  def start(self):
    """Starts GRPC server."""
    gen_rpc_ops.rpc_server_start(self._server_handle)

317 

318 

class GrpcClient(Client):
  """Client wrapper to connect to remote RPC server using GRPC.

  If Client is created with (list_registered_methods=True):
  1. Input and output specs for the methods till this point will be fetched from
  Server.
  2. convenience methods are added to invoke registered methods directly from
  client.
  For example:
    For call a server method `add`
    client.add(a, b) or client.add_blocking(a, b) can be used instead of
    client.call(args=[a,b], output_specs=[..])

  Prerequisite for using list_registered_methods=True:
  1. Server should be already started with the registered methods.
  2. Client must be created in Eager mode.
  """

  def __init__(self,
               address: str,
               name: str = "",
               list_registered_methods=False,
               timeout_in_ms=0):
    """Creates the client resource and attaches the listed server methods.

    Args:
      address: Server address passed to the `rpc_client` op.
      name: Shared name passed to the `rpc_client` op.
      list_registered_methods: Forwarded to the `rpc_client` op; every method
        the op returns is added to this instance via `_add_method`.
      timeout_in_ms: Timeout forwarded to the `rpc_client` op.

    Raises:
      NotImplementedError: If not executing eagerly.
    """
    # Reject non-eager construction *before* creating the op, so a failed
    # construction does not leave a stray RpcClient op behind in the graph.
    if not context.executing_eagerly():
      raise NotImplementedError(
          "Client creation is supported only in eager mode.")
    self._client_handle, methods = gen_rpc_ops.rpc_client(
        shared_name=name,
        server_address=address,
        list_registered_methods=list_registered_methods,
        timeout_in_ms=timeout_in_ms)
    self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._client_handle, handle_device=self._client_handle.device)
    self._server_address = address
    self._method_registry = {}
    for method in methods.numpy():
      m = rpc_pb2.RegisteredMethod()
      m.ParseFromString(method)
      output_specs = nested_structure_coder.decode_proto(m.output_specs)
      input_specs = nested_structure_coder.decode_proto(m.input_specs)
      self._method_registry[m.method] = output_specs
      # TODO(ishark): Perhaps doc string can also be taken as input during
      # function registration.
      doc_string = "RPC Call for " + m.method + " method to server " + address
      self._add_method(m.method, output_specs, input_specs, self._client_handle,
                       doc_string)

  def _add_method(self, method_name, output_specs, input_specs, client_handle,
                  doc_string):
    """Method to add RPC methods to the client object.

    Two callables are set as instance attributes: `<method_name>`, which
    returns a `StatusOrResult`, and `<method_name>_blocking`, which returns the
    value or raises the exception built from the returned error code.

    NOTE: the attributes are set with `setattr` on the instance, so a
    registered name equal to an existing attribute (e.g. `call`) shadows it.

    Args:
      method_name: Registered method name; also the attribute name added.
      output_specs: Specs passed to the `StatusOrResult` for this method.
      input_specs: If truthy, call arguments must have the same nested
        structure as these specs.
      client_handle: Handle passed to the `rpc_call` op.
      doc_string: Docstring assigned to both generated callables.
    """

    def validate_and_get_flat_inputs(*args):
      # `args` is always a tuple here (it is collected through *args), so no
      # None check is needed.
      if input_specs:
        nest.assert_same_structure(args, input_specs)
      return nest.flatten(args)

    def call_wrapper(*args, timeout_in_ms=0):
      status_or, deleter = gen_rpc_ops.rpc_call(
          client_handle,
          args=validate_and_get_flat_inputs(*args),
          method_name=method_name,
          timeout_in_ms=timeout_in_ms)
      return StatusOrResult(status_or, deleter, output_specs)

    def call_blocking_wrapper(*args, timeout_in_ms=0):
      # Issue the RPC exactly as the non-blocking wrapper does, then wait.
      result = call_wrapper(*args, timeout_in_ms=timeout_in_ms)
      if result.is_ok():
        return result.get_value()
      error_code, error_msg = result.get_error()
      raise errors.exception_type_from_error_code(error_code.numpy())(
          None, None, error_msg.numpy())

    setattr(self, method_name, call_wrapper)
    call_wrapper.__doc__ = doc_string

    blocking_method_name = method_name + "_blocking"
    setattr(self, blocking_method_name, call_blocking_wrapper)
    call_blocking_wrapper.__doc__ = doc_string

  def call(self,
           method_name: str,
           args: Optional[Sequence[core_tf_types.Tensor]] = None,
           output_specs=None,
           timeout_in_ms=0):
    """Method to invoke remote registered functions on the connected server.

    Server should be started before making an RPC Call.

    Args:
      method_name: Registered method to invoke on Server.
      args: Input arguments for the method.
      output_specs: Output specs for the output from method.
      timeout_in_ms: Timeout for this call. If 0, default client timeout will be
        used.

    Returns:
      StatusOrResult object. This function issues the RPC call to server, it
      does not block for the duration of RPC. Please call is_ok, get_error or
      get_value methods on the returned object to block till RPC finishes.
    """
    if args is None:
      args = []
    status_or, deleter = gen_rpc_ops.rpc_call(
        self._client_handle,
        args=nest.flatten(args),
        method_name=method_name,
        timeout_in_ms=timeout_in_ms)
    return StatusOrResult(status_or, deleter, output_specs)

437 

438 

class StatusOrResult(object):
  """Class representing result and status from RPC Call."""

  def __init__(self, status_or, deleter, output_specs=None):
    """Stores the handles for a pending RPC.

    Args:
      status_or: Handle passed to the status, value and delete ops.
      deleter: Deleter passed to the delete op in `__del__`.
      output_specs: Optional specs used to pack the value in `get_value`.
    """
    self._status_or = status_or
    self._output_specs = output_specs
    self._deleter = deleter
    # Filled in by the first `_check_status` call.
    self._error_code = None
    self._error_message = None

  def _check_status(self):
    """Fetches error code and message once and caches them on the instance."""
    if self._error_code is not None:
      return
    self._error_code, self._error_message = gen_rpc_ops.rpc_check_status(
        self._status_or)

  def __del__(self):
    # Make sure the resource is deleted in the same mode as it was created in.
    if context.executing_eagerly():
      mode_scope = context.eager_mode
    else:
      mode_scope = context.graph_mode
    with mode_scope():
      gen_rpc_ops.delete_rpc_future_resource(
          handle=self._status_or, deleter=self._deleter)

  def is_ok(self):
    """Returns True if RPC is successful, otherwise returns False.

    The result is the comparison of the error code against an int64 zero.

    This call will block for RPC result.
    """
    self._check_status()
    ok_code = constant_op.constant(0, dtype=dtypes.int64)
    return math_ops.equal(self._error_code, ok_code)

  def get_error(self):
    """Returns (TF Error Code, Error Message) from RPC Response.

    This call will block for RPC result.
    """
    self._check_status()
    return self._error_code, self._error_message

  def get_value(self):
    """Returns the returned response value from RPC Call when RPC is successful.

    The returned value is tensors in the output_specs format as returned from
    the RPC call. When no output specs were given, or they are a
    `NoneTensorSpec`, None is returned.

    This call will block for RPC result.
    """
    self._check_status()
    specs = self._output_specs
    expects_nothing = specs is None or isinstance(specs,
                                                  structure.NoneTensorSpec)
    if expects_nothing:
      flat_output_dtypes = []
    else:
      flat_output_dtypes = [spec.dtype for spec in nest.flatten(specs)]

    # The value op is invoked even when nothing is expected back.
    flat_result = gen_rpc_ops.rpc_get_value(
        self._status_or, Tout=flat_output_dtypes)
    return None if expects_nothing else nest.pack_sequence_as(specs,
                                                              flat_result)