Last active: September 7, 2017 23:47
Save ebernhardson/4e9dcfc9a8e41d05a89d5bd26d1fdc64 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.ml.linalg.{Vector, Vectors} | |
import org.apache.spark.ml.feature.LabeledPoint | |
import org.apache.spark.rdd.RDD | |
import scala.collection.mutable.ArrayBuffer | |
import scala.util.Random | |
/** Builds a dense vector of exactly `size` uniform random doubles in [0, 1).
  *
  * Fix: the original used `0 to size` (inclusive), which produced `size + 1`
  * features, so callers asking for a 3-dimensional vector got 4 features.
  */
def randomVec(r: Random, size: Int): Vector = {
  val feats = Array.fill(size)(r.nextDouble)
  Vectors.dense(feats)
}
/** A labeled point with a random integer label in [0, 3] (four relevance
  * grades, matching the rank:ndcg objective used below) and random features.
  * The label is drawn before the features so the RNG call order is fixed.
  */
def randomPoint(r: Random, size: Int): LabeledPoint = {
  val label = r.nextInt(4)
  LabeledPoint(label, randomVec(r, size))
}
/** Splits `points` items into consecutive query groups of random size 1..9.
  *
  * NOTE(review): behaviour kept exactly as the original — when the last random
  * draw would overshoot, the remainder PLUS ONE is appended, so the sizes sum
  * to `points + 1` unless a draw lands on `points` exactly. This appears to
  * pair with the inclusive `0 to nPoints` range used to build the RDD below,
  * which also yields one extra element.
  */
def randomGroups(r: Random, points: Int): Seq[Int] = {
  val sizes = ArrayBuffer[Int]()
  var covered = 0
  while (covered < points) {
    val candidate = 1 + r.nextInt(9)
    // Clamp the final group (original's off-by-one overshoot preserved).
    val step =
      if (covered + candidate > points) points - covered + 1
      else candidate
    sizes += step
    covered += step
  }
  sizes.toSeq
}
import scala.io.Source | |
/** Estimates off-heap memory (in KiB) used by process `pid`: the resident set
  * size reported by the kernel minus the memory currently reserved by the JVM.
  * Linux-only (reads /proc/&lt;pid&gt;/status).
  *
  * Fixes vs. original: the `Source` is now closed (the original leaked a file
  * handle on every call, and this is called twice per measurement iteration),
  * and a missing VmRSS line raises a descriptive error instead of an opaque
  * ArrayIndexOutOfBoundsException from `.apply(0)`.
  */
def estimateOffHeap(pid: Int): Long = {
  val src = Source.fromFile("/proc/" + pid + "/status")
  try {
    // Line format: "VmRSS:\t  123456 kB" -> 123456
    val vmRssKb = src.getLines
      .collectFirst {
        case line if line.startsWith("VmRSS:") =>
          line.split('\t')(1).trim.split(' ')(0).toLong
      }
      .getOrElse(throw new NoSuchElementException(
        "VmRSS not found in /proc/" + pid + "/status"))
    vmRssKb - (Runtime.getRuntime().totalMemory / 1024)
  } finally {
    src.close()
  }
}
// ---- Driver script: intended for the Spark shell (spark-shell :paste) ----
// NOTE(review): `val points` is defined twice below; that is only legal in the
// REPL, where the second definition shadows the first. This would not compile
// as a regular Scala file.
val nPoints = 1000000
val r = new scala.util.Random(0)
// `0 to nPoints` is inclusive, so the RDD actually holds nPoints + 1 points.
val points = sc.parallelize(for (i <- 0 to nPoints) yield randomPoint(r, 3), 1)
// Round-trip through an object file so every training run reads identical data.
points.saveAsObjectFile("file:///home/ebernhardson/xgboost-mem-leak")
val points: RDD[LabeledPoint] = sc.objectFile("file:///home/ebernhardson/xgboost-mem-leak", 1)
val params = Map(
  "num_rounds" -> 1,
  "objective" -> "rank:ndcg",
  "eval_metric" -> "ndcg@10",
  // One group-size list per partition (the RDD has a single partition).
  // NOTE(review): randomGroups may sum to nPoints + 1, matching the inclusive
  // range above — verify this matches what xgboost expects for groupData.
  "groupData" -> Seq(randomGroups(r, nPoints)))
// PID of this JVM, parsed from the "pid@hostname" RuntimeMXBean name
// (the pre-Java 9 idiom for obtaining one's own PID).
val pid = java.lang.management.ManagementFactory.getRuntimeMXBean().getName().split('@')(0).toInt
import scala.collection.parallel._
// Run 10 trainings concurrently on a dedicated 10-thread pool to amplify any
// per-training native-memory leak.
val jobs = (1 to 10).toList.par
jobs.tasksupport = new ForkJoinTaskSupport(new scala.concurrent.forkjoin.ForkJoinPool(10))
// For each of 21 rounds: sample off-heap usage, train 10 models in parallel,
// pause 1s to let memory settle, then record how much off-heap memory was not
// released. A steadily positive `diffs` sequence indicates a leak.
val diffs = (0 to 20).map { i =>
  val begin = estimateOffHeap(pid)
  // NOTE(review): the lambda's `i` shadows the loop's `i`; results are discarded.
  jobs map { i => ml.dmlc.xgboost4j.scala.spark.XGBoost.trainWithRDD(points, params, 1, 1) }
  Thread.sleep(1000)
  estimateOffHeap(pid) - begin
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.