package sort;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Objects;

import org.hamcrest.core.IsInstanceOf;

/**
 * A class for constructing KD-trees with range-searching abilities. Written using
 * pseudocode from https://en.wikipedia.org/wiki/K-d_tree for building the tree
 * and referencing several resources:
 * - http://www.cs.utah.edu/~lifeifei/cs6931/kdtree.pdf
 * - https://www.youtube.com/watch?v=Z4dNLvno-EY
 * - http://www.cs.uu.nl/docs/vakken/ga/slides5a.pdf
 * - https://www.datasciencecentral.com/profiles/blogs/implementing-kd-tree-for-fast-range-search-nearest-neighbor
 * All code is original and was written by Christopher W. Schankula.
 * @author Christopher W. Schankula
 *
 * @param <KeyVal> The type of key-value objects to insert into the tree.
 */
public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable {
	/** Serialization version; kept fixed so trees written by earlier versions still load. */
	private static final long serialVersionUID = 8807259801436570835L;

	/** Root of the tree, or {@code null} when the tree is empty. */
	private KDNode root;

	/**
	 * One comparator per dimension; level i of the tree is split on {@code axes.get(i % k)}.
	 * Serialized together with the tree, so the comparators must be serializable at runtime
	 * for {@link #writeToFile(String)} to succeed.
	 */
	private ArrayList<GeneralCompare<KeyVal>> axes;

	/**
	 * A node in a serializable KD-tree.
	 * Deliberately left as a (non-static) inner class with unchanged fields so that the
	 * serialized form of previously written trees stays compatible.
	 * @author Christopher W. Schankula
	 */
	private class KDNode implements Serializable {
		/** Serialization version of the node class; must not change. */
		private static final long serialVersionUID = -664511393872278542L;
		/** The element stored at this node. */
		private KeyVal keyval;
		/** Subtrees holding the elements placed before / after this node on the split axis. */
		private KDNode left, right;
		/** Number of nodes in the subtree rooted here, including this node. */
		private int n;

		/**
		 * Creates a new KD-tree node with no children.
		 * @param keyval the element stored at this node
		 * @param n the size of the subtree rooted at this node
		 */
		public KDNode(KeyVal keyval, int n) {
			this.keyval = keyval;
			this.n = n;
		}
	}

	/**
	 * Load a kd-tree from a file previously written by {@link #writeToFile(String)}.
	 *
	 * <p>WARNING: this uses Java native deserialization, which can run code embedded in a
	 * crafted file. Only load files from trusted sources.
	 *
	 * @param fn path of the serialized tree
	 * @throws UncheckedIOException if the file cannot be opened or read
	 * @throws IllegalArgumentException if the file does not contain a serialized KDT, or
	 *         references a class that is not on the classpath
	 */
	public KDT(String fn) {
		KDT<KeyVal> kdt = readTree(fn);
		// A constructor cannot replace `this`, so copy the loaded tree's state instead:
		//https://stackoverflow.com/questions/26327956/set-this-in-a-class
		this.root = kdt.root;
		this.axes = kdt.axes;
	}

	/**
	 * Deserializes a KDT from {@code fn}, closing the streams even on failure.
	 * @param fn path of the serialized tree
	 * @return the deserialized tree
	 */
	@SuppressWarnings("unchecked") // the element type cannot be verified at runtime (erasure)
	private static <T extends Comparable<T>> KDT<T> readTree(String fn) {
		Object obj;
		try (FileInputStream fileIn = new FileInputStream(fn);
				ObjectInputStream in = new ObjectInputStream(fileIn)) {
			obj = in.readObject();
		} catch (IOException e) {
			throw new UncheckedIOException("Could not read kd-tree from " + fn, e);
		} catch (ClassNotFoundException e) {
			throw new IllegalArgumentException(
					"Serialized data in " + fn + " references a class that is not available", e);
		}
		if (!(obj instanceof KDT<?>)) {
			throw new IllegalArgumentException("File " + fn + " does not contain a serialized KDT");
		}
		return (KDT<T>) obj;
	}

	/**
	 * Construct a new kd-tree from an array of nodes.
	 *
	 * <p>Note: {@code keyvals} is reordered in place while the tree is built (the median
	 * selection's result is read back from the array).
	 *
	 * @param axes A GeneralCompare instance for each dimension of the kd-tree. This
	 * GeneralCompare must input an object of type KeyVal and output an ordering corresponding
	 * to that axis.
	 * @param keyvals An array of Key-Value pairs to insert into the tree.
	 * @throws NullPointerException if {@code axes} or {@code keyvals} is null
	 * @throws IllegalArgumentException if {@code axes} is empty but {@code keyvals} is not
	 */
	public KDT(ArrayList<GeneralCompare<KeyVal>> axes, Comparable<KeyVal>[] keyvals) {
		this.axes = Objects.requireNonNull(axes, "axes");
		Objects.requireNonNull(keyvals, "keyvals");
		if (axes.isEmpty() && keyvals.length > 0) {
			// Without this check buildTree fails with an opaque "/ by zero" on depth % k.
			throw new IllegalArgumentException("at least one axis is required to build a non-empty kd-tree");
		}
		root = buildTree(keyvals, 0, keyvals.length - 1, 0);
	}

	/**
	 * Builds a balanced tree recursively. Builds the tree by using the (i % k)th comparison
	 * function on each level i of the tree.
	 *
	 * @param keyvals The array of key-value objects
	 * @param lo The lower bound of the part to build (inclusive)
	 * @param hi The upper bound of the part to build (inclusive)
	 * @param depth The depth of the new node to be created
	 * @return The new node, or {@code null} if the range is empty
	 */
	private KDNode buildTree(Comparable<KeyVal>[] keyvals, int lo, int hi, int depth) {
		if (lo > hi) return null;
		int axis = depth % getK();

		// Same value as (lo + hi) / 2 but cannot overflow for very large arrays.
		int mid = lo + (hi - lo) / 2;
		QuickSelect.median(keyvals, lo, hi, axes.get(axis));
		@SuppressWarnings("unchecked")
		KeyVal median = (KeyVal) keyvals[mid];

		// Children only touch [lo, mid-1] and [mid+1, hi], so `median` stays valid.
		KDNode left = buildTree(keyvals, lo, mid - 1, depth + 1);
		KDNode right = buildTree(keyvals, mid + 1, hi, depth + 1);

		KDNode newNode = new KDNode(median, size(left) + size(right) + 1);
		newNode.left = left;
		newNode.right = right;
		return newNode;
	}

	/**
	 * Perform a range search for nodes in the tree.
	 * @param range	An ArrayList of GeneralRange instances. This must take a value of type
	 * KeyVal and return whether that value is in the range for the given axis. Must provide
	 * an ArrayList with a size equal to the current tree's number of axes, k.
	 *
	 * The search works by traversing the tree as in a BST but only comparing the (i % k)th
	 * axis on each level i of the tree.
	 * @return An Iterable of nodes found to be matching the range search, in pre-order.
	 */
	public Iterable<KeyVal> rangeSearch(ArrayList<GeneralRange<KeyVal>> range) {
		ArrayList<KeyVal> result = new ArrayList<KeyVal>();
		rangeSearch(root, range, result, 0);
		return result;
	}

	/**
	 * Recursive helper for {@link #rangeSearch(ArrayList)}. Assumes (per the original branch
	 * comments) that {@code isInBounds} returns 0 when the value is inside the axis range, a
	 * positive value when it lies above it and a negative value when it lies below it.
	 */
	private void rangeSearch(KDNode x, ArrayList<GeneralRange<KeyVal>> range, ArrayList<KeyVal> result, int depth) {
		if (x == null) return;
		GeneralRange<KeyVal> rg = range.get(depth % getK());
		int bounds = rg.isInBounds(x.keyval);

		if (bounds == 0) {
			// Inside on this axis: a candidate, and both subtrees may hold matches.
			if (pointInside(x.keyval, range)) {
				result.add(x.keyval);
			}
			rangeSearch(x.left, range, result, depth + 1);
			rangeSearch(x.right, range, result, depth + 1);
		} else if (bounds > 0) {
			// Node is above the range on this axis: only the left subtree can match.
			rangeSearch(x.left, range, result, depth + 1);
		} else {
			// Node is below the range on this axis: only the right subtree can match.
			rangeSearch(x.right, range, result, depth + 1);
		}
	}

	/**
	 * See if a point is inside the range given for all axes
	 * @param pt The point to test
	 * @param range The range search for all axes
	 * @return true if inside all ranges, false otherwise
	 */
	private boolean pointInside(KeyVal pt, ArrayList<GeneralRange<KeyVal>> range) {
		for (int i = 0; i < getK(); i++) {
			if (range.get(i).isInBounds(pt) != 0) return false;
		}
		return true;
	}

	/**
	 * Returns the number of nodes in the tree.
	 * @return the number of nodes present in the kd-tree.
	 */
	public int size() {
		return size(root);
	}

	/**
	 * Returns the height of the kd-tree.
	 * @return the number of nodes on the longest root-to-leaf path; 0 for an empty tree.
	 */
	public int height() {
		return height(root);
	}

	private int height(KDNode x) {
		if (x == null) return 0;
		return 1 + Math.max(height(x.left), height(x.right));
	}

	/** Size of the subtree rooted at {@code x}; 0 for {@code null}. */
	private int size(KDNode x) {
		return x == null ? 0 : x.n;
	}

	/**
	 * Return the number of axes in the current kd-tree.
	 * @return the number of axes in the current kd-tree.
	 */
	public int getK() {
		return axes.size();
	}

	/**
	 * Renders the tree in pre-order, one node per line, indenting each level by four spaces.
	 * Missing children are printed as {@code null}.
	 */
	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();
		appendTo(root, "", sb);
		return sb.toString();
	}

	/**
	 * Serializes this tree to {@code fn} so it can later be loaded with {@link #KDT(String)}.
	 * Failures are reported on standard error and are not thrown.
	 * NOTE(review): callers cannot detect a failed write; consider propagating the exception.
	 * @param fn path of the file to write
	 */
	public void writeToFile(String fn) {
		try (FileOutputStream fileOut = new FileOutputStream(fn);
				ObjectOutputStream out = new ObjectOutputStream(fileOut)) {
			out.writeObject(this);
		} catch (IOException i) {
			i.printStackTrace();
			return;
		}
		System.out.printf("Serialized data is saved in %s%n", fn);
	}

	/** Appends the subtree rooted at {@code x} to {@code sb}, each line prefixed by {@code indent}. */
	private void appendTo(KDNode x, String indent, StringBuilder sb) {
		sb.append(indent);
		if (x == null) {
			sb.append("null\n");
			return;
		}
		sb.append(x.keyval.toString()).append('\n');
		appendTo(x.left, indent + "    ", sb);
		appendTo(x.right, indent + "    ", sb);
	}
}