Skip to content

Instantly share code, notes, and snippets.

@tanishiking
Created May 2, 2024 23:14
Show Gist options
  • Save tanishiking/ca04029b4bae113f31c09e7c6b5deb3d to your computer and use it in GitHub Desktop.
diff --git a/sample/src/main/scala/Sample.scala b/sample/src/main/scala/Sample.scala
index 2bd374c..c861535 100644
--- a/sample/src/main/scala/Sample.scala
+++ b/sample/src/main/scala/Sample.scala
@@ -17,6 +17,7 @@ object Main {
def main(args: Array[String]): Unit = {
println("hello world")
+ // NOTE(review): `foo` is never read; presumably the allocation only exists to make
+ // Foo (and with it the Bar/D/C/B/A hierarchy below) reachable for the linker.
+ val foo = new Foo
}
// Tested in SampleTest.scala
@@ -25,3 +26,11 @@ object Main {
private def println(x: Any): Unit =
js.Dynamic.global.console.log("" + x)
}
+
+// Test hierarchy: class Foo extends class Bar, which implements the chain of
+// traits (interfaces) D <: C <: B <: A.
+class Bar extends D
+class Foo extends Bar
+
+trait A
+trait B extends A
+trait C extends B
+trait D extends C
\ No newline at end of file
diff --git a/wasm/src/main/scala/WebAssemblyLinkerBackend.scala b/wasm/src/main/scala/WebAssemblyLinkerBackend.scala
index d06000f..e20e7ec 100644
--- a/wasm/src/main/scala/WebAssemblyLinkerBackend.scala
+++ b/wasm/src/main/scala/WebAssemblyLinkerBackend.scala
@@ -20,6 +20,7 @@ import org.scalajs.linker.backend.webassembly.SourceMapWriterAccess
import wasm.ir2wasm._
import wasm.ir2wasm.SpecialNames._
import wasm.wasm4s._
+// NOTE(review): in this change ClassKind is only referenced by the commented-out
+// debug dump below; drop this import together with that dump if it stays disabled.
+import org.scalajs.ir.ClassKind
final class WebAssemblyLinkerBackend(
linkerConfig: StandardConfig,
@@ -87,7 +88,15 @@ final class WebAssemblyLinkerBackend(
else a.className.compareTo(b.className) < 0
}
- // sortedClasses.foreach(cls => println(utils.LinkedClassPrinters.showLinkedClass(cls)))
+ // Debug aid (disabled): for every class whose source path contains "Sample.scala",
+ // print the linked class, its ancestors and a "===" separator.
+ // sortedClasses.foreach(cls =>
+ // if (cls.pos.source.getPath().contains("Sample.scala")) {
+ // // if (cls.kind == ClassKind.Interface) {
+ // println(utils.LinkedClassPrinters.showLinkedClass(cls))
+ // println(cls.ancestors)
+ // println("===")
+ // // }
+ // }
+ // )
Preprocessor.preprocess(sortedClasses, onlyModule.topLevelExports)(context)
HelperFunctions.genGlobalHelpers()
diff --git a/wasm/src/main/scala/ir2wasm/LoaderContent.scala b/wasm/src/main/scala/ir2wasm/LoaderContent.scala
index f8cb088..5dc5f27 100644
--- a/wasm/src/main/scala/ir2wasm/LoaderContent.scala
+++ b/wasm/src/main/scala/ir2wasm/LoaderContent.scala
@@ -118,6 +118,7 @@ const scalaJSHelpers = {
closureRestNoData: (f, n) => ((...args) => f(...args.slice(0, n), args.slice(n))),
+ // Debugging: log any value to the console (imported by the Wasm module as `print`)
+ print: (x) => console.log(x),
// Strings
emptyString: () => "",
stringLength: (s) => s.length,
stringCharAt: (s, i) => s.charCodeAt(i),
diff --git a/wasm/src/main/scala/ir2wasm/Preprocessor.scala b/wasm/src/main/scala/ir2wasm/Preprocessor.scala
index 65befe2..5e7c4a4 100644
--- a/wasm/src/main/scala/ir2wasm/Preprocessor.scala
+++ b/wasm/src/main/scala/ir2wasm/Preprocessor.scala
@@ -28,11 +28,15 @@ object Preprocessor {
for (clazz <- classes) {
ctx.getClassInfo(clazz.className).buildMethodTable()
+ }
+ // Bucket assignment runs exactly once: after every class's method table has been
+ // built and before the instance-test / clone helpers below are generated.
+ // NOTE(review): presumably because genInstanceTest reads itable indices — confirm.
+ ctx.assignBuckets(classes)
+ for (clazz <- classes) {
if (clazz.kind == ClassKind.Interface && clazz.hasInstanceTests)
HelperFunctions.genInstanceTest(clazz)
HelperFunctions.genCloneFunction(clazz)
}
+
}
private def preprocess(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
diff --git a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala
index aa79f3d..c636635 100644
--- a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala
+++ b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala
@@ -658,6 +658,22 @@ private class WasmExpressionBuilder private (
val itableIdx = ctx.getItableIdx(receiverClassInfo.name)
val methodIdx = receiverClassInfo.tableMethodInfos(methodName).tableIndex
+ // DEBUG(review): runtime sanity check emitted at every interface call site.
+ // It loads the entry at `itableIdx` from the receiver's itables array and tests it
+ // against this interface's itable struct type; on failure it logs a message via the
+ // `print` helper, then falls through to the regular dispatch below.
+ // Remove before merging: it adds a ref.test and a string constant per call site.
+ instrs += LOCAL_GET(receiverLocalForDispatch)
+ instrs += STRUCT_GET(
+ // receiver type should be upcasted into `Object` if it's interface
+ // by TypeTransformer#transformType
+ WasmStructTypeName.forClass(IRNames.ObjectClass),
+ WasmFieldIdx.itables
+ )
+ instrs += I32_CONST(itableIdx)
+ instrs += ARRAY_GET(WasmArrayTypeName.itables)
+ instrs += REF_TEST(Types.WasmRefType(WasmStructTypeName.forITable(receiverClassInfo.name)))
+ fctx.ifThenElse() {
+ } {
+ instrs ++= ctx.getConstantStringInstr(s"fail to cast to ${receiverClassInfo.name} from ${itableIdx} in ${ctx.buckets(itableIdx).elements}")
+ instrs += CALL(WasmFunctionName.print)
+ }
+
instrs += LOCAL_GET(receiverLocalForDispatch)
instrs += STRUCT_GET(
// receiver type should be upcasted into `Object` if it's interface
diff --git a/wasm/src/main/scala/wasm4s/Names.scala b/wasm/src/main/scala/wasm4s/Names.scala
index 79a0e80..8c9c8fe 100644
--- a/wasm/src/main/scala/wasm4s/Names.scala
+++ b/wasm/src/main/scala/wasm4s/Names.scala
@@ -179,6 +179,7 @@ object Names {
val closureThisRest = helper("closureThisRest")
val closureRestNoData = helper("closureRestNoData")
+ // Debug-only console logger (`print` in the JS loader's scalaJSHelpers).
+ val print = helper("print")
val emptyString = helper("emptyString")
val stringLength = helper("stringLength")
val stringCharAt = helper("stringCharAt")
diff --git a/wasm/src/main/scala/wasm4s/WasmContext.scala b/wasm/src/main/scala/wasm4s/WasmContext.scala
index 24ad3ad..4080e90 100644
--- a/wasm/src/main/scala/wasm4s/WasmContext.scala
+++ b/wasm/src/main/scala/wasm4s/WasmContext.scala
@@ -20,6 +20,7 @@ import wasm.ir2wasm.WasmExpressionBuilder
import org.scalajs.linker.interface.ModuleInitializer
import org.scalajs.linker.interface.unstable.ModuleInitializerImpl
import org.scalajs.linker.standard.LinkedTopLevelExport
+import org.scalajs.linker.standard.LinkedClass
abstract class ReadOnlyWasmContext {
import WasmContext._
@@ -27,20 +28,21 @@ abstract class ReadOnlyWasmContext {
protected val itableIdx = mutable.Map[IRNames.ClassName, Int]()
protected val classInfo = mutable.Map[IRNames.ClassName, WasmClassInfo]()
protected var nextItableIdx: Int
+ /** Itable buckets in slot order: bucket `i` owns slot `i` of the itables arrays,
+  * and every interface placed in that bucket uses that slot.
+  */
+ def buckets: List[Bucket]
val cloneFunctionTypeName: WasmTypeName
val isJSClassInstanceFuncTypeName: WasmTypeName
- def itablesLength = nextItableIdx
+ /** Number of itable slots: one per bucket. */
+ def itablesLength: Int = buckets.length
/** Get an index of the itable for the given interface. The itable instance must be placed at the
* index in the array of itables (whose size is `itablesLength`).
*/
- def getItableIdx(iface: IRNames.ClassName): Int =
- itableIdx.getOrElse(
- iface,
- throw new IllegalArgumentException(s"Interface $iface is not registed.")
- )
+ def getItableIdx(iface: IRNames.ClassName): Int = {
+ // The slot is the index of the bucket containing `iface` (linear scan over buckets).
+ val idx = buckets.indexWhere(_.elements.contains(iface))
+ if (idx < 0)
+ throw new IllegalArgumentException(s"Interface $iface is not registered.")
+ idx
+ }
def getClassInfoOption(name: IRNames.ClassName): Option[WasmClassInfo] =
classInfo.get(name)
@@ -212,15 +214,18 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
private val _importedModules: mutable.LinkedHashSet[String] =
new mutable.LinkedHashSet()
+ /** Itable buckets computed by the last call to `assignBuckets` (empty before that). */
+ def buckets: List[Bucket] = _buckets
private var _buckets: List[Bucket] = Nil
+ /** Partitions `classes` into itable buckets via `assignBuckets0`.
+  * `buckets`, `itablesLength` and `getItableIdx` are all derived from the result.
+  */
def assignBuckets(classes: List[LinkedClass]): Unit = {
- val ifaces =
- classes.filter(clazz => (clazz.kind == ClassKind.Interface && clazz.hasInstanceTests))
- _buckets = assignBuckets0(ifaces)
+ _buckets = assignBuckets0(classes)
}
-// def buckets = _buckets
- override protected def bucketLength: Int = _buckets.length
+ // Required by ReadOnlyWasmContext; itable slots are now counted by `buckets`.
+ override protected var nextItableIdx: Int = 0
private val _jsPrivateFieldNames: mutable.ListBuffer[IRNames.FieldName] =
new mutable.ListBuffer()
@@ -392,6 +397,7 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
List(WasmRefType.any)
)
+ // (anyref) -> () console logger, used by the interface-dispatch debug check.
+ addHelperImport(WasmFunctionName.print, List(WasmRefType.any), List())
addHelperImport(WasmFunctionName.emptyString, List(), List(WasmRefType.any))
addHelperImport(WasmFunctionName.stringLength, List(WasmRefType.any), List(WasmInt32))
addHelperImport(WasmFunctionName.stringCharAt, List(WasmRefType.any, WasmInt32), List(WasmInt32))
@@ -580,6 +586,25 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
val classInfo = getClassInfo(name)
val interfaces = classInfo.ancestors.map(getClassInfo(_)).filter(_.isInterface)
val resolvedMethodInfos = classInfo.resolvedMethodInfos
+
+ // Fill one itables slot per bucket with the itable of the (at most one) interface
+ // of this class that belongs to that bucket. Two such interfaces in one bucket
+ // would write the same slot, hence the assertion.
+ for ((bucket, idx) <- buckets.zipWithIndex) {
+ val members = bucket.elements
+ val ifaces = interfaces.filter(iface => members.contains(iface.name))
+ assert(
+ ifaces.length <= 1,
+ s"$name: interfaces ${ifaces.map(_.name)} share itable bucket $idx"
+ )
+ ifaces.headOption.foreach { iface =>
+ instrs += WasmInstr.GLOBAL_GET(globalName)
+ instrs += WasmInstr.I32_CONST(idx)
+ for (method <- iface.tableEntries)
+ instrs += refFuncWithDeclaration(resolvedMethodInfos(method).tableEntryName)
+ instrs += WasmInstr.STRUCT_NEW(WasmTypeName.WasmStructTypeName.forITable(iface.name))
+ instrs += WasmInstr.ARRAY_SET(WasmTypeName.WasmArrayTypeName.itables)
+ }
+ }
+
+ // NOTE(review): previous getItableIdx-based loop, kept commented out for reference;
+ // delete once the bucket-based loop above is validated.
+ /*
interfaces.foreach { iface =>
val idx = getItableIdx(iface.name)
instrs += WasmInstr.GLOBAL_GET(globalName)
@@ -590,6 +615,7 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
instrs += WasmInstr.STRUCT_NEW(WasmTypeName.WasmStructTypeName.forITable(iface.name))
instrs += WasmInstr.ARRAY_SET(WasmTypeName.WasmArrayTypeName.itables)
}
+ */
}
locally {
@@ -712,6 +738,90 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
module.addElement(WasmElement(WasmRefType.funcref, exprs, WasmElement.Mode.Declarative))
}
}
+
+  /** Partitions `classes` into itable buckets; a bucket's index in the result is its
+   *  itable slot.
+   *
+   *  All interfaces in one bucket share one slot, so an entry must never be placed in
+   *  a bucket already occupied by one of its interface ancestors. Every class/interface
+   *  that is, or has, an interface ancestor is placed exactly once.
+   *
+   *  Pass 1 walks `classes` in reverse order:
+   *   - an entry with two or more interface ancestors is a "join": it is recorded in
+   *     `joinsOf` for each of them;
+   *   - an entry that already has joins recorded is a "spine": it goes to the first
+   *     non-full bucket whose joins are disjoint from its own, and its joins are
+   *     propagated to its interface ancestors.
+   *  Pass 2 walks `classes` in order and puts every non-spine entry into the first
+   *  non-full bucket not used by it or by any of its interface ancestors.
+   *
+   *  NOTE(review): pass 2 relies on `classes` listing ancestors before descendants, so
+   *  that `usedOf` is already populated for them — confirm against the sort in
+   *  WebAssemblyLinkerBackend.
+   */
+  private def assignBuckets0(classes: List[LinkedClass]): List[Bucket] = {
+    val buckets = new mutable.ListBuffer[Bucket]()
+    // interface -> join classes that have it among their interface ancestors
+    val joinsOf =
+      new mutable.LinkedHashMap[IRNames.ClassName, mutable.LinkedHashSet[IRNames.ClassName]]()
+    // class/interface -> buckets used by it and by its interface ancestors
+    val usedOf = new mutable.LinkedHashMap[IRNames.ClassName, mutable.LinkedHashSet[Bucket]]()
+    val spines = new mutable.LinkedHashSet[IRNames.ClassName]()
+
+    // The class itself if it is an interface, followed by its interface ancestors.
+    def interfacesOf(clazz: LinkedClass): List[IRNames.ClassName] = {
+      val className = clazz.name.name
+      val self = if (clazz.kind == ClassKind.Interface) List(className) else Nil
+      self ++ getClassInfo(className).ancestors.collect {
+        case a if getClassInfo(a).isInterface => a
+      }
+    }
+
+    // Appends `className` to the first non-full bucket accepted by `fits` (creating a
+    // new bucket when none qualifies) and returns the chosen bucket.
+    def place(className: IRNames.ClassName)(fits: Bucket => Boolean): Bucket = {
+      val bucket = buckets.find(b => b.size < Bucket.MAX_SIZE && fits(b)).getOrElse {
+        val fresh = new Bucket()
+        buckets += fresh
+        fresh
+      }
+      bucket.add(className)
+      bucket
+    }
+
+    // The (stored, mutable) set of buckets used by `className` and its ancestors.
+    def usedBy(className: IRNames.ClassName): mutable.LinkedHashSet[Bucket] =
+      usedOf.getOrElseUpdate(className, new mutable.LinkedHashSet[Bucket]())
+
+    // Pass 1: record joins and place spines.
+    for (clazz <- classes.reverseIterator) {
+      val className = clazz.name.name
+      val ifaces = interfacesOf(clazz)
+      if (ifaces.nonEmpty) {
+        val joins = joinsOf.getOrElse(className, mutable.LinkedHashSet.empty[IRNames.ClassName])
+        if (joins.nonEmpty) { // spine
+          val bucket = place(className)(b => b.joins.intersect(joins).isEmpty)
+          bucket.joins ++= joins
+          // Record the spine's bucket so that pass 2 keeps its descendants out of it.
+          usedBy(className) += bucket
+          for (iface <- ifaces if iface != className)
+            joinsOf.getOrElseUpdate(iface, new mutable.LinkedHashSet()) ++= joins
+          spines += className
+        } else if (ifaces.length > 1) { // join
+          for (iface <- ifaces)
+            joinsOf.getOrElseUpdate(iface, new mutable.LinkedHashSet()) += className
+        }
+        // else: plain (a single interface, no joins); placed in pass 2
+      }
+    }
+
+    // Pass 2: place every non-spine entry away from its interface ancestors' buckets.
+    for (clazz <- classes) {
+      val className = clazz.name.name
+      val ifaces = interfacesOf(clazz)
+      if (ifaces.nonEmpty && !spines.contains(className)) {
+        // Stored in `usedOf` so that descendants processed later inherit it.
+        val used = usedBy(className)
+        for {
+          iface <- ifaces
+          if iface != className
+          parentUsed <- usedOf.get(iface)
+        } used ++= parentUsed
+        used += place(className)(b => !used.contains(b))
+      }
+    }
+
+    buckets.toList
+  }
}
object WasmContext {
@@ -732,6 +842,11 @@ object WasmContext {
val hasRuntimeTypeInfo: Boolean,
val jsNativeLoadSpec: Option[IRTrees.JSNativeLoadSpec],
val jsNativeMembers: Map[IRNames.MethodName, IRTrees.JSNativeLoadSpec]
+
+ // interface dispatch
+ // NOTE(review): not wired up yet; presumably a per-class bucket id, type id and
+ // itable layout for bucket-based dispatch.
+ // private var bucketId: Int,
+ // private var tid: Int,
+ // private var itable: Array[Int]
) {
private val fieldIdxByName: Map[IRNames.FieldName, Int] =
allFieldDefs.map(_.name.name).zipWithIndex.map(p => p._1 -> (p._2 + classFieldOffset)).toMap
@@ -898,4 +1013,18 @@ object WasmContext {
}
final class TableMethodInfo(val methodName: IRNames.MethodName, val tableIndex: Int)
+
+ /** A group of classes/interfaces that share one itable slot.
+  *
+  * `joins` accumulates the join classes of the spine interfaces placed in this bucket;
+  * a spine may only be added to a bucket whose `joins` are disjoint from its own.
+  */
+ class Bucket {
+ private val _elements = new mutable.ListBuffer[IRNames.ClassName]()
+
+ /** Appends `clazz` to this bucket. (Explicit `Unit` so the buffer is not leaked.) */
+ def add(clazz: IRNames.ClassName): Unit = {
+ _elements += clazz
+ }
+
+ /** Number of classes/interfaces in this bucket. */
+ def size: Int = _elements.size
+
+ /** Immutable snapshot of the members, in insertion order. */
+ def elements: List[IRNames.ClassName] = _elements.toList
+
+ val joins: mutable.LinkedHashSet[IRNames.ClassName] = new mutable.LinkedHashSet[IRNames.ClassName]()
+ }
+ object Bucket {
+ /** Maximum number of classes/interfaces per bucket. */
+ val MAX_SIZE = 255
+ }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment