DefaultCallbackExecutor.java

/*-
 * ========================LICENSE_START=================================
 * flyway-core
 * ========================================================================
 * Copyright (C) 2010 - 2025 Red Gate Software Ltd
 * ========================================================================
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =========================LICENSE_END==================================
 */
package org.flywaydb.core.internal.callback;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Callable;
import org.flywaydb.core.FlywayTelemetryManager;
import org.flywaydb.core.api.FlywayException;
import org.flywaydb.core.api.MigrationInfo;
import org.flywaydb.core.api.callback.CallbackEvent;
import org.flywaydb.core.api.callback.Context;
import org.flywaydb.core.api.callback.Error;
import org.flywaydb.core.api.callback.GenericCallback;
import org.flywaydb.core.api.callback.Warning;
import org.flywaydb.core.api.configuration.Configuration;
import org.flywaydb.core.api.exception.FlywayBlockStatementExecutionException;
import org.flywaydb.core.api.output.OperationResult;
import org.flywaydb.core.extensibility.EventTelemetryModel;
import org.flywaydb.core.internal.database.base.Connection;
import org.flywaydb.core.internal.database.base.Database;
import org.flywaydb.core.internal.database.base.Schema;
import org.flywaydb.core.internal.jdbc.ExecutionTemplateFactory;

/**
 * Executes the callbacks for a specific event.
 *
 * <p>The callbacks passed to the constructor are copied and sorted by callback name, so they are always
 * invoked in that order.</p>
 *
 * @param <E> The type of callback event handled by this executor.
 */
public class DefaultCallbackExecutor<E extends CallbackEvent<E>> implements CallbackExecutor<E> {

    private final Configuration configuration;
    private final Database database;
    private final Schema schema;
    private final FlywayTelemetryManager flywayTelemetryManager;
    /** Private copy of the callbacks, sorted by callback name. */
    private final List<GenericCallback<E>> callbacks;
    /** Migration info placed into the contexts of the per-migration, per-statement and operation-finish events. */
    private MigrationInfo migrationInfo;

    /**
     * Creates a new callback executor.
     *
     * @param configuration          The configuration.
     * @param database               The database.
     * @param schema                 The current schema to use for the connection.
     * @param flywayTelemetryManager The telemetry manager handed to the telemetry model created for each
     *                               callback invocation.
     * @param callbacks              The callbacks to execute.
     */
    public DefaultCallbackExecutor(final Configuration configuration,
        final Database database,
        final Schema schema,
        final FlywayTelemetryManager flywayTelemetryManager,
        final Collection<GenericCallback<E>> callbacks) {
        this.configuration = configuration;
        this.database = database;
        this.schema = schema;
        this.flywayTelemetryManager = flywayTelemetryManager;

        // Copy before sorting so the caller's collection is never reordered.
        this.callbacks = new ArrayList<>(callbacks);
        this.callbacks.sort(Comparator.comparing(GenericCallback::getCallbackName));
    }

    /**
     * Runs every callback that supports the event on the database's main connection.
     *
     * @param event The event to handle.
     * @return The names of the callbacks that were executed.
     */
    @Override
    public Collection<String> onEvent(final E event) {
        return execute(event, database.getMainConnection());
    }

    /**
     * Runs every callback that supports the event on the database's event connection and disposes that
     * connection afterwards. The event connection is only requested when at least one callback reports
     * support for the event (queried with a {@code null} context).
     *
     * @param event The event to handle.
     */
    @Override
    public void onMigrateOrUndoEvent(final E event) {
        if (callbacks.stream().noneMatch(callback -> callback.supports(event, null))) {
            return;
        }

        final Connection eventConnection = database.getEventConnection();
        try {
            execute(event, eventConnection);
        } finally {
            // Dispose even when a callback fails, so a throwing callback cannot leave the event connection open.
            database.disposeEventConnection();
        }
    }

    /**
     * Sets the migration info that is passed to callbacks through the context of subsequent per-migration,
     * per-statement and operation-finish events.
     *
     * @param migrationInfo The migration info to expose to callbacks.
     */
    @Override
    public void setMigrationInfo(final MigrationInfo migrationInfo) {
        this.migrationInfo = migrationInfo;
    }

    /**
     * Notifies every supporting callback of the event, using a context built from the migration connection
     * and the current migration info.
     *
     * @param event The event to handle.
     */
    @Override
    public void onEachMigrateOrUndoEvent(final E event) {
        final Context context = new SimpleContext(configuration,
            database.getMigrationConnection(),
            migrationInfo,
            null);
        dispatch(event, context);
    }

    /**
     * Notifies every supporting callback of a statement-level event, using a context built from the migration
     * connection, the current migration info and the given statement details.
     *
     * @param event    The event to handle.
     * @param sql      The SQL statement the event relates to.
     * @param warnings The warnings to expose through the context.
     * @param errors   The errors to expose through the context.
     */
    @Override
    public void onEachMigrateOrUndoStatementEvent(final E event,
        final String sql,
        final List<Warning> warnings,
        final List<Error> errors) {
        final Context context = new SimpleContext(configuration,
            database.getMigrationConnection(),
            migrationInfo,
            sql,
            warnings,
            errors);
        dispatch(event, context);
    }

    /**
     * Notifies every supporting callback that an operation has finished, using a context built from the
     * migration connection, the current migration info and the operation result.
     *
     * @param event           The event to handle.
     * @param operationResult The result of the finished operation.
     */
    // NOTE(review): the original carries no @Override here; add one if CallbackExecutor declares this method.
    public void onOperationFinishEvent(final E event, final OperationResult operationResult) {
        final Context context = new SimpleContext(configuration,
            database.getMigrationConnection(),
            migrationInfo,
            operationResult);
        dispatch(event, context);
    }

    /**
     * Invokes, in sorted order, every callback that supports the event for the given context. No transaction
     * is opened and the connection state is not touched here.
     */
    private void dispatch(final E event, final Context context) {
        for (final GenericCallback<E> callback : callbacks) {
            if (callback.supports(event, context)) {
                handleEvent(callback, event, context);
            }
        }
    }

    /**
     * Runs every callback that supports the event on the given connection. A callback that reports it can
     * handle the event in a transaction is run through an execution template; all others are run directly.
     *
     * @param event      The event to handle.
     * @param connection The connection the callbacks run on.
     * @return The names of the callbacks that were selected for execution.
     */
    private Collection<String> execute(final E event, final Connection connection) {
        final Context context = new SimpleContext(configuration, connection, null, null);

        final Collection<GenericCallback<E>> callbacksToExecute = callbacks.stream()
            .filter(x -> x.supports(event, context))
            .toList();
        callbacksToExecute.forEach(callback -> {
            if (callback.canHandleInTransaction(event, context)) {
                ExecutionTemplateFactory.createExecutionTemplate(connection.getJdbcConnection(), database)
                    .execute((Callable<Void>) () -> {
                        DefaultCallbackExecutor.this.execute(connection, callback, event, context);
                        return null;
                    });
            } else {
                execute(connection, callback, event, context);
            }
        });

        return callbacksToExecute.stream().map(GenericCallback::getCallbackName).toList();
    }

    /**
     * Restores the connection's original state, switches it to this executor's schema and then invokes the
     * callback.
     */
    private void execute(final Connection connection,
        final GenericCallback<? super E> callback,
        final E event,
        final Context context) {
        connection.restoreOriginalState();
        connection.changeCurrentSchemaTo(schema);
        handleEvent(callback, event, context);
    }

    /**
     * Invokes a single callback inside a telemetry scope that is closed once the callback returns or throws.
     *
     * @throws FlywayBlockStatementExecutionException if the callback throws it; it is rethrown unchanged.
     * @throws FlywayException                        wrapping any other exception raised while handling the event.
     */
    private void handleEvent(final GenericCallback<? super E> callback, final E event, final Context context) {
        // Only classes under org.flywaydb are reported by simple name; everything else, including classes
        // without a canonical name, is reported with a fixed placeholder.
        final String callbackType = Optional.ofNullable(callback.getClass().getCanonicalName())
            .map(x -> x.startsWith("org.flywaydb"))
            .orElse(false) ? callback.getClass().getSimpleName() : "(custom callback class)";
        try (final EventTelemetryModel ignored = new CallbackTelemetryModel(event.getId(),
            callbackType,
            flywayTelemetryManager)) {
            callback.handle(event, context);
        } catch (final FlywayBlockStatementExecutionException e) {
            throw e;
        } catch (final Exception e) {
            throw new FlywayException("Error while executing " + event.getId() + " callback: " + e.getMessage(), e);
        }
    }
}