Skip to content

Instantly share code, notes, and snippets.

@mcimadamore
Last active March 31, 2022 18:31
Show Gist options
  • Save mcimadamore/128ee904157bb6c729a10596e69edffd to your computer and use it in GitHub Desktop.
Save mcimadamore/128ee904157bb6c729a10596e69edffd to your computer and use it in GitHub Desktop.
An example of how to implement a mapped memory segment on top of the Panama ABI support
import jdk.incubator.foreign.Addressable;
import jdk.incubator.foreign.CLinker;
import jdk.incubator.foreign.FunctionDescriptor;
import jdk.incubator.foreign.MemoryAddress;
import jdk.incubator.foreign.MemoryLayout;
import jdk.incubator.foreign.MemorySegment;
import jdk.incubator.foreign.ResourceScope;
import jdk.incubator.foreign.SegmentAllocator;
import jdk.incubator.foreign.SequenceLayout;
import jdk.incubator.foreign.ValueLayout;
import java.io.File;
import java.io.IOException;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.VarHandle;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
public class CustomMappedSegment {
static final CLinker ABI = CLinker.systemCLinker();
/* int open(const char *pathname, int flags); */
static final MethodHandle OPEN = ABI.downcallHandle(
ABI.lookup("open").get(),
FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.JAVA_INT));
/* int close(int fd); */
static final MethodHandle CLOSE = ABI.downcallHandle(
ABI.lookup("close").get(),
FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT));
/* void *mmap(void *addr, size_t length, int prot, int flags, int fd, off_t offset); */
static final MethodHandle MMAP = ABI.downcallHandle(
ABI.lookup("mmap").get(),
FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.JAVA_LONG, ValueLayout.JAVA_INT,
ValueLayout.JAVA_INT, ValueLayout.JAVA_INT, ValueLayout.JAVA_LONG));
/* int munmap(void *addr, size_t length); */
static final MethodHandle MUNMAP = ABI.downcallHandle(
ABI.lookup("munmap").get(),
FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.JAVA_LONG));
static final int O_RDONLY = 0;
static final int O_RDWR = 2;
static final int PROT_READ = 1;
static final int PROT_WRITE = 2;
static final int MAP_PRIVATE = 2;
static final int MAP_SHARED = 1;
public static MemorySegment map(Path path, long bytesSize, FileChannel.MapMode mapMode, ResourceScope scope) throws Throwable {
if (bytesSize <= 0) throw new IllegalArgumentException("Requested bytes size must be > 0.");
try (ResourceScope innerScope = ResourceScope.newConfinedScope()) {
int fd = (int) OPEN.invokeExact(
(Addressable)SegmentAllocator.nativeAllocator(innerScope).allocateUtf8String(path.toAbsolutePath().toString()),
openOptions(mapMode));
assertNotEquals(fd, -1);
MemoryAddress mappedAddress = (MemoryAddress) MMAP.invokeExact(
(Addressable)MemoryAddress.NULL,
bytesSize,
mapProt(mapMode),
mapFlags(mapMode),
fd,
0L
);
assertEquals((int)CLOSE.invokeExact(fd), 0);
scope.addCloseAction(() -> {
try {
assertEquals((int)MUNMAP.invokeExact((Addressable)mappedAddress, bytesSize), 0);
} catch (Throwable ex) {
throw new RuntimeException(ex);
}
});
return MemorySegment.ofAddress(mappedAddress, bytesSize, scope);
}
}
private static int openOptions(FileChannel.MapMode mapMode) {
if (mapMode == FileChannel.MapMode.READ_ONLY) {
return O_RDONLY;
} else if (mapMode == FileChannel.MapMode.READ_WRITE || mapMode == FileChannel.MapMode.PRIVATE) {
return O_RDWR;
} else {
throw new UnsupportedOperationException("Unsupported map mode: " + mapMode);
}
}
private static int mapProt(FileChannel.MapMode mapMode) {
if (mapMode == FileChannel.MapMode.READ_ONLY) {
return PROT_READ;
} else if (mapMode == FileChannel.MapMode.READ_WRITE || mapMode == FileChannel.MapMode.PRIVATE) {
return PROT_READ | PROT_WRITE;
} else {
throw new UnsupportedOperationException("Unsupported map mode: " + mapMode);
}
}
private static int mapFlags(FileChannel.MapMode mapMode) {
return (mapMode == FileChannel.MapMode.PRIVATE) ?
MAP_PRIVATE : MAP_SHARED;
}
// quick sanity test
static final SequenceLayout tuples = MemoryLayout.sequenceLayout(500,
MemoryLayout.structLayout(
ValueLayout.JAVA_INT.withOrder(ByteOrder.BIG_ENDIAN).withName("index"),
ValueLayout.JAVA_FLOAT.withOrder(ByteOrder.BIG_ENDIAN).withName("value")
));
static final VarHandle indexHandle = tuples.varHandle(
MemoryLayout.PathElement.sequenceElement(), MemoryLayout.PathElement.groupElement("index"));
static final VarHandle valueHandle = tuples.varHandle(
MemoryLayout.PathElement.sequenceElement(), MemoryLayout.PathElement.groupElement("value"));
public static void main(String[] args) throws Throwable {
File f = new File("test2.out");
f.createNewFile();
f.deleteOnExit();
Files.write(f.toPath(), new byte[500], StandardOpenOption.WRITE);
//write to channel
try (ResourceScope scope = ResourceScope.newConfinedScope()) {
MemorySegment segment = map(f.toPath(), tuples.byteSize(), FileChannel.MapMode.READ_WRITE, scope);
initTuples(segment);
}
//read from channel
try (ResourceScope scope = ResourceScope.newConfinedScope()) {
MemorySegment segment = map(f.toPath(), tuples.byteSize(), FileChannel.MapMode.READ_ONLY, scope);
checkTuples(segment, segment.asByteBuffer());
}
}
static void initTuples(MemorySegment segment) {
for (long i = 0; i < tuples.elementCount().getAsLong() ; i++) {
indexHandle.set(segment, i, (int)i);
valueHandle.set(segment, i, (float)(i / 500f));
}
}
static void checkTuples(MemorySegment segment, ByteBuffer bb) {
for (long i = 0; i < tuples.elementCount().getAsLong() ; i++) {
int index = (int)indexHandle.get(segment, i);
float value = (float)valueHandle.get(segment, i);
System.out.println("Tuple { index = " + index + " ; value = " + value + " }");
assertEquals(bb.getInt(), (int)indexHandle.get(segment, i));
assertEquals(bb.getFloat(), (float)valueHandle.get(segment, i));
}
}
static void assertEquals(int i, int y) {
if (i != y) throw new AssertionError();
}
static void assertEquals(float i, float y) {
if (i != y) throw new AssertionError();
}
static void assertNotEquals(int i, int j) {
if (i == j) {
throw new AssertionError();
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment