Skip to content
Snippets Groups Projects
Commit dacc15f4 authored by Christopher Schankula's avatar Christopher Schankula :earth_africa:
Browse files

clean up KDT

parent 4c9829f9
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,7 @@ import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import sandbox.Point;
import org.hamcrest.core.IsInstanceOf;
/**
* A class for constructing KD-trees with range-searching abilities. Written using
......@@ -31,55 +31,8 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable {
/**
*
*/
KDNode root;
ArrayList<GeneralCompare<KeyVal>> axes;
public static void main(String[] args) {
GeneralCompare<Point> compX = (p1, p2) -> (int) ((Point) p1).getX() - (int) ((Point) p2).getX();
GeneralCompare<Point> compY = (p1, p2) -> (int) ((Point) p1).getY() - (int) ((Point) p2).getY();
//(2,3), (4,7), (5,4), (7,2), (8,1), (9,6)
//(8,1), (7,2), (2,3), (5,4), (9,6), (4,7)
Point p1 = new Point(2,3);
Point p2 = new Point(5,4);
Point p3 = new Point(9,6);
Point p4 = new Point(4,7);
Point p5 = new Point(8,1);
Point p6 = new Point(7,2);
ArrayList<GeneralCompare<Point>> axes = new ArrayList<GeneralCompare<Point>>();
axes.add(compX);
axes.add(compY);
Point[] pts = {p1, p2, p3, p4, p5, p6};
//KDT<Point> kdt = new KDT<Point>(axes, pts);
KDT<Point> kdt = new KDT<Point>("kdtree.ser");
System.out.println(kdt.size());
System.out.println(kdt.height());
//GeneralRange<Point> xRange = p -> 4 <= p.getX() && p.getX() <= 6;
//GeneralRange<Point> yRange = p -> 3 <= p.getY() && p.getY() <= 5;
GeneralRange<Point> xRange = p -> p.getX() < 2 ? -1 : (p.getX() > 7 ? 1 : 0);
GeneralRange<Point> yRange = p -> 0;//p.getY() < 3 ? -1 : (p.getY() > 6 ? 1 : 0);
ArrayList<GeneralRange<Point>> ranges = new ArrayList<GeneralRange<Point>>();
ranges.add(xRange);
ranges.add(yRange);
Iterable<Point> results = kdt.rangeSearch(ranges);
System.out.println("Results");
for (Point p : results) {
System.out.println(p);
}
System.out.println(kdt.toString());
kdt.writeToFile("kdtree.ser");
}
private KDNode root;
private ArrayList<GeneralCompare<KeyVal>> axes;
/**
* A node in a serializable KD-tree.
......@@ -144,6 +97,16 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable {
root = buildTree(keyvals, 0, keyvals.length - 1, 0);
}
/**
* Builds a balanced tree recursively. Builds the tree by using the (i % k)th comparison
* function on each level i of the tree.
*
* @param keyvals The array of key-value objects
* @param lo The lower bound of the part to build
* @param hi The upper bound of the part to build
* @param depth The depth of the new node to be created
* @return The new node
*/
private KDNode buildTree(Comparable<KeyVal>[] keyvals, int lo, int hi, int depth) {
if (lo > hi) return null;
int axis = depth % getK();
......@@ -166,7 +129,10 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable {
* @param range An ArrayList of GeneralRange instances. This must take a value of type
* KeyVal and return whether that value is in the range for the given axis. Must provide
* an ArrayList with a size equal to the current tree's number of axes, k.
* @return An Iterable of nodes found to be matchign the range search.
*
* The search works by traversing the tree as in a BST but only comparing the (i % k)th
* axis on each level i of the tree.
* @return An Iterable of nodes found to be matching the range search.
*/
public Iterable<KeyVal> rangeSearch(ArrayList<GeneralRange<KeyVal>> range){
ArrayList<KeyVal> result = new ArrayList<KeyVal>();
......@@ -174,6 +140,7 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable {
return result;
}
//recursive private range search function
private void rangeSearch(KDNode x, ArrayList<GeneralRange<KeyVal>> range, ArrayList<KeyVal> result, int depth) {
if (x == null) return;
int axis = depth % getK();
......@@ -182,21 +149,29 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable {
//System.out.println("Try: " + x.keyval);
int bounds = rg.isInBounds((KeyVal) x.keyval);
//if it's in the bounds, must search both subtrees. Also a candidate to be included in the results
if (bounds == 0) {
//System.out.println(pointInside(x.keyval, range));
//if the point is inside the axis range, check if it's in the other ranges too
if (pointInside(x.keyval, range)) {
result.add(x.keyval);
}
//range search both subtrees
rangeSearch(x.left, range, result, depth + 1);
rangeSearch(x.right, range, result, depth + 1);
} else if (bounds > 0) {
} else if (bounds > 0) { //if it's bigger than the current axis, search the left subtree
rangeSearch(x.left, range, result, depth + 1);
} else if (bounds < 0)
} else if (bounds < 0) //if it's smaller than the current axis, search the right subtree
rangeSearch(x.right, range, result, depth + 1);
return;
}
/**
* See if a point is inside the range given for all axes
* @param pt The point to test
* @param range The range search for all axes
* @return true if inside all ranges, false otherwise
*/
private boolean pointInside(KeyVal pt, ArrayList<GeneralRange<KeyVal>> range) {
for (int i = 0; i < axes.size(); i++)
if (range.get(i).isInBounds(pt) != 0) return false;
......@@ -236,36 +211,4 @@ public class KDT<KeyVal extends Comparable<KeyVal>> implements Serializable {
public int getK() {
return axes.size();
}
/**
* Generates and returns a string representation of the nodes in the kd-tree.
* Note: for large trees, this may be slow and unstable.
* @return a string representation of the nodes in the kd-tree.
*/
public String toString() {
return toString(root, "");
}
private String toString(KDNode x, String depth) {
if (x == null) return depth + "null\n";
String result = "";
result += depth + x.keyval.toString() + "\n";
result += toString(x.left, depth + " ");
result += toString(x.right, depth + " ");
return result;
}
public void writeToFile(String fn) {
try {
FileOutputStream fileOut =
new FileOutputStream(fn);
ObjectOutputStream out = new ObjectOutputStream(fileOut);
out.writeObject(this);
out.close();
fileOut.close();
System.out.printf("Serialized data is saved in /tmp/kdtree.ser");
} catch (IOException i) {
i.printStackTrace();
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment