Skip to content
Snippets Groups Projects
Commit dacc15f4 authored by Christopher Schankula's avatar Christopher Schankula :earth_africa:
Browse files

clean up KDT

parent 4c9829f9
No related branches found
No related tags found
No related merge requests found
...@@ -8,7 +8,7 @@ import java.io.ObjectOutputStream; ...@@ -8,7 +8,7 @@ import java.io.ObjectOutputStream;
import java.io.Serializable; import java.io.Serializable;
import java.util.ArrayList; import java.util.ArrayList;
import sandbox.Point; import org.hamcrest.core.IsInstanceOf;
/** /**
* A class for constructing KD-trees with range-searching abilities. Written using * A class for constructing KD-trees with range-searching abilities. Written using
...@@ -31,55 +31,8 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable { ...@@ -31,55 +31,8 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable {
/** /**
* *
*/ */
KDNode root; private KDNode root;
ArrayList<GeneralCompare<KeyVal>> axes; private ArrayList<GeneralCompare<KeyVal>> axes;
public static void main(String[] args) {
GeneralCompare<Point> compX = (p1, p2) -> (int) ((Point) p1).getX() - (int) ((Point) p2).getX();
GeneralCompare<Point> compY = (p1, p2) -> (int) ((Point) p1).getY() - (int) ((Point) p2).getY();
//(2,3), (4,7), (5,4), (7,2), (8,1), (9,6)
//(8,1), (7,2), (2,3), (5,4), (9,6), (4,7)
Point p1 = new Point(2,3);
Point p2 = new Point(5,4);
Point p3 = new Point(9,6);
Point p4 = new Point(4,7);
Point p5 = new Point(8,1);
Point p6 = new Point(7,2);
ArrayList<GeneralCompare<Point>> axes = new ArrayList<GeneralCompare<Point>>();
axes.add(compX);
axes.add(compY);
Point[] pts = {p1, p2, p3, p4, p5, p6};
//KDT<Point> kdt = new KDT<Point>(axes, pts);
KDT<Point> kdt = new KDT<Point>("kdtree.ser");
System.out.println(kdt.size());
System.out.println(kdt.height());
//GeneralRange<Point> xRange = p -> 4 <= p.getX() && p.getX() <= 6;
//GeneralRange<Point> yRange = p -> 3 <= p.getY() && p.getY() <= 5;
GeneralRange<Point> xRange = p -> p.getX() < 2 ? -1 : (p.getX() > 7 ? 1 : 0);
GeneralRange<Point> yRange = p -> 0;//p.getY() < 3 ? -1 : (p.getY() > 6 ? 1 : 0);
ArrayList<GeneralRange<Point>> ranges = new ArrayList<GeneralRange<Point>>();
ranges.add(xRange);
ranges.add(yRange);
Iterable<Point> results = kdt.rangeSearch(ranges);
System.out.println("Results");
for (Point p : results) {
System.out.println(p);
}
System.out.println(kdt.toString());
kdt.writeToFile("kdtree.ser");
}
/** /**
* A node in a serializable KD-tree. * A node in a serializable KD-tree.
...@@ -144,6 +97,16 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable { ...@@ -144,6 +97,16 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable {
root = buildTree(keyvals, 0, keyvals.length - 1, 0); root = buildTree(keyvals, 0, keyvals.length - 1, 0);
} }
/**
* Builds a balanced tree recursively. Builds the tree by using the (i % k)th comparison
* function on each level i of the tree.
*
* @param keyvals The array of key-value objects
* @param lo The lower bound of the part to build
* @param hi The upper bound of the part to build
* @param depth The depth of the new node to be created
* @return The new node
*/
private KDNode buildTree(Comparable<KeyVal>[] keyvals, int lo, int hi, int depth) { private KDNode buildTree(Comparable<KeyVal>[] keyvals, int lo, int hi, int depth) {
if (lo > hi) return null; if (lo > hi) return null;
int axis = depth % getK(); int axis = depth % getK();
...@@ -166,7 +129,10 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable { ...@@ -166,7 +129,10 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable {
* @param range An ArrayList of GeneralRange instances. This must take a value of type * @param range An ArrayList of GeneralRange instances. This must take a value of type
* KeyVal and return whether that value is in the range for the given axis. Must provide * KeyVal and return whether that value is in the range for the given axis. Must provide
* an ArrayList with a size equal to the current tree's number of axes, k. * an ArrayList with a size equal to the current tree's number of axes, k.
* @return An Iterable of nodes found to be matchign the range search. *
* The search works by traversing the tree as in a BST but only comparing the (i % k)th
* axis on each level i of the tree.
* @return An Iterable of nodes found to be matching the range search.
*/ */
public Iterable<KeyVal> rangeSearch(ArrayList<GeneralRange<KeyVal>> range){ public Iterable<KeyVal> rangeSearch(ArrayList<GeneralRange<KeyVal>> range){
ArrayList<KeyVal> result = new ArrayList<KeyVal>(); ArrayList<KeyVal> result = new ArrayList<KeyVal>();
...@@ -174,6 +140,7 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable { ...@@ -174,6 +140,7 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable {
return result; return result;
} }
//recursive private range search function
private void rangeSearch(KDNode x, ArrayList<GeneralRange<KeyVal>> range, ArrayList<KeyVal> result, int depth) { private void rangeSearch(KDNode x, ArrayList<GeneralRange<KeyVal>> range, ArrayList<KeyVal> result, int depth) {
if (x == null) return; if (x == null) return;
int axis = depth % getK(); int axis = depth % getK();
...@@ -182,21 +149,29 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable { ...@@ -182,21 +149,29 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable {
//System.out.println("Try: " + x.keyval); //System.out.println("Try: " + x.keyval);
int bounds = rg.isInBounds((KeyVal) x.keyval); int bounds = rg.isInBounds((KeyVal) x.keyval);
//if it's in the bounds, must search both subtrees. Also a candidate to be included in the results
if (bounds == 0) { if (bounds == 0) {
//System.out.println(pointInside(x.keyval, range)); //if the point is inside the axis range, check if it's in the other ranges too
if (pointInside(x.keyval, range)) { if (pointInside(x.keyval, range)) {
result.add(x.keyval); result.add(x.keyval);
} }
//range search both subtrees
rangeSearch(x.left, range, result, depth + 1); rangeSearch(x.left, range, result, depth + 1);
rangeSearch(x.right, range, result, depth + 1); rangeSearch(x.right, range, result, depth + 1);
} else if (bounds > 0) { } else if (bounds > 0) { //if it's bigger than the current axis, search the left subtree
rangeSearch(x.left, range, result, depth + 1); rangeSearch(x.left, range, result, depth + 1);
} else if (bounds < 0) } else if (bounds < 0) //if it's smaller than the current axis, search the right subtree
rangeSearch(x.right, range, result, depth + 1); rangeSearch(x.right, range, result, depth + 1);
return; return;
} }
/**
* See if a point is inside the range given for all axes
* @param pt The point to test
* @param range The range search for all axes
* @return true if inside all ranges, false otherwise
*/
private boolean pointInside(KeyVal pt, ArrayList<GeneralRange<KeyVal>> range) { private boolean pointInside(KeyVal pt, ArrayList<GeneralRange<KeyVal>> range) {
for (int i = 0; i < axes.size(); i++) for (int i = 0; i < axes.size(); i++)
if (range.get(i).isInBounds(pt) != 0) return false; if (range.get(i).isInBounds(pt) != 0) return false;
...@@ -236,36 +211,4 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable { ...@@ -236,36 +211,4 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable {
public int getK() { public int getK() {
return axes.size(); return axes.size();
} }
/**
* Generates and returns a string representation of the nodes in the kd-tree.
* Note: for large trees, this may be slow and unstable.
* @return a string representation of the nodes in the kd-tree.
*/
public String toString() {
return toString(root, "");
}
private String toString(KDNode x, String depth) {
if (x == null) return depth + "null\n";
String result = "";
result += depth + x.keyval.toString() + "\n";
result += toString(x.left, depth + " ");
result += toString(x.right, depth + " ");
return result;
}
public void writeToFile(String fn) {
try {
FileOutputStream fileOut =
new FileOutputStream(fn);
ObjectOutputStream out = new ObjectOutputStream(fileOut);
out.writeObject(this);
out.close();
fileOut.close();
System.out.printf("Serialized data is saved in /tmp/kdtree.ser");
} catch (IOException i) {
i.printStackTrace();
}
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment