Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/experimental/rpc/rpc_ops.py: 32%
143 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module to expose RPC APIs in tensorflow."""
17from typing import Optional, Sequence, Union
19import tensorflow.distribute.experimental.rpc.kernels.gen_rpc_ops as gen_rpc_ops
20from tensorflow.distribute.experimental.rpc.proto import tf_rpc_service_pb2 as rpc_pb2
21from tensorflow.python.data.util import structure
22from tensorflow.python.eager import context
23from tensorflow.python.eager import def_function
24from tensorflow.python.eager import function as tf_function
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import type_spec
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import resource_variable_ops
31from tensorflow.python.saved_model import nested_structure_coder
32from tensorflow.python.types import core as core_tf_types
33from tensorflow.python.util import nest
34from tensorflow.python.util.tf_export import tf_export
def get_output_specs_from_function(func: tf_function.ConcreteFunction):
  """Returns the serialized type specs describing the outputs of `func`.

  Every value in `func.structured_outputs` is mapped to its type spec via
  `type_spec.type_spec_from_value`, the resulting structure is encoded with
  `nested_structure_coder.encode_structure`, and the encoded proto is
  serialized.

  Args:
    func: Concrete function whose structured outputs are described.

  Returns:
    The result of calling `SerializeToString()` on the encoded structure.
  """
  return nested_structure_coder.encode_structure(
      nest.map_structure(type_spec.type_spec_from_value,
                         func.structured_outputs)).SerializeToString()
def get_input_specs_from_function(func: tf_function.ConcreteFunction):
  """Returns the serialized specs of the positional inputs of `func`.

  `func.structured_input_signature` is unpacked as a two-element pair; only
  the first element is encoded with `nested_structure_coder.encode_structure`
  and serialized, the second element is ignored.

  Args:
    func: Concrete function whose input signature is described.

  Returns:
    The result of calling `SerializeToString()` on the encoded structure.
  """
  positional_specs, _ = func.structured_input_signature
  return nested_structure_coder.encode_structure(
      positional_specs).SerializeToString()
@tf_export("distribute.experimental.rpc.Server", v1=[])
class Server(object):
  """A Server base class for accepting RPCs for registered tf.functions.

  Functions can be registered on the server and are exposed via RPCs.

  This base class only provides the `create` factory; `register` and `start`
  raise `NotImplementedError` and are expected to be overridden by the
  concrete subclass returned from `create`.
  """

  @staticmethod
  def create(rpc_layer, address):
    """Create TF RPC server at given address.

    Args:
      rpc_layer: Communication layer between client and server. Only "grpc" rpc
        layer is supported at the moment.
      address: Address where RPC server is hosted.

    Returns:
      An instance of `tf.distribute.experimental.rpc.Server` class.

    Raises:
      ValueError: If rpc_layer other than "grpc" is used. Only GRPC
        is supported at the moment.

    Example usage:

      >>> import portpicker
      >>> @tf.function(input_signature=[
      ...      tf.TensorSpec([], tf.int32),
      ...      tf.TensorSpec([], tf.int32)])
      ... def remote_fn(a, b):
      ...   return tf.add(a, b)

      >>> port = portpicker.pick_unused_port()
      >>> address = "localhost:{}".format(port)
      >>> server = tf.distribute.experimental.rpc.Server.create("grpc", address)
      >>> server.register("addition", remote_fn)
      >>> server.start()

    """
    if rpc_layer != "grpc":
      raise ValueError("Only GRPC backend is supported at the moment.")
    return GrpcServer(address=address)

  def register(self, method_name: str,
               func: Union[def_function.Function,
                           tf_function.ConcreteFunction]):
    """Method for registering tf.function on server.

    Registered methods can be invoked remotely from clients.

    Args:
      method_name: Name of the tf.function. Clients use this method_name to make
        RPCs.
      func: A `tf.function` or ConcreteFunction to register.

    Raises:
      NotImplementedError: Always, on this base class.
    """
    # The two adjacent literals are concatenated by the parser, so the first
    # one must end with a space to keep "a concrete" as separate words.
    raise NotImplementedError("Please use Server.create method to create a "
                              "concrete subclass of Server.")

  def start(self):
    """Starts the RPC server on provided address.

    Server listens for new requests from client, once it is started.

    Raises:
      NotImplementedError: Always, on this base class.
    """
    raise NotImplementedError("Please use Server.create method to create a "
                              "concrete subclass of Server.")
@tf_export("distribute.experimental.rpc.Client", v1=[])
class Client(object):
  """Client class for invoking RPCs to the server."""

  @staticmethod
  def create(rpc_layer, address, name="", timeout_in_ms=0):
    """Create TF RPC client to connect to the given address.

    Args:
      rpc_layer: Communication layer between client and server. Only "grpc" rpc
        layer is supported at the moment.
      address: Address of the server to connect the RPC client to.
      name: Name of the RPC Client. You can create multiple clients connecting
        to same server and distinguish them using different names.
      timeout_in_ms: The default timeout to use for outgoing RPCs from client. 0
        indicates no timeout. Exceeding timeout during RPC will raise
        DeadlineExceeded error.

    Returns:
      An instance of `tf.distribute.experimental.rpc.Client` with the following
      dynamically added methods for eagerly created clients:
        * `Registered methods` e.g. multiply(*args):
            If Client is created when executing eagerly, client will request the
            list of registered methods from server during client creation.
            The convenience methods for RPCs will be dynamically added to the
            created Client instance.

            For example, when a server has method "multiply" registered, the
            client object created in eager mode will have 'multiply' method
            available. Users can use client.multiply(..) to make RPC, instead of
            client.call("multiply", ...)

            Both "call" and "multiply" methods are non-blocking i.e. they return
            a StatusOrResult object which should be used to wait for getting
            value or error.

            Along with the above, blocking versions of the registered
            methods are also dynamically added to client instance.
            e.g. multiply_blocking(*args). These methods block till the RPC is
            finished and return response for successful RPC. Otherwise raise
            exception.

            These methods are not available when Client is created inside a
            tf.function.

    Raises:
      ValueError: If rpc_layer other than "grpc" is used. Only GRPC
        is supported at the moment.
      DeadlineExceeded: In eager mode, if timeout exceeds while
        creating and listing client methods.

    Example usage:
      >>> # Have server already started.
      >>> import portpicker
      >>> @tf.function(input_signature=[
      ...      tf.TensorSpec([], tf.int32),
      ...      tf.TensorSpec([], tf.int32)])
      ... def remote_fn(a, b):
      ...   return tf.add(a, b)

      >>> port = portpicker.pick_unused_port()
      >>> address = "localhost:{}".format(port)
      >>> server = tf.distribute.experimental.rpc.Server.create("grpc", address)
      >>> server.register("addition", remote_fn)
      >>> server.start()

      >>> # Start client
      >>> client = tf.distribute.experimental.rpc.Client.create("grpc",
      ...      address=address, name="test_client")

      >>> a = tf.constant(2, dtype=tf.int32)
      >>> b = tf.constant(3, dtype=tf.int32)

      >>> result = client.call(
      ...    args=[a, b],
      ...    method_name="addition",
      ...    output_specs=tf.TensorSpec((), tf.int32))

      >>> if result.is_ok():
      ...   result.get_value()

      >>> result = client.addition(a, b)

      >>> if result.is_ok():
      ...   result.get_value()

      >>> value = client.addition_blocking(a, b)
    """
    if rpc_layer != "grpc":
      raise ValueError("Only GRPC backend is supported at the moment.")
    # Registered methods are only listed when the client is created eagerly.
    return GrpcClient(
        address=address,
        name=name,
        list_registered_methods=context.executing_eagerly(),
        timeout_in_ms=timeout_in_ms)

  def call(self,
           method_name: str,
           args: Optional[Sequence[core_tf_types.Tensor]] = None,
           output_specs=None,
           timeout_in_ms=0):
    """Method for making RPC calls to remote server.

    This invokes RPC to the server, executing the registered method_name
    remotely.

    Args:
      method_name: Remote registered method to invoke.
      args: List of arguments for the registered method.
      output_specs: Output specs for the output from method.
        For example, if tf.function is:

          @tf.function(input_signature=[
              tf.TensorSpec([], tf.int32),
              tf.TensorSpec([], tf.int32)])
          def multiply_fn(a, b):
            return tf.math.multiply(a, b)

        output_spec is: tf.TensorSpec((), tf.int32)

        If you have access to TF Function, the output specs can be generated
        from tf.function by calling:

          output_specs = tf.nest.map_structure(
              tf.type_spec_from_value,
              tf_function.get_concrete_function().structured_outputs)

        If output_specs are not provided, `StatusOrResult.get_value()` as
        implemented in this module returns None rather than the response
        tensors.
      timeout_in_ms: Timeout for this call. If 0, default client timeout will be
        used.

    Returns:
      An instance of `StatusOrResult` class with the following available
      methods.
        * `is_ok()`:
            Returns True if RPC was successful.
        * `get_error()`:
            Returns TF error_code and error message for the RPC.
        * `get_value()`:
            Returns the returned value from remote TF function execution
            when RPC is successful.

      Calling any of the above methods will block till RPC is completed and
      result is available.

    Raises:
      NotImplementedError: Always, on this base class.
    """
    raise NotImplementedError("Must be implemented in inherited classes.")
class GrpcServer(Server):
  """Server implementation that wraps a GRPC server resource handle.

  Functions are registered locally and then exposed to clients via RPCs.

  Example:
    ```
    server = rpc_ops.GrpcServer("host:port")
    @tf.function
    def add(a, b):
      return a + b

    server.register("add", add)
    server.start()
    ```
  """

  def __init__(self, address: str):
    """Creates the server resource for `address`.

    Args:
      address: Address passed to the `rpc_server` op.

    Raises:
      NotImplementedError: If not executing eagerly.
    """
    self._server_handle = gen_rpc_ops.rpc_server(address)
    if not context.executing_eagerly():
      raise NotImplementedError("Please create the server outside tf.function.")
    # Keep the deleter on the instance so it lives as long as the server.
    self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._server_handle, handle_device=self._server_handle.device)

  def register(self, method_name: str,
               func: Union[def_function.Function,
                           tf_function.ConcreteFunction]):
    """Registers `func` on the server under `method_name`.

    Args:
      method_name: Name under which the function is registered.
      func: A `def_function.Function` or a `ConcreteFunction`. A `Function`
        that declares arguments must have an input signature.

    Raises:
      ValueError: If a `Function` with arguments has no input signature, or
        if `func` is neither a `Function` nor a `ConcreteFunction`.
    """
    if isinstance(func, def_function.Function):
      takes_arguments = func._function_spec.arg_names  # pylint: disable=protected-access
      if takes_arguments and func.input_signature is None:
        raise ValueError("Input signature not specified for the function.")
      concrete = func.get_concrete_function()
    elif isinstance(func, tf_function.ConcreteFunction):
      concrete = func
    else:
      # Python functions
      # TODO(b/186762191): Add an implementation to support python functions.
      raise ValueError("Only TF functions are supported with Register method")

    # Both accepted function kinds end up in the same registration call.
    gen_rpc_ops.rpc_server_register(
        self._server_handle,
        method_name=method_name,
        captured_inputs=concrete.captured_inputs,
        input_specs=get_input_specs_from_function(concrete),
        output_specs=get_output_specs_from_function(concrete),
        f=concrete)

  def start(self):
    """Starts GRPC server."""
    gen_rpc_ops.rpc_server_start(self._server_handle)
class GrpcClient(Client):
  """Client wrapper to connect to remote RPC server using GRPC.

  If Client is created with (list_registered_methods=True):
  1. Input and output specs for the methods till this point will be fetched from
  Server.
  2. convenience methods are added to invoke registered methods directly from
  client.
  For example:
    For call a server method `add`
    client.add(a, b) or client.add_blocking(a, b) can be used instead of
    client.call(args=[a,b], output_specs=[..])

  Prerequisites for using list_registered_methods=True:
  1. Server should be already started with the registered methods.
  2. Client must be created in Eager mode.
  """

  def __init__(self,
               address: str,
               name: str = "",
               list_registered_methods=False,
               timeout_in_ms=0):
    """Creates the client resource and attaches per-method helpers.

    Args:
      address: Server address, forwarded to the `rpc_client` op.
      name: Forwarded to the `rpc_client` op as `shared_name`.
      list_registered_methods: Forwarded to the `rpc_client` op; the methods
        it returns are parsed and exposed as attributes on this instance.
      timeout_in_ms: Forwarded to the `rpc_client` op.

    Raises:
      NotImplementedError: If not executing eagerly.
    """
    self._client_handle, methods = gen_rpc_ops.rpc_client(
        shared_name=name,
        server_address=address,
        list_registered_methods=list_registered_methods,
        timeout_in_ms=timeout_in_ms)
    if not context.executing_eagerly():
      raise NotImplementedError(
          "Client creation is supported only in eager mode.")
    self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._client_handle, handle_device=self._client_handle.device)
    self._server_address = address
    # Maps each listed method name to its decoded output specs.
    self._method_registry = {}
    for method in methods.numpy():
      m = rpc_pb2.RegisteredMethod()
      m.ParseFromString(method)
      output_specs = nested_structure_coder.decode_proto(m.output_specs)
      input_specs = nested_structure_coder.decode_proto(m.input_specs)
      self._method_registry[m.method] = output_specs
      # TODO(ishark): Perhaps doc string can also be taken as input during
      # function registration.
      doc_string = "RPC Call for " + m.method + " method to server " + address
      self._add_method(m.method, output_specs, input_specs, self._client_handle,
                       doc_string)

  def _add_method(self, method_name, output_specs, input_specs, client_handle,
                  doc_string):
    """Method to add RPC methods to the client object.

    Two attributes are set on this instance: `<method_name>`, which issues the
    RPC and returns a `StatusOrResult`, and `<method_name>_blocking`, which
    returns the value or raises the error reported by the RPC.

    Args:
      method_name: Name of the remote method; also the attribute name.
      output_specs: Specs handed to `StatusOrResult` for unpacking the value.
      input_specs: Specs the positional arguments are checked against; the
        check is skipped when this is empty.
      client_handle: Handle passed to the `rpc_call` op.
      doc_string: Docstring assigned to both generated callables.
    """

    def validate_and_get_flat_inputs(*args):
      # `args` is always a tuple here (possibly empty), never None.
      if input_specs:
        nest.assert_same_structure(args, input_specs)
      return nest.flatten(args)

    def call_wrapper(*args, timeout_in_ms=0):
      status_or, deleter = gen_rpc_ops.rpc_call(
          client_handle,
          args=validate_and_get_flat_inputs(*args),
          method_name=method_name,
          timeout_in_ms=timeout_in_ms)
      return StatusOrResult(status_or, deleter, output_specs)

    def call_blocking_wrapper(*args, timeout_in_ms=0):
      # Same RPC as the non-blocking wrapper; only the result handling differs.
      result = call_wrapper(*args, timeout_in_ms=timeout_in_ms)
      if result.is_ok():
        return result.get_value()
      error_code, error_msg = result.get_error()
      raise errors.exception_type_from_error_code(error_code.numpy())(
          None, None, error_msg.numpy())

    call_wrapper.__doc__ = doc_string
    setattr(self, method_name, call_wrapper)

    call_blocking_wrapper.__doc__ = doc_string
    setattr(self, method_name + "_blocking", call_blocking_wrapper)

  def call(self,
           method_name: str,
           args: Optional[Sequence[core_tf_types.Tensor]] = None,
           output_specs=None,
           timeout_in_ms=0):
    """Method to invoke remote registered functions on the connected server.

    Server should be started before making an RPC Call.

    Args:
      method_name: Registered method to invoke on Server.
      args: Input arguments for the method. None is treated as no arguments;
        the arguments are flattened before being sent.
      output_specs: Output specs for the output from method.
      timeout_in_ms: Timeout for this call. If 0, default client timeout will be
        used.

    Returns:
      StatusOrResult object. This function issues the RPC call to server, it
      does not block for the duration of RPC. Please call is_ok, get_error or
      get_value methods on the returned object to block till RPC finishes.
    """
    flat_args = [] if args is None else nest.flatten(args)
    status_or, deleter = gen_rpc_ops.rpc_call(
        self._client_handle,
        args=flat_args,
        method_name=method_name,
        timeout_in_ms=timeout_in_ms)
    return StatusOrResult(status_or, deleter, output_specs)
class StatusOrResult(object):
  """Holds the status handle of an RPC call and exposes its outcome."""

  def __init__(self, status_or, deleter, output_specs=None):
    """Stores the handles returned by the RPC call op.

    Args:
      status_or: Handle passed to the status, value and delete ops.
      deleter: Deleter passed to `delete_rpc_future_resource` on destruction.
      output_specs: Optional structure used to pack the returned values.
    """
    self._status_or = status_or
    self._deleter = deleter
    self._output_specs = output_specs
    # Filled in by the first `_check_status` call.
    self._error_code = None
    self._error_message = None

  def _check_status(self):
    """Fetches the error code and message once and caches them."""
    if self._error_code is not None:
      return
    self._error_code, self._error_message = gen_rpc_ops.rpc_check_status(
        self._status_or)

  def __del__(self):
    # Make sure the resource is deleted in the same mode as it was created in.
    if context.executing_eagerly():
      mode_scope = context.eager_mode
    else:
      mode_scope = context.graph_mode
    with mode_scope():
      gen_rpc_ops.delete_rpc_future_resource(
          handle=self._status_or, deleter=self._deleter)

  def is_ok(self):
    """Returns True if RPC is successful, otherwise returns False.

    The result is the comparison of the error code with an int64 zero.

    This call will block for RPC result.
    """
    self._check_status()
    ok_code = constant_op.constant(0, dtype=dtypes.int64)
    return math_ops.equal(self._error_code, ok_code)

  def get_error(self):
    """Returns (TF Error Code, Error Message) from RPC Response.

    This call will block for RPC result.
    """
    self._check_status()
    return self._error_code, self._error_message

  def get_value(self):
    """Returns the returned response value from RPC Call when RPC is successful.

    The returned value is tensors in the output_specs format as returned from
    the RPC call. When no output specs were given, or they are a
    `NoneTensorSpec`, the value op is still run but None is returned.

    This call will block for RPC result.
    """
    self._check_status()
    specs = self._output_specs
    has_specs = not (specs is None or
                     isinstance(specs, structure.NoneTensorSpec))

    if has_specs:
      flat_output_dtypes = [spec.dtype for spec in nest.flatten(specs)]
    else:
      flat_output_dtypes = []

    result = gen_rpc_ops.rpc_get_value(self._status_or, Tout=flat_output_dtypes)
    return nest.pack_sequence_as(specs, result) if has_specs else None