package com.atlassian.vcache.internal.core.metrics;

import com.atlassian.vcache.LocalCacheOperations;
import com.atlassian.vcache.internal.MetricLabel;

import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;

import static com.atlassian.vcache.internal.MetricLabel.NUMBER_OF_FACTORY_KEYS;
import static com.atlassian.vcache.internal.MetricLabel.NUMBER_OF_HITS;
import static com.atlassian.vcache.internal.MetricLabel.NUMBER_OF_MISSES;
import static com.atlassian.vcache.internal.MetricLabel.TIMED_FACTORY_CALL;
import static com.atlassian.vcache.internal.MetricLabel.TIMED_GET_CALL;
import static com.atlassian.vcache.internal.MetricLabel.TIMED_PUT_CALL;
import static com.atlassian.vcache.internal.MetricLabel.TIMED_REMOVE_ALL_CALL;
import static com.atlassian.vcache.internal.MetricLabel.TIMED_REMOVE_CALL;
import static com.atlassian.vcache.internal.MetricLabel.TIMED_SUPPLIER_CALL;
import static java.util.Objects.requireNonNull;

/**
 * Wrapper for a {@link LocalCacheOperations} that records metrics.
 *
 * <p>Every operation is forwarded to the instance returned by {@link #getDelegate()}. The call is
 * timed with an {@link ElapsedTimer}, and the elapsed time is reported to the {@link MetricsRecorder}
 * under {@link #cacheName} and {@link #cacheType}. The lookup operations additionally report
 * hit/miss counts or supplier/factory timings.
 *
 * @param <K> the key type
 * @param <V> the value type
 * @since 1.0.0
 */
abstract class TimedLocalCacheOperations<K, V>
        implements LocalCacheOperations<K, V> {
    /** Name of the cache, passed to every metric recorded by this wrapper. */
    protected final String cacheName;
    /** Type of the cache, passed to every metric recorded by this wrapper. */
    protected final CacheType cacheType;
    /** Destination for all timings and counters. */
    protected final MetricsRecorder metricsRecorder;

    /**
     * @param cacheName       name reported with every metric; must not be {@code null}
     * @param cacheType       type reported with every metric; must not be {@code null}
     * @param metricsRecorder recorder receiving the metrics; must not be {@code null}
     * @throws NullPointerException if any argument is {@code null}
     */
    TimedLocalCacheOperations(String cacheName, CacheType cacheType, MetricsRecorder metricsRecorder) {
        this.cacheName = requireNonNull(cacheName, "cacheName");
        this.cacheType = requireNonNull(cacheType, "cacheType");
        this.metricsRecorder = requireNonNull(metricsRecorder, "metricsRecorder");
    }

    /**
     * Returns the cache operations that every call on this wrapper is forwarded to.
     *
     * @return the wrapped operations
     */
    protected abstract LocalCacheOperations<K, V> getDelegate();

    /**
     * Looks up {@code key} in the delegate.
     *
     * <p>The call is timed as {@code TIMED_GET_CALL}. A present result is counted as
     * {@code NUMBER_OF_HITS} and an empty one as {@code NUMBER_OF_MISSES}.
     */
    @Override
    public Optional<V> get(K key) {
        try (ElapsedTimer ignored = newTimer(TIMED_GET_CALL)) {
            final Optional<V> result = getDelegate().get(key);
            metricsRecorder.record(cacheName, cacheType, result.isPresent() ? NUMBER_OF_HITS : NUMBER_OF_MISSES, 1);
            return result;
        }
    }

    /**
     * Looks up {@code key} in the delegate, passing a timed wrapper around {@code supplier}.
     *
     * <p>The call is timed as {@code TIMED_GET_CALL}. Supplier timing and hit/miss counting are
     * delegated to {@link #handleTimedSupplier(Optional)} through the {@link TimedSupplier}.
     */
    @Override
    public V get(K key, Supplier<? extends V> supplier) {
        // Resources close in reverse order: the TimedSupplier reports before the outer timer does.
        try (ElapsedTimer ignored = newTimer(TIMED_GET_CALL);
             TimedSupplier<? extends V> timedSupplier = new TimedSupplier<>(supplier, this::handleTimedSupplier)) {
            return getDelegate().get(key, timedSupplier);
        }
    }

    /**
     * Performs a bulk lookup in the delegate, passing a timed wrapper around {@code factory}.
     *
     * <p>The whole call is timed as {@code TIMED_GET_CALL}. Factory timing and key counts are
     * reported by {@link #handleTimedFactory(Optional, Long)} through the {@link TimedFactory}.
     */
    @Override
    public Map<K, V> getBulk(Function<Set<K>, Map<K, V>> factory, Iterable<K> keys) {
        // Resources close in reverse order: the TimedFactory reports before the outer timer does.
        try (ElapsedTimer ignored = newTimer(TIMED_GET_CALL);
             TimedFactory<K, V> timedFactory = new TimedFactory<>(factory, this::handleTimedFactory)) {
            return getDelegate().getBulk(timedFactory, keys);
        }
    }

    /** Forwards to the delegate, timing the call as {@code TIMED_PUT_CALL}. */
    @Override
    public void put(K key, V value) {
        try (ElapsedTimer ignored = newTimer(TIMED_PUT_CALL)) {
            getDelegate().put(key, value);
        }
    }

    /** Forwards to the delegate, timing the call as {@code TIMED_PUT_CALL}. */
    @Override
    public Optional<V> putIfAbsent(K key, V value) {
        try (ElapsedTimer ignored = newTimer(TIMED_PUT_CALL)) {
            return getDelegate().putIfAbsent(key, value);
        }
    }

    /** Forwards to the delegate, timing the call as {@code TIMED_PUT_CALL}. */
    @Override
    public boolean replaceIf(K key, V currentValue, V newValue) {
        try (ElapsedTimer ignored = newTimer(TIMED_PUT_CALL)) {
            return getDelegate().replaceIf(key, currentValue, newValue);
        }
    }

    /** Forwards to the delegate, timing the call as {@code TIMED_REMOVE_CALL}. */
    @Override
    public boolean removeIf(K key, V value) {
        try (ElapsedTimer ignored = newTimer(TIMED_REMOVE_CALL)) {
            return getDelegate().removeIf(key, value);
        }
    }

    /** Forwards to the delegate, timing the call as {@code TIMED_REMOVE_CALL}. */
    @Override
    public void remove(K key) {
        try (ElapsedTimer ignored = newTimer(TIMED_REMOVE_CALL)) {
            getDelegate().remove(key);
        }
    }

    /** Forwards to the delegate, timing the call as {@code TIMED_REMOVE_ALL_CALL}. */
    @Override
    public void removeAll() {
        try (ElapsedTimer ignored = newTimer(TIMED_REMOVE_ALL_CALL)) {
            getDelegate().removeAll();
        }
    }

    /**
     * Creates an {@link ElapsedTimer} whose callback records the elapsed time for {@code label}.
     *
     * <p>NOTE(review): assumes ElapsedTimer starts timing on construction and invokes the callback
     * when closed, as the try-with-resources usage above implies; confirm against ElapsedTimer.
     *
     * @param label the metric label to record the elapsed time under
     * @return a new timer; callers close it via try-with-resources
     */
    private ElapsedTimer newTimer(MetricLabel label) {
        return new ElapsedTimer(t -> metricsRecorder.record(cacheName, cacheType, label, t));
    }

    /**
     * Callback from {@link TimedSupplier}. It records the supplier timing, if any, and one hit or
     * miss.
     *
     * <p>A present {@code time} is recorded as {@code TIMED_SUPPLIER_CALL} and counted as a miss. An
     * empty {@code time} is counted as a hit. Presumably TimedSupplier reports a time only when the
     * supplier was actually invoked; confirm against TimedSupplier.
     *
     * @param time the supplier's elapsed time, or empty if none was reported
     */
    private void handleTimedSupplier(Optional<Long> time) {
        time.ifPresent(t -> metricsRecorder.record(cacheName, cacheType, TIMED_SUPPLIER_CALL, t));
        metricsRecorder.record(cacheName, cacheType, time.isPresent() ? NUMBER_OF_MISSES : NUMBER_OF_HITS, 1);
    }

    /**
     * Callback from {@link TimedFactory}. It records the factory timing and the number of keys
     * passed to it.
     *
     * <p>Nothing is recorded when {@code time} is empty. In that case {@code numberOfKeys} is ignored
     * as well.
     *
     * @param time         the factory's elapsed time, or empty if none was reported
     * @param numberOfKeys the key count to record as {@code NUMBER_OF_FACTORY_KEYS}
     */
    private void handleTimedFactory(Optional<Long> time, Long numberOfKeys) {
        time.ifPresent(t -> {
            metricsRecorder.record(cacheName, cacheType, TIMED_FACTORY_CALL, t);
            metricsRecorder.record(cacheName, cacheType, NUMBER_OF_FACTORY_KEYS, numberOfKeys);
        });
    }
}
