/*
 * Decompiled with CFR 0.152.
 */
package org.hibernate.testing.orm.transaction;

import jakarta.persistence.EntityManager;
import jakarta.persistence.QueryTimeoutException;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Function;
import org.hibernate.LockMode;
import org.hibernate.LockOptions;
import org.hibernate.SharedSessionContract;
import org.hibernate.Transaction;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.SQLServerDialect;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.engine.spi.SessionImplementor;
import org.hibernate.engine.spi.StatelessSessionImplementor;
import org.hibernate.exception.ConstraintViolationException;
import org.hibernate.exception.LockTimeoutException;
import org.hibernate.testing.orm.AsyncExecutor;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.jboss.logging.Logger;
import org.junit.jupiter.api.Assertions;

/**
 * Test helpers that run an action inside a Hibernate {@link Transaction}, and that check whether
 * rows of a table are currently locked by probing them from an asynchronous transaction.
 */
public abstract class TransactionUtil {
    private static final Logger log = Logger.getLogger(TransactionUtil.class);

    /** How long a probing statement may run before it is considered blocked. */
    private static final int BLOCK_PROBE_TIMEOUT_SECONDS = 2;

    /**
     * Opens a new session from {@code sessionFactory}, runs {@code action} in a transaction on it,
     * and closes the session afterwards.
     *
     * @param sessionFactory factory used to open the session
     * @param action work to run; see {@link #inTransaction(SessionImplementor, Consumer)} for the
     *     commit/rollback rules
     */
    public static void inTransaction(SessionFactoryImplementor sessionFactory, Consumer<SessionImplementor> action) {
        try (SessionImplementor session = sessionFactory.openSession()) {
            inTransaction(session, action);
        }
    }

    /**
     * Runs {@code action} in a new transaction on an existing session (the session is not closed).
     * The transaction is committed when the action completes normally, rolled back when the action
     * marked it rollback-only, and rolled back (with the failure rethrown) when the action throws.
     *
     * @param session session on which to begin the transaction
     * @param action work to run, receiving {@code session}
     */
    public static void inTransaction(SessionImplementor session, Consumer<SessionImplementor> action) {
        wrapInTransaction(session, session, action);
    }

    /**
     * Same as {@link #inTransaction(SessionImplementor, Consumer)} for an {@link EntityManager}.
     * The entity manager must be a Hibernate {@link SharedSessionContract}; otherwise the cast
     * below fails with a {@link ClassCastException}.
     *
     * @param entityManager entity manager on which to begin the transaction
     * @param action work to run, receiving {@code entityManager}
     */
    public static void inTransaction(EntityManager entityManager, Consumer<EntityManager> action) {
        wrapInTransaction((SharedSessionContract) entityManager, entityManager, action);
    }

    /**
     * Same as {@link #inTransaction(SessionImplementor, Consumer)} for a stateless session.
     *
     * @param session stateless session on which to begin the transaction
     * @param action work to run, receiving {@code session}
     */
    public static void inTransaction(StatelessSessionImplementor session, Consumer<StatelessSessionImplementor> action) {
        wrapInTransaction(session, session, action);
    }

    /**
     * Opens a new session from {@code sessionFactory}, runs {@code action} in a transaction on it,
     * closes the session, and returns the action's result.
     *
     * @param sessionFactory factory used to open the session
     * @param action work to run
     * @param <R> result type
     * @return the value produced by {@code action}
     */
    public static <R> R fromTransaction(SessionFactoryImplementor sessionFactory, Function<SessionImplementor, R> action) {
        try (SessionImplementor session = sessionFactory.openSession()) {
            return fromTransaction(session, action);
        }
    }

    /**
     * Runs {@code action} in a new transaction on an existing session and returns its result.
     * The transaction is committed after the action completes and rolled back if it throws.
     *
     * @param session session on which to begin the transaction
     * @param action work to run, receiving {@code session}
     * @param <R> result type
     * @return the value produced by {@code action}
     */
    public static <R> R fromTransaction(SessionImplementor session, Function<SessionImplementor, R> action) {
        return wrapInTransaction(session, session, action);
    }

    /**
     * Same as {@link #fromTransaction(SessionImplementor, Function)} for an {@link EntityManager},
     * which must be a Hibernate {@link SharedSessionContract}.
     *
     * @param entityManager entity manager on which to begin the transaction
     * @param action work to run, receiving {@code entityManager}
     * @param <R> result type
     * @return the value produced by {@code action}
     */
    public static <R> R fromTransaction(EntityManager entityManager, Function<EntityManager, R> action) {
        return wrapInTransaction((SharedSessionContract) entityManager, entityManager, action);
    }

    /**
     * Begins a transaction on {@code session} and passes {@code actionInput} to {@code action}.
     * The transaction is committed on normal completion, rolled back instead if the action marked it
     * rollback-only, and rolled back before rethrowing if anything (exception or error) escapes.
     */
    private static <T> void wrapInTransaction(SharedSessionContract session, T actionInput, Consumer<T> action) {
        final Transaction txn = session.beginTransaction();
        log.trace("Started transaction");
        try {
            log.trace("Calling action in txn");
            action.accept(actionInput);
            log.trace("Called action - in txn");
            if (txn.getRollbackOnly()) {
                // best effort: a failure here is logged, not propagated
                try {
                    log.trace("Rollback transaction marked for rollback only");
                    txn.rollback();
                } catch (RuntimeException e) {
                    log.error("Rollback failure", e);
                }
            } else {
                log.trace("Committing transaction");
                txn.commit();
                log.trace("Committed transaction");
            }
        } catch (Throwable failure) {
            // Throwable (not just Exception/AssertionError) so no failure leaves the txn open;
            // precise rethrow keeps the method free of a throws clause.
            log.tracef("Error calling action: %s (%s) - rolling back", failure.getClass().getName(), failure.getMessage());
            rollbackAfterFailure(txn, failure);
            throw failure;
        }
    }

    /**
     * Begins a transaction on {@code session}, applies {@code action} to {@code actionInput}, commits,
     * and returns the result; rolls back before rethrowing if anything escapes.
     * Unlike the {@link Consumer} variant, this does not check {@code getRollbackOnly()} itself and
     * leaves handling of a rollback-only transaction to {@code commit()}.
     */
    private static <T, R> R wrapInTransaction(SharedSessionContract session, T actionInput, Function<T, R> action) {
        final Transaction txn = session.beginTransaction();
        log.trace("Started transaction");
        try {
            log.trace("Calling action in txn");
            final R result = action.apply(actionInput);
            log.trace("Called action - in txn");
            log.trace("Committing transaction");
            txn.commit();
            log.trace("Committed transaction");
            return result;
        } catch (Throwable failure) {
            log.tracef("Error calling action: %s (%s) - rolling back", failure.getClass().getName(), failure.getMessage());
            rollbackAfterFailure(txn, failure);
            throw failure;
        }
    }

    /**
     * Tries to roll back {@code txn} after {@code failure} escaped the transactional action.
     * If the rollback itself fails, that exception is logged and attached to {@code failure} as a
     * suppressed exception, so the original failure remains the one reported.
     */
    private static void rollbackAfterFailure(Transaction txn, Throwable failure) {
        try {
            txn.rollback();
        } catch (RuntimeException rollbackFailure) {
            log.trace("Was unable to roll back transaction");
            if (rollbackFailure != failure) {
                failure.addSuppressed(rollbackFailure);
            }
        }
    }

    /**
     * Asserts whether {@code delete from <tableName>} blocks when it is run asynchronously in its own
     * transaction. The statement counts as blocked if it does not finish within the probe timeout
     * ({@value #BLOCK_PROBE_TIMEOUT_SECONDS} seconds) or fails with a lock/query timeout.
     * If it does not block, the delete is committed.
     *
     * @param factoryScope scope providing the session factory
     * @param tableName table to delete from
     * @param expectingToBlock whether the delete is expected to be blocked
     */
    public static void deleteRow(SessionFactoryScope factoryScope, String tableName, boolean expectingToBlock) {
        final String sql = String.format("delete from %s", tableName);
        assertStatementBlocking(factoryScope, sql, "`" + sql + "`", expectingToBlock);
    }

    /**
     * Asserts whether the row {@code idColumn = id} of {@code tableName} is locked.
     * <p>
     * When the dialect supports SKIP LOCKED (and is not SQL Server), the row is selected with
     * {@link LockMode#UPGRADE_SKIPLOCKED}; an empty result means the row is locked. Otherwise, the
     * method falls back to running {@code update <tableName> set <columnName> = null} asynchronously
     * and checking whether it blocks.
     * <p>
     * NOTE: the fallback update has no WHERE clause (so {@code idColumn}/{@code id} are unused there).
     * When it is not blocked, it commits {@code null} into {@code columnName} for every row of the table.
     *
     * @param factoryScope scope providing the session factory
     * @param tableName table holding the row
     * @param columnName column selected (or nulled by the fallback update)
     * @param idColumn identifier column used to pick the row in the SKIP LOCKED probe
     * @param id identifier value of the row
     * @param expectingToBlock whether the row is expected to be locked
     */
    public static void assertRowLock(SessionFactoryScope factoryScope, String tableName, String columnName, String idColumn, Number id, boolean expectingToBlock) {
        final Dialect dialect = factoryScope.getSessionFactory().getJdbcServices().getDialect();
        final boolean skipLocked = dialect.getLockingSupport().getMetadata().supportsSkipLocked();
        if (skipLocked && !(dialect instanceof SQLServerDialect)) {
            factoryScope.inTransaction(session -> {
                final String baseSql = String.format("select %s from %s t where %s=%s", columnName, tableName, idColumn, id);
                final String sql = dialect.applyLocksToSql(baseSql, new LockOptions(LockMode.UPGRADE_SKIPLOCKED), Map.of("t", new String[0]));
                final int resultSize = session.createNativeQuery(sql).getResultList().size();
                if (expectingToBlock && resultSize > 0) {
                    Assertions.fail("Expecting update to " + tableName + " to block due to locks");
                } else if (!expectingToBlock && resultSize == 0) {
                    Assertions.fail("Unexpected lock found on " + tableName);
                }
            });
        } else {
            final String sql = String.format("update %s set %s = null", tableName, columnName);
            assertStatementBlocking(factoryScope, sql, "update to " + tableName, expectingToBlock);
        }
    }

    /**
     * Runs {@code sql} as a native update in a transaction via {@link AsyncExecutor}, and asserts
     * whether it blocks.
     *
     * @param factoryScope scope providing the session factory
     * @param sql native DML statement to execute
     * @param description human-readable name of the statement used in assertion messages
     * @param expectingToBlock whether the statement is expected to be blocked
     */
    private static void assertStatementBlocking(SessionFactoryScope factoryScope, String sql, String description, boolean expectingToBlock) {
        try {
            AsyncExecutor.executeAsync(BLOCK_PROBE_TIMEOUT_SECONDS, TimeUnit.SECONDS, () -> factoryScope.inTransaction(session -> {
                session.createNativeQuery(sql).executeUpdate();
                if (expectingToBlock) {
                    Assertions.fail("Expecting " + description + " to block due to locks");
                }
            }));
        } catch (AsyncExecutor.TimeoutException expected) {
            if (!expectingToBlock) {
                Assertions.fail("Expecting " + description + " to succeed, but failed due to async timeout (presumably due to locks)", expected);
            }
        } catch (RuntimeException re) {
            final Throwable cause = re.getCause();
            if (isTimeoutCause(cause)) {
                if (!expectingToBlock) {
                    Assertions.fail("Expecting " + description + " to succeed, but it failed with a lock or query timeout (presumably due to locks)", cause);
                }
                // the statement was blocked, which is exactly what the caller expected
                return;
            }
            if (cause instanceof ConstraintViolationException) {
                throw (ConstraintViolationException) cause;
            }
            throw re;
        }
    }

    /** Whether {@code cause} indicates the statement was stopped by a lock or query timeout. */
    private static boolean isTimeoutCause(Throwable cause) {
        return cause instanceof jakarta.persistence.LockTimeoutException
                || cause instanceof LockTimeoutException
                || cause instanceof QueryTimeoutException;
    }
}

