from flax import linen
from flaxformer.architectures.t5 import t5_architecture
from flaxformer.components.attention import dense_attention
from flaxformer.components import convolution
from flaxformer.components import dense
from flaxformer.components import embedding
from flaxformer.components import layer_norm
from flaxformer.components import relative_position_biases
import seqio
from t5x import adafactor
from t5x import decoding as decoding2
from t5x import gin_utils
from t5x import models
from t5x import partitioning
from t5x import utils

# Macros:
# ==============================================================================
ACTIVATION_DTYPE = 'bfloat16'
ACTIVATION_PARTITIONING_DIMS = 1
ARCHITECTURE = @t5_architecture.DecoderOnly()
BATCH_SIZE = 8
BIAS_INIT = @bias_init/linen.initializers.normal()
BIDIRECTIONAL_ATTENTION = True
CHECKPOINT_PATH = \
    'gs://madlad-400-checkpoints/checkpoints/8b-lm'
DROPOUT_FACTORY = @dropout_factory/linen.Dropout
DROPOUT_RATE = 0.0
EMBED_DIM = 4096
EOS_ID = 3
EOS_STR = '\n'
EXTRA_IDS = 512
HEAD_DIM = 256
LOSS_NORMALIZING_FACTOR = True
MLP_DIM = 16384
MODE = 'specific'
MODEL = @models.DecoderOnlyModel()
NUM_DECODER_LAYERS = 32
NUM_DECODES = 4
NUM_EMBEDDINGS = %VOCAB_SIZE
NUM_ENCODER_LAYERS = None
NUM_HEADS = 16
NUM_PARTITIONS = 4
OPTIMIZER = @adafactor.Adafactor()
PARALLEL_LAYERS = True
PARAMETER_PARTITIONING_DIMS = 1
SCALE = 4.0
TASK_FEATURE_LENGTHS = None
VOCAB_SIZE = 256512
VOCABULARY = @seqio.SentencePieceVocabulary()

# Parameters for adafactor.Adafactor:
# ==============================================================================
adafactor.Adafactor.beta1 = 0.9
adafactor.Adafactor.factored = False
adafactor.Adafactor.multiply_by_parameter_scale = @adafactor.HParamMap()
adafactor.Adafactor.step_offset = 0

# Parameters for utils.DatasetConfig:
# ==============================================================================
utils.DatasetConfig.batch_size = %BATCH_SIZE
utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
utils.DatasetConfig.pack = False
utils.DatasetConfig.seed = 42
utils.DatasetConfig.shuffle = False
utils.DatasetConfig.split = 'test'
utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
utils.DatasetConfig.trim_output_features = False
utils.DatasetConfig.use_cached = False
utils.DatasetConfig.use_custom_packing_ops = False
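# Note: MIXTURE_OR_TASK_NAME, MIXTURE_OR_TASK_MODULE, EVAL_OUTPUT_DIR and a
# concrete TASK_FEATURE_LENGTHS are referenced above but never bound in this
# file; they are expected to arrive as gin overrides at launch time. A purely
# illustrative sketch of such overrides (the task name, module, output dir and
# lengths below are placeholders, not part of this config):
#
#   MIXTURE_OR_TASK_NAME = 'my_eval_task'
#   MIXTURE_OR_TASK_MODULE = 'my_project.tasks'
#   EVAL_OUTPUT_DIR = '/tmp/madlad_8b_lm_eval'
#   TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 256}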
# Parameters for t5_architecture.Decoder:
# ==============================================================================
t5_architecture.Decoder.dropout_factory = %DROPOUT_FACTORY
t5_architecture.Decoder.dtype = %ACTIVATION_DTYPE
t5_architecture.Decoder.layer_factory = @t5_architecture.DecoderLayer
t5_architecture.Decoder.layer_norm_factory = \
    @final_layer_norm/layer_norm.T5LayerNorm
t5_architecture.Decoder.layer_remat = 'full'
t5_architecture.Decoder.num_layers = %NUM_DECODER_LAYERS
t5_architecture.Decoder.output_logits_factory = None
t5_architecture.Decoder.position_embedder_factory = None
t5_architecture.Decoder.shared_relative_position_bias_factory = None
t5_architecture.Decoder.token_embedder_factory = @embedding.Embed

# Parameters for seqio.DecoderFeatureConverter:
# ==============================================================================
seqio.DecoderFeatureConverter.apply_length_check = False

# Parameters for t5_architecture.DecoderLayer:
# ==============================================================================
t5_architecture.DecoderLayer.activation_partitioning_dims = \
    %ACTIVATION_PARTITIONING_DIMS
t5_architecture.DecoderLayer.dropout_factory = %DROPOUT_FACTORY
t5_architecture.DecoderLayer.encoder_decoder_attention = \
    @dense_attention.MultiHeadDotProductAttention()
t5_architecture.DecoderLayer.layer_norm_factory = @layer_norm.T5LayerNorm
t5_architecture.DecoderLayer.mlp = @dense.MlpBlock()
t5_architecture.DecoderLayer.parallel = %PARALLEL_LAYERS
t5_architecture.DecoderLayer.relative_position_bias_factory = None
t5_architecture.DecoderLayer.self_attention = \
    @dense_attention.MultiQueryDotProductAttention()

# Parameters for t5_architecture.DecoderOnly:
# ==============================================================================
t5_architecture.DecoderOnly.decoder_factory = @t5_architecture.Decoder

# Parameters for models.DecoderOnlyModel:
# ==============================================================================
models.DecoderOnlyModel.decode_fn = @decoding.constrained_beam_search
models.DecoderOnlyModel.feature_converter_cls = @seqio.DecoderFeatureConverter
models.DecoderOnlyModel.inputs_bidirectional_attention = %BIDIRECTIONAL_ATTENTION
models.DecoderOnlyModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
models.DecoderOnlyModel.module = %ARCHITECTURE
models.DecoderOnlyModel.optimizer_def = %OPTIMIZER
models.DecoderOnlyModel.vocabulary = %VOCABULARY
models.DecoderOnlyModel.z_loss = 0.0001

# Parameters for models.DecoderOnlyModel.predict_batch_with_aux:
# ==============================================================================
models.DecoderOnlyModel.predict_batch_with_aux.decoder_params = {'eos_id': %EOS_ID}
models.DecoderOnlyModel.predict_batch_with_aux.num_decodes = %NUM_DECODES

# Parameters for dropout_factory/linen.Dropout:
# ==============================================================================
dropout_factory/linen.Dropout.broadcast_dims = (-2,)
dropout_factory/linen.Dropout.rate = %DROPOUT_RATE

# Parameters for embedding.Embed:
# ==============================================================================
embedding.Embed.attend_dtype = 'float32'
embedding.Embed.cast_input_dtype = 'int32'
embedding.Embed.dtype = %ACTIVATION_DTYPE
embedding.Embed.embedding_init = @token_embedder_init/linen.initializers.normal()
embedding.Embed.features = %EMBED_DIM
embedding.Embed.name = 'token_embedder'
embedding.Embed.num_embeddings = %NUM_EMBEDDINGS
embedding.Embed.one_hot = True

# Parameters for t5_architecture.Encoder:
# ==============================================================================
t5_architecture.Encoder.dtype = %ACTIVATION_DTYPE
t5_architecture.Encoder.input_dropout_factory = %DROPOUT_FACTORY
t5_architecture.Encoder.layer_factory = @t5_architecture.EncoderLayer
t5_architecture.Encoder.layer_norm_factory = @layer_norm.T5LayerNorm
t5_architecture.Encoder.num_layers = %NUM_ENCODER_LAYERS
t5_architecture.Encoder.output_dropout_factory = %DROPOUT_FACTORY
t5_architecture.Encoder.position_embedder_factory = None
t5_architecture.Encoder.shared_relative_position_bias_factory = \
    @relative_position_biases.RelativePositionBiases

# Parameters for t5_architecture.EncoderDecoder:
# ==============================================================================
t5_architecture.EncoderDecoder.decoder_factory = @t5_architecture.Decoder
t5_architecture.EncoderDecoder.dtype = %ACTIVATION_DTYPE
t5_architecture.EncoderDecoder.encoder_factory = @t5_architecture.Encoder
t5_architecture.EncoderDecoder.shared_token_embedder_factory = @embedding.Embed
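# Note: ARCHITECTURE is a DecoderOnly stack (NUM_DECODER_LAYERS = 32 layers
# with parallel attention and MLP, EMBED_DIM = 4096, MLP_DIM = 16384) and
# NUM_ENCODER_LAYERS is None, so the t5_architecture.Encoder and
# t5_architecture.EncoderDecoder bindings above appear to be carried along in
# the dump rather than instantiated for this 8B language model.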
# Parameters for t5_architecture.EncoderLayer:
# ==============================================================================
t5_architecture.EncoderLayer.activation_partitioning_dims = \
    %ACTIVATION_PARTITIONING_DIMS
t5_architecture.EncoderLayer.attention = \
    @dense_attention.MultiHeadDotProductAttention()
t5_architecture.EncoderLayer.dropout_factory = %DROPOUT_FACTORY
t5_architecture.EncoderLayer.layer_norm_factory = @layer_norm.T5LayerNorm
t5_architecture.EncoderLayer.mlp = @dense.MlpBlock()

# Parameters for eval_script.evaluate:
# ==============================================================================
eval_script.evaluate.dataset_cfg = @utils.DatasetConfig()
eval_script.evaluate.inference_evaluator_cls = @seqio.Evaluator
eval_script.evaluate.model = %MODEL
eval_script.evaluate.output_dir = %EVAL_OUTPUT_DIR
eval_script.evaluate.partitioner = @partitioning.PjitPartitioner()
eval_script.evaluate.restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()

# Parameters for seqio.Evaluator:
# ==============================================================================
seqio.Evaluator.logger_cls = \
    [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
seqio.Evaluator.num_examples = None
seqio.Evaluator.use_memory_cache = True

# Parameters for adafactor.HParamMap:
# ==============================================================================
adafactor.HParamMap.rules = [('relpos_bias', False), ('.*', True)]

# Parameters for dense.MlpBlock:
# ==============================================================================
dense.MlpBlock.activations = ('swish', 'linear')
dense.MlpBlock.bias_init = %BIAS_INIT
dense.MlpBlock.dtype = %ACTIVATION_DTYPE
dense.MlpBlock.final_dropout_rate = 0
dense.MlpBlock.intermediate_dim = %MLP_DIM
dense.MlpBlock.intermediate_dropout_rate = %DROPOUT_RATE
dense.MlpBlock.kernel_init = @mlp_kernel_init/linen.initializers.variance_scaling()
dense.MlpBlock.use_bias = False

# Parameters for dense_attention.MultiHeadDotProductAttention:
# ==============================================================================
dense_attention.MultiHeadDotProductAttention.bias_init = %BIAS_INIT
dense_attention.MultiHeadDotProductAttention.broadcast_dropout = True
dense_attention.MultiHeadDotProductAttention.dropout_rate = %DROPOUT_RATE
dense_attention.MultiHeadDotProductAttention.dtype = %ACTIVATION_DTYPE
dense_attention.MultiHeadDotProductAttention.float32_logits = True
dense_attention.MultiHeadDotProductAttention.head_dim = %HEAD_DIM
dense_attention.MultiHeadDotProductAttention.kernel_init = \
    @attention_kernel_init/linen.initializers.variance_scaling()
dense_attention.MultiHeadDotProductAttention.num_heads = %NUM_HEADS
dense_attention.MultiHeadDotProductAttention.use_bias = False
dense_attention.MultiHeadDotProductAttention.use_rotary_embedding = True

# Parameters for dense_attention.MultiQueryDotProductAttention:
# ==============================================================================
dense_attention.MultiQueryDotProductAttention.bias_init = %BIAS_INIT
dense_attention.MultiQueryDotProductAttention.broadcast_dropout = True
dense_attention.MultiQueryDotProductAttention.dropout_rate = %DROPOUT_RATE
dense_attention.MultiQueryDotProductAttention.dtype = %ACTIVATION_DTYPE
dense_attention.MultiQueryDotProductAttention.float32_logits = True
dense_attention.MultiQueryDotProductAttention.head_dim = %HEAD_DIM
dense_attention.MultiQueryDotProductAttention.kernel_init = \
    @attention_kernel_init/linen.initializers.variance_scaling()
dense_attention.MultiQueryDotProductAttention.num_heads = %NUM_HEADS
dense_attention.MultiQueryDotProductAttention.use_bias = False
dense_attention.MultiQueryDotProductAttention.use_rotary_embedding = True
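# Note on attention: decoder self-attention is multi-query attention with
# NUM_HEADS * HEAD_DIM = 16 * 256 = 4096 = EMBED_DIM and rotary embeddings
# (use_rotary_embedding = True). All relative-position-bias factories for the
# decoder are None, so as configured here positional information comes from
# RoPE alone. The adafactor.HParamMap rules above apply
# multiply_by_parameter_scale to every parameter except those matching
# 'relpos_bias'.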
# Parameters for bias_init/linen.initializers.normal:
# ==============================================================================
bias_init/linen.initializers.normal.stddev = 1e-06

# Parameters for token_embedder_init/linen.initializers.normal:
# ==============================================================================
token_embedder_init/linen.initializers.normal.stddev = 1.0

# Parameters for partitioning.PjitPartitioner:
# ==============================================================================
partitioning.PjitPartitioner.logical_axis_rules = \
    @partitioning.standard_logical_axis_rules()
partitioning.PjitPartitioner.num_partitions = %NUM_PARTITIONS

# Parameters for relative_position_biases.RelativePositionBiases:
# ==============================================================================
relative_position_biases.RelativePositionBiases.dtype = %ACTIVATION_DTYPE
relative_position_biases.RelativePositionBiases.embedding_init = \
    @relative_position_bias_init/linen.initializers.variance_scaling()
relative_position_biases.RelativePositionBiases.max_distance = 128
relative_position_biases.RelativePositionBiases.num_buckets = 32
relative_position_biases.RelativePositionBiases.num_heads = %NUM_HEADS

# Parameters for utils.RestoreCheckpointConfig:
# ==============================================================================
utils.RestoreCheckpointConfig.mode = %MODE
utils.RestoreCheckpointConfig.path = %CHECKPOINT_PATH

# Parameters for seqio.SentencePieceVocabulary:
# ==============================================================================
seqio.SentencePieceVocabulary.extra_ids = %EXTRA_IDS
seqio.SentencePieceVocabulary.sentencepiece_model_file = \
    'gs://madlad-400-checkpoints/vocabulary/spm.model'

# Parameters for layer_norm.T5LayerNorm:
# ==============================================================================
layer_norm.T5LayerNorm.center_scale_at_zero = True
layer_norm.T5LayerNorm.dtype = %ACTIVATION_DTYPE
layer_norm.T5LayerNorm.use_scale = True

# Parameters for final_layer_norm/layer_norm.T5LayerNorm:
# ==============================================================================
final_layer_norm/layer_norm.T5LayerNorm.use_scale = False

# Parameters for decoding2.temperature_sample:
# ==============================================================================
decoding2.temperature_sample.max_decode_steps = 256
decoding2.temperature_sample.temperature = 1.0
decoding2.temperature_sample.topk = 1

# Parameters for attention_kernel_init/linen.initializers.variance_scaling:
# ==============================================================================
attention_kernel_init/linen.initializers.variance_scaling.distribution = 'normal'
attention_kernel_init/linen.initializers.variance_scaling.mode = 'fan_in'
attention_kernel_init/linen.initializers.variance_scaling.scale = 1.0

# Parameters for mlp_kernel_init/linen.initializers.variance_scaling:
# ==============================================================================
mlp_kernel_init/linen.initializers.variance_scaling.distribution = \
    'truncated_normal'
mlp_kernel_init/linen.initializers.variance_scaling.mode = 'fan_in'
mlp_kernel_init/linen.initializers.variance_scaling.scale = 1.0
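# Note on the vocabulary and EOS handling bound above: VOCAB_SIZE = 256512
# equals EXTRA_IDS = 512 sentinel tokens on top of what is presumably a 256k
# SentencePiece model (256000 + 512 = 256512). EOS_ID = 3 is passed to the
# decode function through predict_batch_with_aux.decoder_params, and
# EOS_STR = '\n' suggests generations are expected to terminate at a newline.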
# Parameters for relative_position_bias_init/linen.initializers.variance_scaling:
# ==============================================================================
relative_position_bias_init/linen.initializers.variance_scaling.distribution = \
    'uniform'
relative_position_bias_init/linen.initializers.variance_scaling.mode = 'fan_avg'
relative_position_bias_init/linen.initializers.variance_scaling.scale = 1.0
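# Rough parameter budget implied by the macros above (a rough estimate only;
# it ignores layer norms, biases and any unused cross-attention, so the real
# checkpoint may differ):
#   token embedding: VOCAB_SIZE * EMBED_DIM = 256512 * 4096          ~ 1.05e9
#   per-layer gated MLP ('swish', 'linear'): 3 * EMBED_DIM * MLP_DIM ~ 2.0e8
#   per-layer multi-query attention (Q and output projections of size
#       EMBED_DIM x EMBED_DIM, plus single K and V heads)            ~ 3.6e7
#   total over NUM_DECODER_LAYERS = 32 layers plus the embedding     ~ 8.6e9
# which is in line with the '8b-lm' checkpoint path. The restored model is
# sharded across NUM_PARTITIONS = 4 model-parallel partitions by the
# PjitPartitioner bound above.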