Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/text/crf_wrapper.py: 22%

64 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1import tensorflow as tf 

2 

3from tensorflow_addons.text import crf_log_likelihood 

4from tensorflow_addons.utils import types 

5 

6 

@tf.keras.utils.register_keras_serializable(package="Addons")
class CRFModelWrapper(tf.keras.Model):
    """Keras model that runs a `CRF` layer on top of a base model.

    The output of `base_model` is fed to a `CRF` layer. `train_step` and
    `test_step` are overridden so that the negative CRF log-likelihood, plus
    the sum of `self.losses`, is used as the loss.

    Args:
        base_model: Model whose output is passed to the CRF layer.
        units: Forwarded to `CRF` as `units`.
        chain_initializer: Forwarded to `CRF` as `chain_initializer`.
        use_boundary: Forwarded to `CRF` as `use_boundary`.
        boundary_initializer: Forwarded to `CRF` as `boundary_initializer`.
        use_kernel: Forwarded to `CRF` as `use_kernel`.
        **kwargs: Any further keyword arguments, forwarded to `CRF`.
    """

    def __init__(
        self,
        base_model: tf.keras.Model,
        units: int,
        chain_initializer: types.Initializer = "orthogonal",
        use_boundary: bool = True,
        boundary_initializer: types.Initializer = "zeros",
        use_kernel: bool = True,
        **kwargs,
    ):
        super().__init__()

        # Lazy import to break a circular import:
        # tfa.layers.CRF -> tfa.text.__init__ -> tfa.text.crf_wrapper -> tfa.layers.CRF
        from tensorflow_addons.layers.crf import CRF  # noqa

        self.crf_layer = CRF(
            units=units,
            chain_initializer=chain_initializer,
            use_boundary=use_boundary,
            boundary_initializer=boundary_initializer,
            use_kernel=use_kernel,
            **kwargs,
        )

        self.base_model = base_model

    def unpack_training_data(self, data):
        """Split one batch into `(x, y, sample_weight)`.

        A three-element `data` is unpacked as `(x, y, sample_weight)`; anything
        else is unpacked as `(x, y)` and `sample_weight` is set to `None`.
        Override this method if that layout does not suit your task.
        """
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            x, y = data
            sample_weight = None
        return x, y, sample_weight

    def call(self, inputs, training=None, mask=None, return_crf_internal=False):
        """Run the base model followed by the CRF layer.

        Args:
            inputs: Input passed to `base_model`.
            training: Passed through to `base_model`.
            mask: Passed through to `base_model`.
            return_crf_internal: When true, also return the CRF internals
                that `compute_crf_loss` needs.

        Returns:
            `((potentials, sequence_length, kernel), decode_sequence)` when
            `return_crf_internal` is true, otherwise `decode_sequence` alone.
        """
        base_model_outputs = self.base_model(inputs, training, mask)

        # Change the next line if your base model has more outputs.
        crf_input = base_model_outputs

        decode_sequence, potentials, sequence_length, kernel = self.crf_layer(crf_input)

        # Change the next line if your base model has more outputs.
        # Always keep `(potentials, sequence_length, kernel), decode_sequence`
        # as the first two outputs of the model: `train_step()` and
        # `test_step()` rely on that layout.
        outputs = (potentials, sequence_length, kernel), decode_sequence

        if return_crf_internal:
            return outputs

        # outputs[0] holds the CRF internals; skip it.
        output_without_crf_internal = outputs[1:]

        # It is nicer to return a tensor than a one-element sequence.
        if len(output_without_crf_internal) == 1:
            return output_without_crf_internal[0]
        return output_without_crf_internal

    def compute_crf_loss(
        self, potentials, sequence_length, kernel, y, sample_weight=None
    ):
        """Return the mean negative CRF log-likelihood of `y`.

        When `sample_weight` is not `None`, the negated likelihood is
        multiplied by it before the mean is taken.
        """
        crf_likelihood, _ = crf_log_likelihood(potentials, y, sequence_length, kernel)
        # Convert likelihood to loss.
        flat_crf_loss = -1 * crf_likelihood
        if sample_weight is not None:
            flat_crf_loss = flat_crf_loss * sample_weight
        return tf.reduce_mean(flat_crf_loss)

    def _forward_and_loss(self, x, y, sample_weight, training):
        """Run a forward pass and return `(loss, crf_loss, decode_sequence)`."""
        (potentials, sequence_length, kernel), decode_sequence, *_ = self(
            x, training=training, return_crf_internal=True
        )
        crf_loss = self.compute_crf_loss(
            potentials, sequence_length, kernel, y, sample_weight
        )
        # Add the losses tracked by the model (`self.losses`).
        loss = crf_loss + tf.reduce_sum(self.losses)
        return loss, crf_loss, decode_sequence

    def train_step(self, data):
        """Run one optimization step and return metric and loss values.

        Returns:
            Dict mapping each metric name to its current result, plus the
            `"loss"` and `"crf_loss"` entries.
        """
        x, y, sample_weight = self.unpack_training_data(data)
        # The forward pass must happen under the tape so gradients can flow.
        with tf.GradientTape() as tape:
            loss, crf_loss, decode_sequence = self._forward_and_loss(
                x, y, sample_weight, training=True
            )
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update the compiled metrics with the decoded sequence.
        self.compiled_metrics.update_state(y, decode_sequence)
        # Return a dict mapping metric names to their current value.
        results = {m.name: m.result() for m in self.metrics}
        results.update({"loss": loss, "crf_loss": crf_loss})
        return results

    def test_step(self, data):
        """Run one evaluation step and return metric and loss values.

        Returns:
            Dict mapping each metric name to its current result, plus the
            `"loss"` and `"crf_loss"` entries.
        """
        x, y, sample_weight = self.unpack_training_data(data)
        loss, crf_loss, decode_sequence = self._forward_and_loss(
            x, y, sample_weight, training=False
        )
        # Update the compiled metrics with the decoded sequence.
        self.compiled_metrics.update_state(y, decode_sequence)
        # Return a dict mapping metric names to their current value.
        results = {m.name: m.result() for m in self.metrics}
        results.update({"loss": loss, "crf_loss": crf_loss})
        return results

    def get_config(self):
        """Return the CRF layer config plus the base model config.

        The base model's config is stored under the `"base_model"` key; the
        CRF layer's config entries are merged in at the top level.
        """
        base_model_config = self.base_model.get_config()
        crf_config = self.crf_layer.get_config()

        return {"base_model": base_model_config, **crf_config}

    @classmethod
    def from_config(cls, config):
        """Rebuild a wrapper from a dict produced by `get_config`.

        The caller's `config` dict is left unmodified.
        """
        # Work on a shallow copy: popping from the caller's dict would make
        # a second `from_config(config)` call fail with `KeyError`.
        config = dict(config)
        base_model_config = config.pop("base_model")
        base_model = tf.keras.Model.from_config(base_model_config)

        return cls(base_model=base_model, **config)

128 base_model = tf.keras.Model.from_config(base_model_config) 

129 

130 return cls(base_model=base_model, **config)