Skip to content

Instantly share code, notes, and snippets.

@m-manu
Last active January 25, 2016 06:55
Show Gist options
  • Save m-manu/9073351 to your computer and use it in GitHub Desktop.
Save m-manu/9073351 to your computer and use it in GitHub Desktop.
Implementation of a "sorted set", a type of set in which elements are sorted according to their 'scores'. Worst-case time complexity: O(1) for retrieval of score, O(log(n)) for addition/removal of elements, O(log(n)) for changes to scores, O(k) for getting top k or bottom k elements.
package manu.sandbox.utils;
import java.util.Set;
/**
 * A {@link Set} whose elements each carry a {@code Double} score and can be
 * queried in score order.
 *
 * @param <T> element type; must be comparable
 */
public interface ScoredSet<T extends Comparable<T>> extends Set<T>, Cloneable {

    /**
     * An element paired with its score, as returned by
     * {@link #topWithScores(int)} and {@link #bottomWithScores(int)}.
     *
     * @param <T> element type
     */
    interface ElementWithScore<T> {
        /** @return the element */
        T getElement();

        /** @return the score associated with the element */
        Double getScore();
    }

    // ---- insertion ---------------------------------------------------------

    /**
     * Adds an element together with its score.
     *
     * @param element element to add
     * @param score   score to associate with the element
     * @return {@code true} if the element was added
     */
    boolean add(T element, Double score);

    // ---- score access ------------------------------------------------------

    /**
     * Looks up the score of an element.
     *
     * @param element element whose score is requested
     * @return the score of the element
     */
    Double getScore(T element);

    /**
     * Assigns a new score to an element.
     *
     * @param element element whose score is to be changed
     * @param score   score to set
     */
    void setScore(T element, Double score);

    // ---- score arithmetic --------------------------------------------------

    /**
     * Raises the score of an element by the given amount.
     *
     * @param element    element to adjust
     * @param deltaScore amount to add to the score
     * @return {@code true} if the score was changed
     */
    boolean incrBy(T element, Double deltaScore);

    /**
     * Lowers the score of an element by the given amount.
     *
     * @param element    element to adjust
     * @param deltaScore amount to subtract from the score
     * @return {@code true} if the score was changed
     */
    boolean decrBy(T element, Double deltaScore);

    /**
     * Raises the score of an element by the implementation's default step.
     *
     * @param element element to adjust
     * @return {@code true} if the score was changed
     */
    boolean incr(T element);

    /**
     * Lowers the score of an element by the implementation's default step.
     *
     * @param element element to adjust
     * @return {@code true} if the score was changed
     */
    boolean decr(T element);

    // ---- ranked queries ----------------------------------------------------

    /**
     * @param k number of elements requested
     * @return the {@code k} elements with the highest scores
     */
    T[] top(int k);

    /**
     * @param k number of elements requested
     * @return the {@code k} elements with the highest scores, each paired with its score
     */
    ElementWithScore<T>[] topWithScores(int k);

    /**
     * @param k number of elements requested
     * @return the {@code k} elements with the lowest scores
     */
    T[] bottom(int k);

    /**
     * @param k number of elements requested
     * @return the {@code k} elements with the lowest scores, each paired with its score
     */
    ElementWithScore<T>[] bottomWithScores(int k);
}
package manu.sandbox.utils;
import org.junit.Test;
import java.lang.reflect.Field;
import java.util.*;
import static org.junit.Assert.*;
/**
 * JUnit tests for {@link ZSet}: basic {@link java.util.Set} operations over several
 * element types, score manipulation, and an internal-consistency check that inspects
 * the private {@code hMap} and {@code tSet} fields through reflection.
 */
public class TestZSet {
    /** Tolerance used when comparing floating-point scores. */
    public static final double DELTA = 0.0001;

    /** Runs the generic set-behaviour scenario against a variety of element types and score patterns. */
    @Test
    public void test() throws Exception {
        testGiven(String.class, new String[]{"e1", "e2", "e3", "e4", "e5"}, new Double[]{1.0, 2.0, 3.0, 4.0, 5.0});
        testGiven(String.class, new String[]{"e1", "e2", "e3", "e4", "e5", "e6"}, new Double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0});
        testGiven(Character.class, new Character[]{'A', 'B', 'C', 'D', 'E'}, new Double[]{1.0, 2.0, 3.0, 4.0, 5.0});
        testGiven(Integer.class, new Integer[]{5, 4, 3, 2, 1, 0}, new Double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0});
        testGiven(Double.class, new Double[]{5.3, 4.2, 3.1, 2.9, 1.000000001, 0.00000001}, new Double[]{-1.0, -2.0, -3.0, -4.0, -10.0, -110.0});
        testGiven(String.class, new String[]{"e1", "e2", "e3", "e4", "e5", "e6"}, new Double[]{1.0, 1.0, 0.9, -5.0, 4.0, -5.0});
    }

    /**
     * Exercises add/remove/contains/retain/clear on a fresh set built from the given data.
     *
     * @param clazz    element class handed to {@link ZSet#create(Class)}
     * @param elements at least five distinct elements
     * @param scores   one score per element
     */
    public <T extends Comparable<T>> void testGiven(Class<T> clazz, T[] elements, Double[] scores) throws Exception {
        assertTrue("Bug in test case!", elements.length == scores.length && elements.length >= 5);
        T e1 = elements[0], e2 = elements[1], e3 = elements[2], e4 = elements[3], e5 = elements[4];
        Double s1 = scores[0], s2 = scores[1], s3 = scores[2], s4 = scores[3], s5 = scores[4];
        ScoredSet<T> z = ZSet.create(clazz);
        assertTrue("Set should be empty", z.isEmpty());
        for (int i = 0; i < scores.length; i++) {
            z.add(elements[i], scores[i]);
        }
        for (int i = 0; i < scores.length; i++) {
            Double score = z.getScore(elements[i]);
            assertEquals("Score mismatch for " + elements[i], scores[i], score);
        }
        T[] arrInUse = Arrays.copyOf(elements, elements.length);
        // Both toArray variants are sorted before comparison, so their iteration order is irrelevant here.
        Object[] arrExtracted1 = new Object[elements.length];
        z.toArray(arrExtracted1);
        Arrays.sort(arrExtracted1);
        Object[] arrExtracted2 = z.toArray();
        Arrays.sort(arrExtracted2);
        assertTrue("Extracted array for toArray(a) is not same as input array",
                ArrayAlgos.areArraysSame(arrExtracted1, arrInUse));
        assertTrue("Extracted array for toArray() is not same as input array",
                ArrayAlgos.areArraysSame(arrExtracted2, arrInUse));
        testIntegrity(z, "Integrity tests failed for simple add");
        z.clear();
        testSize(z, 0);
        testIntegrity(z, "Integrity tests failed for clear()");
        assertTrue("Set should be empty here", z.isEmpty());
        // add elements:
        z.add(e1, s1);
        testSize(z, 1);
        z.add(e2, s2);
        testSize(z, 2);
        assertTrue("addAll() added two elements - should've returned true", z.addAll(Arrays.asList(e3, e4)));
        testSize(z, 4);
        assertTrue("addAll() added an element - should've returned true", z.addAll(Arrays.asList(e5)));
        testSize(z, 5);
        assertFalse("addAll() didn't add any elements - should've returned false", z.addAll(Arrays.asList(e3, e4)));
        testIntegrity(z, "Integrity tests failed after add() calls");
        assertTrue("e1 should exist", z.contains(e1));
        // Remove:
        assertTrue("remove() should return true when it removes an element", z.remove(e1));
        testIntegrity(z, "Integrity tests failed after remove()");
        // Remove again:
        assertFalse("remove() should return false when it doesn't remove an element", z.remove(e1));
        testSize(z, 4);
        testIntegrity(z, "Integrity tests failed after remove()");
        assertFalse("e1 should not exist", z.contains(e1));
        assertTrue("Missing elements", z.containsAll(Arrays.asList(e2, e3, e4, e5)));
        assertFalse("containsAll() should've returned false here", z.containsAll(Arrays.asList(e2, e3, e4, e1)));
        assertTrue("remove() should return true when it removes an element", z.remove(e2));
        testSize(z, 3);
        testIntegrity(z, "Integrity tests failed after second remove()");
        assertTrue("removeAll() removed an element - but returned false", z.removeAll(Arrays.asList(e3)));
        testIntegrity(z, "Integrity tests failed after removeAll()");
        assertFalse("removeAll() didn't remove any elements - but returned true", z.removeAll(Arrays.asList(e2, e3)));
        testIntegrity(z, "Integrity tests failed after second removeAll()");
        testSize(z, 2);
        z.addAll(Arrays.asList(e1, e2, e3));
        assertTrue("retainAll() removed elements - should've returned true", z.retainAll(Arrays.asList(e4)));
        assertFalse("retainAll() didn't remove any elements - should've returned false", z.retainAll(Arrays.asList(e4)));
        testIntegrity(z, "Integrity tests failed after retainAll()");
        testSize(z, 1);
        z.retainAll(Arrays.asList(e5));
        testIntegrity(z, "Integrity tests failed after calling retainAll() for non-existent element");
        testSize(z, 0);
        // Re-populate the emptied set and verify it is still usable.
        z.add(e3, s3);
        z.add(e4, s4);
        z.add(e5, s5);
        testSize(z, 3);
        testIntegrity(z, "Integrity tests failed after re-adding elements to an emptied set");
    }

    /** Verifies that score changes are stored and that they re-order the set. */
    @Test
    public void testScores() throws Exception {
        ScoredSet<String> z = ZSet.create(String.class);
        z.add("e1", 1.0);
        z.add("e2", 2.0);
        z.add("e3", 3.0);
        z.add("e4", 4.0);
        z.add("e5", 0.0);
        testIntegrity(z, "Failed after initialization itself!");
        assertEquals("e5 was expected to be the bottom element", "e5", z.bottom(1)[0]);
        assertEquals("Scores mismatch", 1.0, z.getScore("e1"), DELTA);
        z.setScore("e5", 5.0);
        assertEquals("e5 was expected to be the top element", "e5", z.top(1)[0]);
        assertEquals("Scores mismatch", 5.0, z.getScore("e5"), DELTA);
        z.incr("e4");
        z.incr("e4");
        assertEquals("Scores mismatch", 6.0, z.getScore("e4"), DELTA);
        assertEquals("e4 was expected to be the top element", "e4", z.top(1)[0]);
        z.decr("e2");
        z.decr("e2");
        System.out.println(z);
        System.out.println(Arrays.toString(z.topWithScores(5)));
        assertArrayEquals("e2 was expected to be the bottom element", new String[]{"e2", "e1", "e3"}, z.bottom(3));
        testIntegrity(z, "Failed after manipulating a few scores");
    }

    /** Adds 100 randomly scored elements, reports the elapsed time and checks the resulting size. */
    @Test
    public void perfTest() {
        ScoredSet<String> z = ZSet.create(String.class);
        long start = System.currentTimeMillis();
        for (int i = 0; i < 100; i++) {
            z.add("e" + i, 100 * Math.random());
        }
        long end = System.currentTimeMillis();
        System.out.println("Added 100 elements in " + (end - start) + " ms");
        System.out.println(z);
        testSize(z, 100);
    }

    /** Asserts that the set reports exactly {@code size} elements. */
    public void testSize(ScoredSet<?> z, int size) {
        assertEquals("Set was of unexpected size", size, z.size());
    }

    /**
     * Checks, via reflection, that the two private collections backing a {@link ZSet}
     * ({@code hMap} and {@code tSet}) hold exactly the same elements.
     *
     * @param z          set under test; must be a {@link ZSet}
     * @param annotation text appended to failure messages to identify the calling step
     */
    private <T> void testIntegrity(ScoredSet<?> z, String annotation) throws Exception {
        assertEquals(ZSet.class, z.getClass());
        Field vh = z.getClass().getDeclaredField("hMap");
        vh.setAccessible(true);
        @SuppressWarnings("unchecked")
        Map<T, Double> hMap = (HashMap<T, Double>) vh.get(z);
        Field vt = z.getClass().getDeclaredField("tSet");
        vt.setAccessible(true);
        @SuppressWarnings("unchecked")
        NavigableSet<T> tSet = (TreeSet<T>) vt.get(z);
        assertEquals(String.format("hMap is of size %d but tSet is of size %d\n(%s)", hMap.size(), tSet.size(), annotation), hMap.size(), tSet.size());
        assertEquals("Sets are unequal" + "\n(" + annotation + ")", hMap.keySet(), tSet);
        for (T key : tSet) {
            assertTrue("hMap does not contain key " + key + "\n(" + annotation + ")", hMap.containsKey(key));
        }
        for (T e : hMap.keySet()) {
            assertTrue("tSet does not contain element " + e + "\n(" + annotation + ")", tSet.contains(e));
        }
    }
}
package manu.sandbox.utils;
import java.lang.reflect.Array;
import java.util.*;
/**
 * Implementation of a 'scored set', a type of set in which elements are sorted by their 'scores'.
 * Comes handy when you wish to compute top 'K' or bottom 'K' of a given data set.
 *
 * <p>Internally every element lives in two collections that must always hold the same
 * elements: a {@link HashMap} from element to score, and a {@link TreeSet} ordered by
 * score (ties broken by the elements' natural ordering).
 *
 * <p>Intended time complexities: <ul>
 * <li><code>O(1)</code> for retrieval of score (hash lookup)</li>
 * <li><code>O(log(n))</code> for addition/removal of elements</li>
 * <li><code>O(log(n))</code> for changes to scores</li>
 * <li><code>O(k)</code> for getting top k or bottom k elements</li>
 * </ul>
 *
 * <p>Mutating methods are declared <code>synchronized</code>; read-only methods are not.
 *
 * @param <T> Should be a comparable data type
 * @author Manu Manjunath
 */
public class ZSet<T extends Comparable<T>> implements ScoredSet<T> {
    private static final Double
            DEFAULT_SCORE = 0.0,
            DEFAULT_SCORE_INCREMENT = 1.0;

    /**
     * Orders elements by the score currently stored for them in the backing map,
     * falling back to the elements' natural ordering when scores are equal or missing.
     */
    private static class ScoreComparator<T extends Comparable<T>> implements Comparator<T> {
        private final Map<T, Double> parentMap;

        private ScoreComparator(Map<T, Double> _hMap) {
            this.parentMap = _hMap;
        }

        @Override
        public int compare(T a, T b) {
            Double sa = this.parentMap.get(a), sb = this.parentMap.get(b);
            if (sa == null || sb == null || sa.equals(sb)) {
                // if scores are same (or one is unknown), compare elements
                // (this totally depends on your null handling policy)
                return a.compareTo(b);
            }
            return sa.compareTo(sb);
        }
    }

    /** Immutable (element, score) pair returned by the <code>*WithScores</code> queries. */
    public static class ElementWithScoreImpl<T> implements ElementWithScore<T> {
        private final T element;
        private final Double score;

        private ElementWithScoreImpl(T element, Double score) {
            this.element = element;
            this.score = score;
        }

        @Override
        public T getElement() {
            return element;
        }

        @Override
        public Double getScore() {
            return score;
        }

        /** @return the pair formatted as <code>element(score)</code> */
        @Override
        public String toString() {
            return String.format("%s(%s)", getElement(), getScore());
        }
    }

    /** Element -> score. */
    private final Map<T, Double> hMap;
    /** Elements ordered by score; its comparator reads scores from {@link #hMap}. */
    private final NavigableSet<T> tSet;
    /** Runtime element class, needed to allocate the <code>T[]</code> results. */
    private final Class<T> componentType;

    private ZSet(Class<T> componentType) {
        this.componentType = componentType;
        this.hMap = new HashMap<>();
        this.tSet = new TreeSet<>(new ScoreComparator<>(hMap));
    }

    /**
     * Construct an object of {@link ZSet} class of specified container type
     *
     * @param clazz container type (e.g. <code>String.class</code>)
     * @return a new, empty set
     */
    public static <T extends Comparable<T>> ScoredSet<T> create(Class<T> clazz) {
        return new ZSet<>(clazz);
    }

    /**
     * Adds an element with the default score of <code>0.0</code>.
     *
     * @param element Element to add
     * @return <code>true</code> if the element was not already present
     */
    @Override
    public boolean add(T element) {
        return this.add(element, DEFAULT_SCORE);
    }

    /**
     * Adds an element with specified score. If the element already exists, leaves
     * the {@link ZSet} unchanged (and returns <code>false</code>)
     *
     * @param element Element to add
     * @param score   Score for the element
     * @return <code>true</code> if addition was successful
     * @throws IllegalArgumentException if <code>element</code> or <code>score</code> is null
     */
    @Override
    public synchronized boolean add(T element, Double score) {
        requireElementAndScore(element, score);
        if (this.hMap.containsKey(element)) { // if element was already present
            return false;
        }
        // The score must be in hMap before tSet.add(), because the comparator looks it up there.
        this.hMap.put(element, score);
        this.tSet.add(element);
        return true;
    }

    /**
     * @return the elements from highest to lowest score, formatted as
     * <code>{e1(score1), e2(score2), ...}</code>
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append('{');
        int ri = this.tSet.size();
        for (T key : this.tSet.descendingSet()) {
            sb.append(key);
            sb.append('(');
            sb.append(this.hMap.get(key));
            sb.append(')');
            if (--ri > 0) {
                sb.append(", ");
            }
        }
        sb.append('}');
        return sb.toString();
    }

    /**
     * Get top <code>k</code> elements (those with the highest scores, highest first)
     * from this sorted set
     *
     * @param k Number of elements
     * @return Array of length <code>k</code>; if the set holds fewer than <code>k</code>
     * elements the remaining slots are <code>null</code>
     * @throws IllegalArgumentException if <code>k</code> is negative
     */
    @Override
    public T[] top(int k) {
        return firstElements(this.tSet.descendingSet(), k);
    }

    /**
     * Same as {@link #top(int)} but each element is paired with its score.
     *
     * @param k Number of elements
     * @return Array of length <code>k</code>; unused trailing slots are <code>null</code>
     * @throws IllegalArgumentException if <code>k</code> is negative
     */
    @Override
    public ElementWithScore<T>[] topWithScores(int k) {
        return firstElementsWithScores(this.tSet.descendingSet(), k);
    }

    /**
     * Get bottom <code>k</code> elements (those with the lowest scores, lowest first)
     * from this sorted set
     *
     * @param k Number of elements
     * @return Array of length <code>k</code>; if the set holds fewer than <code>k</code>
     * elements the remaining slots are <code>null</code>
     * @throws IllegalArgumentException if <code>k</code> is negative
     */
    @Override
    public T[] bottom(int k) {
        return firstElements(this.tSet, k);
    }

    /**
     * Same as {@link #bottom(int)} but each element is paired with its score.
     *
     * @param k Number of elements
     * @return Array of length <code>k</code>; unused trailing slots are <code>null</code>
     * @throws IllegalArgumentException if <code>k</code> is negative
     */
    @Override
    public ElementWithScore<T>[] bottomWithScores(int k) {
        return firstElementsWithScores(this.tSet, k);
    }

    /** Copies up to <code>k</code> leading elements of <code>ordered</code> into a new <code>T[k]</code>. */
    @SuppressWarnings("unchecked") // the array is created with componentType, which is Class<T>
    private T[] firstElements(Iterable<T> ordered, int k) {
        requireNonNegative(k);
        T[] result = (T[]) Array.newInstance(componentType, k);
        int i = 0;
        for (T e : ordered) {
            if (i == k) {
                break;
            }
            result[i++] = e;
        }
        return result;
    }

    /** Like {@link #firstElements(Iterable, int)} but wraps each element with its current score. */
    @SuppressWarnings("unchecked") // generic array creation; only ElementWithScoreImpl<T> is stored
    private ElementWithScore<T>[] firstElementsWithScores(Iterable<T> ordered, int k) {
        requireNonNegative(k);
        ElementWithScore<T>[] result =
                (ElementWithScore<T>[]) Array.newInstance(ElementWithScoreImpl.class, k);
        int i = 0;
        for (T e : ordered) {
            if (i == k) {
                break;
            }
            result[i++] = new ElementWithScoreImpl<>(e, this.getScore(e));
        }
        return result;
    }

    /** Rejects negative element counts (zero is allowed and yields an empty array). */
    private static void requireNonNegative(int k) {
        if (k < 0) {
            throw new IllegalArgumentException("k cannot be negative");
        }
    }

    /** Shared argument validation for {@link #add(Comparable, Double)} and {@link #setScore}. */
    private static void requireElementAndScore(Object element, Double score) {
        if (element == null) {
            throw new IllegalArgumentException("Cannot accept a null element");
        }
        if (score == null) {
            throw new IllegalArgumentException("Cannot accept a null score");
        }
    }

    @Override
    public int size() {
        return this.hMap.size();
    }

    @Override
    public boolean isEmpty() {
        return this.hMap.isEmpty();
    }

    @Override
    public boolean contains(Object o) {
        return this.hMap.containsKey(o);
    }

    /**
     * Iterates from lowest to highest score.
     *
     * <p>The returned iterator supports {@link Iterator#remove()} and removes the element
     * from both backing collections, so the set stays internally consistent.
     */
    @Override
    public Iterator<T> iterator() {
        final Iterator<T> ordered = this.tSet.iterator();
        return new Iterator<T>() {
            private T last;

            @Override
            public boolean hasNext() {
                return ordered.hasNext();
            }

            @Override
            public T next() {
                last = ordered.next();
                return last;
            }

            @Override
            public void remove() {
                synchronized (ZSet.this) {
                    // Throws IllegalStateException (leaving hMap untouched) if there is nothing to remove.
                    ordered.remove();
                    hMap.remove(last);
                }
            }
        };
    }

    /** @return the elements in iteration order (lowest to highest score) */
    @Override
    public Object[] toArray() {
        return this.tSet.toArray();
    }

    /** @return the elements in iteration order (lowest to highest score), stored in <code>a</code> if it is large enough */
    @Override
    public <E> E[] toArray(E[] a) {
        return this.tSet.toArray(a);
    }

    /**
     * Removes the given element, if present.
     *
     * @return <code>true</code> if the element was present
     */
    @Override
    public synchronized boolean remove(Object o) {
        if (!this.hMap.containsKey(o)) {
            return false;
        }
        // Order matters: remove() on the TreeSet uses the score still stored in hMap.
        this.tSet.remove(o);
        this.hMap.remove(o);
        return true;
    }

    @Override
    public boolean containsAll(Collection<?> c) {
        for (Object e : c) {
            if (!this.contains(e)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Adds every element of <code>c</code> with the default score.
     *
     * @return <code>true</code> if at least one element was added
     */
    @Override
    public synchronized boolean addAll(Collection<? extends T> c) {
        boolean flag = false;
        for (T o : c) {
            flag |= this.add(o);
        }
        return flag;
    }

    /**
     * Removes every element that is not contained in <code>c</code>.
     *
     * @return <code>true</code> if at least one element was removed
     */
    @Override
    public synchronized boolean retainAll(Collection<?> c) {
        boolean flag = false;
        // Iterate over a snapshot, because remove() mutates hMap.
        Set<T> keys = new HashSet<>(this.hMap.keySet());
        for (T key : keys) {
            if (!c.contains(key)) {
                flag |= this.remove(key);
            }
        }
        return flag;
    }

    /**
     * Removes every element of <code>c</code> that is present in this set.
     *
     * @return <code>true</code> if at least one element was removed
     */
    @Override
    public synchronized boolean removeAll(Collection<?> c) {
        boolean flag = false;
        for (Object o : c) {
            flag |= this.remove(o);
        }
        return flag;
    }

    @Override
    public synchronized void clear() {
        this.tSet.clear();
        this.hMap.clear();
    }

    /**
     * Compares by element membership only (scores are ignored), as required by the
     * {@link Set#equals(Object)} contract.
     */
    @Override
    public boolean equals(Object o) {
        return o == this || (o instanceof Set<?> && this.hMap.keySet().equals(o));
    }

    /** @return the sum of the elements' hash codes, as required by {@link Set#hashCode()} */
    @Override
    public int hashCode() {
        return this.hMap.keySet().hashCode();
    }

    /**
     * Gets score of given element
     *
     * @param element Element whose score needs to be fetched
     * @return score, or <code>null</code> if the element is not in this set
     */
    @Override
    public Double getScore(T element) {
        return this.hMap.get(element);
    }

    /**
     * Sets score of given element to specified value (Re-orders set automatically).
     * If the element is not yet in the set it is added with that score.
     *
     * @param element Element whose score needs to be changed
     * @param score   Score to be set
     * @throws IllegalArgumentException if <code>element</code> or <code>score</code> is null
     */
    @Override
    public synchronized void setScore(T element, Double score) {
        requireElementAndScore(element, score);
        // Take the element out of the tree while its old score is still in hMap,
        // then re-insert it so it lands at the position for the new score.
        this.tSet.remove(element);
        this.hMap.put(element, score);
        this.tSet.add(element);
    }

    /**
     * Increment score of the specified element by a given delta
     *
     * @param element    Element to increment
     * @param deltaScore Amount to add; must not be null
     * @return <code>false</code> if the element is not in this set, <code>true</code> otherwise
     */
    @Override
    public synchronized boolean incrBy(T element, Double deltaScore) {
        Double preScore = this.hMap.get(element);
        if (preScore == null) {
            return false;
        }
        setScore(element, preScore + deltaScore);
        return true;
    }

    /**
     * Decrement score of the specified element by a given delta
     *
     * @param element    Element to decrement
     * @param deltaScore Amount to subtract; must not be null
     * @return <code>false</code> if the element is not in this set, <code>true</code> otherwise
     */
    @Override
    public synchronized boolean decrBy(T element, Double deltaScore) {
        return incrBy(element, -deltaScore);
    }

    /**
     * Decrement score of the specified element by <code>1.0</code>
     *
     * @param element Element to decrement
     * @return <code>false</code> if the element is not in this set, <code>true</code> otherwise
     */
    @Override
    public synchronized boolean decr(T element) {
        return this.decrBy(element, DEFAULT_SCORE_INCREMENT);
    }

    /**
     * Increment score of the specified element by <code>1.0</code>
     *
     * @param element Element to increment
     * @return <code>false</code> if the element is not in this set, <code>true</code> otherwise
     */
    @Override
    public synchronized boolean incr(T element) {
        return this.incrBy(element, DEFAULT_SCORE_INCREMENT);
    }

    /**
     * @return a new, independent {@link ZSet} holding the same elements and scores
     */
    @Override
    public ZSet<T> clone() {
        ZSet<T> cloneObj = new ZSet<>(componentType);
        for (Map.Entry<T, Double> e : this.hMap.entrySet()) {
            cloneObj.add(e.getKey(), e.getValue());
        }
        return cloneObj;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment