Created
May 2, 2024 23:14
-
-
Save tanishiking/ca04029b4bae113f31c09e7c6b5deb3d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/sample/src/main/scala/Sample.scala b/sample/src/main/scala/Sample.scala | |
index 2bd374c..c861535 100644 | |
--- a/sample/src/main/scala/Sample.scala | |
+++ b/sample/src/main/scala/Sample.scala | |
@@ -17,6 +17,7 @@ object Main { | |
def main(args: Array[String]): Unit = { | |
println("hello world") | |
+ val foo = new Foo | |
} | |
// Tested in SampleTest.scala | |
@@ -25,3 +26,11 @@ object Main { | |
private def println(x: Any): Unit = | |
js.Dynamic.global.console.log("" + x) | |
} | |
+ | |
+// Linear interface chain A <- B <- C <- D, implemented by Bar and, through it, by Foo. | |
+trait A | |
+trait B extends A | |
+trait C extends B | |
+trait D extends C | |
+ | |
+class Bar extends D | |
+class Foo extends Bar | |
\ No newline at end of file | |
diff --git a/wasm/src/main/scala/WebAssemblyLinkerBackend.scala b/wasm/src/main/scala/WebAssemblyLinkerBackend.scala | |
index d06000f..e20e7ec 100644 | |
--- a/wasm/src/main/scala/WebAssemblyLinkerBackend.scala | |
+++ b/wasm/src/main/scala/WebAssemblyLinkerBackend.scala | |
@@ -20,6 +20,7 @@ import org.scalajs.linker.backend.webassembly.SourceMapWriterAccess | |
import wasm.ir2wasm._ | |
import wasm.ir2wasm.SpecialNames._ | |
import wasm.wasm4s._ | |
+import org.scalajs.ir.ClassKind | |
final class WebAssemblyLinkerBackend( | |
linkerConfig: StandardConfig, | |
@@ -87,7 +88,15 @@ final class WebAssemblyLinkerBackend( | |
else a.className.compareTo(b.className) < 0 | |
} | |
- // sortedClasses.foreach(cls => println(utils.LinkedClassPrinters.showLinkedClass(cls))) | |
+ // sortedClasses.foreach(cls => | |
+ // if (cls.pos.source.getPath().contains("Sample.scala")) { | |
+ // // if (cls.kind == ClassKind.Interface) { | |
+ // println(utils.LinkedClassPrinters.showLinkedClass(cls)) | |
+ // println(cls.ancestors) | |
+ // println("===") | |
+ // // } | |
+ // } | |
+ // ) | |
Preprocessor.preprocess(sortedClasses, onlyModule.topLevelExports)(context) | |
HelperFunctions.genGlobalHelpers() | |
diff --git a/wasm/src/main/scala/ir2wasm/LoaderContent.scala b/wasm/src/main/scala/ir2wasm/LoaderContent.scala | |
index f8cb088..5dc5f27 100644 | |
--- a/wasm/src/main/scala/ir2wasm/LoaderContent.scala | |
+++ b/wasm/src/main/scala/ir2wasm/LoaderContent.scala | |
@@ -118,6 +118,7 @@ const scalaJSHelpers = { | |
closureRestNoData: (f, n) => ((...args) => f(...args.slice(0, n), args.slice(n))), | |
// Strings | |
+ print: (x) => console.log(x), | |
emptyString: () => "", | |
stringLength: (s) => s.length, | |
stringCharAt: (s, i) => s.charCodeAt(i), | |
diff --git a/wasm/src/main/scala/ir2wasm/Preprocessor.scala b/wasm/src/main/scala/ir2wasm/Preprocessor.scala | |
index 65befe2..5e7c4a4 100644 | |
--- a/wasm/src/main/scala/ir2wasm/Preprocessor.scala | |
+++ b/wasm/src/main/scala/ir2wasm/Preprocessor.scala | |
@@ -28,11 +28,15 @@ object Preprocessor { | |
for (clazz <- classes) { | |
ctx.getClassInfo(clazz.className).buildMethodTable() | |
+ } | |
+ ctx.assignBuckets(classes) | |
+ for (clazz <- classes) { | |
if (clazz.kind == ClassKind.Interface && clazz.hasInstanceTests) | |
HelperFunctions.genInstanceTest(clazz) | |
HelperFunctions.genCloneFunction(clazz) | |
} | |
+ | |
} | |
private def preprocess(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = { | |
diff --git a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala | |
index aa79f3d..c636635 100644 | |
--- a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala | |
+++ b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala | |
@@ -658,6 +658,22 @@ private class WasmExpressionBuilder private ( | |
val itableIdx = ctx.getItableIdx(receiverClassInfo.name) | |
val methodIdx = receiverClassInfo.tableMethodInfos(methodName).tableIndex | |
+ instrs += LOCAL_GET(receiverLocalForDispatch) | |
+ instrs += STRUCT_GET( | |
+ // receiver type should be upcasted into `Object` if it's interface | |
+ // by TypeTransformer#transformType | |
+ WasmStructTypeName.forClass(IRNames.ObjectClass), | |
+ WasmFieldIdx.itables | |
+ ) | |
+ instrs += I32_CONST(itableIdx) | |
+ instrs += ARRAY_GET(WasmArrayTypeName.itables) | |
+ instrs += REF_TEST(Types.WasmRefType(WasmStructTypeName.forITable(receiverClassInfo.name))) | |
+ fctx.ifThenElse() { | |
+ } { | |
+ instrs ++= ctx.getConstantStringInstr(s"fail to cast to ${receiverClassInfo.name} from ${itableIdx} in ${ctx.buckets(itableIdx).elements}") | |
+ instrs += CALL(WasmFunctionName.print) | |
+ } | |
+ | |
instrs += LOCAL_GET(receiverLocalForDispatch) | |
instrs += STRUCT_GET( | |
// receiver type should be upcasted into `Object` if it's interface | |
diff --git a/wasm/src/main/scala/wasm4s/Names.scala b/wasm/src/main/scala/wasm4s/Names.scala | |
index 79a0e80..8c9c8fe 100644 | |
--- a/wasm/src/main/scala/wasm4s/Names.scala | |
+++ b/wasm/src/main/scala/wasm4s/Names.scala | |
@@ -179,6 +179,7 @@ object Names { | |
val closureThisRest = helper("closureThisRest") | |
val closureRestNoData = helper("closureRestNoData") | |
+ val print = helper("print") | |
val emptyString = helper("emptyString") | |
val stringLength = helper("stringLength") | |
val stringCharAt = helper("stringCharAt") | |
diff --git a/wasm/src/main/scala/wasm4s/WasmContext.scala b/wasm/src/main/scala/wasm4s/WasmContext.scala | |
index 24ad3ad..4080e90 100644 | |
--- a/wasm/src/main/scala/wasm4s/WasmContext.scala | |
+++ b/wasm/src/main/scala/wasm4s/WasmContext.scala | |
@@ -20,6 +20,7 @@ import wasm.ir2wasm.WasmExpressionBuilder | |
import org.scalajs.linker.interface.ModuleInitializer | |
import org.scalajs.linker.interface.unstable.ModuleInitializerImpl | |
import org.scalajs.linker.standard.LinkedTopLevelExport | |
+import org.scalajs.linker.standard.LinkedClass | |
abstract class ReadOnlyWasmContext { | |
import WasmContext._ | |
@@ -27,20 +28,21 @@ abstract class ReadOnlyWasmContext { | |
protected val itableIdx = mutable.Map[IRNames.ClassName, Int]() | |
protected val classInfo = mutable.Map[IRNames.ClassName, WasmClassInfo]() | |
protected var nextItableIdx: Int | |
+ def buckets: List[Bucket] | |
val cloneFunctionTypeName: WasmTypeName | |
val isJSClassInstanceFuncTypeName: WasmTypeName | |
- def itablesLength = nextItableIdx | |
+ def itablesLength = buckets.length | |
/** Get an index of the itable for the given interface. The itable instance must be placed at the | |
* index in the array of itables (whose size is `itablesLength`). | |
*/ | |
- def getItableIdx(iface: IRNames.ClassName): Int = | |
- itableIdx.getOrElse( | |
- iface, | |
- throw new IllegalArgumentException(s"Interface $iface is not registed.") | |
- ) | |
+ def getItableIdx(iface: IRNames.ClassName): Int = | |
+   // The index of an interface is the position of the bucket that holds it. | |
+   buckets.indexWhere(_.elements.contains(iface)) match { | |
+     case -1 => throw new IllegalArgumentException(s"Interface $iface is not registed.") | |
+     case slot => slot | |
+   } | |
def getClassInfoOption(name: IRNames.ClassName): Option[WasmClassInfo] = | |
classInfo.get(name) | |
@@ -212,15 +214,18 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext { | |
private val _importedModules: mutable.LinkedHashSet[String] = | |
new mutable.LinkedHashSet() | |
+ def buckets: List[Bucket] = _buckets | |
private var _buckets: List[Bucket] = Nil | |
def assignBuckets(classes: List[LinkedClass]): Unit = { | |
- val ifaces = | |
- classes.filter(clazz => (clazz.kind == ClassKind.Interface && clazz.hasInstanceTests)) | |
- _buckets = assignBuckets0(ifaces) | |
+ _buckets = assignBuckets0(classes) | |
+ _buckets.zipWithIndex.foreach { case (b, i) => | |
+ println(s"bucket$i: ${b.elements}") | |
+ } | |
+ println(nextItableIdx) | |
} | |
// def buckets = _buckets | |
- override protected def bucketLength: Int = _buckets.length | |
+ override protected var nextItableIdx: Int = 0 | |
private val _jsPrivateFieldNames: mutable.ListBuffer[IRNames.FieldName] = | |
new mutable.ListBuffer() | |
@@ -392,6 +397,7 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext { | |
List(WasmRefType.any) | |
) | |
+ addHelperImport(WasmFunctionName.print, List(WasmRefType.any), List()) | |
addHelperImport(WasmFunctionName.emptyString, List(), List(WasmRefType.any)) | |
addHelperImport(WasmFunctionName.stringLength, List(WasmRefType.any), List(WasmInt32)) | |
addHelperImport(WasmFunctionName.stringCharAt, List(WasmRefType.any, WasmInt32), List(WasmInt32)) | |
@@ -580,6 +586,25 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext { | |
val classInfo = getClassInfo(name) | |
val interfaces = classInfo.ancestors.map(getClassInfo(_)).filter(_.isInterface) | |
val resolvedMethodInfos = classInfo.resolvedMethodInfos | |
+ | |
+ for { | |
+ bucketWithIdx <- buckets.zipWithIndex | |
+ b = bucketWithIdx._1 | |
+ idx = bucketWithIdx._2 | |
+ } { | |
+ val ifaces = interfaces.filter(iface => b.elements.contains(iface.name)) | |
+ assert(ifaces.length <= 1, s" $ifaces") | |
+ val iface = ifaces.headOption.foreach { iface => | |
+ instrs += WasmInstr.GLOBAL_GET(globalName) | |
+ instrs += WasmInstr.I32_CONST(idx) | |
+ for (method <- iface.tableEntries) | |
+ instrs += refFuncWithDeclaration(resolvedMethodInfos(method).tableEntryName) | |
+ instrs += WasmInstr.STRUCT_NEW(WasmTypeName.WasmStructTypeName.forITable(iface.name)) | |
+ instrs += WasmInstr.ARRAY_SET(WasmTypeName.WasmArrayTypeName.itables) | |
+ } | |
+ } | |
+ | |
+ /* | |
interfaces.foreach { iface => | |
val idx = getItableIdx(iface.name) | |
instrs += WasmInstr.GLOBAL_GET(globalName) | |
@@ -590,6 +615,7 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext { | |
instrs += WasmInstr.STRUCT_NEW(WasmTypeName.WasmStructTypeName.forITable(iface.name)) | |
instrs += WasmInstr.ARRAY_SET(WasmTypeName.WasmArrayTypeName.itables) | |
} | |
+ */ | |
} | |
locally { | |
@@ -712,6 +738,90 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext { | |
module.addElement(WasmElement(WasmRefType.funcref, exprs, WasmElement.Mode.Declarative)) | |
} | |
} | |
+ | |
+ /** Distributes the classes that are interfaces, or have an interface among their ancestors, | |
+  *  over a list of [[Bucket]]s. | |
+  * | |
+  *  The work is done in two passes: | |
+  *   1. `classes` is walked in reverse order. A class with more than one interface and no join | |
+  *      recorded for it is a "join": its name is recorded in `joinsOf` for each of its | |
+  *      interfaces. A class for which joins were recorded is a "spine": it goes into the first | |
+  *      bucket that has room and whose joins are disjoint from its own (a new bucket otherwise), | |
+  *      and its joins are propagated to its interfaces. | |
+  *   1. `classes` is walked in order. Every remaining (non-spine) class goes into the first | |
+  *      bucket that has room and is not already used by one of its interfaces (a new bucket | |
+  *      otherwise). | |
+  * | |
+  *  @param classes the linked classes to distribute | |
+  *  @return the buckets, in creation order | |
+  */ | |
+ private def assignBuckets0(classes: List[LinkedClass]): List[Bucket] = { | |
+   val buckets = new mutable.ListBuffer[Bucket]() | |
+   val joinsOf = | |
+     new mutable.LinkedHashMap[IRNames.ClassName, mutable.LinkedHashSet[IRNames.ClassName]]() | |
+   val usedOf = new mutable.LinkedHashMap[IRNames.ClassName, mutable.LinkedHashSet[Bucket]]() | |
+   val spines = new mutable.LinkedHashSet[IRNames.ClassName]() | |
+ | |
+   // The class itself (when it is an interface) followed by its interface ancestors. | |
+   def interfacesOf(clazz: LinkedClass) = { | |
+     val className = clazz.name.name | |
+     (if (clazz.kind == ClassKind.Interface) List(className) else Nil) ++ | |
+       getClassInfo(className).ancestors.collect { | |
+         case a if getClassInfo(a).isInterface => a | |
+       } | |
+   } | |
+ | |
+   // Pass 1: record joins and place spines. | |
+   for (clazz <- classes.reverseIterator) { | |
+     val className = clazz.name.name | |
+     val ifaces = interfacesOf(clazz) | |
+ | |
+     if (ifaces.nonEmpty) { | |
+       // Read-only lookup: a class without recorded joins must not get an entry here. | |
+       val joins = joinsOf.getOrElse(className, new mutable.LinkedHashSet()) | |
+ | |
+       if (joins.nonEmpty) { // spine | |
+         val bucket = buckets | |
+           .find(b => b.size < Bucket.MAX_SIZE && b.joins.intersect(joins).isEmpty) | |
+           .getOrElse { | |
+             // no compatible bucket: create a new one | |
+             val fresh = new Bucket() | |
+             buckets += fresh | |
+             fresh | |
+           } | |
+         bucket.add(className) | |
+         bucket.joins ++= joins | |
+         for (iface <- ifaces) { | |
+           joinsOf.getOrElseUpdate(iface, new mutable.LinkedHashSet()) ++= joins | |
+         } | |
+         spines.add(className) | |
+       } else if (ifaces.length > 1) { // join | |
+         ifaces.foreach { iface => | |
+           joinsOf.getOrElseUpdate(iface, new mutable.LinkedHashSet()) += className | |
+         } | |
+       } | |
+       // else: plain, do nothing | |
+     } | |
+   } | |
+ | |
+   // Pass 2: place every class that was not placed as a spine. | |
+   for (clazz <- classes) { | |
+     val className = clazz.name.name | |
+     val ifaces = interfacesOf(clazz) | |
+     if (ifaces.nonEmpty && !spines.contains(className)) { | |
+       // The set has to be stored in `usedOf` (getOrElseUpdate, not getOrElse); otherwise the | |
+       // `usedOf.get(iface)` lookups below never find the buckets taken by an interface. | |
+       // NOTE(review): this relies on interfaces appearing before their subtypes in `classes`. | |
+       val used = usedOf.getOrElseUpdate(className, new mutable.LinkedHashSet()) | |
+       for { | |
+         iface <- ifaces | |
+         if iface != className // the class's own set is `used` itself; nothing to merge | |
+         parentUsed <- usedOf.get(iface) | |
+       } { used ++= parentUsed } | |
+ | |
+       val bucket = buckets | |
+         .find(b => b.size < Bucket.MAX_SIZE && !used.contains(b)) | |
+         .getOrElse { | |
+           val fresh = new Bucket() | |
+           buckets += fresh | |
+           fresh | |
+         } | |
+       bucket.add(className) | |
+       used.add(bucket) | |
+     } | |
+   } | |
+   buckets.toList | |
+ } | |
} | |
object WasmContext { | |
@@ -732,6 +842,11 @@ object WasmContext { | |
val hasRuntimeTypeInfo: Boolean, | |
val jsNativeLoadSpec: Option[IRTrees.JSNativeLoadSpec], | |
val jsNativeMembers: Map[IRNames.MethodName, IRTrees.JSNativeLoadSpec] | |
+ | |
+ // interface dispatch | |
+ // private var bucketId: Int, | |
+ // private var tid: Int, | |
+ // private var itable: Array[Int] | |
) { | |
private val fieldIdxByName: Map[IRNames.FieldName, Int] = | |
allFieldDefs.map(_.name.name).zipWithIndex.map(p => p._1 -> (p._2 + classFieldOffset)).toMap | |
@@ -898,4 +1013,18 @@ object WasmContext { | |
} | |
final class TableMethodInfo(val methodName: IRNames.MethodName, val tableIndex: Int) | |
+ | |
+ /** A mutable, insertion-ordered group of class names plus a set of join names. | |
+  * | |
+  *  Instances are compared by reference (no `equals` override). | |
+  */ | |
+ class Bucket { | |
+   private val _elements = new mutable.ListBuffer[IRNames.ClassName]() | |
+ | |
+   /** Appends `clazz` to this bucket. */ | |
+   def add(clazz: IRNames.ClassName): Unit = { | |
+     _elements += clazz | |
+   } | |
+ | |
+   /** Number of class names added so far. */ | |
+   def size: Int = _elements.size | |
+ | |
+   /** Snapshot of the class names in insertion order; a new list is built on every call. */ | |
+   def elements: List[IRNames.ClassName] = _elements.toList | |
+ | |
+   /** Join names attached to this bucket; mutated directly by the code that fills the bucket. */ | |
+   val joins: mutable.LinkedHashSet[IRNames.ClassName] = | |
+     new mutable.LinkedHashSet[IRNames.ClassName]() | |
+ } | |
+ object Bucket { | |
+   /** Bucket size limit used by `assignBuckets0`: only buckets with fewer elements are reused. */ | |
+   val MAX_SIZE: Int = 255 | |
+ } | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment