Skip to content

Instantly share code, notes, and snippets.

@Aldlevine
Created November 6, 2024 21:20
Show Gist options
  • Save Aldlevine/2b32ac0855e3d72a587017a8e03b66f6 to your computer and use it in GitHub Desktop.
Save Aldlevine/2b32ac0855e3d72a587017a8e03b66f6 to your computer and use it in GitHub Desktop.
Zig virtual interface (dynamic dispatch)

virtual.zig

Provides a mixin and a type constructor that make it simple to define virtual types with a dynamic-dispatch interface.

How to use

To create a virtual type, you must create a struct which:

  • declares:
    pub usingnamespace virtual(access, @This());
  • has a single field:
    vtable: Vtable(access, @This()),
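  • and implements the interface for a generic type T as a pub fn named iface; each method forwards to the vtable entry of the same name, passing self.vtable.__ptr__ as the first argument:
    pub fn iface(T: type) type {
        return struct {
            pub inline fn doSomething(self: *T) void {
                return self.vtable.doSomething(self.vtable.__ptr__);
            }
        };
    }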

To use a virtual type, you must:

  • Create a type which implements the interface; each method takes a pointer to the implementation as its first parameter
    const MyImpl = struct {
        pub fn doSomething(self: *@This()) void {
            _ = self;
        }
    };
  • Init the virtual type with a pointer to your implementation
    var my_impl: MyImpl = .{};
    var my_virt = MyVirt.init(&my_impl);
    my_virt.doSomething();

Example

/// A type which has functions `doSomethingVar(i32) void` and `doSomethingConst(f32) i32`
fn SomethingDoer(access: Access) type {
    return struct {
        pub usingnamespace virtual(access, @This());
        vtable: Vtable(access, @This()),

        pub fn iface(T: type) type {
            return struct {
                pub inline fn doSomethingVar(self: *T, value: i32) void {
                    return self.vtable.doSomethingVar(self.vtable.__ptr__, value);
                }

                pub inline fn doSomethingConst(self: *const T, value: f32) i32 {
                    return self.vtable.doSomethingConst(self.vtable.__ptr__, value);
                }
            };
        }
    };
}

const ExampleDoer = struct {
    value: i32,

    pub fn doSomethingVar(self: *@This(), value: i32) void {
        self.value = value;
    }

    pub fn doSomethingConst(self: *const @This(), value: f32) i32 {
        return @as(i32, @intFromFloat(value)) + self.value;
    }
};

pub fn main() void {
    // declare the virtuals
    var something_doer: SomethingDoer(.variable) = undefined;
    var something_doer_const: SomethingDoer(.constant) = undefined;

    // declare the implementations
    var example_doer = ExampleDoer{.value = 0};
    const example_doer_const = ExampleDoer{.value = 0};    
    
    // use var impl with var virtual
    something_doer = SomethingDoer(.variable).init(&example_doer); // works
    something_doer.doSomethingVar(12); // works
    _ = something_doer.doSomethingConst(12.0); // works
    
    // use const impl with const virtual
    something_doer_const = SomethingDoer(.constant).init(&example_doer_const); // works
    something_doer_const.doSomethingVar(12); // does not work! attempt to modify const value
    _ = something_doer_const.doSomethingConst(12); // works!
    
    // use const impl with var virtual
    something_doer = SomethingDoer(.variable).init(&example_doer_const); // does not work! attempt to set variable virtual with const pointer
    
    // use var impl with const virtual is same as const with const
}
const std = @import("std");
/// Represents constness of a virtual type.
///
/// Selects whether the vtable's `__ptr__` field is `*anyopaque` or
/// `*const anyopaque`.
pub const Access = enum(u1) {
    /// The virtual stores a mutable `*anyopaque` to its implementation.
    variable,
    /// The virtual stores a `*const anyopaque`, so it can wrap a
    /// pointer-to-const implementation.
    constant,
};
/// A mixin which makes a virtual type.
///
/// Intended use: `pub usingnamespace virtual(access, @This());` inside a struct
/// `T` that has a `vtable: Vtable(access, @This())` field and a
/// `pub fn iface(type) type`. The mixin re-exports the declarations of
/// `T.iface(T)` and adds `init`.
pub fn virtual(comptime access: Access, T: type) type {
    return struct {
        pub usingnamespace T.iface(T);

        /// Creates a `T` whose vtable dispatches to the implementation `ptr`
        /// points to.
        ///
        /// `ptr` must be a pointer to an implementation value; anything else is
        /// rejected with a compile error. A pointer-to-const is only accepted when
        /// `T`'s vtable stores a const pointer (a `.constant` virtual).
        /// The result stores `ptr` without copying the implementation, so the
        /// implementation must outlive the returned value.
        pub fn init(ptr: anytype) T {
            comptime {
                const Ptr = @TypeOf(ptr);
                const ptr_info = switch (@typeInfo(Ptr)) {
                    .Pointer => |info| info,
                    else => @compileError(
                        "virtual init expects a pointer to an implementation, got " ++ @typeName(Ptr),
                    ),
                };
                // Read constness from T's own vtable field so the check agrees
                // with the field the pointer is stored in.
                const Slot = @TypeOf(@as(T, undefined).vtable.__ptr__);
                if (ptr_info.is_const and !@typeInfo(Slot).Pointer.is_const) {
                    @compileError("Cannot assign const ptr to mutable virtual");
                }
            }
            return .{
                .vtable = makeVtable(access, T, ptr),
            };
        }
    };
}
/// A vtable for a virtual type.
///
/// Builds, at comptime, a struct containing:
/// - `__ptr__`: a type-erased pointer to the implementation, `*anyopaque` for
///   `.variable` access or `*const anyopaque` for `.constant` access;
/// - one function-pointer field per function declared in `T.iface(anyopaque)`
///   (a decl named `init` and non-function decls are skipped), with the same
///   name and signature as the interface method.
pub fn Vtable(comptime access: Access, T: type) type {
    const Field = std.builtin.Type.StructField;
    const Iface = T.iface(anyopaque);

    const ErasedPtr = switch (access) {
        .variable => *anyopaque,
        .constant => *const anyopaque,
    };

    var fields: []const Field = &[_]Field{.{
        .name = "__ptr__",
        .type = ErasedPtr,
        .default_value = null,
        .is_comptime = false,
        .alignment = 0,
    }};

    for (@typeInfo(Iface).Struct.decls) |decl| {
        if (std.mem.eql(u8, decl.name, "init")) continue;

        var info = @typeInfo(@TypeOf(@field(Iface, decl.name)));
        if (info != .Fn) continue;

        // Interface methods are typically `inline`, and a pointer to an
        // inline-convention function type is not allowed, so the slot uses
        // the unspecified (default) calling convention instead.
        info.Fn.calling_convention = .Unspecified;

        fields = fields ++ [_]Field{.{
            .name = decl.name,
            .type = *const @Type(info),
            .default_value = null,
            .is_comptime = false,
            .alignment = 0,
        }};
    }

    return @Type(.{ .Struct = .{
        .layout = .auto,
        .fields = fields,
        .decls = &.{},
        .is_tuple = false,
    } });
}
/// Adapts an implementation method to the type-erased signature of a vtable slot.
///
/// `IfaceFn` is the slot's function-pointer type, whose first parameter is an
/// opaque pointer. The returned thunk casts that first argument back to the
/// receiver type expected by `impl_fn` (via `@ptrCast` + `@alignCast`) and
/// forwards the remaining arguments unchanged. A slot with no parameters is
/// filled with `impl_fn` itself. At most six parameters (receiver included)
/// are supported; more is a compile error.
fn makeVirtualFn(
    IfaceFn: type,
    impl_fn: anytype,
) IfaceFn {
    const slot_info = @typeInfo(@typeInfo(IfaceFn).Pointer.child).Fn;
    const params = slot_info.params;
    // Zig has no variadic functions, so each arity gets its own thunk.
    return switch (params.len) {
        0 => impl_fn,
        1 => struct {
            fn thunk(
                receiver: params[0].type.?,
            ) slot_info.return_type.? {
                return impl_fn(@alignCast(@ptrCast(receiver)));
            }
        }.thunk,
        2 => struct {
            fn thunk(
                receiver: params[0].type.?,
                a1: params[1].type.?,
            ) slot_info.return_type.? {
                return impl_fn(@alignCast(@ptrCast(receiver)), a1);
            }
        }.thunk,
        3 => struct {
            fn thunk(
                receiver: params[0].type.?,
                a1: params[1].type.?,
                a2: params[2].type.?,
            ) slot_info.return_type.? {
                return impl_fn(@alignCast(@ptrCast(receiver)), a1, a2);
            }
        }.thunk,
        4 => struct {
            fn thunk(
                receiver: params[0].type.?,
                a1: params[1].type.?,
                a2: params[2].type.?,
                a3: params[3].type.?,
            ) slot_info.return_type.? {
                return impl_fn(@alignCast(@ptrCast(receiver)), a1, a2, a3);
            }
        }.thunk,
        5 => struct {
            fn thunk(
                receiver: params[0].type.?,
                a1: params[1].type.?,
                a2: params[2].type.?,
                a3: params[3].type.?,
                a4: params[4].type.?,
            ) slot_info.return_type.? {
                return impl_fn(@alignCast(@ptrCast(receiver)), a1, a2, a3, a4);
            }
        }.thunk,
        6 => struct {
            fn thunk(
                receiver: params[0].type.?,
                a1: params[1].type.?,
                a2: params[2].type.?,
                a3: params[3].type.?,
                a4: params[4].type.?,
                a5: params[5].type.?,
            ) slot_info.return_type.? {
                return impl_fn(@alignCast(@ptrCast(receiver)), a1, a2, a3, a4, a5);
            }
        }.thunk,
        else => @compileError("Virtual methods only support up to 6 parameters"),
    };
}
/// Builds a `Vtable(access, T)` that dispatches to the implementation `ptr`
/// points to.
///
/// Every function slot is filled with a thunk from `makeVirtualFn` that adapts
/// the implementation's method of the same name. `__ptr__` is set to `ptr`.
/// The implementation type must declare a function for every slot.
fn makeVtable(comptime access: Access, T: type, ptr: anytype) Vtable(access, T) {
    const Impl = @TypeOf(ptr.*);
    const Iface = T.iface(anyopaque);
    var vtable: Vtable(access, T) = undefined;
    inline for (@typeInfo(Iface).Struct.decls) |decl| {
        // Fill only the slots `Vtable` actually creates: it skips a decl named
        // `init` and any non-function decl. Without this filter, such decls
        // would name fields that do not exist on the vtable.
        const has_slot = comptime (!std.mem.eql(u8, decl.name, "init") and
            @typeInfo(@TypeOf(@field(Iface, decl.name))) == .Fn);
        if (has_slot) {
            @field(vtable, decl.name) = makeVirtualFn(
                @TypeOf(@field(vtable, decl.name)),
                @field(Impl, decl.name),
            );
        }
    }
    vtable.__ptr__ = ptr;
    return vtable;
}
// -------------------------------------------------------------------------------------------------
// Testing
// -------------------------------------------------------------------------------------------------
/// Test fixture: a virtual type exposing `get() i32` and `set(i32) void`, plus a
/// non-virtual helper `double` built on top of `get`.
fn SetGet(access: Access) type {
    return struct {
        const Self = @This();

        pub usingnamespace virtual(access, Self);
        vtable: Vtable(access, Self),

        /// Interface methods, generic over the receiver type `Host`; each one
        /// forwards to the vtable slot of the same name.
        pub fn iface(Host: type) type {
            return struct {
                /// Reads the value through the vtable.
                pub inline fn get(self: *const Host) i32 {
                    return self.vtable.get(self.vtable.__ptr__);
                }

                /// Writes `new_value` through the vtable.
                pub inline fn set(self: *Host, new_value: i32) void {
                    self.vtable.set(self.vtable.__ptr__, new_value);
                }
            };
        }

        /// Returns twice the value reported by `get`.
        pub inline fn double(self: *const Self) i32 {
            return 2 * self.get();
        }
    };
}
test "Virtual" {
    // Two implementations with the same interface but different `set` behavior.
    const Impl = struct {
        value: i32,
        pub fn get(self: *const @This()) i32 {
            return self.value;
        }
        pub fn set(self: *@This(), value: i32) void {
            self.value = value;
        }
    };
    const Impl2 = struct {
        value: i32,
        pub fn get(self: *const @This()) i32 {
            return self.value;
        }
        pub fn set(self: *@This(), value: i32) void {
            self.value = value * 2;
        }
    };

    // A `.variable` virtual can be re-pointed at different implementations.
    // `expectEqual` takes (expected, actual).
    const SG = SetGet(.variable);
    var sg: SG = undefined;
    var impl: Impl = .{ .value = 0 };
    sg = SG.init(&impl);
    sg.set(10);
    try std.testing.expectEqual(10, sg.get());
    try std.testing.expectEqual(20, sg.double());
    // Writes go through the pointer to the original implementation value.
    try std.testing.expectEqual(10, impl.value);

    var impl2: Impl2 = .{ .value = 0 };
    sg = SG.init(&impl2);
    sg.set(10);
    try std.testing.expectEqual(20, sg.get());
    try std.testing.expectEqual(40, sg.double());

    // A `.constant` virtual accepts a pointer-to-const implementation and can
    // call methods that take a const receiver.
    const SGC = SetGet(.constant);
    const impl_const: Impl = .{ .value = 7 };
    const sgc = SGC.init(&impl_const);
    try std.testing.expectEqual(7, sgc.get());
    try std.testing.expectEqual(14, sgc.double());
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment