Last active
January 25, 2016 06:55
-
-
Save m-manu/9073351 to your computer and use it in GitHub Desktop.
Implementation of a "sorted set", a type of set in which elements are sorted according to their 'scores'. Worst-case time complexity: O(1) for retrieval of score, O(log(n)) for addition/removal of elements, O(log(n)) for changes to scores, O(k) for getting top k or bottom k elements.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package manu.sandbox.utils; | |
import java.util.Set; | |
/**
 * A {@link Set} whose elements each carry a {@code Double} score and can be
 * queried from either end of the score ordering.
 *
 * <p>The behavioural notes on the methods below describe what the {@code ZSet}
 * implementation in this package does; other implementations should follow the
 * same conventions.
 *
 * @param <T> element type; must be comparable
 */
public interface ScoredSet<T extends Comparable<T>> extends Set<T>, Cloneable {

    /**
     * An element paired with its score, as returned by
     * {@link #topWithScores(int)} and {@link #bottomWithScores(int)}.
     *
     * @param <T> element type
     */
    interface ElementWithScore<T> {
        /** @return the score paired with the element */
        Double getScore();

        /** @return the element */
        T getElement();
    }

    // ---- adding and scoring ------------------------------------------------

    /**
     * Adds an element with an explicit score.
     *
     * @param element element to add
     * @param score   score to associate with the element
     * @return {@code true} if the element was added, {@code false} if it was
     *         already present
     */
    boolean add(T element, Double score);

    /**
     * @param element element to look up
     * @return the element's score
     */
    Double getScore(T element);

    /**
     * Sets the score of an element.
     *
     * @param element element whose score is to be set
     * @param score   new score
     */
    void setScore(T element, Double score);

    /**
     * Raises an element's score by the implementation's default step.
     *
     * @param element element to update
     * @return {@code true} if a score was changed
     */
    boolean incr(T element);

    /**
     * Raises an element's score by {@code deltaScore}.
     *
     * @param element    element to update
     * @param deltaScore amount to add to the score
     * @return {@code true} if a score was changed
     */
    boolean incrBy(T element, Double deltaScore);

    /**
     * Lowers an element's score by the implementation's default step.
     *
     * @param element element to update
     * @return {@code true} if a score was changed
     */
    boolean decr(T element);

    /**
     * Lowers an element's score by {@code deltaScore}.
     *
     * @param element    element to update
     * @param deltaScore amount to subtract from the score
     * @return {@code true} if a score was changed
     */
    boolean decrBy(T element, Double deltaScore);

    // ---- ranked retrieval --------------------------------------------------

    /**
     * @param k number of elements requested
     * @return the {@code k} highest-scored elements, highest first
     */
    T[] top(int k);

    /**
     * @param k number of elements requested
     * @return the {@code k} lowest-scored elements, lowest first
     */
    T[] bottom(int k);

    /**
     * @param k number of elements requested
     * @return the {@code k} highest-scored elements with their scores, highest first
     */
    ElementWithScore<T>[] topWithScores(int k);

    /**
     * @param k number of elements requested
     * @return the {@code k} lowest-scored elements with their scores, lowest first
     */
    ElementWithScore<T>[] bottomWithScores(int k);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package manu.sandbox.utils; | |
import org.junit.Test; | |
import java.lang.reflect.Field; | |
import java.util.*; | |
import static org.junit.Assert.*; | |
public class TestZSet { | |
public static final double DELTA = 0.0001; | |
@Test | |
public void test() throws Exception { | |
testGiven(String.class, new String[]{"e1", "e2", "e3", "e4", "e5"}, new Double[]{1.0, 2.0, 3.0, 4.0, 5.0}); | |
testGiven(String.class, new String[]{"e1", "e2", "e3", "e4", "e5", "e6"}, new Double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0}); | |
testGiven(Character.class, new Character[]{'A', 'B', 'C', 'D', 'E'}, new Double[]{1.0, 2.0, 3.0, 4.0, 5.0}); | |
testGiven(Integer.class, new Integer[]{5, 4, 3, 2, 1, 0}, new Double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0}); | |
testGiven(Double.class, new Double[]{5.3, 4.2, 3.1, 2.9, 1.000000001, 0.00000001}, new Double[]{-1.0, -2.0, -3.0, -4.0, -10.0, -110.0}); | |
testGiven(String.class, new String[]{"e1", "e2", "e3", "e4", "e5", "e6"}, new Double[]{1.0, 1.0, 0.9, -5.0, 4.0, -5.0}); | |
} | |
public <T extends Comparable<T>> void testGiven(Class<T> clazz, T[] elements, Double[] scores) throws Exception { | |
assertTrue("Bug in tests case!", elements.length == scores.length && elements.length >= 5); | |
T e1 = elements[0], e2 = elements[1], e3 = elements[2], e4 = elements[3], e5 = elements[4]; | |
Double s1 = scores[0], s2 = scores[1], s3 = scores[2], s4 = scores[3], s5 = scores[4]; | |
ScoredSet<T> z = ZSet.create(clazz); | |
assertTrue("Set should be empty", z.isEmpty()); | |
for (int i = 0; i < scores.length; i++) { | |
z.add(elements[i], scores[i]); | |
} | |
for (int i = 0; i < scores.length; i++) { | |
Double score = z.getScore(elements[i]); | |
assertEquals("Score mismatch for " + elements[i], scores[i], score); | |
} | |
T[] arrInUse = Arrays.copyOf(elements, elements.length); | |
// System.out.println(Arrays.asList(arrInUse)); | |
Object[] arrExtracted1 = new Object[elements.length]; | |
z.toArray(arrExtracted1); | |
Arrays.sort(arrExtracted1); | |
// System.out.println(Arrays.asList(arrExtracted1)); | |
Object[] arrExtracted2 = z.toArray(); | |
Arrays.sort(arrExtracted2); | |
// System.out.println(Arrays.asList(arrExtracted2)); | |
assertTrue("Extracted array for toArray(a) is not same as input array", | |
ArrayAlgos.areArraysSame(arrExtracted1, arrInUse)); | |
assertTrue("Extracted array for toArray() is not same as input array", | |
ArrayAlgos.areArraysSame(arrExtracted2, arrInUse)); | |
testIntegrity(z, "Integrity tests failed for simple add"); | |
z.clear(); | |
testSize(z, 0); | |
testIntegrity(z, "Integrity tests failed for clear()"); | |
assertTrue("Set should be empty here", z.isEmpty()); | |
// add elements: | |
z.add(e1, s1); | |
testSize(z, 1); | |
z.add(e2, s2); | |
testSize(z, 2); | |
assertTrue("addAll() added two elements - should've returned true", z.addAll(Arrays.asList(e3, e4))); | |
testSize(z, 4); | |
assertTrue("addAll() added an element - should've returned true", z.addAll(Arrays.asList(e5))); | |
testSize(z, 5); | |
assertFalse("addAll() didn't add any elements - should've returned false", z.addAll(Arrays.asList(e3, e4))); | |
testIntegrity(z, "Integrity tests failed after add() calls"); | |
assertTrue("e1 should exist", z.contains(e1)); | |
// Remove: | |
assertTrue("remove() should return true when it removes an element", z.remove(e1)); | |
testIntegrity(z, "Integrity tests failed after remove()"); | |
// Remove again: | |
assertFalse("remove() should return false when it doesn't remove an element", z.remove(e1)); | |
testSize(z, 4); | |
testIntegrity(z, "Integrity tests failed after remove()"); | |
assertTrue("e1 should not exist", !z.contains(e1)); | |
assertTrue("Missing elements", z.containsAll(Arrays.asList(e2, e3, e4, e5))); | |
assertFalse("containsAll() should've returned false here", z.containsAll(Arrays.asList(e2, e3, e4, e1))); | |
assertTrue("remove() should return true when it removes an element", z.remove(e2)); | |
testSize(z, 3); | |
testIntegrity(z, "Integrity tests failed after second remove()"); | |
assertTrue("removeAll() removed an element - but returned false", z.removeAll(Arrays.asList(e3))); | |
testIntegrity(z, "Integrity tests failed after removeAll()"); | |
assertFalse("removeAll() didn't remove any elements - but returned true", z.removeAll(Arrays.asList(e2, e3))); | |
testIntegrity(z, "Integrity tests failed after second removeAll()"); | |
testSize(z, 2); | |
z.addAll(Arrays.asList(e1, e2, e3)); | |
assertTrue("retainAll() removed elements - should've returned true", z.retainAll(Arrays.asList(e4))); | |
assertFalse("retainAll() didn't remove any elements - should've returned false", z.retainAll(Arrays.asList(e4))); | |
testIntegrity(z, "Integrity tests failed after retainAll()"); | |
testSize(z, 1); | |
z.retainAll(Arrays.asList(e5)); | |
testIntegrity(z, "Integrity tests failed after calling retainAll() for non-existent element"); | |
testSize(z, 0); | |
z.add(e3, s3); | |
z.add(e4, s4); | |
z.add(e5, s5); | |
} | |
@Test | |
public void testScores() throws Exception { | |
ScoredSet<String> z = ZSet.create(String.class); | |
z.add("e1", 1.0); | |
z.add("e2", 2.0); | |
z.add("e3", 3.0); | |
z.add("e4", 4.0); | |
z.add("e5", 0.0); | |
testIntegrity(z, "Failed after initialization itself!"); | |
assertEquals("e5 was expected to be the top element", "e5", z.bottom(1)[0]); | |
assertEquals("Scores mismatch", z.getScore("e1"), 1.0, DELTA); | |
z.setScore("e5", 5.0); | |
assertEquals("e5 was expected to be the top element", "e5", z.top(1)[0]); | |
assertEquals("Scores mismatch", z.getScore("e5"), 5.0, DELTA); | |
z.incr("e4"); | |
z.incr("e4"); | |
assertEquals("Scores mismatch", z.getScore("e4"), 6.0, DELTA); | |
assertEquals("e4 was expected to be the top element", "e4", z.top(1)[0]); | |
z.decr("e2"); | |
z.decr("e2"); | |
System.out.println(z); | |
System.out.println(Arrays.toString(z.topWithScores(5))); | |
assertArrayEquals("e2 was expected to be the bottom element", new String[]{"e2", "e1", "e3"}, z.bottom(3)); | |
testIntegrity(z, "Failed after manipulating a few scores"); | |
} | |
@Test | |
public void perfTest() { | |
ScoredSet<String> z = ZSet.create(String.class); | |
long start = System.currentTimeMillis(); | |
for (int i = 0; i < 100; i++) { | |
z.add("e" + i, 100 * Math.random()); | |
} | |
long end = System.currentTimeMillis(); | |
System.out.println(z); | |
} | |
public void testSize(ScoredSet<?> z, int size) { | |
assertEquals("Set was of unexpected size", size, z.size()); | |
} | |
private <T> void testIntegrity(ScoredSet<?> z, String annotation) throws Exception { | |
assertEquals(z.getClass(), ZSet.class); | |
Field vh = z.getClass().getDeclaredField("hMap"); | |
vh.setAccessible(true); | |
@SuppressWarnings("unchecked") | |
Map<T, Double> hMap = (HashMap<T, Double>) vh.get(z); | |
Field vt = z.getClass().getDeclaredField("tSet"); | |
vt.setAccessible(true); | |
@SuppressWarnings("unchecked") | |
NavigableSet<T> tSet = (TreeSet<T>) vt.get(z); | |
assertEquals(String.format("hMap is of size %d but tMap is of size %d\n(%s)", hMap.size(), tSet.size(), annotation), hMap.size(), tSet.size()); | |
assertEquals("Sets are unequal" + "\n(" + annotation + ")", hMap.keySet(), tSet); | |
for (T key : tSet) { | |
assertTrue("hMap does not contain key " + key + "\n(" + annotation + ")", hMap.containsKey(key)); | |
} | |
for (T e : hMap.keySet()) { | |
assertTrue("tSet does not contain element " + e + "\n(" + annotation + ")", tSet.contains(e)); | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package manu.sandbox.utils; | |
import java.lang.reflect.Array; | |
import java.util.*; | |
/** | |
* Implementation of a 'scored set', a type of set in which elements are sorted by their 'scores'. | |
* Comes handy when you wish to compute top 'K' or bottom 'K' of a given data set. | |
* <p/>Worst-case time complexities: <ul> | |
* <li><code>O(1)</code> for retrieval of score</li> | |
* <li><code>O(log(n))</code> for addition/removal of elements</li> | |
* <li><code>O(log(n))</code> for changes to scores</li> | |
* <li><code>O(k)</code> for getting top k or bottom k elements</li> | |
* </ul><p/> | |
* | |
* @param <T> Should be a comparable data type | |
* @author Manu Manjunath | |
*/ | |
public class ZSet<T extends Comparable<T>> implements ScoredSet<T> { | |
private static final Double | |
DEFAULT_SCORE = 0.0, | |
DEFAULT_SCORE_INCREMENT = 1.0; | |
private static class ScoreComparator<T extends Comparable<T>> implements Comparator<T> { | |
private final Map<T, Double> parentMap; | |
private ScoreComparator(Map<T, Double> _hMap) { | |
this.parentMap = _hMap; | |
} | |
@Override | |
public int compare(T a, T b) { | |
Double sa = this.parentMap.get(a), sb = this.parentMap.get(b); | |
if (sa == null || sb == null || sa.equals(sb)) { | |
// if scores are same, compare elements | |
// (this totally depends on your null handling policy) | |
return a.compareTo(b); | |
} | |
return sa.compareTo(sb); | |
} | |
} | |
public static class ElementWithScoreImpl<T> implements ElementWithScore<T> { | |
private final T element; | |
private final Double score; | |
private ElementWithScoreImpl(T element, Double score) { | |
this.element = element; | |
this.score = score; | |
} | |
@Override | |
public T getElement() { | |
return element; | |
} | |
@Override | |
public Double getScore() { | |
return score; | |
} | |
@Override | |
public String toString() { | |
return String.format("%s(%s)", getElement(), getScore()); | |
} | |
} | |
private final Map<T, Double> hMap; | |
private final NavigableSet<T> tSet; | |
private final Class<T> componentType; | |
private ZSet(Class<T> componentType) { | |
this.componentType = componentType; | |
this.hMap = new HashMap<>(); | |
this.tSet = new TreeSet<>(new ScoreComparator<>(hMap)); | |
} | |
/** | |
* Construct an object of {@link ZSet} class of specified container type | |
* | |
* @param clazz container type (e.g. <code>String.class</code>) | |
*/ | |
public static <T extends Comparable<T>> ScoredSet<T> create(Class<T> clazz) { | |
return new ZSet<>(clazz); | |
} | |
@Override | |
public boolean add(T element) { | |
return this.add(element, DEFAULT_SCORE); | |
} | |
/** | |
* Adds an element with specified score. If the element already exists, leaves | |
* the {@link ZSet} unchanged (and returns <code>false</code>) | |
* | |
* @param element Element to add | |
* @param score Score for the element | |
* @return <code>true</code> if addition was successful | |
*/ | |
@Override | |
public synchronized boolean add(T element, Double score) { | |
if (element == null) { | |
throw new IllegalArgumentException("Cannot accept a null element"); | |
} | |
if (score == null) { | |
throw new IllegalArgumentException("Cannot accept a null score"); | |
} | |
if (this.hMap.containsKey(element)) { // if element was already present | |
return false; | |
} else { | |
this.hMap.put(element, score); | |
this.tSet.add(element); | |
return true; | |
} | |
} | |
@Override | |
public String toString() { | |
StringBuilder sb = new StringBuilder(); | |
sb.append('{'); | |
int ri = this.tSet.size(); | |
for (T key : this.tSet.descendingSet()) { | |
sb.append(key); | |
sb.append('('); | |
sb.append(this.hMap.get(key)); | |
sb.append(')'); | |
if (--ri > 0) { | |
sb.append(", "); | |
} | |
} | |
sb.append('}'); | |
return sb.toString(); | |
} | |
/** | |
* Get top <code>k</code> elements (by magnitude of scores) from this sorted set | |
* | |
* @param k Number of elements | |
* @return Top <code>k</code> elements | |
*/ | |
@Override | |
public T[] top(int k) { | |
if (k < 0) { | |
throw new IllegalArgumentException("K cannot be negative or zero"); | |
} | |
T[] topArray = (T[]) Array.newInstance(componentType, k); | |
int i = 0; | |
for (T e : this.tSet.descendingSet()) { | |
if (i == k) { | |
break; | |
} | |
topArray[i] = e; | |
i++; | |
} | |
return topArray; | |
} | |
@Override | |
public ElementWithScore<T>[] topWithScores(int k) { | |
ElementWithScore<T>[] topArray = (ElementWithScore<T>[]) Array.newInstance(ElementWithScoreImpl.class, k); | |
int i = 0; | |
for (T e : this.tSet.descendingSet()) { | |
if (i == k) { | |
break; | |
} | |
topArray[i] = new ElementWithScoreImpl<>(e, this.getScore(e)); | |
i++; | |
} | |
return topArray; | |
} | |
/** | |
* Get bottom <code>k</code> elements (by magnitude of scores) from this sorted set | |
* | |
* @param k Number of elements | |
* @return Bottom <code>k</code> elements | |
*/ | |
@Override | |
public T[] bottom(int k) { | |
if (k < 0) { | |
throw new IllegalArgumentException("k cannot be negative or zero"); | |
} | |
T[] bottomArray = (T[]) Array.newInstance(componentType, k); | |
int i = 0; | |
for (T e : this.tSet) { | |
if (i == k) { | |
break; | |
} | |
bottomArray[i] = e; | |
i++; | |
} | |
return bottomArray; | |
} | |
@Override | |
public ElementWithScore<T>[] bottomWithScores(int k) { | |
ElementWithScore<T>[] topArray = (ElementWithScore<T>[]) Array.newInstance(ElementWithScoreImpl.class, k); | |
int i = 0; | |
for (T e : this.tSet) { | |
if (i == k) { | |
break; | |
} | |
topArray[i] = new ElementWithScoreImpl<>(e, this.getScore(e)); | |
i++; | |
} | |
return topArray; | |
} | |
@Override | |
public int size() { | |
return this.hMap.size(); | |
} | |
@Override | |
public boolean isEmpty() { | |
return this.hMap.isEmpty(); | |
} | |
@Override | |
public boolean contains(Object o) { | |
return this.hMap.containsKey(o); | |
} | |
@Override | |
public Iterator<T> iterator() { | |
return this.tSet.iterator(); | |
} | |
@Override | |
public Object[] toArray() { | |
return this.hMap.keySet().toArray(); | |
} | |
@Override | |
public <E> E[] toArray(E[] a) { | |
return this.hMap.keySet().toArray(a); | |
} | |
@Override | |
public synchronized boolean remove(Object o) { | |
if (!this.hMap.containsKey(o)) | |
return false; | |
this.tSet.remove(o); | |
// remove() on above TreeSet uses values stored in hMap below | |
this.hMap.remove(o); | |
return true; | |
} | |
@Override | |
public boolean containsAll(Collection<?> c) { | |
boolean flag = true; | |
for (Object e : c) { | |
flag &= this.contains(e); | |
} | |
return flag; | |
} | |
@Override | |
public synchronized boolean addAll(Collection<? extends T> c) { | |
boolean flag = false; | |
for (T o : c) { | |
flag |= this.add(o); | |
} | |
return flag; | |
} | |
@Override | |
public synchronized boolean retainAll(Collection<?> c) { | |
boolean flag = false; | |
Set<T> keys = new HashSet<>(this.hMap.keySet()); | |
for (T key : keys) { | |
if (!c.contains(key)) { | |
flag |= this.remove(key); | |
} | |
} | |
return flag; | |
} | |
@Override | |
public synchronized boolean removeAll(Collection<?> c) { | |
boolean flag = false; | |
for (Object o : c) { | |
flag |= this.remove(o); | |
} | |
return flag; | |
} | |
@Override | |
public synchronized void clear() { | |
this.tSet.clear(); | |
this.hMap.clear(); | |
} | |
/** | |
* Gets score of given element | |
* | |
* @param element Element whose score needs to be fetched | |
* @return score | |
*/ | |
@Override | |
public Double getScore(T element) { | |
return this.hMap.get(element); | |
} | |
/** | |
* Sets score of given element to specified value (Re-orders set automatically) | |
* | |
* @param element Element whose score needs to be changed | |
* @param score Score to be set | |
*/ | |
@Override | |
public synchronized void setScore(T element, Double score) { | |
this.tSet.remove(element); | |
this.hMap.put(element, score); | |
this.tSet.add(element); | |
} | |
/** | |
* Increment score of the specified element by a given delta | |
* | |
* @param element Element to increment | |
*/ | |
@Override | |
public synchronized boolean incrBy(T element, Double deltaScore) { | |
Double preScore = this.hMap.get(element); | |
if (preScore == null) { | |
return false; | |
} else { | |
setScore(element, preScore + deltaScore); | |
return true; | |
} | |
} | |
/** | |
* Decrement score of the specified element by a given delta | |
* | |
* @param element Element to increment | |
*/ | |
@Override | |
public synchronized boolean decrBy(T element, Double deltaScore) { | |
return incrBy(element, -deltaScore); | |
} | |
/** | |
* Increment score of the specified element | |
* | |
* @param element Element to increment | |
*/ | |
@Override | |
public synchronized boolean decr(T element) { | |
return this.decrBy(element, DEFAULT_SCORE_INCREMENT); | |
} | |
/** | |
* Decrement score of the specified element | |
* | |
* @param element Element to increment | |
*/ | |
@Override | |
public synchronized boolean incr(T element) { | |
return this.incrBy(element, DEFAULT_SCORE_INCREMENT); | |
} | |
@Override | |
public ZSet<T> clone() { | |
ZSet<T> cloneObj = new ZSet<>(componentType); | |
for (Map.Entry<T, Double> e : this.hMap.entrySet()) { | |
cloneObj.add(e.getKey(), e.getValue()); | |
} | |
return cloneObj; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment