Last active
July 9, 2020 11:02
-
-
Save dacr/4642782 to your computer and use it in GitHub Desktop.
Custom Scala collection examples
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.collection._ | |
import scala.collection.mutable.{ArrayBuffer,ListBuffer, Builder} | |
import scala.collection.generic._ | |
import scala.collection.immutable.VectorBuilder | |
// ================================ CustomTraversable ================================== | |
/**
 * Companion of the `CustomTraversable` class.
 *
 * Extends `TraversableFactory`, and provides the two pieces that factory needs:
 * an implicit `CanBuildFrom` and a builder producing `CustomTraversable` instances.
 */
object CustomTraversable extends TraversableFactory[CustomTraversable] {

  /** Implicit builder factory targeting `CustomTraversable[A]`. */
  implicit def canBuildFrom[A]: CanBuildFrom[Coll, A, CustomTraversable[A]] =
    new GenericCanBuildFrom[A]

  /** Collects elements in a `ListBuffer`, then wraps the result in a `CustomTraversable`. */
  def newBuilder[A]: Builder[A, CustomTraversable[A]] =
    (new ListBuffer[A]).mapResult(elems => new CustomTraversable(elems: _*))
}
/**
 * Minimal custom `Traversable` wrapping the elements passed at construction.
 *
 * Only `foreach` and `companion` are defined here; every other operation comes
 * from the mixed-in collection traits.
 *
 * @param seq the elements held by this collection, in traversal order
 * @tparam A element type
 */
class CustomTraversable[A](seq: A*)
  extends Traversable[A]
  with GenericTraversableTemplate[A, CustomTraversable]
  with TraversableLike[A, CustomTraversable[A]] {

  /** Companion used by the generic template to obtain builders. */
  override def companion = CustomTraversable

  /** Applies `f` to every wrapped element, in order. */
  override def foreach[U](f: A => U): Unit =
    for (elem <- seq) f(elem)
}
// ================================ CustomSeq ================================== | |
/**
 * Companion of the `CustomSeq` class.
 *
 * Extends `SeqFactory`, and provides the implicit `CanBuildFrom` and the builder
 * that produce `CustomSeq` instances.
 */
object CustomSeq extends SeqFactory[CustomSeq] {

  /** Implicit builder factory targeting `CustomSeq[A]`. */
  implicit def canBuildFrom[A]: CanBuildFrom[Coll, A, CustomSeq[A]] =
    new GenericCanBuildFrom[A]

  /** Collects elements in a `ListBuffer`, then wraps the result in a `CustomSeq`. */
  def newBuilder[A]: Builder[A, CustomSeq[A]] =
    (new ListBuffer[A]).mapResult(elems => new CustomSeq(elems: _*))
}
/**
 * Minimal custom `Seq` wrapping the elements passed at construction.
 *
 * Defines the three members a `Seq` needs (`iterator`, `apply`, `length`) by
 * delegating to the wrapped elements, plus `companion` for the generic template.
 *
 * @param seq the elements held by this sequence, in order
 * @tparam A element type
 */
class CustomSeq[A](seq: A*)
  extends Seq[A]
  with GenericTraversableTemplate[A, CustomSeq]
  with SeqLike[A, CustomSeq[A]] {

  /** Companion used by the generic template to obtain builders. */
  override def companion = CustomSeq

  /** Iterates over the wrapped elements in order. */
  def iterator: Iterator[A] = seq.iterator

  /**
   * Returns the element at position `idx`.
   *
   * @param idx zero-based position
   * @throws IndexOutOfBoundsException if `idx` is negative or not below `length`;
   *                                   the message is the offending index
   */
  def apply(idx: Int): A = {
    if (idx < 0 || idx >= length) throw new IndexOutOfBoundsException(idx.toString)
    seq(idx)
  }

  /** Number of wrapped elements. */
  def length: Int = seq.size
}
// ================================ MySeq ================================== | |
/**
 * Companion of the `MySeq` class: factory methods, builder and implicit
 * `CanBuildFrom` producing `MySeq` instances.
 */
object MySeq {

  /** Builds a `MySeq` from the given elements. */
  def apply[Base](bases: Base*): MySeq[Base] = fromSeq(bases)

  /**
   * Copies `buf` into a fresh `ArrayBuffer` and wraps it in a `MySeq`.
   *
   * The copy is done with a single bulk append rather than by indexing into
   * `buf`, so the cost stays linear even when `buf` has slow positional access.
   *
   * @param buf the elements to copy, in order
   */
  def fromSeq[Base](buf: Seq[Base]): MySeq[Base] = {
    val array = new ArrayBuffer[Base](buf.size)
    array ++= buf
    new MySeq[Base](array)
  }

  /** Builder that accumulates into an `ArrayBuffer` and finishes through `fromSeq`. */
  def newBuilder[Base]: Builder[Base, MySeq[Base]] =
    (new ArrayBuffer[Base]).mapResult(elems => fromSeq(elems))

  /**
   * Implicit builder factory targeting `MySeq[Base]`.
   *
   * The `From` type parameter is not used in the body; it is kept so that the
   * signature stays unchanged for existing callers.
   */
  implicit def canBuildFrom[Base, From]: CanBuildFrom[MySeq[_], Base, MySeq[Base]] =
    new CanBuildFrom[MySeq[_], Base, MySeq[Base]] {
      def apply(): Builder[Base, MySeq[Base]] = newBuilder
      // The source collection carries no extra state, so it is ignored.
      def apply(from: MySeq[_]): Builder[Base, MySeq[Base]] = newBuilder
    }
}
/**
 * Custom indexed sequence backed by an `ArrayBuffer`.
 *
 * The constructor is protected; the factory methods of the `MySeq` object
 * create the instances.
 *
 * @param buffer backing storage holding the elements
 * @tparam Base element type
 */
class MySeq[Base] protected (buffer: ArrayBuffer[Base])
  extends IndexedSeq[Base]
  with IndexedSeqLike[Base, MySeq[Base]] {

  /** Builder used by inherited operations so that they produce a `MySeq`. */
  override protected[this] def newBuilder: Builder[Base, MySeq[Base]] =
    MySeq.newBuilder

  /**
   * Returns the element at position `idx`.
   *
   * @param idx zero-based position
   * @throws IndexOutOfBoundsException if `idx` is negative or not below `length`;
   *                                   the message is the offending index
   */
  def apply(idx: Int): Base = {
    if (idx < 0 || length <= idx) throw new IndexOutOfBoundsException(idx.toString)
    buffer(idx)
  }

  /** Number of elements in the backing buffer. */
  def length: Int = buffer.length
}
// ================================ NamedSeq ================================== | |
/**
 * Companion of the `NamedSeq` class: factory methods, builder and implicit
 * `CanBuildFrom` producing `NamedSeq` instances.
 */
object NamedSeq {

  /** Builds a `NamedSeq` called `name` from the given elements. */
  def apply[Base](name: String, bases: Base*): NamedSeq[Base] = fromSeq(name, bases)

  /**
   * Copies `buf` into a fresh `ArrayBuffer` and wraps it in a `NamedSeq`.
   *
   * The copy is done with a single bulk append rather than by indexing into
   * `buf`, so the cost stays linear even when `buf` has slow positional access.
   *
   * @param name name given to the resulting sequence
   * @param buf  the elements to copy, in order
   */
  def fromSeq[Base](name: String, buf: Seq[Base]): NamedSeq[Base] = {
    val array = new ArrayBuffer[Base](buf.size)
    array ++= buf
    new NamedSeq[Base](name, array)
  }

  /**
   * Builder that accumulates into an `ArrayBuffer` and finishes through `fromSeq`.
   *
   * @param name name given to the sequence the builder produces
   */
  def newBuilder[Base](name: String): Builder[Base, NamedSeq[Base]] =
    (new ArrayBuffer[Base]).mapResult(elems => fromSeq(name, elems))

  /**
   * Implicit builder factory targeting `NamedSeq[Base]`.
   *
   * When a source collection is available its name is reused for the result;
   * without one the result is named "default".
   */
  implicit def canBuildFrom[Base]: CanBuildFrom[NamedSeq[_], Base, NamedSeq[Base]] =
    new CanBuildFrom[NamedSeq[_], Base, NamedSeq[Base]] {
      def apply(): Builder[Base, NamedSeq[Base]] = newBuilder("default")
      def apply(from: NamedSeq[_]): Builder[Base, NamedSeq[Base]] =
        newBuilder(from.name)
    }
}
/**
 * Custom indexed sequence that carries a name next to its elements.
 *
 * The constructor is protected; the factory methods of the `NamedSeq` object
 * create the instances. This class defines no `equals`, so comparisons use the
 * inherited sequence equality and `name` takes no part in them.
 *
 * @param name   name attached to this sequence
 * @param buffer backing storage holding the elements
 * @tparam Base element type
 */
class NamedSeq[Base] protected (
    val name: String,
    buffer: ArrayBuffer[Base])
  extends IndexedSeq[Base]
  with IndexedSeqLike[Base, NamedSeq[Base]] {

  /** Builder used by inherited operations; it hands this sequence's name to the result. */
  override protected[this] def newBuilder: Builder[Base, NamedSeq[Base]] =
    NamedSeq.newBuilder(name)

  /**
   * Returns the element at position `idx`.
   *
   * @param idx zero-based position
   * @throws IndexOutOfBoundsException if `idx` is negative or not below `length`;
   *                                   the message is the offending index
   */
  def apply(idx: Int): Base = {
    if (idx < 0 || length <= idx) throw new IndexOutOfBoundsException(idx.toString)
    buffer(idx)
  }

  /** Number of elements in the backing buffer. */
  def length: Int = buffer.length

  /** Renders as `NamedSeq(<name> : <elements separated by ", ">)`. */
  override def toString(): String = s"NamedSeq($name : ${mkString(", ")})"
}
// ============================= CustomVector =================================== | |
/**
 * Companion of the `CustomVector` class: factory methods, builder and implicit
 * `CanBuildFrom` producing `CustomVector` instances.
 */
object CustomVector {

  /** Builds a `CustomVector` from the given elements. */
  def apply[Base](bases: Base*): CustomVector[Base] = {
    val asVector = bases.toVector
    fromSeq(asVector)
  }

  /** Wraps the given `Vector` in a `CustomVector`. */
  def fromSeq[Base](buf: Vector[Base]): CustomVector[Base] =
    new CustomVector[Base](buf)

  /** Builder that accumulates through a `VectorBuilder` and finishes through `fromSeq`. */
  def newBuilder[Base]: Builder[Base, CustomVector[Base]] =
    (new VectorBuilder[Base]).mapResult(vec => fromSeq(vec))

  /**
   * Implicit builder factory targeting `CustomVector[Base]`.
   *
   * The `From` type parameter is not used in the body.
   */
  implicit def canBuildFrom[Base, From]: CanBuildFrom[CustomVector[_], Base, CustomVector[Base]] =
    new CanBuildFrom[CustomVector[_], Base, CustomVector[Base]] {
      def apply(): Builder[Base, CustomVector[Base]] = newBuilder
      // The source collection is ignored: every builder starts empty.
      def apply(from: CustomVector[_]): Builder[Base, CustomVector[Base]] = newBuilder
    }
}
/**
 * Custom indexed sequence backed by an immutable `Vector`.
 *
 * The constructor is protected; the factory methods of the `CustomVector`
 * object create the instances.
 *
 * @param buffer backing vector holding the elements
 * @tparam Base element type
 */
class CustomVector[Base] protected (buffer: Vector[Base])
  extends IndexedSeq[Base]
  with IndexedSeqLike[Base, CustomVector[Base]] {

  /** Builder used by inherited operations so that they produce a `CustomVector`. */
  override protected[this] def newBuilder: Builder[Base, CustomVector[Base]] =
    CustomVector.newBuilder

  /**
   * Returns the element at position `idx`.
   *
   * @param idx zero-based position
   * @throws IndexOutOfBoundsException if `idx` is negative or not below `length`;
   *                                   the message is the offending index
   */
  def apply(idx: Int): Base = {
    if (idx < 0 || length <= idx) throw new IndexOutOfBoundsException(idx.toString)
    buffer(idx)
  }

  /** Number of elements in the backing vector. */
  def length: Int = buffer.length
}
// Test cases. They use test(...) blocks and "should" matchers, so they presumably belong inside a ScalaTest FunSuite body.
// Checks that CustomTraversable keeps its own type through ++, map and filter.
// Assertions use `should equal`: the former `should be equals (...)` form only
// called Any.equals and dropped the result, so it could never fail.
test("CustomTraversable test") {
  val l = CustomTraversable(1, 2, 3, 4)
  val c = List(5, 6, 7)
  // A plain Traversable has no element-wise equals, so it is never equal to a List;
  // contents are therefore compared through toList.
  l should not equal (List(1, 2, 3, 4))
  l.toList should equal (List(1, 2, 3, 4))
  (l ++ c).toList should equal (List(1, 2, 3, 4, 5, 6, 7))
  l.map(_.toString).toList should equal (List("1", "2", "3", "4"))
  l.map(_.toString).getClass.getName should include ("CustomTraversable")
  l.filter(_ > 2).toList should equal (List(3, 4))
  l.filter(_ > 2).getClass.getName should include ("CustomTraversable")
  l.reduce(_ + _) should equal (10)
}
// Checks that CustomSeq keeps its own type through :+, ++, map and filter.
// Assertions use `should equal`: the former `should be equals (...)` form only
// called Any.equals and dropped the result, so it could never fail.
test("CustomSeq test") {
  val l = CustomSeq(1, 2, 3, 4)
  val c = List(5, 6, 7)
  // Seq equality is element-based, so a CustomSeq equals any Seq holding the same
  // elements; the concrete type is verified through the class name instead.
  l should equal (List(1, 2, 3, 4))
  l.getClass.getName should include ("CustomSeq")
  (l :+ 8) should equal (CustomSeq(1, 2, 3, 4, 8))
  (l ++ c) should equal (CustomSeq(1, 2, 3, 4, 5, 6, 7))
  l.map(_.toString) should equal (CustomSeq("1", "2", "3", "4"))
  l.map(_.toString) should equal (IndexedSeq("1", "2", "3", "4"))
  l.map(_.toString).getClass.getName should include ("CustomSeq")
  l.filter(_ > 2) should equal (CustomSeq(3, 4))
  l.filter(_ > 2).getClass.getName should include ("CustomSeq")
  l.reduce(_ + _) should equal (10)
}
// Checks that CustomVector keeps its own type through :+, ++, map and filter.
// Assertions use `should equal`: the former `should be equals (...)` form only
// called Any.equals and dropped the result, so it could never fail.
test("CustomVector test") {
  val l = CustomVector(1, 2, 3, 4)
  val c = List(5, 6, 7)
  // Seq equality is element-based, so a CustomVector equals any Seq holding the same
  // elements; the concrete type is verified through the class name instead.
  l should equal (List(1, 2, 3, 4))
  l.getClass.getName should include ("CustomVector")
  (l :+ 8) should equal (CustomVector(1, 2, 3, 4, 8))
  (l ++ c) should equal (CustomVector(1, 2, 3, 4, 5, 6, 7))
  l.map(_.toString) should equal (CustomVector("1", "2", "3", "4"))
  l.map(_.toString) should equal (IndexedSeq("1", "2", "3", "4"))
  l.map(_.toString).getClass.getName should include ("CustomVector")
  l.filter(_ > 2) should equal (CustomVector(3, 4))
  l.filter(_ > 2).getClass.getName should include ("CustomVector")
  l.reduce(_ + _) should equal (10)
}
// Checks that mapping a MySeq yields a MySeq with the mapped elements.
// Assertions use `should equal`: the former `should be equals (...)` form only
// called Any.equals and dropped the result, so it could never fail.
test("MySeq test") {
  val cs = MySeq("1", "2", "3")
  info(cs.toString)
  cs should equal (MySeq("1", "2", "3"))
  val scs = cs.map(_.toInt)
  info(scs.toString)
  scs should equal (MySeq(1, 2, 3))
  scs.getClass.getName should include ("MySeq")
}
// Checks that mapping a NamedSeq yields a NamedSeq that keeps the original name.
// Assertions use `should equal`: the former `should be equals (...)` form only
// called Any.equals and dropped the result, so it could never fail.
test("NamedSeq test") {
  val cs = NamedSeq("toto", "1", "2", "3")
  info(cs.toString)
  cs should equal (NamedSeq("toto", "1", "2", "3"))
  cs.name should equal ("toto")
  val scs = cs.map(_.toInt)
  info(scs.toString)
  scs should equal (NamedSeq("toto", 1, 2, 3))
  // NamedSeq does not override equals, so the name is not part of equality;
  // it is checked explicitly instead of comparing against a differently named sequence.
  scs.name should equal ("toto")
  scs.getClass.getName should include ("NamedSeq")
}
// Checks how NamedSeq and MySeq combine: the left operand decides the result type,
// and a NamedSeq result keeps the name of the NamedSeq it was derived from.
// Assertions use `should equal`: the former `should be equals (...)` form only
// called Any.equals and dropped the result, so it could never fail.
test("NamedSeq && MySeq combined test") {
  val cs = MySeq(5, 6, 7, 8)
  val scs = cs.filter(_ > 6)
  val ncs = NamedSeq("myseq", 1, 2, 3, 4)
  val nscs = ncs.filter(_ > 2)
  (nscs :+ 10) should equal (NamedSeq("myseq", 3, 4, 10))
  (nscs ++ scs) should equal (NamedSeq("myseq", 3, 4, 7, 8))
  // The name is not part of equality, so it is asserted on its own.
  (nscs ++ scs).name should equal ("myseq")
  (nscs ++ scs).getClass.getName should include ("NamedSeq")
  (scs ++ nscs) should equal (MySeq(7, 8, 3, 4))
  (scs ++ nscs).getClass.getName should include ("MySeq")
  nscs.map(_ + 1) should equal (NamedSeq("myseq", 4, 5))
  nscs.map(_.toString) should equal (NamedSeq("myseq", "3", "4"))
  scs.map(_.toString) should equal (MySeq("7", "8"))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment