package org.libj.util;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.NavigableMap;
import java.util.NavigableSet;
import java.util.Objects;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.function.BiFunction;
import java.util.function.Function;

/**
 * A {@link TreeMap} that additionally mirrors every mapping in a private {@link HashMap}.
 * <p>
 * The tree is authoritative. Mutating operations are applied to the tree first, exactly once,
 * and the hash mirror is then brought in line with the tree's resulting state. This means
 * user-supplied functions (in {@code compute*}, {@code merge}, {@code replaceAll}) are invoked
 * only once per key, and keys rejected by the tree never reach the mirror.
 * <p>
 * {@link #getOrDefault(Object, Object)} is answered from the hash mirror. {@link #get(Object)}
 * and {@link #containsKey(Object)} are answered from the tree.
 * <p>
 * Only the implicit no-argument constructor exists, so keys use their natural ordering. Like
 * {@link TreeMap} and {@link HashMap}, this class is not synchronized.
 *
 * @param <K> The type of keys maintained by this map.
 * @param <V> The type of mapped values.
 */
public class TreeHashMap<K,V> extends TreeMap<K,V> {
  // Hash-based mirror of the tree's mappings. Not final so that clone() can give each copy its
  // own instance; Object.clone() would otherwise leave both maps sharing one mirror.
  private HashMap<K,V> hashMap = new HashMap<>();

  /**
   * Re-synchronizes the mirror's entry for {@code key} with the tree. If the tree contains the
   * key, its mapping is copied into the mirror. Otherwise the key is removed from the mirror.
   *
   * @param key The key whose mirror entry should be refreshed.
   */
  private void syncFromTree(final K key) {
    final V value = super.get(key);
    if (value != null || super.containsKey(key))
      hashMap.put(key, value);
    else
      hashMap.remove(key);
  }

  @Override
  public V put(final K key, final V value) {
    // Tree first: it throws for null or non-comparable keys, and the mirror must then stay
    // untouched.
    final V previous = super.put(key, value);
    hashMap.put(key, value);
    return previous;
  }

  /**
   * Returns the value mapped to {@code key}, or {@code defaultValue} if there is none. The
   * lookup is answered from the hash mirror rather than the tree.
   */
  @Override
  public V getOrDefault(final Object key, final V defaultValue) {
    return hashMap.getOrDefault(key, defaultValue);
  }

  /**
   * Per the {@link Map#putIfAbsent(Object, Object)} contract, this returns the current value
   * when one is present (not {@code null}). A key mapped to {@code null} is treated as absent.
   */
  @Override
  public V putIfAbsent(final K key, final V value) {
    final V previous = super.putIfAbsent(key, value);
    syncFromTree(key);
    return previous;
  }

  /**
   * Invokes {@code mappingFunction} at most once, against the tree, and mirrors the outcome.
   * Returns the current (existing or newly computed) value, per the {@link Map} contract.
   */
  @Override
  public V computeIfAbsent(final K key, final Function<? super K,? extends V> mappingFunction) {
    final V value = super.computeIfAbsent(key, mappingFunction);
    syncFromTree(key);
    return value;
  }

  /**
   * Invokes {@code remappingFunction} at most once, against the tree, and mirrors the outcome.
   * This includes removal when the function returns {@code null}.
   */
  @Override
  public V computeIfPresent(final K key, final BiFunction<? super K,? super V,? extends V> remappingFunction) {
    final V value = super.computeIfPresent(key, remappingFunction);
    syncFromTree(key);
    return value;
  }

  /**
   * Invokes {@code remappingFunction} exactly once, against the tree, and mirrors the outcome.
   * This includes removal when the function returns {@code null}.
   */
  @Override
  public V compute(final K key, final BiFunction<? super K,? super V,? extends V> remappingFunction) {
    final V value = super.compute(key, remappingFunction);
    syncFromTree(key);
    return value;
  }

  /**
   * Invokes {@code remappingFunction} at most once, against the tree, and mirrors the outcome.
   * This includes removal when the function returns {@code null}.
   */
  @Override
  public V merge(final K key, final V value, final BiFunction<? super V,? super V,? extends V> remappingFunction) {
    final V merged = super.merge(key, value, remappingFunction);
    syncFromTree(key);
    return merged;
  }

  /** Answered from the tree. */
  @Override
  public boolean containsKey(final Object key) {
    return super.containsKey(key);
  }

  /** Answered from the tree. */
  @Override
  public V get(final Object key) {
    return super.get(key);
  }

  @Override
  public void putAll(final Map<? extends K,? extends V> map) {
    // Tree first, so a rejected key never reaches the mirror. Depending on the path TreeMap
    // takes, some entries may already have been mirrored through put(); re-putting them into
    // the mirror is harmless.
    super.putAll(map);
    hashMap.putAll(map);
  }

  @Override
  public V remove(final Object key) {
    final V previous = super.remove(key);
    hashMap.remove(key);
    return previous;
  }

  /**
   * Removes and returns the lowest mapping. The key is also removed from the mirror, because
   * {@link TreeMap#pollFirstEntry()} does not route through {@link #remove(Object)}.
   */
  @Override
  public Map.Entry<K,V> pollFirstEntry() {
    final Map.Entry<K,V> entry = super.pollFirstEntry();
    if (entry != null)
      hashMap.remove(entry.getKey());

    return entry;
  }

  /**
   * Removes and returns the highest mapping. The key is also removed from the mirror, because
   * {@link TreeMap#pollLastEntry()} does not route through {@link #remove(Object)}.
   */
  @Override
  public Map.Entry<K,V> pollLastEntry() {
    final Map.Entry<K,V> entry = super.pollLastEntry();
    if (entry != null)
      hashMap.remove(entry.getKey());

    return entry;
  }

  @Override
  public void clear() {
    hashMap.clear();
    super.clear();
  }

  /**
   * Returns a shallow copy of this map. The copy has its own tree and its own hash mirror.
   */
  @Override
  @SuppressWarnings("unchecked")
  public TreeHashMap<K,V> clone() {
    final TreeHashMap<K,V> clone = (TreeHashMap<K,V>)super.clone();
    // super.clone() copies the field reference. Without this line, the clone would share
    // (and mutate) this instance's mirror.
    clone.hashMap = new HashMap<>(hashMap);
    return clone;
  }

  // NOTE(review): the views below only intercept removals reported through the Observable*
  // beforeRemove/beforePut hooks. Whether iterator removal or Map.Entry#setValue on view
  // entries reaches these hooks depends on the org.libj.util Observable* classes -- confirm.

  @Override
  public Set<K> keySet() {
    return new ObservableSet<K>(super.keySet()) {
      @Override
      protected boolean beforeRemove(final Object element) {
        hashMap.remove(element);
        return true;
      }
    };
  }

  @Override
  public NavigableSet<K> navigableKeySet() {
    return new ObservableNavigableSet<K>(super.navigableKeySet()) {
      @Override
      protected boolean beforeRemove(final Object element) {
        hashMap.remove(element);
        return true;
      }
    };
  }

  @Override
  public NavigableSet<K> descendingKeySet() {
    return new ObservableNavigableSet<K>(super.descendingKeySet()) {
      @Override
      protected boolean beforeRemove(final Object element) {
        hashMap.remove(element);
        return true;
      }
    };
  }

  /**
   * Returns a view of the values. Removal through this view is rejected with
   * {@link UnsupportedOperationException}, because a value alone does not identify which key
   * the mirror would need to drop.
   */
  @Override
  public Collection<V> values() {
    return new ObservableCollection<V>(super.values()) {
      @Override
      protected boolean beforeRemove(final Object element) {
        throw new UnsupportedOperationException();
      }
    };
  }

  @Override
  public Set<Map.Entry<K,V>> entrySet() {
    return new ObservableSet<Map.Entry<K,V>>(super.entrySet()) {
      @Override
      protected boolean beforeRemove(final Object element) {
        // Match the tree's entry-set semantics. A mapping is removed only when both key and
        // value match, and a non-entry argument removes nothing (rather than failing a cast).
        if (element instanceof Map.Entry) {
          final Map.Entry<?,?> entry = (Map.Entry<?,?>)element;
          hashMap.remove(entry.getKey(), entry.getValue());
        }

        return true;
      }
    };
  }

  // NOTE(review): in the map views below, beforePut updates the mirror before the view
  // performs the put. A put the view rejects (e.g. a key outside a sub-map's range) would
  // presumably leave a stale mirror entry -- confirm against the ObservableSortedMap contract.

  @Override
  public NavigableMap<K,V> descendingMap() {
    return new ObservableNavigableMap<K,V>(super.descendingMap()) {
      @Override
      protected Object beforePut(final K key, final V oldValue, final V newValue, final Object preventDefault) {
        hashMap.put(key, newValue);
        return newValue;
      }

      @Override
      protected boolean beforeRemove(final Object key, final V value) {
        hashMap.remove(key, value);
        return true;
      }
    };
  }

  /** Sorted sub-map view that forwards puts and removals to the hash mirror. */
  private final class SortedSubMap extends ObservableSortedMap<K,V> {
    private SortedSubMap(final SortedMap<K,V> map) {
      super(map);
    }

    @Override
    protected Object beforePut(final K key, final V oldValue, final V newValue, final Object preventDefault) {
      hashMap.put(key, newValue);
      return newValue;
    }

    @Override
    protected boolean beforeRemove(final Object key, final V value) {
      hashMap.remove(key, value);
      return true;
    }
  }

  /** Navigable sub-map view that forwards puts and removals to the hash mirror. */
  private final class NavigableSubMap extends ObservableNavigableMap<K,V> {
    private NavigableSubMap(final NavigableMap<K,V> map) {
      super(map);
    }

    @Override
    protected Object beforePut(final K key, final V oldValue, final V newValue, final Object preventDefault) {
      hashMap.put(key, newValue);
      return newValue;
    }

    @Override
    protected boolean beforeRemove(final Object key, final V value) {
      hashMap.remove(key, value);
      return true;
    }
  }

  @Override
  public NavigableMap<K,V> subMap(final K fromKey, final boolean fromInclusive, final K toKey, final boolean toInclusive) {
    return new NavigableSubMap(super.subMap(fromKey, fromInclusive, toKey, toInclusive));
  }

  @Override
  public NavigableMap<K,V> headMap(final K toKey, final boolean inclusive) {
    return new NavigableSubMap(super.headMap(toKey, inclusive));
  }

  @Override
  public NavigableMap<K,V> tailMap(final K fromKey, final boolean inclusive) {
    return new NavigableSubMap(super.tailMap(fromKey, inclusive));
  }

  @Override
  public SortedMap<K,V> subMap(final K fromKey, final K toKey) {
    return new SortedSubMap(super.subMap(fromKey, toKey));
  }

  @Override
  public SortedMap<K,V> headMap(final K toKey) {
    return new SortedSubMap(super.headMap(toKey));
  }

  @Override
  public SortedMap<K,V> tailMap(final K fromKey) {
    return new SortedSubMap(super.tailMap(fromKey));
  }

  /**
   * Replaces the mapping for {@code key} only if it is currently mapped to {@code oldValue},
   * per {@link Map#replace(Object, Object, Object)}. The comparison is against
   * {@code oldValue}, not {@code newValue}.
   */
  @Override
  public boolean replace(final K key, final V oldValue, final V newValue) {
    if (!super.replace(key, oldValue, newValue))
      return false;

    hashMap.put(key, newValue);
    return true;
  }

  @Override
  public V replace(final K key, final V value) {
    if (!hashMap.containsKey(key))
      return null;

    // The key is already present, so the tree will accept it.
    hashMap.put(key, value);
    return super.put(key, value);
  }

  /**
   * Applies {@code function} exactly once per entry, against the tree, then refreshes the
   * mirror. The mirror is refreshed even if the function throws part-way through.
   */
  @Override
  public void replaceAll(final BiFunction<? super K,? super V,? extends V> function) {
    try {
      super.replaceAll(function);
    } finally {
      // replaceAll never adds or removes keys, so overwriting every mirrored value restores
      // consistency with whatever the tree now holds.
      for (final Map.Entry<K,V> entry : super.entrySet())
        hashMap.put(entry.getKey(), entry.getValue());
    }
  }
}