package com.atlassian.audit.ao.service;

import com.atlassian.audit.api.AuditEntityCursor;
import com.atlassian.audit.api.AuditQuery;
import com.atlassian.audit.api.AuditSearchService;
import com.atlassian.audit.api.util.pagination.Page;
import com.atlassian.audit.api.util.pagination.PageRequest;
import com.atlassian.audit.entity.AuditEntity;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

/**
 * {@link AuditSearchService} decorator that caps how many requests may be in flight against the delegate at once.
 * <p>
 * Queries that carry search text and queries that do not draw permits from two independent {@link Semaphore}
 * pools, so load on one kind of query does not consume the permits of the other. A caller waits at most
 * {@code queryTimeoutSeconds} for a permit. If none becomes available, a {@link TimeoutException} is thrown
 * without calling the delegate. If the waiting thread is interrupted, its interrupt flag is restored and the
 * {@link InterruptedException} is rethrown wrapped in a {@link RuntimeException}.
 */
public class RateLimitedSearchService implements AuditSearchService {

    private static final String SEARCH_TIMEOUT_MESSAGE =
            "Can't perform search as there are many other search requests in progress";
    private static final String STREAM_TIMEOUT_MESSAGE =
            "Can't perform streamed search as there are many other search requests in progress";
    private static final String COUNT_TIMEOUT_MESSAGE =
            "Can't perform count as there are many other count requests in progress";

    private final Semaphore textSearchLimiter;
    private final Semaphore nonTextSearchLimiter;
    private final int queryTimeoutSeconds;
    private final AuditSearchService delegate;

    /**
     * @param maxConcurrentTextSearchRequests    permits for queries that have search text; must be at least 1
     * @param maxConcurrentNonTextSearchRequests permits for queries without search text; must be at least 1
     * @param queryTimeoutSeconds                maximum number of seconds to wait for a permit; a value
     *                                           {@code <= 0} means "do not wait at all"
     *                                           (see {@link Semaphore#tryAcquire(long, TimeUnit)})
     * @param delegate                           the service that actually performs the searches
     * @throws IllegalArgumentException if either permit count is less than 1
     * @throws NullPointerException     if {@code delegate} is {@code null}
     */
    public RateLimitedSearchService(int maxConcurrentTextSearchRequests, int maxConcurrentNonTextSearchRequests,
                                    int queryTimeoutSeconds, AuditSearchService delegate) {
        checkArgument(maxConcurrentTextSearchRequests > 0, "Max concurrent text search requests should be at least 1");
        checkArgument(maxConcurrentNonTextSearchRequests > 0, "Max concurrent non-text search requests should be at least 1");
        this.textSearchLimiter = new Semaphore(maxConcurrentTextSearchRequests);
        this.nonTextSearchLimiter = new Semaphore(maxConcurrentNonTextSearchRequests);
        this.queryTimeoutSeconds = queryTimeoutSeconds;
        // Fail fast on misconfiguration instead of throwing NPE on the first search.
        this.delegate = requireNonNull(delegate, "delegate");
    }

    @Nonnull
    @Override
    public Page<AuditEntity, AuditEntityCursor> findBy(@Nonnull AuditQuery query,
                                                       @Nonnull PageRequest<AuditEntityCursor> pageRequest)
            throws TimeoutException {
        requireNonNull(query, "query");
        requireNonNull(pageRequest, "pageRequest");
        return withPermit(limiterFor(query), SEARCH_TIMEOUT_MESSAGE,
                () -> delegate.findBy(query, pageRequest));
    }

    @Nonnull
    @Override
    public Page<AuditEntity, AuditEntityCursor> findBy(@Nonnull AuditQuery query,
                                                       @Nonnull PageRequest<AuditEntityCursor> pageRequest,
                                                       int scanLimit) throws TimeoutException {
        requireNonNull(query, "query");
        requireNonNull(pageRequest, "pageRequest");
        return withPermit(limiterFor(query), SEARCH_TIMEOUT_MESSAGE,
                () -> delegate.findBy(query, pageRequest, scanLimit));
    }

    @Override
    public void stream(@Nonnull AuditQuery query, int offset, int limit, @Nonnull Consumer<AuditEntity> consumer)
            throws TimeoutException {
        requireNonNull(query, "query");
        requireNonNull(consumer, "consumer");
        withPermit(limiterFor(query), STREAM_TIMEOUT_MESSAGE, () -> {
            delegate.stream(query, offset, limit, consumer);
            return null;
        });
    }

    /**
     * Counts via the delegate under the rate limit. A {@code null} query is passed through to the delegate
     * unchanged and is accounted against the non-text limiter.
     */
    @Override
    public long count(@Nullable AuditQuery query) throws TimeoutException {
        return withPermit(limiterFor(query), COUNT_TIMEOUT_MESSAGE, () -> delegate.count(query));
    }

    /**
     * Picks the permit pool for a query: the text limiter when the query has search text, otherwise the
     * non-text limiter (including for a {@code null} query).
     */
    private Semaphore limiterFor(@Nullable AuditQuery query) {
        return query != null && query.getSearchText().isPresent() ? textSearchLimiter : nonTextSearchLimiter;
    }

    /**
     * Runs {@code operation} while holding one permit from {@code limiter}. The permit is always released
     * afterwards, even if the operation throws.
     *
     * @param limiter        the pool to take the permit from
     * @param timeoutMessage message of the {@link TimeoutException} thrown when no permit can be obtained
     * @param operation      the delegate call to execute
     * @return whatever {@code operation} returns
     * @throws TimeoutException if no permit becomes available within {@code queryTimeoutSeconds}, or if the
     *                          operation itself throws it
     * @throws RuntimeException wrapping the {@link InterruptedException} if the thread is interrupted while
     *                          waiting; the interrupt flag is restored first
     */
    private <T> T withPermit(Semaphore limiter, String timeoutMessage, LimitedOperation<T> operation)
            throws TimeoutException {
        final boolean acquired;
        try {
            acquired = limiter.tryAcquire(queryTimeoutSeconds, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        if (!acquired) {
            throw new TimeoutException(timeoutMessage);
        }
        try {
            return operation.execute();
        } finally {
            limiter.release();
        }
    }

    /** A delegate call that may throw {@link TimeoutException}, which {@code Supplier} cannot express. */
    @FunctionalInterface
    private interface LimitedOperation<T> {
        T execute() throws TimeoutException;
    }
}
