CALL com.maxdemarzi.similarity(0.90, 100)
Created
September 12, 2017 17:26
-
-
Save maxdemarzi/ee3e3be8fa10f4e25a8ba9df31a629ac to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.maxdemarzi; | |
import com.maxdemarzi.results.StringResult; | |
import org.neo4j.kernel.internal.GraphDatabaseAPI; | |
import org.neo4j.logging.Log; | |
import org.neo4j.procedure.*; | |
import java.util.stream.Stream; | |
public class Similarity { | |
@Context | |
public GraphDatabaseAPI db; | |
@Context | |
public Log log; | |
@Description("com.maxdemarzi.similarity() ") | |
@Procedure(name = "com.maxdemarzi.similarity", mode = Mode.WRITE) | |
public Stream<StringResult> Similarity(@Name("minimum") Double min, @Name("limit") Number limit) throws InterruptedException { | |
Thread t1 = new Thread(new SimilarityRunnable(min, limit.intValue(), db, log)); | |
t1.start(); | |
t1.join(); | |
return Stream.of(new StringResult("Similarities were calculated.")); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.maxdemarzi; | |
import com.maxdemarzi.schema.Labels; | |
import com.maxdemarzi.schema.RelationshipTypes; | |
import org.neo4j.graphdb.*; | |
import org.neo4j.helpers.collection.Pair; | |
import org.neo4j.kernel.internal.GraphDatabaseAPI; | |
import org.neo4j.logging.Log; | |
import java.util.*; | |
import java.util.concurrent.TimeUnit; | |
import static java.util.Collections.reverseOrder; | |
public class SimilarityRunnable implements Runnable { | |
private static final int TRANSACTION_LIMIT = 10; | |
private static GraphDatabaseAPI db; | |
private Double min; | |
private Integer limit; | |
private static Log log; | |
public SimilarityRunnable (Double min, Integer limit, GraphDatabaseAPI db, Log log) { | |
this.min = min; | |
this.limit = limit; | |
this.db = db; | |
this.log = log; | |
} | |
@Override | |
public void run() { | |
long start = System.nanoTime(); | |
// Get all the Customer Accounts that have been divested | |
ArrayList<Node> divestedAccounts = new ArrayList<>(); | |
try (Transaction tx = db.beginTx()) { | |
ResourceIterator<Node> iterator = db.findNodes(Labels.divested); | |
while (iterator.hasNext()) { | |
divestedAccounts.add(iterator.next()); | |
} | |
tx.success(); | |
} | |
// For each divested account find similar accounts | |
Transaction tx = db.beginTx(); | |
int count = 0; | |
try { | |
for (Node account : divestedAccounts) { | |
count++; | |
Map<Node, List<Double>> mine = new HashMap<>(); | |
Map<Node, List<Double>> theirs = new HashMap<>(); | |
for (Relationship r : account.getRelationships(Direction.OUTGOING, RelationshipTypes.TAGGED)) { | |
Double weight = (Double)r.getProperty("weight"); | |
Node vector = r.getEndNode(); | |
for (Relationship r2 : vector.getRelationships(Direction.INCOMING, RelationshipTypes.TAGGED)) { | |
Node account2 = r2.getStartNode(); | |
if (!account.equals(account2)) { | |
addVectorNodes(mine, account2, weight); | |
addVectorNodes(theirs, account2, (Double)r2.getProperty("weight")); | |
} | |
} | |
} | |
ArrayList<Pair<Node, Double>> top = new ArrayList<>(); | |
for (Map.Entry<Node, List<Double>> entry : mine.entrySet()) { | |
double score = calculateSimilarity(entry.getValue(), theirs.get(entry.getKey())); | |
if (score >= min) { | |
top.add(Pair.of(entry.getKey(), score)); | |
} | |
} | |
top.sort(Comparator.comparing(m -> (Double) m.other(), reverseOrder())); | |
for (Pair<Node, Double> calculation : top.subList(0, Math.min(top.size(), limit))){ | |
Relationship similar = account.createRelationshipTo(calculation.first(), RelationshipTypes.SIMILAR); | |
similar.setProperty("similarity", calculation.other()); | |
} | |
if (count % TRANSACTION_LIMIT == 0) { | |
tx.success(); | |
tx.close(); | |
tx = db.beginTx(); | |
log.info("Committing similarity work after " + count + " in " + TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - start) + " seconds since starting."); | |
} | |
} | |
tx.success(); | |
} finally { | |
tx.close(); | |
} | |
long timeTaken = TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - start); | |
log.info("Similarity calculated in " + timeTaken + " Seconds"); | |
} | |
private void addVectorNodes(Map<Node, List<Double>> multimap, Node key, Double value) { | |
List<Double> list = multimap.computeIfAbsent(key, k -> new ArrayList<>()); | |
list.add(value); | |
} | |
private double calculateSimilarity(List<Double> vector1, List<Double> vector2) { | |
double dotProduct = 0d; | |
double xLength = 0d; | |
double yLength = 0d; | |
for (int i = 0; i < vector1.size(); i++) { | |
dotProduct += vector1.get(i) * vector2.get(i); | |
xLength += vector1.get(i) * vector1.get(i); | |
yLength += vector2.get(i) * vector2.get(i); | |
} | |
xLength = Math.sqrt(xLength); | |
yLength = Math.sqrt(yLength); | |
return dotProduct / (xLength * yLength); | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.maxdemarzi.results; | |
/**
 * Result row exposing a single string through the public {@link #value} field.
 *
 * @author mh
 * @since 26.02.16
 */
public class StringResult {

    /** The string carried by this row; may be {@code null}. */
    public final String value;

    /** Shared row whose {@link #value} is {@code null}. */
    public static final StringResult EMPTY = new StringResult(null);

    /**
     * Creates a row wrapping the given string.
     *
     * @param value the string to expose, stored as-is
     */
    public StringResult(String value) {
        this.value = value;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment