package com.atlassian.audit.broker;

import com.atlassian.audit.api.AuditConsumer;
import com.atlassian.audit.entity.AuditEntity;
import com.atlassian.audit.event.AuditConsumerAddedEvent;
import com.atlassian.audit.event.AuditConsumerRemovedEvent;
import com.atlassian.event.api.EventListener;
import com.atlassian.event.api.EventPublisher;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;

import javax.annotation.Nonnull;
import javax.annotation.concurrent.ThreadSafe;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

import static java.util.Objects.requireNonNull;

/**
 * Audit broker that dispatches every entity accepted by the {@link AuditPolicy} to all registered,
 * enabled {@link AuditConsumer}s.
 *
 * <p>Each consumer gets its own bounded queue and a dedicated delivery thread that hands entities to the
 * consumer in batches. When a consumer's queue is full, its oldest entities are removed and passed to the
 * {@link AuditEntityRejectionHandler} to make room. Runtime exceptions thrown by a consumer are passed,
 * together with the failed batch, to the {@link AuditConsumerExceptionHandler}.
 */
@ThreadSafe
public class ScatterAuditBroker implements InternalAuditBroker, InitializingBean, DisposableBean {

    private final EventPublisher eventPublisher;

    private final AuditPolicy auditPolicy;

    private final AuditEntityRejectionHandler rejectAuditEntityHandler;

    private final AuditConsumerExceptionHandler exceptionHandler;

    private final ConcurrentHashMap<AuditConsumer, ConsumerRegistration> consumerRegistry;

    private final int defaultConsumerBufferSize;

    private final int defaultConsumerBatchSize;

    /**
     * @param eventPublisher            publisher this broker registers itself with in {@link #afterPropertiesSet()}
     *                                  to receive consumer added/removed events
     * @param auditPolicy               decides which audited entities are dispatched to consumers
     * @param rejectionHandler          receives entities discarded because a consumer's queue was full
     * @param exceptionHandler          receives runtime exceptions thrown by consumers, with the failed batch
     * @param defaultConsumerBufferSize queue capacity for consumers registered through
     *                                  {@link AuditConsumerAddedEvent}
     * @param defaultConsumerBatchSize  maximum batch size for consumers registered through
     *                                  {@link AuditConsumerAddedEvent}
     */
    public ScatterAuditBroker(EventPublisher eventPublisher,
                              AuditPolicy auditPolicy,
                              AuditEntityRejectionHandler rejectionHandler,
                              AuditConsumerExceptionHandler exceptionHandler,
                              int defaultConsumerBufferSize,
                              int defaultConsumerBatchSize) {
        this.eventPublisher = eventPublisher;
        this.defaultConsumerBatchSize = defaultConsumerBatchSize;
        this.defaultConsumerBufferSize = defaultConsumerBufferSize;
        this.auditPolicy = requireNonNull(auditPolicy);
        this.rejectAuditEntityHandler = requireNonNull(rejectionHandler);
        this.exceptionHandler = requireNonNull(exceptionHandler);
        this.consumerRegistry = new ConcurrentHashMap<>();
    }

    /**
     * Registers this broker with the event publisher so it receives consumer added/removed events.
     */
    @Override
    public void afterPropertiesSet() {
        eventPublisher.register(this);
    }

    /**
     * Performs a graceful {@link #shutdown()}.
     */
    @Override
    public void destroy() {
        shutdown();
    }

    /**
     * Stops all consumer threads gracefully. Each thread first delivers the entities still in its queue
     * and then exits. This method waits for every consumer thread to terminate, unless the calling
     * thread is interrupted while waiting.
     */
    public synchronized void shutdown() {
        consumerRegistry.values().forEach(registration -> registration.getThread().shutdown());
        waitForTermination();
    }

    /**
     * Stops all consumer threads and clears their queues; in-flight entities are lost. This method waits
     * for every consumer thread to terminate, unless the calling thread is interrupted while waiting.
     */
    public synchronized void shutdownNow() {
        consumerRegistry.values().forEach(registration -> registration.getThread().shutdownNow());
        waitForTermination();
    }

    /**
     * Registers the added consumer using the default buffer and batch sizes.
     */
    @EventListener
    public void onAuditConsumerAdded(AuditConsumerAddedEvent event) {
        addConsumer(event.getConsumerService(),
                defaultConsumerBufferSize,
                defaultConsumerBatchSize);
    }

    /**
     * Removes the consumer gracefully: entities already queued for it are still delivered.
     */
    @EventListener
    public void onAuditConsumerRemoved(AuditConsumerRemovedEvent event) {
        removeConsumer(event.getConsumerService(), false);
    }

    /**
     * Registers {@code consumer} with its own bounded queue and delivery thread, and starts that thread.
     * If the consumer is already registered, the previous registration is replaced. Its thread is shut
     * down gracefully, so entities already queued for it are still delivered.
     *
     * @param consumer   the consumer to register
     * @param bufferSize capacity of the consumer's queue; must be positive
     * @param batchSize  maximum number of entities passed to the consumer per call; must be positive
     * @throws IllegalArgumentException if {@code bufferSize} or {@code batchSize} is not positive
     */
    public void addConsumer(AuditConsumer consumer, int bufferSize, int batchSize) {
        ConsumerQueue queue = new ConsumerQueue(new ArrayBlockingQueue<>(bufferSize),
                batchSize,
                discarded -> rejectAuditEntityHandler.reject(this, consumer, discarded)
        );
        ConsumerThread thread = new ConsumerThread(queue,
                consumer,
                (exception, batch) -> exceptionHandler.handle(consumer, exception, batch));
        ConsumerRegistration previous = consumerRegistry.put(consumer, new ConsumerRegistration(queue, thread));
        thread.start();
        if (previous != null) {
            // Without this, the replaced thread would stay blocked on its orphaned queue forever.
            previous.getThread().shutdown();
        }
    }

    /**
     * Removes a consumer from the broker if it is registered, and stops its delivery thread.
     * If {@code force} is {@code false}, the thread delivers the entities still in its queue before it
     * exits. If {@code force} is {@code true}, in-flight entities are lost. Either way this method does
     * not wait for the thread to terminate.
     */
    public void removeConsumer(AuditConsumer consumer, boolean force) {
        ConsumerRegistration registration = consumerRegistry.remove(consumer);
        if (registration == null) {
            return;
        }
        if (force) {
            registration.getThread().shutdownNow();
        } else {
            registration.getThread().shutdown();
        }
    }

    /**
     * Enqueues {@code entity} for every registered, enabled consumer, provided the audit policy passes it.
     * If a consumer's queue is full, its oldest entities are discarded to the rejection handler.
     */
    @Override
    public void audit(@Nonnull AuditEntity entity) {
        requireNonNull(entity, "entity");
        if (!auditPolicy.pass(entity)) {
            return;
        }
        consumerRegistry.forEach((consumer, registration) -> {
            if (consumer.isEnabled()) {
                registration.getQueue().offer(entity);
            }
        });
    }

    /**
     * Joins every registered consumer thread. If the calling thread is interrupted, it stops waiting
     * and restores the interrupt status.
     */
    private void waitForTermination() {
        for (ConsumerRegistration registration : consumerRegistry.values()) {
            ConsumerThread thread = registration.getThread();
            if (thread == Thread.currentThread()) {
                // Invoked from inside a consumer callback: joining our own thread would never return.
                continue;
            }
            try {
                thread.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    /** Pairs a consumer's queue with the thread that drains it. */
    private static class ConsumerRegistration {
        private final ConsumerQueue queue;
        private final ConsumerThread thread;

        private ConsumerRegistration(ConsumerQueue queue, ConsumerThread thread) {
            this.queue = requireNonNull(queue);
            this.thread = requireNonNull(thread);
        }

        ConsumerQueue getQueue() {
            return queue;
        }

        ConsumerThread getThread() {
            return thread;
        }
    }

    /**
     * Dedicated delivery thread for one consumer. It repeatedly takes a batch from its {@link ConsumerQueue}
     * and passes it to the consumer. Once stopped, it delivers whatever is still queued before exiting,
     * unless it was stopped with {@link #shutdownNow()}. All calls to the consumer are made from this thread.
     */
    private final class ConsumerThread extends Thread {
        private final AtomicBoolean running;
        // Guards "clear running + interrupt" so the thread can clear its interrupt status without a later
        // shutdown interrupt slipping in while it drains the queue.
        private final Object signalLock;
        // Set by shutdownNow(); tells the thread to skip, or abandon, the final drain.
        private volatile boolean discardPending;
        private final AuditConsumer consumer;
        private final ConsumerQueue queue;
        private final BiConsumer<RuntimeException, List<AuditEntity>> exceptionHandler;

        ConsumerThread(ConsumerQueue queue,
                       AuditConsumer consumer,
                       BiConsumer<RuntimeException, List<AuditEntity>> exceptionHandler) {
            super("audit-broker-consumer-thread-" + ScatterAuditBroker.this.hashCode());
            this.running = new AtomicBoolean(false);
            this.signalLock = new Object();
            this.queue = requireNonNull(queue);
            this.consumer = requireNonNull(consumer);
            this.exceptionHandler = requireNonNull(exceptionHandler);
        }

        @Override
        public void run() {
            // The running flag, not the interrupt flag, decides when to stop. Consumer code that swallows
            // an interrupt therefore cannot leave this thread blocked in take() forever.
            while (running.get()) {
                try {
                    processBatch(queue.take());
                } catch (InterruptedException ignored) {
                    // Interrupts only wake the thread up; the loop condition decides whether to stop.
                }
            }
            synchronized (signalLock) {
                // stopRunning() clears `running` and interrupts inside this lock. Once we hold it, no
                // further shutdown interrupt can arrive. Clear any pending one so the consumer is not
                // invoked with the interrupt flag set during the drain.
                Thread.interrupted();
            }
            drainQueue();
        }

        @Override
        public void start() {
            if (running.compareAndSet(false, true)) {
                super.start();
            }
        }

        /**
         * Asks the thread to stop after delivering everything still queued. Does not wait for it.
         */
        public void shutdown() {
            stopRunning();
        }

        /**
         * Asks the thread to stop and discards everything still queued. Does not wait for it.
         */
        public void shutdownNow() {
            discardPending = true;
            stopRunning();
            queue.clear();
        }

        private void stopRunning() {
            synchronized (signalLock) {
                if (running.compareAndSet(true, false)) {
                    // Wake the thread if it is blocked in queue.take().
                    interrupt();
                }
            }
        }

        /** Delivers queued batches until the queue is empty or shutdownNow() is requested. */
        private void drainQueue() {
            while (!discardPending) {
                List<AuditEntity> batch = queue.poll();
                if (batch.isEmpty()) {
                    return;
                }
                processBatch(batch);
            }
        }

        private void processBatch(List<AuditEntity> batch) {
            try {
                consumer.accept(batch);
            } catch (RuntimeException e) {
                exceptionHandler.accept(e, batch);
            }
        }
    }

    /**
     * Bounded queue of entities for one consumer, read in batches of at most {@code batchSize}.
     * When full, {@link #offer} makes room by discarding the oldest entities to the rejection handler.
     */
    @ThreadSafe
    private static final class ConsumerQueue {
        private final BlockingQueue<AuditEntity> queue;
        private final int batchSize;
        private final Consumer<List<AuditEntity>> rejectionHandler;

        /**
         * @throws IllegalArgumentException if {@code batchSize} is not positive. A zero batch size would
         *                                  make {@link #offer} spin forever on a full queue.
         */
        ConsumerQueue(BlockingQueue<AuditEntity> queue,
                      int batchSize,
                      Consumer<List<AuditEntity>> rejectionHandler) {
            if (batchSize < 1) {
                throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
            }
            this.queue = requireNonNull(queue);
            this.batchSize = batchSize;
            this.rejectionHandler = requireNonNull(rejectionHandler);
        }

        /** Enqueues {@code entity}, discarding the oldest entities as needed to make room. */
        void offer(AuditEntity entity) {
            while (!queue.offer(entity)) {
                // Queue is full: discard some entities from it and retry.
                discardOldestEntities();
            }
        }

        void clear() {
            queue.clear();
        }

        /** Blocks until at least one entity is available, then returns up to {@code batchSize} entities. */
        List<AuditEntity> take() throws InterruptedException {
            List<AuditEntity> batch = new ArrayList<>(batchSize);
            batch.add(queue.take());
            queue.drainTo(batch, batchSize - 1);
            return batch;
        }

        /** Returns up to {@code batchSize} entities without blocking; the list is empty if none are queued. */
        List<AuditEntity> poll() {
            List<AuditEntity> batch = new ArrayList<>(batchSize);
            queue.drainTo(batch, batchSize);
            return batch;
        }

        private void discardOldestEntities() {
            List<AuditEntity> discarded = poll();
            // The consumer thread may have emptied the queue concurrently; don't report an empty rejection.
            if (!discarded.isEmpty()) {
                rejectionHandler.accept(discarded);
            }
        }
    }
}
