package org.terracotta.upgradability.serialization;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.io.ObjectOutput;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeMatcher;
import org.hamcrest.core.IsCollectionContaining;

import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertThat;

/**
 * Test helpers that check a {@link Serializable} object still produces, and can still be restored
 * from, one of a set of previously recorded serialized forms.
 * <p>
 * A recorded form is the raw byte output of {@link ObjectOutputStream#writeObject(Object)}. If the
 * freshly serialized form of the object under test matches none of the recorded forms, it is written
 * to a {@code <ClassName>-<n>.ser} file (relative path, so in the working directory) and the
 * assertion fails with that file's location.
 *
 * @author cdennis
 */
public class SerializationUpgradabilityTesting {

  /**
   * Equality-only comparator: returns {@code 0} when the two objects are equal (null-safe, via
   * {@link #nullSafeEquals(Object, Object)}) and {@code -1} otherwise.
   * <p>
   * This does NOT define an ordering (it is not sign-symmetric). It must only be used to test
   * whether {@code compare(...) == 0}.
   */
  private static final Comparator<Object> EQUALS_COMPARATOR = new Comparator<Object>() {

    public int compare(Object o1, Object o2) {
      return nullSafeEquals(o1, o2) ? 0 : -1;
    }
  };

  /**
   * Asserts that {@code o} serializes to exactly one of the given class-path resources. Also
   * asserts that every one of those resources deserializes to an object equal to {@code o}.
   *
   * @param o the object under test
   * @param serializedFormResources class-path resource names of the accepted serialized forms
   * @throws IllegalArgumentException if a resource cannot be found
   * @throws IOException if serializing {@code o} or reading a resource fails
   * @throws ClassNotFoundException if a recorded form refers to a class that cannot be loaded
   */
  public static void validateSerializedForm(Serializable o, String ... serializedFormResources) throws IOException, ClassNotFoundException {
    validateSerializedForm(o, EQUALS_COMPARATOR, serializedFormResources);
  }

  /**
   * Asserts that {@code o} serializes to exactly one of the forms at the given URLs. Also asserts
   * that every one of those forms deserializes to an object equal to {@code o}.
   *
   * @param o the object under test
   * @param serializedFormUrls locations of the accepted serialized forms
   * @throws IOException if serializing {@code o} or reading a URL fails
   * @throws ClassNotFoundException if a recorded form refers to a class that cannot be loaded
   */
  public static void validateSerializedForm(Serializable o, URL ... serializedFormUrls) throws IOException, ClassNotFoundException {
    validateSerializedForm(o, EQUALS_COMPARATOR, serializedFormUrls);
  }

  /**
   * Resolves the given class-path resources and delegates to
   * {@link #validateSerializedForm(Serializable, Comparator, URL...)}.
   * <p>
   * Resources are resolved against the class loader of {@code o}'s class. When that class was
   * loaded by the bootstrap loader (no class loader object), the thread context class loader is
   * used instead, and failing that the system class loader.
   *
   * @param o the object under test
   * @param comparator decides equivalence; a result of {@code 0} means "equivalent"
   * @param serializedFormResources class-path resource names of the accepted serialized forms
   * @throws IllegalArgumentException if a resource cannot be found
   * @throws IOException if serializing {@code o} or reading a resource fails
   * @throws ClassNotFoundException if a recorded form refers to a class that cannot be loaded
   */
  public static <T extends Serializable> void validateSerializedForm(T o, Comparator<? super T> comparator, String ... serializedFormResources) throws IOException, ClassNotFoundException {
    validateSerializedForm(o, comparator, getResources(resourceLoaderFor(o), serializedFormResources));
  }

  /**
   * Asserts that the serialized bytes of {@code o} are byte-for-byte equal to one of the forms read
   * from {@code serializedFormUrls}. Also asserts that {@code comparator} reports every one of those
   * forms, once deserialized, as equivalent ({@code 0}) to {@code o}.
   *
   * @param o the object under test
   * @param comparator decides equivalence; a result of {@code 0} means "equivalent"
   * @param serializedFormUrls locations of the accepted serialized forms
   * @throws AssertionError if no form matches; the message names the file the unrecognized form was
   *     saved to. Also thrown if a form does not deserialize to an equivalent object; the message
   *     names the offending URL.
   * @throws IOException if serializing {@code o}, reading a URL, or saving the unrecognized form fails
   * @throws ClassNotFoundException if a recorded form refers to a class that cannot be loaded
   */
  public static <T extends Serializable> void validateSerializedForm(T o, Comparator<? super T> comparator, URL ... serializedFormUrls) throws IOException, ClassNotFoundException {
    byte[] inputForm = serialize(o);

    List<byte[]> validForms = extractAll(Arrays.asList(serializedFormUrls));
    try {
      assertThat(validForms, IsCollectionContaining.<byte[]>hasItem(arrayEquals(inputForm)));
    } catch (AssertionError e) {
      // Persist the unrecognized form so it can be inspected or committed as a new accepted form.
      File f = dumpSerializedForm(o, inputForm);
      AssertionError e2 = new AssertionError("Unrecognized serialized form saved to " + f.getAbsolutePath());
      e2.initCause(e);
      throw e2;
    }

    // Every accepted form (not only the matching one) must still deserialize to an equivalent object.
    // validForms is in the same order as serializedFormUrls, so index i identifies the source URL.
    for (int i = 0; i < validForms.size(); i++) {
      // Erased cast: a form of an unexpected type surfaces inside the comparator, not here.
      @SuppressWarnings("unchecked")
      T deserialized = (T) deserialize(validForms.get(i));
      try {
        assertThat(comparator.compare(o, deserialized), is(0));
      } catch (AssertionError e) {
        AssertionError e2 = new AssertionError("Serialized form " + serializedFormUrls[i] + " does not deserialize to an equivalent object");
        e2.initCause(e);
        throw e2;
      }
    }
  }

  /**
   * Chooses the class loader used to look up serialized-form resources for {@code o}.
   * <p>
   * The preferred loader is the one that defined {@code o}'s class. {@link Class#getClassLoader()}
   * returns {@code null} for bootstrap-loaded classes, so in that case this falls back to the thread
   * context loader, and then to the system loader.
   */
  private static ClassLoader resourceLoaderFor(Object o) {
    ClassLoader loader = o.getClass().getClassLoader();
    if (loader == null) {
      loader = Thread.currentThread().getContextClassLoader();
    }
    if (loader == null) {
      loader = ClassLoader.getSystemClassLoader();
    }
    return loader;
  }

  /** Serializes {@code o} with a fresh {@link ObjectOutputStream} and returns the bytes written. */
  private static byte[] serialize(Serializable o) throws IOException {
    ByteArrayOutputStream bout = new ByteArrayOutputStream();
    try {
      ObjectOutput oout = new ObjectOutputStream(bout);
      try {
        oout.writeObject(o);
      } finally {
        // Closed before bout is read, so nothing written is left buffered in the object stream.
        oout.close();
      }
    } finally {
      bout.close();
    }
    return bout.toByteArray();
  }

  /** Reads a single object from {@code serialized} using a fresh {@link ObjectInputStream}. */
  private static Object deserialize(byte[] serialized) throws IOException, ClassNotFoundException {
    InputStream in = new ByteArrayInputStream(serialized);
    try {
      ObjectInput oin = new ObjectInputStream(in);
      try {
        return oin.readObject();
      } finally {
        oin.close();
      }
    } finally {
      in.close();
    }
  }

  /** Reads the full contents of each URL, returning them in the same order as {@code urls}. */
  private static List<byte[]> extractAll(List<URL> urls) throws IOException {
    List<byte[]> forms = new ArrayList<byte[]>(urls.size());

    for (URL url : urls) {
      forms.add(extract(url));
    }
    return forms;
  }

  /** Reads every byte available from {@code url}, closing the stream afterwards. */
  private static byte[] extract(URL url) throws IOException {
    InputStream in = url.openStream();
    try {
      ByteArrayOutputStream bout = new ByteArrayOutputStream();
      byte[] buffer = new byte[4096];
      int read;
      while ((read = in.read(buffer)) != -1) {
        bout.write(buffer, 0, read);
      }
      return bout.toByteArray();
    } finally {
      in.close();
    }
  }

  /** Matcher that accepts byte arrays whose contents equal {@code array} ({@link Arrays#equals}). */
  private static Matcher<byte[]> arrayEquals(final byte[] array) {
    return new TypeSafeMatcher<byte[]>() {

      @Override
      protected boolean matchesSafely(byte[] item) {
        return Arrays.equals(array, item);
      }

      public void describeTo(Description description) {
        description.appendText("array equal to ").appendValue(array);
      }
    };
  }

  /**
   * Writes {@code inputForm} to the first non-existent file named {@code <name>-<n>.ser}, where
   * n = 0, 1, 2, .... The path is relative, so it resolves against the working directory.
   * {@code <name>} is the simple class name of {@code o}, or its binary name when the simple name is
   * empty (anonymous classes).
   *
   * @return the file that was written
   * @throws IOException if the file cannot be written
   */
  private static File dumpSerializedForm(Serializable o, byte[] inputForm) throws IOException {
    String nameStem = o.getClass().getSimpleName();
    if (nameStem.length() == 0) {
      // Anonymous classes have an empty simple name, which would yield files named "-0.ser".
      nameStem = o.getClass().getName();
    }

    // Not atomic (exists() then create), which is acceptable for a test diagnostic.
    int index = 0;
    File f = new File(nameStem + "-" + index + ".ser");
    while (f.exists()) {
      index++;
      f = new File(nameStem + "-" + index + ".ser");
    }

    FileOutputStream fout = new FileOutputStream(f);
    try {
      fout.write(inputForm);
    } finally {
      fout.close();
    }
    return f;
  }

  /**
   * Resolves each resource name against {@code loader}.
   *
   * @throws IllegalArgumentException naming the first resource that {@code loader} cannot find
   */
  private static URL[] getResources(ClassLoader loader, String ... serializedFormResources) {
    URL[] urls = new URL[serializedFormResources.length];
    for (int i = 0; i < serializedFormResources.length; i++) {
      URL url = loader.getResource(serializedFormResources[i]);
      if (url == null) {
        throw new IllegalArgumentException("Could not find serialized form " + serializedFormResources[i]);
      } else {
        urls[i] = url;
      }
    }
    return urls;
  }

  /**
   * Null-tolerant equality: {@code true} if both arguments are {@code null}, {@code false} if exactly
   * one is, and otherwise {@code o1.equals(o2)}.
   */
  public static boolean nullSafeEquals(Object o1, Object o2) {
    if (o1 == null) {
      return o2 == null;
    } else if (o2 == null) {
      return false;
    } else {
      return o1.equals(o2);
    }
  }
}
