Created
May 21, 2024 11:10
-
-
Save farnoy/5ce05c4cb12e2f3189214442e6324713 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.lang.foreign.* | |
import scala.util.{Using, Try} | |
/** Reports the outcome of a computation: prints `Result: <value>` on success, or the
  * exception's stack trace on failure.
  */
def handleResult[T](res: Try[T]) =
  res.toEither match
    case Right(value)  => println(s"Result: $value")
    case Left(failure) => failure.printStackTrace()
/** Scala-side mirror of the native `#[repr(C)] struct MyData { x: f64, y: u8, z: u16, last: f64 }`.
  *
  * Note that `Byte` and `Short` are signed, so `u8`/`u16` values above 127/32767 read back negative.
  */
final case class MyData(x: Double, y: Byte, z: Short, last: Double)
given NativeType[MyData] with
  /** C layout of [[MyData]]: `x` at offset 0, `y` at 8, `z` at 10 and `last` at 16,
    * for 24 bytes in total. The explicit padding keeps `z` 2-byte aligned and `last`
    * 8-byte aligned, as a `#[repr(C)]` struct would be.
    */
  def memoryLayout =
    val doubleLayout = summon[NativeType[Double]].memoryLayout
    val byteLayout = summon[NativeType[Byte]].memoryLayout
    val shortLayout = summon[NativeType[Short]].memoryLayout
    MemoryLayout.structLayout(
      doubleLayout.withName("x"),
      byteLayout.withName("y"),
      MemoryLayout.paddingLayout(1), // 9 -> 10: align `z`
      shortLayout.withName("z"),
      MemoryLayout.paddingLayout(4), // 12 -> 16: align `last`
      doubleLayout.withName("last")
    )
/** Location of the native test library. The `FOREIGN_TEST_LIB` environment variable overrides
  * the built-in default path.
  */
val libraryPath =
  sys.env.getOrElse(
    "FOREIGN_TEST_LIB",
    "C:\\Projects\\scala3-learning/foreign-test/target/release/foreign_test.dll"
  )

val linker = Linker.nativeLinker()

// The following are lazy so that a missing or unloadable library fails on first use (for
// example inside a `Try`). Eager vals would fail while the enclosing top-level object is
// initialized, before `main`'s body runs. That failure surfaces as an
// ExceptionInInitializerError, which `Try` does not catch.
lazy val lookup = SymbolLookup.libraryLookup(libraryPath, Arena.global())

/** Native `add`, declared here as taking two `Long`s and returning a `Long`. */
lazy val add =
  genNativeFunction[(Long, Long), Long]("add", Arena.global(), linker, lookup)

/** Native `process`, declared here as taking a pointer to [[MyData]] and returning a `Double`. */
lazy val process =
  genNativeFunction[Tuple1[Pointer[MyData]], Double](
    "process",
    Arena.global(),
    linker,
    lookup
  )
/** Demo entry point: calls the native `add` and `process` functions and prints each result.
  *
  * @param args optional first and second operands for `add` (defaults: 1 and 2)
  */
@main def Foreign(args: String*) =
  println(s"args: $args")
  handleResult(runAdd(args))
  handleResult(runProcess())

/** Parses the optional operands and calls native `add`. Malformed numbers are captured in the
  * returned `Try` instead of escaping `main`.
  */
private def runAdd(args: Seq[String]): Try[Long] =
  def operand(index: Int, default: Long): Try[Long] =
    args.lift(index).fold(Try(default))(raw => Try(raw.toLong))
  for
    first <- operand(0, 1L)
    second <- operand(1, 2L)
    sum <- Try(add((first, second)))
  yield sum

/** Fills a `MyData` struct in a confined arena and passes its address to native `process`.
  * The arena is closed when the block finishes.
  */
private def runProcess(): Try[Double] =
  // #[repr(C)]
  // pub struct MyData {
  //     x: f64,
  //     y: u8,
  //     z: u16,
  //     last: f64,
  // }
  // #[no_mangle]
  // pub unsafe extern "C" fn process(data: *const MyData) -> f64 {
  //     let d = data.read();
  //     d.x + d.y as f64 + d.z as f64 + d.last as f64
  // }
  Using(Arena.ofConfined()): arena =>
    val layout = summon[NativeType[MyData]].memoryLayout
    val data = arena.allocate(layout)
    def field(name: String) = layout.varHandle(MemoryLayout.PathElement.groupElement(name))
    // The VarHandle's extra coordinate is a `long` base offset, so pass 0L exactly.
    field("x").set(data, 0L, 3.14)
    // Set every field explicitly rather than relying on the allocator zero-filling memory.
    field("y").set(data, 0L, 0: Byte)
    field("z").set(data, 0L, 6: Short)
    field("last").set(data, 0L, 8.9)
    process(Tuple1(data))
  // process(Tuple1(MyData(3.14, 1, 384, 8.9))) // Error
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.lang.foreign.* | |
import scala.compiletime.* | |
import scala.quoted.* | |
import java.lang.invoke.MethodHandle | |
/** Type class describing how values of the Scala type `T` are laid out in native memory. */
trait NativeType[T]:
  /** The Foreign Function & Memory API layout used for `T` in function descriptors and allocations. */
  def memoryLayout: MemoryLayout
end NativeType
/** A typed wrapper around a native downcall.
  *
  * @tparam Args tuple of the argument values passed at runtime
  * @tparam R    Scala type the native return value is cast to
  */
sealed trait NativeFunction[Args <: Tuple, R]:
  /** The underlying downcall method handle. */
  def handle: MethodHandle
  /** Invokes the native function with the given argument tuple. */
  def apply(args: Args): R
end NativeFunction
// NativeType instances for JVM primitive types, backed by the FFM `ValueLayout.JAVA_*` constants.
// Sizes: Byte/Boolean 1 byte, Short/Char 2, Int/Float 4, Long/Double 8.

given NativeType[Int] with
  def memoryLayout = ValueLayout.JAVA_INT

given NativeType[Long] with
  def memoryLayout = ValueLayout.JAVA_LONG

given NativeType[Float] with
  def memoryLayout = ValueLayout.JAVA_FLOAT

given NativeType[Double] with
  def memoryLayout = ValueLayout.JAVA_DOUBLE

given NativeType[Short] with
  def memoryLayout = ValueLayout.JAVA_SHORT

given NativeType[Byte] with
  def memoryLayout = ValueLayout.JAVA_BYTE

// Added so functions taking/returning C `bool` or 16-bit code units (e.g. Windows `wchar_t`)
// can be described as well.
given NativeType[Boolean] with
  def memoryLayout = ValueLayout.JAVA_BOOLEAN

given NativeType[Char] with
  def memoryLayout = ValueLayout.JAVA_CHAR
/** Marks a native-function argument as a pointer to a `T` (`T*` in C).
  *
  * In this file it is only used at the type level: `NativeArgs` maps each `Pointer[T]`
  * element to the `MemorySegment` that is actually passed at runtime.
  */
final case class Pointer[T](t: T) extends AnyVal
/** A `Pointer[T]` argument is passed to native code as a bare address (`const T*` in C).
  *
  * Using `T`'s own layout here would make the linker pass the whole struct *by value*. That
  * only coincidentally matches a pointer parameter on ABIs that pass large structs through a
  * hidden pointer (such as Windows x64), and it breaks elsewhere (e.g. System V x86-64).
  * The `NativeType[T]` bound is kept so that only pointers to types with a known layout are
  * accepted.
  */
given [T: NativeType]: NativeType[Pointer[T]] with
  def memoryLayout: MemoryLayout = ValueLayout.ADDRESS
/** Maps the declared argument-type tuple to the tuple of values passed at runtime.
  *
  * Each `Pointer[_]` element becomes a `MemorySegment`; every other element is kept as-is.
  * Match-type cases are tried in order, so the `Pointer` case must stay ahead of the generic one.
  */
type NativeArgs[T <: Tuple] <: Tuple = T match
  case EmptyTuple               => EmptyTuple
  case Pointer[pointee] *: rest => MemorySegment *: NativeArgs[rest]
  case other *: rest            => other *: NativeArgs[rest]
/** Builds, at compile time, a [[NativeFunction]] bound to the native symbol `name`.
  *
  * `Args` lists the declared argument types (use `Pointer[T]` for `T*`) and `R` the return type.
  * Each of them needs a given [[NativeType]]. `name` must be a compile-time constant because the
  * macro reads its value during expansion.
  *
  * NOTE(review): `arena` is forwarded to the macro implementation but not used by the generated
  * code — confirm whether it is still needed.
  */
inline def genNativeFunction[Args <: Tuple, R](
    inline name: String,
    arena: Arena,
    linker: Linker,
    lookup: SymbolLookup
): NativeFunction[NativeArgs[Args], R] =
  ${ makeNativeFunction[Args, R]('name, 'arena, 'linker, 'lookup) }
/** Macro implementation behind `genNativeFunction`.
  *
  * Resolves a `MemoryLayout` for every element of `Args` and for `R` through their
  * [[NativeType]] givens. It then emits a [[NativeFunction]] that looks up the symbol and
  * creates its downcall handle when the instance is constructed. Missing layouts and
  * non-tuple `Args` abort the expansion with a compile error.
  *
  * @param nameExpr symbol name; must be a compile-time constant
  * @param arena    currently unused by the generated code
  * @param linker   linker used to create the downcall handle
  * @param lookup   symbol lookup in which the name is resolved
  */
def makeNativeFunction[Args <: Tuple: Type, R: Type](
    nameExpr: Expr[String],
    arena: Expr[Arena],
    linker: Expr[Linker],
    lookup: Expr[SymbolLookup]
)(using Quotes): Expr[NativeFunction[NativeArgs[Args], R]] =
  import quotes.reflect.*

  val name = nameExpr.valueOrAbort

  // Element types of a tuple type. Accepts both `(A, B)` and `A *: B *: EmptyTuple` forms,
  // including `EmptyTuple` for zero-argument functions.
  def tupleElements(tpe: TypeRepr): List[TypeRepr] = tpe.dealias.asType match
    case '[EmptyTuple]    => Nil
    case '[head *: tail]  => TypeRepr.of[head] :: tupleElements(TypeRepr.of[tail])
    case _ =>
      report.errorAndAbort(s"Expected a concrete tuple of argument types, got: ${tpe.show}")

  // Layout of one type, taken from its NativeType given; aborts the expansion if there is none.
  def layoutOf(tpe: TypeRepr): Expr[MemoryLayout] = tpe.asType match
    case '[t] =>
      Expr.summon[NativeType[t]] match
        case Some(instance) => '{ $instance.memoryLayout }
        case None =>
          report.errorAndAbort(s"Unexpected type when generating ValueLayout: ${tpe.show}")

  val argLayouts = Expr.ofSeq(tupleElements(TypeRepr.of[Args]).map(layoutOf))
  val returnLayout = layoutOf(TypeRepr.of[R])
  val symbolName = Expr(name)
  val missingSymbol = Expr(s"Native symbol not found: $name")

  '{
    new NativeFunction[NativeArgs[Args], R]:
      val fun = ${ lookup }
        .find(${ symbolName })
        .orElseThrow(() => new NoSuchElementException(${ missingSymbol }))
      val desc = FunctionDescriptor.of(${ returnLayout }, ${ argLayouts }*)
      // NOTE: a critical downcall is only appropriate for functions that run very briefly and
      // never call back into Java; every function generated here is marked critical.
      val _handle =
        ${ linker }.downcallHandle(fun, desc, Linker.Option.critical(false))
      override def handle = _handle
      override def apply(args: NativeArgs[Args]): R =
        handle.invokeWithArguments(args.toArray*).asInstanceOf[R]
  }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment