/*
 * All content copyright (c) Terracotta, Inc., except as may otherwise be noted in a separate copyright notice. All
 * rights reserved.
 */
package org.terracotta.cache.serialization;

import org.terracotta.collections.ConcurrentDistributedMap;
import org.terracotta.collections.LockType;
import org.terracotta.collections.NullLockStrategy;

import com.tc.object.bytecode.Clearable;
import com.tc.object.bytecode.Manager;
import com.tc.object.bytecode.ManagerUtil;
import com.tc.object.bytecode.NotClearable;

import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Assigns compact integer identifiers to class names and resolves those identifiers back to
 * {@link ObjectStreamClass} descriptors.
 * <p>
 * Both directions of the mapping live in a single {@link ConcurrentDistributedMap}. The class name maps to an
 * {@link Integer}, and the decimal string form of that integer maps back to the class name. The two key spaces cannot
 * collide because a Java class name never starts with a digit.
 * <p>
 * Resolved descriptors are memoized in a transient {@link ConcurrentHashMap}. {@link #initialize()} (re)builds that
 * cache and the descriptor lookup strategy. It is invoked from the constructor and is also used as the Terracotta
 * on-load method.
 */
class ObjectStreamClassSerializer implements NotClearable {

  /** Best descriptor lookup strategy available on this JVM, chosen once during class initialization. */
  private static final Lookup                                DEFAULT_LOOKUP = obtainLookup();

  /** Monitor entered (via ManagerUtil) around allocation of new mappings, only when {@link #locked} is true. */
  private final Object                                       writeLock      = new Object();
  /** Two-way mapping: className -> Integer id, and Integer.toString(id) -> className. */
  private final ConcurrentDistributedMap<String, Object>     mappings;
  /** Next identifier to hand out; advanced only inside {@link #getMappingFor(String)}. */
  private int                                                nextMapping    = 0;

  /** Local memo of resolved descriptors keyed by mapping id; replaced by {@link #initialize()}. */
  private volatile transient Map<Integer, ObjectStreamClass> localCache     = new ConcurrentHashMap<Integer, ObjectStreamClass>();
  /** Strategy that turns a {@link Class} into its descriptor; reset by {@link #initialize()}. */
  private volatile transient Lookup                          lookup;

  /**
   * When true, the map uses its default locking and {@link #getMappingFor(String)} enters {@link #writeLock}.
   * When false, the map is built with a {@link NullLockStrategy} and no monitor is taken.
   */
  private final boolean                                      locked;

  /** Creates a serializer whose mapping allocation is guarded by {@link #writeLock}. */
  ObjectStreamClassSerializer() {
    this(true);
  }

  /**
   * Creates a serializer.
   *
   * @param locked whether mapping allocation takes the write lock and the backing map uses default locking
   */
  ObjectStreamClassSerializer(boolean locked) {
    this.locked = locked;

    if (locked) {
      mappings = new ConcurrentDistributedMap<String, Object>();
    } else {
      mappings = new ConcurrentDistributedMap<String, Object>(LockType.WRITE, new NullLockStrategy<String>());
    }

    initialize();
  }

  /**
   * Returns the integer identifier for {@code className}, allocating and storing a new one if none exists yet.
   * <p>
   * A lock-free {@code unsafeGet} is tried first. On a miss, the map is re-checked (under {@link #writeLock} when
   * {@link #locked}) before a new id is allocated and both directions of the mapping are inserted.
   * NOTE(review): when {@link #locked} is false, no monitor is taken here, so callers presumably must serialize
   * calls themselves. Confirm against the call sites.
   *
   * @param className fully qualified class name
   * @return the identifier mapped to {@code className}
   * @throws AssertionError if either direction of a freshly allocated mapping was already present
   */
  public int getMappingFor(String className) {
    Integer value = (Integer) mappings.unsafeGet(className);
    if (value != null) { return value.intValue(); }

    if (locked) ManagerUtil.monitorEnter(writeLock, Manager.LOCK_TYPE_WRITE);
    try {
      // Re-check under the lock: another caller may have allocated this mapping since the unlocked read.
      value = (Integer) mappings.get(className);
      if (value != null) { return value.intValue(); }

      value = Integer.valueOf(nextMapping++);

      // insert two-way mapping exploiting the fact that className can never start with a digit
      put(className, value);
      put(value.toString(), className);
    } finally {
      if (locked) ManagerUtil.monitorExit(writeLock, Manager.LOCK_TYPE_WRITE);
    }

    return value.intValue();
  }

  /**
   * Resolves a mapping id (as produced by {@link #getMappingFor(String)}) back to the descriptor of its class.
   *
   * @param mapping identifier previously handed out by {@link #getMappingFor(String)}
   * @param loader loader used to resolve the class without initializing it. When {@code null}, the
   *          single-argument {@link Class#forName(String)} is used instead.
   * @return the descriptor for the mapped class
   * @throws ClassNotFoundException if the mapped class name cannot be loaded
   * @throws AssertionError if no reverse mapping exists for {@code mapping}
   */
  public ObjectStreamClass getObjectStreamClassFor(int mapping, ClassLoader loader) throws ClassNotFoundException {
    ObjectStreamClass osc = localCache.get(mapping);
    if (osc == null) {
      String className = (String) mappings.get(Integer.toString(mapping));
      if (className == null) { throw new AssertionError("missing reverse mapping for " + mapping); }

      Class<?> c = loader == null ? Class.forName(className) : Class.forName(className, false, loader);
      osc = lookup.lookup(c);

      // can't assert on this put since we're intentionally allowing races
      localCache.put(mapping, osc);
    }
    return osc;
  }

  /**
   * Resets the transient state (lookup strategy and local descriptor cache). It also disables eviction on every
   * {@link Clearable} constituent map of {@link #mappings}.
   * <p>
   * Called from the constructor; also registered as the terracotta {@code <on-load>} method.
   */
  public void initialize() {
    lookup = DEFAULT_LOOKUP;
    localCache = new ConcurrentHashMap<Integer, ObjectStreamClass>();

    // There isn't anything clearable in this map so don't waste time visiting it
    for (Map<?, ?> map : mappings.getConstituentMaps()) {
      if (map instanceof Clearable) {
        ((Clearable) map).setEvictionEnabled(false);
      }
    }
  }

  /**
   * Inserts {@code key -> value} into {@link #mappings}, failing loudly if the key was already mapped.
   *
   * @throws AssertionError if a previous value existed for {@code key}
   */
  private void put(String key, Object value) {
    Object prev = mappings.put(key, value);
    if (prev != null) {
      // this shouldn't ever happen
      throw new AssertionError("replaced mapping for key (" + key + "), old value = " + prev + ", new value = " + value);
    }
  }

  /** Switches this instance to the serialization-based lookup and drops cached descriptors (test hook). */
  void forceSlowLookup() {
    lookup = new SlowLookup();
    localCache.clear();
  }

  /**
   * Picks the first usable descriptor lookup strategy, in this order:
   * <ol>
   * <li>public {@code ObjectStreamClass.lookupAny(Class)};</li>
   * <li>the package-private {@code ObjectStreamClass.lookup(Class, boolean)}, made accessible reflectively;</li>
   * <li>{@link SlowLookup}, which serializes the class object to capture its descriptor.</li>
   * </ol>
   */
  private static Lookup obtainLookup() {
    try {
      // use public lookupAny method if available (1.6 only)
      Method lookupMethod = ObjectStreamClass.class.getMethod("lookupAny", Class.class);
      return new LookupAny(lookupMethod);
    } catch (Exception ignored) {
      // best effort: method absent on this JVM, fall through to the next strategy
    }

    try {
      // try to use package private lookup() method
      Method method = ObjectStreamClass.class.getDeclaredMethod("lookup", Class.class, Boolean.TYPE);
      method.setAccessible(true);
      return new PrivateLookup(method);
    } catch (Exception ignored) {
      // best effort: method absent or not accessible (e.g. security restrictions), fall back to slow path
    }

    return new SlowLookup();
  }

  /** Strategy for obtaining the {@link ObjectStreamClass} of a class. */
  private interface Lookup {
    ObjectStreamClass lookup(Class<?> cl);
  }

  /** Lookup via the public static {@code ObjectStreamClass.lookupAny(Class)}, invoked reflectively. */
  private static class LookupAny implements Lookup {

    private final Method lookupMethod;

    LookupAny(Method lookupMethod) {
      this.lookupMethod = lookupMethod;
    }

    public ObjectStreamClass lookup(Class<?> cl) {
      try {
        return (ObjectStreamClass) lookupMethod.invoke(null, cl);
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    }

  }

  /** Lookup via the package-private static {@code ObjectStreamClass.lookup(Class, boolean)}, with {@code all=true}. */
  private static class PrivateLookup implements Lookup {

    private final Method method;

    PrivateLookup(Method method) {
      this.method = method;
    }

    public ObjectStreamClass lookup(Class<?> cl) {
      try {
        return (ObjectStreamClass) method.invoke(null, cl, Boolean.TRUE);
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    }
  }

  /**
   * Fallback lookup. It writes the {@link Class} object to an {@link ObjectOutputStream} whose bytes are discarded.
   * It captures the descriptor that the stream hands to {@code writeClassDescriptor} for that class.
   */
  private static class SlowLookup implements Lookup {

    public ObjectStreamClass lookup(Class<?> cl) {
      try {
        OOS oos = new OOS(cl);
        try {
          oos.writeObject(cl);
        } finally {
          // release the stream even if writing fails; closing a NullOutputStream-backed stream cannot fail
          oos.close();
        }
        return oos.getDescriptor();
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }

    /** ObjectOutputStream that records the class descriptor written for {@link #cl}. */
    private static class OOS extends ObjectOutputStream {

      private final Class<?> cl;

      /** Descriptor captured in {@link #writeClassDescriptor}; null until the target class has been written. */
      private ObjectStreamClass desc;

      public OOS(Class<?> cl) throws IOException {
        super(new NullOutputStream());
        this.cl = cl;
      }

      /**
       * @return the captured descriptor
       * @throws IllegalStateException if no descriptor for {@link #cl} was written. NOTE(review): this is presumably
       *           the case for proxy classes, whose descriptors are not routed through writeClassDescriptor. Confirm.
       */
      ObjectStreamClass getDescriptor() {
        if (this.desc == null) { throw new IllegalStateException("Descriptor never set for " + cl); }
        return this.desc;
      }

      @Override
      protected void writeClassDescriptor(ObjectStreamClass osc) throws IOException {
        // ObjectOutputStream will call lookup(Class, true) and pass the result to us
        if (osc.forClass().equals(cl)) {
          this.desc = osc;
        }
        super.writeClassDescriptor(osc);
      }
    }

    /** OutputStream that discards everything written to it. */
    private static class NullOutputStream extends OutputStream {

      NullOutputStream() {
        //
      }

      @Override
      public void write(int b) {
        //
      }
    }

  }

}
