diff --git a/src/sort/KDT.java b/src/sort/KDT.java index 1ee1043de2be18d487bedc24ac721a9104069143..242a5af51cc932a58f6ea4fea150b5904bc37654 100644 --- a/src/sort/KDT.java +++ b/src/sort/KDT.java @@ -8,7 +8,7 @@ import java.io.ObjectOutputStream; import java.io.Serializable; import java.util.ArrayList; -import sandbox.Point; +import org.hamcrest.core.IsInstanceOf; /** * A class for constructing KD-trees with range-searching abilities. Written using @@ -31,55 +31,8 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable { /** * */ - KDNode root; - ArrayList<GeneralCompare<KeyVal>> axes; - - public static void main(String[] args) { - GeneralCompare<Point> compX = (p1, p2) -> (int) ((Point) p1).getX() - (int) ((Point) p2).getX(); - GeneralCompare<Point> compY = (p1, p2) -> (int) ((Point) p1).getY() - (int) ((Point) p2).getY(); - - //(2,3), (4,7), (5,4), (7,2), (8,1), (9,6) - //(8,1), (7,2), (2,3), (5,4), (9,6), (4,7) - Point p1 = new Point(2,3); - Point p2 = new Point(5,4); - Point p3 = new Point(9,6); - Point p4 = new Point(4,7); - Point p5 = new Point(8,1); - Point p6 = new Point(7,2); - - ArrayList<GeneralCompare<Point>> axes = new ArrayList<GeneralCompare<Point>>(); - axes.add(compX); - axes.add(compY); - - Point[] pts = {p1, p2, p3, p4, p5, p6}; - - //KDT<Point> kdt = new KDT<Point>(axes, pts); - KDT<Point> kdt = new KDT<Point>("kdtree.ser"); - System.out.println(kdt.size()); - System.out.println(kdt.height()); - - //GeneralRange<Point> xRange = p -> 4 <= p.getX() && p.getX() <= 6; - //GeneralRange<Point> yRange = p -> 3 <= p.getY() && p.getY() <= 5; - - GeneralRange<Point> xRange = p -> p.getX() < 2 ? -1 : (p.getX() > 7 ? 1 : 0); - GeneralRange<Point> yRange = p -> 0;//p.getY() < 3 ? -1 : (p.getY() > 6 ? 1 : 0); - - - ArrayList<GeneralRange<Point>> ranges = new ArrayList<GeneralRange<Point>>(); - ranges.add(xRange); - ranges.add(yRange); - - Iterable<Point> results = kdt.rangeSearch(ranges); - - System.out.println("Results"); - for (Point p : results) { - System.out.println(p); - } - - System.out.println(kdt.toString()); - - kdt.writeToFile("kdtree.ser"); - } + private KDNode root; + private ArrayList<GeneralCompare<KeyVal>> axes; /** * A node in a serializable KD-tree. @@ -144,6 +97,16 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable { root = buildTree(keyvals, 0, keyvals.length - 1, 0); } + /** + * Builds a balanced tree recursively. Builds the tree by using the (i % k)th comparison + * function on each level i of the tree. + * + * @param keyvals The array of key-value objects + * @param lo The lower bound of the part to build + * @param hi The upper bound of the part to build + * @param depth The depth of the new node to be created + * @return The new node + */ private KDNode buildTree(Comparable<KeyVal>[] keyvals, int lo, int hi, int depth) { if (lo > hi) return null; int axis = depth % getK(); @@ -166,7 +129,10 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable { * @param range An ArrayList of GeneralRange instances. This must take a value of type * KeyVal and return whether that value is in the range for the given axis. Must provide * an ArrayList with a size equal to the current tree's number of axes, k. - * @return An Iterable of nodes found to be matchign the range search. + * + * The search works by traversing the tree as in a BST but only comparing the (i % k)th + * axis on each level i of the tree. + * @return An Iterable of nodes found to be matching the range search. */ public Iterable<KeyVal> rangeSearch(ArrayList<GeneralRange<KeyVal>> range){ ArrayList<KeyVal> result = new ArrayList<KeyVal>(); @@ -174,6 +140,7 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable { return result; } + //recursive private range search function private void rangeSearch(KDNode x, ArrayList<GeneralRange<KeyVal>> range, ArrayList<KeyVal> result, int depth) { if (x == null) return; int axis = depth % getK(); @@ -182,21 +149,29 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable { //System.out.println("Try: " + x.keyval); int bounds = rg.isInBounds((KeyVal) x.keyval); + //if it's in the bounds, must search both subtrees. Also a candidate to be included in the results if (bounds == 0) { - //System.out.println(pointInside(x.keyval, range)); + //if the point is inside the axis range, check if it's in the other ranges too if (pointInside(x.keyval, range)) { result.add(x.keyval); } + //range search both subtrees rangeSearch(x.left, range, result, depth + 1); rangeSearch(x.right, range, result, depth + 1); - } else if (bounds > 0) { + } else if (bounds > 0) { //if it's bigger than the current axis, search the left subtree rangeSearch(x.left, range, result, depth + 1); - } else if (bounds < 0) + } else if (bounds < 0) //if it's smaller than the current axis, search the right subtree rangeSearch(x.right, range, result, depth + 1); return; } + /** + * See if a point is inside the range given for all axes + * @param pt The point to test + * @param range The range search for all axes + * @return true if inside all ranges, false otherwise + */ private boolean pointInside(KeyVal pt, ArrayList<GeneralRange<KeyVal>> range) { for (int i = 0; i < axes.size(); i++) if (range.get(i).isInBounds(pt) != 0) return false; @@ -236,36 +211,4 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable { public int getK() { return axes.size(); } - - /** - * Generates and returns a string representation of the nodes in the kd-tree. - * Note: for large trees, this may be slow and unstable. - * @return a string representation of the nodes in the kd-tree. - */ - public String toString() { - return toString(root, ""); - } - - private String toString(KDNode x, String depth) { - if (x == null) return depth + "null\n"; - String result = ""; - result += depth + x.keyval.toString() + "\n"; - result += toString(x.left, depth + " "); - result += toString(x.right, depth + " "); - return result; - } - - public void writeToFile(String fn) { - try { - FileOutputStream fileOut = - new FileOutputStream(fn); - ObjectOutputStream out = new ObjectOutputStream(fileOut); - out.writeObject(this); - out.close(); - fileOut.close(); - System.out.printf("Serialized data is saved in /tmp/kdtree.ser"); - } catch (IOException i) { - i.printStackTrace(); - } - } }