Created
January 6, 2021 00:12
-
-
Save benvanik/f5a8ff4928f58572882669dd92caced4 to your computer and use it in GitHub Desktop.
WIP api_interfaces_cc.h example for #iree/4369
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright 2020 Google LLC | |
// | |
// Licensed under the Apache License, Version 2.0 (the "License"); | |
// you may not use this file except in compliance with the License. | |
// You may obtain a copy of the License at | |
// | |
// https://www.apache.org/licenses/LICENSE-2.0 | |
// | |
// Unless required by applicable law or agreed to in writing, software | |
// distributed under the License is distributed on an "AS IS" BASIS, | |
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
// See the License for the specific language governing permissions and | |
// limitations under the License. | |
#ifndef IREE_HAL_API_INTERFACES_CC_H_ | |
#define IREE_HAL_API_INTERFACES_CC_H_ | |
#include <array>
#include <cstdint>
#include <string>
#include <type_traits>

#include "iree/base/status.h"
#include "iree/base/time.h"
#include "iree/hal/api.h"
#include "iree/hal/api_interfaces.h"
#include "iree/hal/executable_format.h"
#ifndef __cplusplus | |
#error "This header is meant for use with C++ HAL implementations." | |
#endif // __cplusplus | |
namespace iree { | |
namespace hal { | |
//===----------------------------------------------------------------------===// | |
// iree_hal_resource_t | |
//===----------------------------------------------------------------------===// | |
// CRTP base shared by every C++ HAL object that wraps a C `iree_hal_*_t`.
//
// `T` must expose `base()` returning a pointer to its C struct, whose
// `resource.ref_count` field is the reference count shared between the C API
// and `ref_ptr<T>`. Instances are neither copyable nor assignable because the
// C handle embedded in each object must keep a stable address.
template <typename T>
class ResourceBase {
 public:
  ResourceBase(const ResourceBase&) = delete;
  ResourceBase& operator=(const ResourceBase&) = delete;

  // Adds a reference; used by ref_ptr (found via argument-dependent lookup).
  friend void ref_ptr_add_ref(T* p) {
    iree_atomic_ref_count_inc(&p->base()->resource.ref_count);
  }

  // Releases a reference, deleting the object when the last one is dropped;
  // used by ref_ptr. Assumes iree_atomic_ref_count_dec returns the count as it
  // was before the decrement, so 1 means this was the final reference.
  friend void ref_ptr_release_ref(T* p) {
    if (iree_atomic_ref_count_dec(&p->base()->resource.ref_count) == 1) {
      delete p;
    }
  }

 protected:
  // Declaring the (deleted) copy constructor suppresses the implicit default
  // constructor, so it must be restored for derived classes to be
  // constructible at all.
  ResourceBase() = default;
  ~ResourceBase() = default;
};
//===----------------------------------------------------------------------===//
// iree_hal_buffer_t
//===----------------------------------------------------------------------===//

// C++ base class for buffers exposed through the C `iree_hal_buffer_t` API.
//
// NOTE: BufferBase is defined before AllocatorBase because the allocator
// thunks need the complete BufferBase type to hand buffers back to C callers.
class BufferBase : public ResourceBase<BufferBase> {
 public:
  virtual ~BufferBase() = default;

  // Returns the C handle embedded in this object. The handle is `mutable`
  // because the C API (e.g. reference counting) mutates it even when it is
  // reached through a const reference to this wrapper.
  iree_hal_buffer_t* base() const noexcept { return &handle_.base; }

  iree_hal_allocator_t* allocator() const { return base()->allocator; }
  iree_hal_memory_type_t memory_type() const {
    return static_cast<iree_hal_memory_type_t>(base()->memory_type);
  }
  iree_hal_memory_access_t allowed_access() const {
    return static_cast<iree_hal_memory_access_t>(base()->allowed_access);
  }
  iree_hal_buffer_usage_t usage() const {
    return static_cast<iree_hal_buffer_usage_t>(base()->usage);
  }
  iree_hal_buffer_t* allocated_buffer() const noexcept {
    return base()->allocated_buffer;
  }
  iree_device_size_t allocation_size() const {
    return base()->allocation_size;
  }
  iree_device_size_t byte_offset() const noexcept {
    return base()->byte_offset;
  }
  iree_device_size_t byte_length() const noexcept {
    return base()->byte_length;
  }

 protected:
  BufferBase() {
    static const iree_hal_buffer_vtable_t vtable = {
        DestroyThunk,
        FillThunk,
        ReadDataThunk,
        WriteDataThunk,
        CopyDataThunk,
        MapThunk,
        UnmapThunk,
        InvalidateMappedMemoryThunk,
        FlushMappedMemoryThunk,
    };
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.owner = this;
  }

  virtual Status FillImpl(iree_device_size_t byte_offset,
                          iree_device_size_t byte_length, const void* pattern,
                          iree_device_size_t pattern_length) = 0;
  virtual Status ReadDataImpl(iree_device_size_t source_offset,
                              void* target_data,
                              iree_device_size_t data_length) = 0;
  virtual Status WriteDataImpl(iree_device_size_t target_offset,
                               const void* source_data,
                               iree_device_size_t data_length) = 0;
  virtual Status CopyDataImpl(iree_device_size_t target_offset,
                              iree_hal_buffer_t* source_buffer,
                              iree_device_size_t source_offset,
                              iree_device_size_t data_length) = 0;
  virtual Status MapMemoryImpl(MappingMode mapping_mode,
                               iree_hal_memory_access_t memory_access,
                               iree_device_size_t local_byte_offset,
                               iree_device_size_t local_byte_length,
                               void** out_data) = 0;
  virtual Status UnmapMemoryImpl(iree_device_size_t local_byte_offset,
                                 iree_device_size_t local_byte_length,
                                 void* data) = 0;
  virtual Status InvalidateMappedMemoryImpl(
      iree_device_size_t local_byte_offset,
      iree_device_size_t local_byte_length) = 0;
  virtual Status FlushMappedMemoryImpl(
      iree_device_size_t local_byte_offset,
      iree_device_size_t local_byte_length) = 0;

 private:
  // The C handle plus a back-pointer to its owner. `base` must remain the
  // first member: the `iree_hal_buffer_t*` given to C code is then
  // pointer-interconvertible with the enclosing Handle. Casting it straight to
  // BufferBase* would be wrong, because a polymorphic class does not place its
  // first data member at offset 0 (the vtable pointer usually comes first).
  struct Handle {
    iree_hal_buffer_t base;
    BufferBase* owner;
  };
  static_assert(std::is_standard_layout<Handle>::value,
                "Handle must be standard-layout for FromBase to be valid");

  // Recovers the owning C++ object from a handle created by this class.
  static BufferBase* FromBase(iree_hal_buffer_t* buffer) {
    return reinterpret_cast<Handle*>(buffer)->owner;
  }

  static void DestroyThunk(iree_hal_buffer_t* buffer) {
    delete FromBase(buffer);
  }
  static iree_status_t FillThunk(iree_hal_buffer_t* buffer,
                                 iree_device_size_t byte_offset,
                                 iree_device_size_t byte_length,
                                 const void* pattern,
                                 iree_host_size_t pattern_length) {
    return FromBase(buffer)->FillImpl(byte_offset, byte_length, pattern,
                                      pattern_length);
  }
  static iree_status_t ReadDataThunk(iree_hal_buffer_t* buffer,
                                     iree_device_size_t source_offset,
                                     void* target_buffer,
                                     iree_device_size_t data_length) {
    return FromBase(buffer)->ReadDataImpl(source_offset, target_buffer,
                                          data_length);
  }
  static iree_status_t WriteDataThunk(iree_hal_buffer_t* buffer,
                                      iree_device_size_t target_offset,
                                      const void* source_buffer,
                                      iree_device_size_t data_length) {
    return FromBase(buffer)->WriteDataImpl(target_offset, source_buffer,
                                           data_length);
  }
  static iree_status_t CopyDataThunk(iree_hal_buffer_t* source_buffer,
                                     iree_device_size_t source_offset,
                                     iree_hal_buffer_t* target_buffer,
                                     iree_device_size_t target_offset,
                                     iree_device_size_t data_length) {
    // CopyDataImpl takes the source as a parameter, so `this` is the target.
    // NOTE(review): assumes target_buffer was created by a BufferBase; confirm
    // how the C API routes copies between buffers of different backends.
    return FromBase(target_buffer)
        ->CopyDataImpl(target_offset, source_buffer, source_offset,
                       data_length);
  }

  // TODO: bridge the mapping thunks to MapMemoryImpl/UnmapMemoryImpl/
  // InvalidateMappedMemoryImpl/FlushMappedMemoryImpl once the contents of
  // iree_hal_buffer_mapping_t are settled. Until then they report
  // UNIMPLEMENTED instead of falling off the end of a non-void function
  // (undefined behavior).
  static iree_status_t MapThunk(iree_hal_buffer_t* buffer,
                                iree_hal_memory_access_t memory_access,
                                iree_device_size_t byte_offset,
                                iree_device_size_t byte_length,
                                iree_hal_buffer_mapping_t* out_mapped_memory) {
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                            "buffer mapping not yet bridged to C++");
  }
  static iree_status_t UnmapThunk(iree_hal_buffer_t* buffer,
                                  iree_hal_buffer_mapping_t* mapped_memory) {
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                            "buffer unmapping not yet bridged to C++");
  }
  static iree_status_t InvalidateMappedMemoryThunk(
      iree_hal_buffer_mapping_t* mapped_memory,
      iree_device_size_t local_byte_offset,
      iree_device_size_t local_byte_length) {
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                            "mapped memory invalidation not yet bridged to C++");
  }
  static iree_status_t FlushMappedMemoryThunk(
      iree_hal_buffer_mapping_t* mapped_memory,
      iree_device_size_t local_byte_offset,
      iree_device_size_t local_byte_length) {
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                            "mapped memory flush not yet bridged to C++");
  }

  mutable Handle handle_{};
};

//===----------------------------------------------------------------------===//
// iree_hal_allocator_t
//===----------------------------------------------------------------------===//

// C++ base class for allocators exposed through the C `iree_hal_allocator_t`
// API. Buffers returned by the virtual methods are handed to C callers with
// the reference held by the returned ref_ptr.
class AllocatorBase : public ResourceBase<AllocatorBase> {
 public:
  virtual ~AllocatorBase() = default;

  // Returns the C handle embedded in this object (`mutable`: the C API mutates
  // it, e.g. the reference count, even through const wrappers).
  iree_hal_allocator_t* base() const noexcept { return &handle_.base; }

  virtual bool CheckBufferCompatibility(
      iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t allowed_usage,
      iree_hal_buffer_usage_t intended_usage) = 0;
  virtual StatusOr<ref_ptr<BufferBase>> AllocateBuffer(
      iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t buffer_usage,
      iree_host_size_t allocation_size) = 0;
  virtual StatusOr<ref_ptr<BufferBase>> WrapBuffer(
      iree_hal_memory_type_t memory_type,
      iree_hal_memory_access_t allowed_access,
      iree_hal_buffer_usage_t buffer_usage, absl::Span<uint8_t> data) = 0;

 protected:
  AllocatorBase() {
    static const iree_hal_allocator_vtable_t vtable = {
        DestroyThunk,
        CheckBufferCompatibilityThunk,
        AllocateBufferThunk,
        WrapBufferThunk,
    };
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.owner = this;
  }

 private:
  // `base` must be the first member so the C pointer is pointer-
  // interconvertible with Handle; see FromBase.
  struct Handle {
    iree_hal_allocator_t base;
    AllocatorBase* owner;
  };
  static_assert(std::is_standard_layout<Handle>::value,
                "Handle must be standard-layout for FromBase to be valid");

  // Recovers the owning C++ object from a handle created by this class.
  static AllocatorBase* FromBase(iree_hal_allocator_t* allocator) {
    return reinterpret_cast<Handle*>(allocator)->owner;
  }

  static void DestroyThunk(iree_hal_allocator_t* allocator) {
    delete FromBase(allocator);
  }
  static bool CheckBufferCompatibilityThunk(
      iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
      iree_hal_buffer_usage_t allowed_usage,
      iree_hal_buffer_usage_t intended_usage) {
    return FromBase(allocator)->CheckBufferCompatibility(
        memory_type, allowed_usage, intended_usage);
  }
  static iree_status_t AllocateBufferThunk(iree_hal_allocator_t* allocator,
                                           iree_hal_memory_type_t memory_type,
                                           iree_hal_buffer_usage_t buffer_usage,
                                           iree_host_size_t allocation_size,
                                           iree_hal_buffer_t** out_buffer) {
    *out_buffer = nullptr;  // Leave the out-param defined on failure.
    IREE_ASSIGN_OR_RETURN(auto buffer,
                          FromBase(allocator)->AllocateBuffer(
                              memory_type, buffer_usage, allocation_size));
    // Hand the ref_ptr's reference to the caller via the embedded C handle.
    *out_buffer = buffer.release()->base();
    return iree_ok_status();
  }
  static iree_status_t WrapBufferThunk(iree_hal_allocator_t* allocator,
                                       iree_hal_memory_type_t memory_type,
                                       iree_hal_memory_access_t allowed_access,
                                       iree_hal_buffer_usage_t buffer_usage,
                                       iree_byte_span_t data,
                                       iree_hal_buffer_t** out_buffer) {
    *out_buffer = nullptr;  // Leave the out-param defined on failure.
    // NOTE(review): iree_byte_span_t is a C struct; assumes its fields are
    // `data` and `data_length` — confirm against iree/base/api.h.
    IREE_ASSIGN_OR_RETURN(
        auto buffer,
        FromBase(allocator)->WrapBuffer(
            memory_type, allowed_access, buffer_usage,
            absl::MakeSpan(data.data, data.data_length)));
    *out_buffer = buffer.release()->base();
    return iree_ok_status();
  }

  mutable Handle handle_{};
};
//===----------------------------------------------------------------------===// | |
// iree_hal_command_buffer_t | |
//===----------------------------------------------------------------------===// | |
// C++ base class for command buffers exposed through the C
// `iree_hal_command_buffer_t` API. Each C vtable entry forwards to the
// matching virtual method, converting (pointer, count) pairs into spans.
class CommandBufferBase : public ResourceBase<CommandBufferBase> {
 public:
  virtual ~CommandBufferBase() = default;

  // Returns the C handle embedded in this object (`mutable`: the C API mutates
  // it, e.g. the reference count, even through const wrappers).
  iree_hal_command_buffer_t* base() const noexcept { return &handle_.base; }

  virtual Status Begin() = 0;
  virtual Status End() = 0;
  virtual Status ExecutionBarrier(
      iree_hal_execution_stage_t source_stage_mask,
      iree_hal_execution_stage_t target_stage_mask,
      absl::Span<const iree_hal_memory_barrier_t> memory_barriers,
      absl::Span<const iree_hal_buffer_barrier_t> buffer_barriers) = 0;
  virtual Status SignalEvent(iree_hal_event_t* event,
                             iree_hal_execution_stage_t source_stage_mask) = 0;
  virtual Status ResetEvent(iree_hal_event_t* event,
                            iree_hal_execution_stage_t source_stage_mask) = 0;
  virtual Status WaitEvents(
      absl::Span<iree_hal_event_t*> events,
      iree_hal_execution_stage_t source_stage_mask,
      iree_hal_execution_stage_t target_stage_mask,
      absl::Span<const iree_hal_memory_barrier_t> memory_barriers,
      absl::Span<const iree_hal_buffer_barrier_t> buffer_barriers) = 0;
  virtual Status FillBuffer(iree_hal_buffer_t* target_buffer,
                            iree_device_size_t target_offset,
                            iree_device_size_t length, const void* pattern,
                            size_t pattern_length) = 0;
  // NOTE(review): DiscardBuffer has no slot in the C vtable built below;
  // confirm whether iree_hal_command_buffer_vtable_t expects one.
  virtual Status DiscardBuffer(iree_hal_buffer_t* buffer) = 0;
  virtual Status UpdateBuffer(const void* source_buffer,
                              iree_device_size_t source_offset,
                              iree_hal_buffer_t* target_buffer,
                              iree_device_size_t target_offset,
                              iree_device_size_t length) = 0;
  virtual Status CopyBuffer(iree_hal_buffer_t* source_buffer,
                            iree_device_size_t source_offset,
                            iree_hal_buffer_t* target_buffer,
                            iree_device_size_t target_offset,
                            iree_device_size_t length) = 0;
  virtual Status PushConstants(iree_hal_executable_layout_t* executable_layout,
                               size_t offset,
                               absl::Span<const uint32_t> values) = 0;
  virtual Status PushDescriptorSet(
      iree_hal_executable_layout_t* executable_layout, int32_t set,
      absl::Span<const iree_hal_descriptor_set_binding_t> bindings) = 0;
  virtual Status BindDescriptorSet(
      iree_hal_executable_layout_t* executable_layout, int32_t set,
      iree_hal_descriptor_set_t* descriptor_set,
      absl::Span<const iree_device_size_t> dynamic_offsets) = 0;
  virtual Status Dispatch(iree_hal_executable_t* executable,
                          int32_t entry_point,
                          std::array<uint32_t, 3> workgroups) = 0;
  virtual Status DispatchIndirect(iree_hal_executable_t* executable,
                                  int32_t entry_point,
                                  iree_hal_buffer_t* workgroups_buffer,
                                  iree_device_size_t workgroups_offset) = 0;

 protected:
  CommandBufferBase() {
    static const iree_hal_command_buffer_vtable_t vtable = {
        DestroyThunk,           BeginThunk,
        EndThunk,               ExecutionBarrierThunk,
        SignalEventThunk,       ResetEventThunk,
        WaitEventsThunk,        FillBufferThunk,
        UpdateBufferThunk,      CopyBufferThunk,
        PushConstantsThunk,     PushDescriptorSetThunk,
        BindDescriptorSetThunk, DispatchThunk,
        DispatchIndirectThunk,
    };
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.owner = this;
  }

 private:
  // `base` must be the first member so the C pointer is pointer-
  // interconvertible with Handle. Casting the C pointer straight to
  // CommandBufferBase* would be wrong: a polymorphic class does not place its
  // first data member at offset 0 (the vtable pointer usually comes first).
  struct Handle {
    iree_hal_command_buffer_t base;
    CommandBufferBase* owner;
  };
  static_assert(std::is_standard_layout<Handle>::value,
                "Handle must be standard-layout for FromBase to be valid");

  // Recovers the owning C++ object from a handle created by this class.
  static CommandBufferBase* FromBase(
      iree_hal_command_buffer_t* command_buffer) {
    return reinterpret_cast<Handle*>(command_buffer)->owner;
  }

  static void DestroyThunk(iree_hal_command_buffer_t* command_buffer) {
    delete FromBase(command_buffer);
  }
  static iree_status_t BeginThunk(iree_hal_command_buffer_t* command_buffer) {
    return FromBase(command_buffer)->Begin();
  }
  static iree_status_t EndThunk(iree_hal_command_buffer_t* command_buffer) {
    return FromBase(command_buffer)->End();
  }
  static iree_status_t ExecutionBarrierThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_execution_stage_t source_stage_mask,
      iree_hal_execution_stage_t target_stage_mask,
      iree_host_size_t memory_barrier_count,
      const iree_hal_memory_barrier_t* memory_barriers,
      iree_host_size_t buffer_barrier_count,
      const iree_hal_buffer_barrier_t* buffer_barriers) {
    return FromBase(command_buffer)
        ->ExecutionBarrier(
            source_stage_mask, target_stage_mask,
            absl::MakeConstSpan(memory_barriers, memory_barrier_count),
            absl::MakeConstSpan(buffer_barriers, buffer_barrier_count));
  }
  static iree_status_t SignalEventThunk(
      iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event,
      iree_hal_execution_stage_t source_stage_mask) {
    return FromBase(command_buffer)->SignalEvent(event, source_stage_mask);
  }
  static iree_status_t ResetEventThunk(
      iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event,
      iree_hal_execution_stage_t source_stage_mask) {
    return FromBase(command_buffer)->ResetEvent(event, source_stage_mask);
  }
  // NOTE(review): the original thunk had no command_buffer parameter (so it
  // could not dispatch) and took events as `const iree_hal_event_t*`. This
  // assumes the C vtable passes the command buffer first and events as an
  // array of handles; confirm against iree_hal_command_buffer_vtable_t.
  static iree_status_t WaitEventsThunk(
      iree_hal_command_buffer_t* command_buffer, iree_host_size_t event_count,
      const iree_hal_event_t** events,
      iree_hal_execution_stage_t source_stage_mask,
      iree_hal_execution_stage_t target_stage_mask,
      iree_host_size_t memory_barrier_count,
      const iree_hal_memory_barrier_t* memory_barriers,
      iree_host_size_t buffer_barrier_count,
      const iree_hal_buffer_barrier_t* buffer_barriers) {
    // The C++ interface takes mutable handles; the C API only promises not to
    // modify the array itself.
    return FromBase(command_buffer)
        ->WaitEvents(
            absl::MakeSpan(const_cast<iree_hal_event_t**>(events), event_count),
            source_stage_mask, target_stage_mask,
            absl::MakeConstSpan(memory_barriers, memory_barrier_count),
            absl::MakeConstSpan(buffer_barriers, buffer_barrier_count));
  }
  static iree_status_t FillBufferThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
      iree_device_size_t length, const void* pattern,
      iree_host_size_t pattern_length) {
    return FromBase(command_buffer)
        ->FillBuffer(target_buffer, target_offset, length, pattern,
                     pattern_length);
  }
  static iree_status_t UpdateBufferThunk(
      iree_hal_command_buffer_t* command_buffer, const void* source_buffer,
      iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer,
      iree_device_size_t target_offset, iree_device_size_t length) {
    return FromBase(command_buffer)
        ->UpdateBuffer(source_buffer, source_offset, target_buffer,
                       target_offset, length);
  }
  static iree_status_t CopyBufferThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
      iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
      iree_device_size_t length) {
    return FromBase(command_buffer)
        ->CopyBuffer(source_buffer, source_offset, target_buffer,
                     target_offset, length);
  }
  static iree_status_t PushConstantsThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset,
      const void* values, iree_host_size_t values_length) {
    // NOTE(review): assumes values_length is in bytes and values holds 32-bit
    // words; `offset` is forwarded unchanged — confirm both units.
    return FromBase(command_buffer)
        ->PushConstants(
            executable_layout, offset,
            absl::MakeConstSpan(static_cast<const uint32_t*>(values),
                                values_length / sizeof(uint32_t)));
  }
  static iree_status_t PushDescriptorSetThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_executable_layout_t* executable_layout, int32_t set,
      iree_host_size_t binding_count,
      const iree_hal_descriptor_set_binding_t* bindings) {
    return FromBase(command_buffer)
        ->PushDescriptorSet(executable_layout, set,
                            absl::MakeConstSpan(bindings, binding_count));
  }
  static iree_status_t BindDescriptorSetThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_executable_layout_t* executable_layout, int32_t set,
      iree_hal_descriptor_set_t* descriptor_set,
      iree_host_size_t dynamic_offset_count,
      const iree_device_size_t* dynamic_offsets) {
    return FromBase(command_buffer)
        ->BindDescriptorSet(
            executable_layout, set, descriptor_set,
            absl::MakeConstSpan(dynamic_offsets, dynamic_offset_count));
  }
  static iree_status_t DispatchThunk(iree_hal_command_buffer_t* command_buffer,
                                     iree_hal_executable_t* executable,
                                     int32_t entry_point, uint32_t workgroup_x,
                                     uint32_t workgroup_y,
                                     uint32_t workgroup_z) {
    return FromBase(command_buffer)
        ->Dispatch(executable, entry_point,
                   std::array<uint32_t, 3>{{workgroup_x, workgroup_y,
                                            workgroup_z}});
  }
  static iree_status_t DispatchIndirectThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_executable_t* executable, int32_t entry_point,
      iree_hal_buffer_t* workgroups_buffer,
      iree_device_size_t workgroups_offset) {
    return FromBase(command_buffer)
        ->DispatchIndirect(executable, entry_point, workgroups_buffer,
                           workgroups_offset);
  }

  mutable Handle handle_{};
};
//===----------------------------------------------------------------------===// | |
// iree_hal_descriptor_set_t | |
//===----------------------------------------------------------------------===// | |
// C++ base class for descriptor sets exposed through the C
// `iree_hal_descriptor_set_t` API. Only destruction is bridged so far.
class DescriptorSetBase : public ResourceBase<DescriptorSetBase> {
 public:
  virtual ~DescriptorSetBase() = default;

  // Returns the C handle embedded in this object (`mutable`: the C API mutates
  // it, e.g. the reference count, even through const wrappers).
  iree_hal_descriptor_set_t* base() const noexcept { return &handle_.base; }

 protected:
  DescriptorSetBase() {
    static const iree_hal_descriptor_set_vtable_t vtable = {
        DestroyThunk,
    };
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.owner = this;
  }

 private:
  // `base` must be the first member so the C pointer is pointer-
  // interconvertible with Handle. Casting the C pointer straight to
  // DescriptorSetBase* would be wrong: a polymorphic class does not place its
  // first data member at offset 0 (the vtable pointer usually comes first).
  struct Handle {
    iree_hal_descriptor_set_t base;
    DescriptorSetBase* owner;
  };
  static_assert(std::is_standard_layout<Handle>::value,
                "Handle must be standard-layout for FromBase to be valid");

  // Recovers the owning C++ object from a handle created by this class.
  static DescriptorSetBase* FromBase(
      iree_hal_descriptor_set_t* descriptor_set) {
    return reinterpret_cast<Handle*>(descriptor_set)->owner;
  }

  static void DestroyThunk(iree_hal_descriptor_set_t* descriptor_set) {
    delete FromBase(descriptor_set);
  }

  mutable Handle handle_{};
};
//===----------------------------------------------------------------------===// | |
// iree_hal_descriptor_set_layout_t | |
//===----------------------------------------------------------------------===// | |
// C++ base class for descriptor set layouts exposed through the C
// `iree_hal_descriptor_set_layout_t` API. Only destruction is bridged so far.
class DescriptorSetLayoutBase : public ResourceBase<DescriptorSetLayoutBase> {
 public:
  virtual ~DescriptorSetLayoutBase() = default;

  // Returns the C handle embedded in this object (`mutable`: the C API mutates
  // it, e.g. the reference count, even through const wrappers).
  iree_hal_descriptor_set_layout_t* base() const noexcept {
    return &handle_.base;
  }

 protected:
  // (Previously misnamed `DescriptorSetBase()`, which did not compile.)
  DescriptorSetLayoutBase() {
    static const iree_hal_descriptor_set_layout_vtable_t vtable = {
        DestroyThunk,
    };
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.owner = this;
  }

 private:
  // `base` must be the first member so the C pointer is pointer-
  // interconvertible with Handle. Casting the C pointer straight to
  // DescriptorSetLayoutBase* would be wrong: a polymorphic class does not
  // place its first data member at offset 0.
  struct Handle {
    iree_hal_descriptor_set_layout_t base;
    DescriptorSetLayoutBase* owner;
  };
  static_assert(std::is_standard_layout<Handle>::value,
                "Handle must be standard-layout for FromBase to be valid");

  // Recovers the owning C++ object from a handle created by this class.
  static DescriptorSetLayoutBase* FromBase(
      iree_hal_descriptor_set_layout_t* descriptor_set_layout) {
    return reinterpret_cast<Handle*>(descriptor_set_layout)->owner;
  }

  static void DestroyThunk(
      iree_hal_descriptor_set_layout_t* descriptor_set_layout) {
    delete FromBase(descriptor_set_layout);
  }

  mutable Handle handle_{};
};
//===----------------------------------------------------------------------===// | |
// iree_hal_device_t | |
//===----------------------------------------------------------------------===// | |
// C++ base class for devices exposed through the C `iree_hal_device_t` API.
//
// Only destruction is bridged so far; every other C entry point reports
// UNIMPLEMENTED. (Previously these thunks were declared but never defined,
// which fails to link once the vtable takes their addresses.)
class DeviceBase : public ResourceBase<DeviceBase> {
 public:
  virtual ~DeviceBase() = default;

  // Returns the C handle embedded in this object (`mutable`: the C API mutates
  // it, e.g. the reference count, even through const wrappers).
  iree_hal_device_t* base() const noexcept { return &handle_.base; }

 protected:
  // `id` is copied into the device so the C handle's `id` view stays valid
  // for the device's lifetime regardless of the caller's storage. The object
  // is non-copyable, so the view into `id_` can never be invalidated by a move.
  explicit DeviceBase(absl::string_view id) : id_(id.data(), id.size()) {
    static const iree_hal_device_vtable_t vtable = {
        DestroyThunk,
        CreateBufferThunk,
        CreateCommandBufferThunk,
        CreateDescriptorSetThunk,
        CreateDescriptorSetLayoutThunk,
        CreateEventThunk,
        CreateExecutableCacheThunk,
        CreateExecutableLayoutThunk,
        CreateSemaphoreThunk,
        QueueSubmitThunk,
        WaitSemaphoresWithDeadlineThunk,
        WaitSemaphoresWithTimeoutThunk,
        WaitIdleWithDeadlineThunk,
        WaitIdleWithTimeoutThunk,
    };
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.base.id = iree_make_string_view(id_.data(), id_.size());
    handle_.owner = this;
  }

 private:
  // `base` must be the first member so the C pointer is pointer-
  // interconvertible with Handle. Casting the C pointer straight to
  // DeviceBase* would be wrong: a polymorphic class does not place its first
  // data member at offset 0.
  struct Handle {
    iree_hal_device_t base;
    DeviceBase* owner;
  };
  static_assert(std::is_standard_layout<Handle>::value,
                "Handle must be standard-layout for FromBase to be valid");

  // Recovers the owning C++ object from a handle created by this class.
  static DeviceBase* FromBase(iree_hal_device_t* device) {
    return reinterpret_cast<Handle*>(device)->owner;
  }

  // Shared result for the device entry points not yet bridged to C++.
  static iree_status_t Unimplemented() {
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                            "device operation not yet bridged to C++");
  }

  static void DestroyThunk(iree_hal_device_t* device) {
    delete FromBase(device);
  }
  // NOTE(review): this signature duplicates CreateCommandBufferThunk and looks
  // like a copy-paste placeholder; confirm the intended C vtable slot.
  static iree_status_t CreateBufferThunk(
      iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
      iree_hal_command_category_t command_categories,
      iree_allocator_t allocator,
      iree_hal_command_buffer_t** out_command_buffer) {
    return Unimplemented();
  }
  static iree_status_t CreateCommandBufferThunk(
      iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
      iree_hal_command_category_t command_categories,
      iree_allocator_t allocator,
      iree_hal_command_buffer_t** out_command_buffer) {
    return Unimplemented();
  }
  static iree_status_t CreateDescriptorSetThunk(
      iree_hal_device_t* device, iree_hal_descriptor_set_layout_t* set_layout,
      iree_host_size_t binding_count,
      const iree_hal_descriptor_set_binding_t* bindings,
      iree_allocator_t allocator,
      iree_hal_descriptor_set_t** out_descriptor_set) {
    return Unimplemented();
  }
  static iree_status_t CreateDescriptorSetLayoutThunk(
      iree_hal_device_t* device,
      iree_hal_descriptor_set_layout_usage_type_t usage_type,
      iree_host_size_t binding_count,
      const iree_hal_descriptor_set_layout_binding_t* bindings,
      iree_allocator_t allocator,
      iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
    return Unimplemented();
  }
  static iree_status_t CreateEventThunk(iree_hal_device_t* device,
                                        iree_allocator_t allocator,
                                        iree_hal_event_t** out_event) {
    return Unimplemented();
  }
  static iree_status_t CreateExecutableCacheThunk(
      iree_hal_device_t* device, iree_string_view_t identifier,
      iree_allocator_t allocator,
      iree_hal_executable_cache_t** out_executable_cache) {
    return Unimplemented();
  }
  static iree_status_t CreateExecutableLayoutThunk(
      iree_hal_device_t* device, iree_host_size_t set_layout_count,
      iree_hal_descriptor_set_layout_t** set_layouts,
      iree_host_size_t push_constants, iree_allocator_t allocator,
      iree_hal_executable_layout_t** out_executable_layout) {
    return Unimplemented();
  }
  static iree_status_t CreateSemaphoreThunk(
      iree_hal_device_t* device, uint64_t initial_value,
      iree_allocator_t allocator, iree_hal_semaphore_t** out_semaphore) {
    return Unimplemented();
  }
  static iree_status_t QueueSubmitThunk(
      iree_hal_device_t* device, iree_hal_command_category_t command_categories,
      uint64_t queue_affinity, iree_host_size_t batch_count,
      const iree_hal_submission_batch_t* batches) {
    return Unimplemented();
  }
  static iree_status_t WaitSemaphoresWithDeadlineThunk(
      iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode,
      const iree_hal_semaphore_list_t* semaphore_list,
      iree_time_t deadline_ns) {
    return Unimplemented();
  }
  static iree_status_t WaitSemaphoresWithTimeoutThunk(
      iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode,
      const iree_hal_semaphore_list_t* semaphore_list,
      iree_duration_t timeout_ns) {
    return Unimplemented();
  }
  static iree_status_t WaitIdleWithDeadlineThunk(iree_hal_device_t* device,
                                                 iree_time_t deadline_ns) {
    return Unimplemented();
  }
  static iree_status_t WaitIdleWithTimeoutThunk(iree_hal_device_t* device,
                                                iree_duration_t timeout_ns) {
    return Unimplemented();
  }

  // Owns the storage viewed by handle_.base.id.
  std::string id_;
  mutable Handle handle_{};
};
//===----------------------------------------------------------------------===// | |
// iree_hal_driver_t | |
//===----------------------------------------------------------------------===// | |
// C++ base class for drivers exposed through the C `iree_hal_driver_t` API.
//
// Only destruction is bridged so far. The other entry points report
// UNIMPLEMENTED; previously they fell off the end of a non-void function,
// which is undefined behavior.
class DriverBase : public ResourceBase<DriverBase> {
 public:
  virtual ~DriverBase() = default;

  // Returns the C handle embedded in this object (`mutable`: the C API mutates
  // it, e.g. the reference count, even through const wrappers).
  iree_hal_driver_t* base() const noexcept { return &handle_.base; }

 protected:
  DriverBase() {
    static const iree_hal_driver_vtable_t vtable = {
        DestroyThunk,
        QueryAvailableDevicesThunk,
        CreateDeviceThunk,
        CreateDefaultDeviceThunk,
    };
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.owner = this;
  }

 private:
  // `base` must be the first member so the C pointer is pointer-
  // interconvertible with Handle. Casting the C pointer straight to
  // DriverBase* would be wrong: a polymorphic class does not place its first
  // data member at offset 0.
  struct Handle {
    iree_hal_driver_t base;
    DriverBase* owner;
  };
  static_assert(std::is_standard_layout<Handle>::value,
                "Handle must be standard-layout for FromBase to be valid");

  // Recovers the owning C++ object from a handle created by this class.
  static DriverBase* FromBase(iree_hal_driver_t* driver) {
    return reinterpret_cast<Handle*>(driver)->owner;
  }

  static void DestroyThunk(iree_hal_driver_t* driver) {
    delete FromBase(driver);
  }
  static iree_status_t QueryAvailableDevicesThunk(
      iree_hal_driver_t* driver, iree_allocator_t allocator,
      iree_hal_device_info_t** out_device_infos,
      iree_host_size_t* out_device_info_count) {
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                            "device enumeration not yet bridged to C++");
  }
  static iree_status_t CreateDeviceThunk(iree_hal_driver_t* driver,
                                         iree_hal_device_id_t device_id,
                                         iree_allocator_t allocator,
                                         iree_hal_device_t** out_device) {
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                            "device creation not yet bridged to C++");
  }
  static iree_status_t CreateDefaultDeviceThunk(
      iree_hal_driver_t* driver, iree_allocator_t allocator,
      iree_hal_device_t** out_device) {
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                            "default device creation not yet bridged to C++");
  }

  mutable Handle handle_{};
};
//===----------------------------------------------------------------------===// | |
// iree_hal_event_t | |
//===----------------------------------------------------------------------===// | |
// C++ base class for events exposed through the C `iree_hal_event_t` API.
// Only destruction is bridged so far.
class EventBase : public ResourceBase<EventBase> {
 public:
  virtual ~EventBase() = default;

  // Returns the C handle embedded in this object (`mutable`: the C API mutates
  // it, e.g. the reference count, even through const wrappers).
  iree_hal_event_t* base() const noexcept { return &handle_.base; }

 protected:
  EventBase() {
    static const iree_hal_event_vtable_t vtable = {
        DestroyThunk,
    };
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.owner = this;
  }

 private:
  // `base` must be the first member so the C pointer is pointer-
  // interconvertible with Handle. Casting the C pointer straight to
  // EventBase* would be wrong: a polymorphic class does not place its first
  // data member at offset 0.
  struct Handle {
    iree_hal_event_t base;
    EventBase* owner;
  };
  static_assert(std::is_standard_layout<Handle>::value,
                "Handle must be standard-layout for FromBase to be valid");

  // Recovers the owning C++ object from a handle created by this class.
  static EventBase* FromBase(iree_hal_event_t* event) {
    return reinterpret_cast<Handle*>(event)->owner;
  }

  static void DestroyThunk(iree_hal_event_t* event) { delete FromBase(event); }

  mutable Handle handle_{};
};
//===----------------------------------------------------------------------===// | |
// iree_hal_executable_t | |
//===----------------------------------------------------------------------===// | |
// C++ base class for executables exposed through the C
// `iree_hal_executable_t` API. Only destruction is bridged so far.
class ExecutableBase : public ResourceBase<ExecutableBase> {
 public:
  virtual ~ExecutableBase() = default;

  // Returns the C handle embedded in this object (`mutable`: the C API mutates
  // it, e.g. the reference count, even through const wrappers).
  iree_hal_executable_t* base() const noexcept { return &handle_.base; }

 protected:
  ExecutableBase() {
    static const iree_hal_executable_vtable_t vtable = {
        DestroyThunk,
    };
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.owner = this;
  }

 private:
  // `base` must be the first member so the C pointer is pointer-
  // interconvertible with Handle. Casting the C pointer straight to
  // ExecutableBase* would be wrong: a polymorphic class does not place its
  // first data member at offset 0.
  struct Handle {
    iree_hal_executable_t base;
    ExecutableBase* owner;
  };
  static_assert(std::is_standard_layout<Handle>::value,
                "Handle must be standard-layout for FromBase to be valid");

  // Recovers the owning C++ object from a handle created by this class.
  static ExecutableBase* FromBase(iree_hal_executable_t* executable) {
    return reinterpret_cast<Handle*>(executable)->owner;
  }

  static void DestroyThunk(iree_hal_executable_t* executable) {
    delete FromBase(executable);
  }

  mutable Handle handle_{};
};
//===----------------------------------------------------------------------===// | |
// iree_hal_executable_cache_t | |
//===----------------------------------------------------------------------===// | |
// C++ base class for executable caches exposed through the C
// `iree_hal_executable_cache_t` API. Prepared executables are handed to C
// callers with the reference held by the returned ref_ptr.
class ExecutableCacheBase : public ResourceBase<ExecutableCacheBase> {
 public:
  virtual ~ExecutableCacheBase() = default;

  // Returns the C handle embedded in this object (`mutable`: the C API mutates
  // it, e.g. the reference count, even through const wrappers).
  iree_hal_executable_cache_t* base() const noexcept { return &handle_.base; }

  virtual bool CanPrepareFormat(ExecutableFormat format) const = 0;
  virtual StatusOr<ref_ptr<ExecutableBase>> PrepareExecutable(
      iree_hal_executable_layout_t* executable_layout,
      iree_hal_executable_caching_mode_t caching_mode,
      absl::Span<const uint8_t> executable_data,
      iree_allocator_t allocator) = 0;

 protected:
  ExecutableCacheBase() {
    static const iree_hal_executable_cache_vtable_t vtable = {
        DestroyThunk,
        CanPrepareFormatThunk,
        PrepareExecutableThunk,
    };
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.owner = this;
  }

 private:
  // `base` must be the first member so the C pointer is pointer-
  // interconvertible with Handle. Casting the C pointer straight to
  // ExecutableCacheBase* would be wrong: a polymorphic class does not place
  // its first data member at offset 0.
  struct Handle {
    iree_hal_executable_cache_t base;
    ExecutableCacheBase* owner;
  };
  static_assert(std::is_standard_layout<Handle>::value,
                "Handle must be standard-layout for FromBase to be valid");

  // Recovers the owning C++ object from a handle created by this class.
  static ExecutableCacheBase* FromBase(
      iree_hal_executable_cache_t* executable_cache) {
    return reinterpret_cast<Handle*>(executable_cache)->owner;
  }

  static void DestroyThunk(iree_hal_executable_cache_t* executable_cache) {
    delete FromBase(executable_cache);
  }
  static bool CanPrepareFormatThunk(
      iree_hal_executable_cache_t* executable_cache,
      iree_hal_executable_format_t format) {
    return FromBase(executable_cache)
        ->CanPrepareFormat(static_cast<ExecutableFormat>(format));
  }
  static iree_status_t PrepareExecutableThunk(
      iree_hal_executable_cache_t* executable_cache,
      iree_hal_executable_layout_t* executable_layout,
      iree_hal_executable_caching_mode_t caching_mode,
      iree_const_byte_span_t executable_data, iree_allocator_t allocator,
      iree_hal_executable_t** out_executable) {
    *out_executable = nullptr;  // Leave the out-param defined on failure.
    // The virtual takes the C layout handle directly, so it is forwarded
    // as-is. NOTE(review): assumes iree_const_byte_span_t's fields are
    // `data` and `data_length` — confirm against iree/base/api.h.
    IREE_ASSIGN_OR_RETURN(
        auto executable,
        FromBase(executable_cache)
            ->PrepareExecutable(
                executable_layout, caching_mode,
                absl::MakeConstSpan(executable_data.data,
                                    executable_data.data_length),
                allocator));
    // Hand the ref_ptr's reference to the caller via the embedded C handle.
    *out_executable = executable.release()->base();
    return iree_ok_status();
  }

  mutable Handle handle_{};
};
//===----------------------------------------------------------------------===// | |
// iree_hal_executable_layout_t | |
//===----------------------------------------------------------------------===// | |
// C++ base class for executable layouts exposed through the C
// `iree_hal_executable_layout_t` API. Only destruction is bridged so far.
class ExecutableLayoutBase : public ResourceBase<ExecutableLayoutBase> {
 public:
  virtual ~ExecutableLayoutBase() = default;

  // Returns the C handle embedded in this object (`mutable`: the C API mutates
  // it, e.g. the reference count, even through const wrappers).
  iree_hal_executable_layout_t* base() const noexcept { return &handle_.base; }

 protected:
  ExecutableLayoutBase() {
    static const iree_hal_executable_layout_vtable_t vtable = {
        DestroyThunk,
    };
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.owner = this;
  }

 private:
  // `base` must be the first member so the C pointer is pointer-
  // interconvertible with Handle. Casting the C pointer straight to
  // ExecutableLayoutBase* would be wrong: a polymorphic class does not place
  // its first data member at offset 0.
  struct Handle {
    iree_hal_executable_layout_t base;
    ExecutableLayoutBase* owner;
  };
  static_assert(std::is_standard_layout<Handle>::value,
                "Handle must be standard-layout for FromBase to be valid");

  // Recovers the owning C++ object from a handle created by this class.
  static ExecutableLayoutBase* FromBase(
      iree_hal_executable_layout_t* executable_layout) {
    return reinterpret_cast<Handle*>(executable_layout)->owner;
  }

  static void DestroyThunk(iree_hal_executable_layout_t* executable_layout) {
    delete FromBase(executable_layout);
  }

  mutable Handle handle_{};
};
//===----------------------------------------------------------------------===// | |
// iree_hal_semaphore_t | |
//===----------------------------------------------------------------------===// | |
// C++ base class for timeline semaphores exposed through the C
// `iree_hal_semaphore_t` API. Each C vtable entry forwards to the matching
// virtual method.
class SemaphoreBase : public ResourceBase<SemaphoreBase> {
 public:
  virtual ~SemaphoreBase() = default;

  // Returns the C handle embedded in this object (`mutable`: the C API mutates
  // it, e.g. the reference count, even through const wrappers).
  iree_hal_semaphore_t* base() const noexcept { return &handle_.base; }

  virtual StatusOr<uint64_t> Query() = 0;
  virtual Status Signal(uint64_t value) = 0;
  virtual void Fail(Status status) = 0;
  virtual Status Wait(uint64_t value, Time deadline_ns) = 0;
  virtual Status Wait(uint64_t value, Duration timeout_ns) = 0;

 protected:
  SemaphoreBase() {
    static const iree_hal_semaphore_vtable_t vtable = {
        DestroyThunk,          QueryThunk,           SignalThunk, FailThunk,
        WaitWithDeadlineThunk, WaitWithTimeoutThunk,
    };
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.owner = this;
  }

 private:
  // `base` must be the first member so the C pointer is pointer-
  // interconvertible with Handle. Casting the C pointer straight to
  // SemaphoreBase* would be wrong: a polymorphic class does not place its
  // first data member at offset 0.
  struct Handle {
    iree_hal_semaphore_t base;
    SemaphoreBase* owner;
  };
  static_assert(std::is_standard_layout<Handle>::value,
                "Handle must be standard-layout for FromBase to be valid");

  // Recovers the owning C++ object from a handle created by this class.
  static SemaphoreBase* FromBase(iree_hal_semaphore_t* semaphore) {
    return reinterpret_cast<Handle*>(semaphore)->owner;
  }

  static void DestroyThunk(iree_hal_semaphore_t* semaphore) {
    delete FromBase(semaphore);
  }
  static iree_status_t QueryThunk(iree_hal_semaphore_t* semaphore,
                                  uint64_t* out_value) {
    IREE_ASSIGN_OR_RETURN(uint64_t value, FromBase(semaphore)->Query());
    *out_value = value;
    return iree_ok_status();
  }
  static iree_status_t SignalThunk(iree_hal_semaphore_t* semaphore,
                                   uint64_t new_value) {
    return FromBase(semaphore)->Signal(new_value);
  }
  static void FailThunk(iree_hal_semaphore_t* semaphore, iree_status_t status) {
    // Wraps the C status in a C++ Status before forwarding.
    FromBase(semaphore)->Fail(Status(status));
  }
  // The virtual interface is the overloaded Wait(); the previously called
  // WaitWithDeadline/WaitWithTimeout members do not exist. The explicit
  // Time/Duration conversions select the intended overload.
  static iree_status_t WaitWithDeadlineThunk(iree_hal_semaphore_t* semaphore,
                                             uint64_t value,
                                             iree_time_t deadline_ns) {
    return FromBase(semaphore)->Wait(value, Time(deadline_ns));
  }
  static iree_status_t WaitWithTimeoutThunk(iree_hal_semaphore_t* semaphore,
                                            uint64_t value,
                                            iree_duration_t timeout_ns) {
    return FromBase(semaphore)->Wait(value, Duration(timeout_ns));
  }

  mutable Handle handle_{};
};
}  // namespace hal
}  // namespace iree

// The `#ifndef __cplusplus` check near the top is already closed by its own
// `#endif`, so only the include guard remains open here. (A second, stray
// `#endif // __cplusplus` previously left the final `#endif` unmatched.)
#endif  // IREE_HAL_API_INTERFACES_CC_H_
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment