Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow_addons/text/crf_wrapper.py: 22%
64 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1import tensorflow as tf
3from tensorflow_addons.text import crf_log_likelihood
4from tensorflow_addons.utils import types
@tf.keras.utils.register_keras_serializable(package="Addons")
class CRFModelWrapper(tf.keras.Model):
    """Keras model that stacks a CRF layer on top of a base model.

    The wrapper feeds the output of `base_model` into a
    `tensorflow_addons.layers.CRF` layer and overrides `train_step` /
    `test_step` so that the model is optimized with the negative CRF
    log-likelihood (plus any regularization losses collected in
    `self.losses`).

    Args:
        base_model: Model whose output is used as the input of the CRF layer.
        units: Forwarded to the CRF layer as `units`.
        chain_initializer: Forwarded to the CRF layer.
        use_boundary: Forwarded to the CRF layer.
        boundary_initializer: Forwarded to the CRF layer.
        use_kernel: Forwarded to the CRF layer.
        **kwargs: Extra keyword arguments forwarded to the CRF layer
            (not to `tf.keras.Model.__init__`).
    """

    def __init__(
        self,
        base_model: tf.keras.Model,
        units: int,
        chain_initializer: types.Initializer = "orthogonal",
        use_boundary: bool = True,
        boundary_initializer: types.Initializer = "zeros",
        use_kernel: bool = True,
        **kwargs,
    ):
        super().__init__()

        # Lazy import to break a circular import:
        # tfa.layers.CRF -> tfa.text.__init__ -> tfa.text.crf_wrapper -> tfa.layers.CRF
        from tensorflow_addons.layers.crf import CRF  # noqa

        self.crf_layer = CRF(
            units=units,
            chain_initializer=chain_initializer,
            use_boundary=use_boundary,
            boundary_initializer=boundary_initializer,
            use_kernel=use_kernel,
            **kwargs,
        )

        self.base_model = base_model

    def unpack_training_data(self, data):
        """Split a training/evaluation batch into `(x, y, sample_weight)`.

        Override this method if it does not suit your task.

        Args:
            data: Either a 3-element `(x, y, sample_weight)` sequence or a
                2-element `(x, y)` sequence.

        Returns:
            A tuple `(x, y, sample_weight)`; `sample_weight` is `None` when
            `data` does not have exactly three elements.
        """
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            x, y = data
            sample_weight = None
        return x, y, sample_weight

    def call(self, inputs, training=None, mask=None, return_crf_internal=False):
        """Run the base model followed by the CRF layer.

        Args:
            inputs: Inputs forwarded to the base model.
            training: Forwarded to the base model.
            mask: Forwarded to the base model.
            return_crf_internal: When true, also return the CRF internals
                needed to compute the loss.

        Returns:
            The decoded sequence when `return_crf_internal` is false;
            otherwise the tuple
            `((potentials, sequence_length, kernel), decode_sequence)`.
        """
        # Pass `training` and `mask` by keyword so the call does not depend
        # on the positional order of the base model's `call` parameters.
        base_model_outputs = self.base_model(inputs, training=training, mask=mask)

        # Change the next line if your base model has more outputs.
        crf_input = base_model_outputs

        decode_sequence, potentials, sequence_length, kernel = self.crf_layer(
            crf_input
        )

        # Change the next line if your base model has more outputs.
        # Always keep `(potentials, sequence_length, kernel), decode_sequence`
        # as the first two outputs of the model: `train_step()` and
        # `test_step()` unpack the outputs in that order.
        outputs = (potentials, sequence_length, kernel), decode_sequence

        if return_crf_internal:
            return outputs

        # outputs[0] holds the CRF internals; skip it.
        output_without_crf_internal = outputs[1:]

        # It is nicer to return a tensor instead of a one-element sequence.
        if len(output_without_crf_internal) == 1:
            return output_without_crf_internal[0]
        return output_without_crf_internal

    def compute_crf_loss(
        self, potentials, sequence_length, kernel, y, sample_weight=None
    ):
        """Return the mean negative CRF log-likelihood of `y`.

        Args:
            potentials: Passed to `crf_log_likelihood` as its first argument.
            sequence_length: Passed to `crf_log_likelihood`.
            kernel: Passed to `crf_log_likelihood`.
            y: Target sequence passed to `crf_log_likelihood`.
            sample_weight: Optional weights multiplied into the per-sample
                loss before averaging.

        Returns:
            The result of `tf.reduce_mean` over the (optionally weighted)
            negative log-likelihood.
        """
        crf_likelihood, _ = crf_log_likelihood(
            potentials, y, sequence_length, kernel
        )
        # Convert likelihood to loss.
        flat_crf_loss = -1 * crf_likelihood
        if sample_weight is not None:
            flat_crf_loss = flat_crf_loss * sample_weight
        crf_loss = tf.reduce_mean(flat_crf_loss)

        return crf_loss

    def _results_with_losses(self, loss, crf_loss):
        """Return current metric results plus the `loss` and `crf_loss` entries."""
        results = {m.name: m.result() for m in self.metrics}
        # The loss entries are written last, so they win on a name collision.
        results.update({"loss": loss, "crf_loss": crf_loss})
        return results

    def train_step(self, data):
        """Run one optimization step on a batch.

        Args:
            data: Batch accepted by `unpack_training_data`.

        Returns:
            Dict mapping metric names to their current values, plus `"loss"`
            (CRF loss + `self.losses`) and `"crf_loss"`.
        """
        x, y, sample_weight = self.unpack_training_data(data)
        # The forward pass and loss must be recorded on the tape so that
        # gradients can be taken with respect to the trainable variables.
        with tf.GradientTape() as tape:
            (potentials, sequence_length, kernel), decoded_sequence, *_ = self(
                x, training=True, return_crf_internal=True
            )
            crf_loss = self.compute_crf_loss(
                potentials, sequence_length, kernel, y, sample_weight
            )
            loss = crf_loss + tf.reduce_sum(self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update the compiled metrics with the decoded sequence.
        self.compiled_metrics.update_state(y, decoded_sequence)
        return self._results_with_losses(loss, crf_loss)

    def test_step(self, data):
        """Evaluate one batch without updating any weights.

        Args:
            data: Batch accepted by `unpack_training_data`.

        Returns:
            Dict mapping metric names to their current values, plus `"loss"`
            (CRF loss + `self.losses`) and `"crf_loss"`.
        """
        x, y, sample_weight = self.unpack_training_data(data)
        (potentials, sequence_length, kernel), decoded_sequence, *_ = self(
            x, training=False, return_crf_internal=True
        )
        crf_loss = self.compute_crf_loss(
            potentials, sequence_length, kernel, y, sample_weight
        )
        loss = crf_loss + tf.reduce_sum(self.losses)

        # Update the compiled metrics with the decoded sequence.
        self.compiled_metrics.update_state(y, decoded_sequence)
        return self._results_with_losses(loss, crf_loss)

    def get_config(self):
        """Return the CRF layer config extended with a `"base_model"` entry."""
        base_model_config = self.base_model.get_config()
        crf_config = self.crf_layer.get_config()

        return {"base_model": base_model_config, **crf_config}

    @classmethod
    def from_config(cls, config):
        """Rebuild a wrapper from a dict produced by `get_config`.

        The base model is restored with `tf.keras.Model.from_config`; every
        remaining entry is passed to the constructor as a keyword argument.
        """
        # Work on a shallow copy so the caller's dict is not mutated by `pop`.
        config = dict(config)
        base_model_config = config.pop("base_model")
        base_model = tf.keras.Model.from_config(base_model_config)

        return cls(base_model=base_model, **config)