Bug 1676916 - Implicit bind group layouts in WebGPU r=jgilbert,webidl,smaug

This change updates and enables Naga to get the
SPIRV shaders parsed, validated, and reflected back into
implicit bind group layouts.
WebGPU examples heavily rely on the implicit layouts now,
and the PR also updates the WebIDL to make that possible.
With the change, we are able to run most of the examples again!

Differential Revision: https://phabricator.services.mozilla.com/D96850
This commit is contained in:
Dzmitry Malyshau 2020-11-13 14:15:49 +00:00
parent c4b40283ab
commit 8f74799ba5
67 changed files with 3182 additions and 1186 deletions

View file

@ -50,7 +50,7 @@ rev = "0917fe780032a6bbb23d71be545f9c1834128d75"
[source."https://github.com/gfx-rs/naga"] [source."https://github.com/gfx-rs/naga"]
git = "https://github.com/gfx-rs/naga" git = "https://github.com/gfx-rs/naga"
replace-with = "vendored-sources" replace-with = "vendored-sources"
rev = "aa35110471ee7915e1f4e1de61ea41f2f32f92c4" rev = "4d4e1cd4cbfad2b81264a7239a336b6ec1346611"
[source."https://github.com/djg/cubeb-pulse-rs"] [source."https://github.com/djg/cubeb-pulse-rs"]
git = "https://github.com/djg/cubeb-pulse-rs" git = "https://github.com/djg/cubeb-pulse-rs"

2
Cargo.lock generated
View file

@ -3329,7 +3329,7 @@ checksum = "a2983372caf4480544083767bf2d27defafe32af49ab4df3a0b7fc90793a3664"
[[package]] [[package]]
name = "naga" name = "naga"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/gfx-rs/naga?rev=aa35110471ee7915e1f4e1de61ea41f2f32f92c4#aa35110471ee7915e1f4e1de61ea41f2f32f92c4" source = "git+https://github.com/gfx-rs/naga?rev=4d4e1cd4cbfad2b81264a7239a336b6ec1346611#4d4e1cd4cbfad2b81264a7239a336b6ec1346611"
dependencies = [ dependencies = [
"bitflags", "bitflags",
"fxhash", "fxhash",

View file

@ -13,8 +13,11 @@ namespace webgpu {
GPU_IMPL_CYCLE_COLLECTION(ComputePipeline, mParent) GPU_IMPL_CYCLE_COLLECTION(ComputePipeline, mParent)
GPU_IMPL_JS_WRAP(ComputePipeline) GPU_IMPL_JS_WRAP(ComputePipeline)
ComputePipeline::ComputePipeline(Device* const aParent, RawId aId) ComputePipeline::ComputePipeline(Device* const aParent, RawId aId,
: ChildOf(aParent), mId(aId) {} nsTArray<RawId>&& aImplicitBindGroupLayoutIds)
: ChildOf(aParent),
mImplicitBindGroupLayoutIds(std::move(aImplicitBindGroupLayoutIds)),
mId(aId) {}
ComputePipeline::~ComputePipeline() { Cleanup(); } ComputePipeline::~ComputePipeline() { Cleanup(); }
@ -28,5 +31,12 @@ void ComputePipeline::Cleanup() {
} }
} }
already_AddRefed<BindGroupLayout> ComputePipeline::GetBindGroupLayout(
uint32_t index) const {
RefPtr<BindGroupLayout> object =
new BindGroupLayout(mParent, mImplicitBindGroupLayoutIds[index]);
return object.forget();
}
} // namespace webgpu } // namespace webgpu
} // namespace mozilla } // namespace mozilla

View file

@ -12,17 +12,22 @@
namespace mozilla { namespace mozilla {
namespace webgpu { namespace webgpu {
class BindGroupLayout;
class Device; class Device;
class ComputePipeline final : public ObjectBase, public ChildOf<Device> { class ComputePipeline final : public ObjectBase, public ChildOf<Device> {
const nsTArray<RawId> mImplicitBindGroupLayoutIds;
public: public:
GPU_DECL_CYCLE_COLLECTION(ComputePipeline) GPU_DECL_CYCLE_COLLECTION(ComputePipeline)
GPU_DECL_JS_WRAP(ComputePipeline) GPU_DECL_JS_WRAP(ComputePipeline)
ComputePipeline(Device* const aParent, RawId aId);
const RawId mId; const RawId mId;
ComputePipeline(Device* const aParent, RawId aId,
nsTArray<RawId>&& aImplicitBindGroupLayoutIds);
already_AddRefed<BindGroupLayout> GetBindGroupLayout(uint32_t index) const;
private: private:
~ComputePipeline(); ~ComputePipeline();
void Cleanup(); void Cleanup();

View file

@ -194,15 +194,21 @@ already_AddRefed<ShaderModule> Device::CreateShaderModule(
already_AddRefed<ComputePipeline> Device::CreateComputePipeline( already_AddRefed<ComputePipeline> Device::CreateComputePipeline(
const dom::GPUComputePipelineDescriptor& aDesc) { const dom::GPUComputePipelineDescriptor& aDesc) {
RawId id = mBridge->DeviceCreateComputePipeline(mId, aDesc); nsTArray<RawId> implicitBindGroupLayoutIds;
RefPtr<ComputePipeline> object = new ComputePipeline(this, id); RawId id = mBridge->DeviceCreateComputePipeline(mId, aDesc,
&implicitBindGroupLayoutIds);
RefPtr<ComputePipeline> object =
new ComputePipeline(this, id, std::move(implicitBindGroupLayoutIds));
return object.forget(); return object.forget();
} }
already_AddRefed<RenderPipeline> Device::CreateRenderPipeline( already_AddRefed<RenderPipeline> Device::CreateRenderPipeline(
const dom::GPURenderPipelineDescriptor& aDesc) { const dom::GPURenderPipelineDescriptor& aDesc) {
RawId id = mBridge->DeviceCreateRenderPipeline(mId, aDesc); nsTArray<RawId> implicitBindGroupLayoutIds;
RefPtr<RenderPipeline> object = new RenderPipeline(this, id); RawId id = mBridge->DeviceCreateRenderPipeline(mId, aDesc,
&implicitBindGroupLayoutIds);
RefPtr<RenderPipeline> object =
new RenderPipeline(this, id, std::move(implicitBindGroupLayoutIds));
return object.forget(); return object.forget();
} }

View file

@ -13,8 +13,11 @@ namespace webgpu {
GPU_IMPL_CYCLE_COLLECTION(RenderPipeline, mParent) GPU_IMPL_CYCLE_COLLECTION(RenderPipeline, mParent)
GPU_IMPL_JS_WRAP(RenderPipeline) GPU_IMPL_JS_WRAP(RenderPipeline)
RenderPipeline::RenderPipeline(Device* const aParent, RawId aId) RenderPipeline::RenderPipeline(Device* const aParent, RawId aId,
: ChildOf(aParent), mId(aId) {} nsTArray<RawId>&& aImplicitBindGroupLayoutIds)
: ChildOf(aParent),
mImplicitBindGroupLayoutIds(std::move(aImplicitBindGroupLayoutIds)),
mId(aId) {}
RenderPipeline::~RenderPipeline() { Cleanup(); } RenderPipeline::~RenderPipeline() { Cleanup(); }
@ -28,5 +31,12 @@ void RenderPipeline::Cleanup() {
} }
} }
already_AddRefed<BindGroupLayout> RenderPipeline::GetBindGroupLayout(
uint32_t index) const {
RefPtr<BindGroupLayout> object =
new BindGroupLayout(mParent, mImplicitBindGroupLayoutIds[index]);
return object.forget();
}
} // namespace webgpu } // namespace webgpu
} // namespace mozilla } // namespace mozilla

View file

@ -15,14 +15,18 @@ namespace webgpu {
class Device; class Device;
class RenderPipeline final : public ObjectBase, public ChildOf<Device> { class RenderPipeline final : public ObjectBase, public ChildOf<Device> {
const nsTArray<RawId> mImplicitBindGroupLayoutIds;
public: public:
GPU_DECL_CYCLE_COLLECTION(RenderPipeline) GPU_DECL_CYCLE_COLLECTION(RenderPipeline)
GPU_DECL_JS_WRAP(RenderPipeline) GPU_DECL_JS_WRAP(RenderPipeline)
RenderPipeline(Device* const aParent, RawId aId);
const RawId mId; const RawId mId;
RenderPipeline(Device* const aParent, RawId aId,
nsTArray<RawId>&& aImplicitBindGroupLayoutIds);
already_AddRefed<BindGroupLayout> GetBindGroupLayout(uint32_t index) const;
private: private:
virtual ~RenderPipeline(); virtual ~RenderPipeline();
void Cleanup(); void Cleanup();

View file

@ -36,6 +36,7 @@ parent:
async DeviceAction(RawId selfId, ByteBuf buf); async DeviceAction(RawId selfId, ByteBuf buf);
async TextureAction(RawId selfId, ByteBuf buf); async TextureAction(RawId selfId, ByteBuf buf);
async CommandEncoderAction(RawId selfId, ByteBuf buf); async CommandEncoderAction(RawId selfId, ByteBuf buf);
async BumpImplicitBindGroupLayout(RawId pipelineId, bool isCompute, uint32_t index);
async InstanceRequestAdapter(GPURequestAdapterOptions options, RawId[] ids) returns (RawId adapterId); async InstanceRequestAdapter(GPURequestAdapterOptions options, RawId[] ids) returns (RawId adapterId);
async AdapterRequestDevice(RawId selfId, GPUDeviceDescriptor desc, RawId newId); async AdapterRequestDevice(RawId selfId, GPUDeviceDescriptor desc, RawId newId);
@ -69,6 +70,7 @@ parent:
async Shutdown(); async Shutdown();
child: child:
async DropAction(ByteBuf buf);
async FreeAdapter(RawId id); async FreeAdapter(RawId id);
async FreeDevice(RawId id); async FreeDevice(RawId id);
async FreePipelineLayout(RawId id); async FreePipelineLayout(RawId id);

View file

@ -14,10 +14,6 @@ NS_IMPL_CYCLE_COLLECTION(WebGPUChild)
NS_IMPL_CYCLE_COLLECTION_ROOT_NATIVE(WebGPUChild, AddRef) NS_IMPL_CYCLE_COLLECTION_ROOT_NATIVE(WebGPUChild, AddRef)
NS_IMPL_CYCLE_COLLECTION_UNROOT_NATIVE(WebGPUChild, Release) NS_IMPL_CYCLE_COLLECTION_UNROOT_NATIVE(WebGPUChild, Release)
ffi::WGPUByteBuf* ToFFI(ipc::ByteBuf* x) {
return reinterpret_cast<ffi::WGPUByteBuf*>(x);
}
static ffi::WGPUClient* initialize() { static ffi::WGPUClient* initialize() {
ffi::WGPUInfrastructure infra = ffi::wgpu_client_new(); ffi::WGPUInfrastructure infra = ffi::wgpu_client_new();
return infra.client; return infra.client;
@ -376,21 +372,30 @@ RawId WebGPUChild::DeviceCreateShaderModule(
} }
RawId WebGPUChild::DeviceCreateComputePipeline( RawId WebGPUChild::DeviceCreateComputePipeline(
RawId aSelfId, const dom::GPUComputePipelineDescriptor& aDesc) { RawId aSelfId, const dom::GPUComputePipelineDescriptor& aDesc,
nsTArray<RawId>* const aImplicitBindGroupLayoutIds) {
ffi::WGPUComputePipelineDescriptor desc = {}; ffi::WGPUComputePipelineDescriptor desc = {};
nsCString label, entryPoint; nsCString label, entryPoint;
if (aDesc.mLabel.WasPassed()) { if (aDesc.mLabel.WasPassed()) {
LossyCopyUTF16toASCII(aDesc.mLabel.Value(), label); LossyCopyUTF16toASCII(aDesc.mLabel.Value(), label);
desc.label = label.get(); desc.label = label.get();
} }
desc.layout = aDesc.mLayout->mId; if (aDesc.mLayout.WasPassed()) {
desc.layout = aDesc.mLayout.Value().mId;
}
desc.compute_stage.module = aDesc.mComputeStage.mModule->mId; desc.compute_stage.module = aDesc.mComputeStage.mModule->mId;
LossyCopyUTF16toASCII(aDesc.mComputeStage.mEntryPoint, entryPoint); LossyCopyUTF16toASCII(aDesc.mComputeStage.mEntryPoint, entryPoint);
desc.compute_stage.entry_point = entryPoint.get(); desc.compute_stage.entry_point = entryPoint.get();
ByteBuf bb; ByteBuf bb;
RawId id = ffi::wgpu_client_create_compute_pipeline(mClient, aSelfId, &desc, RawId implicit_bgl_ids[WGPUMAX_BIND_GROUPS] = {};
ToFFI(&bb)); RawId id = ffi::wgpu_client_create_compute_pipeline(
mClient, aSelfId, &desc, ToFFI(&bb), implicit_bgl_ids);
for (const auto& cur : implicit_bgl_ids) {
if (!cur) break;
aImplicitBindGroupLayoutIds->AppendElement(cur);
}
if (!SendDeviceAction(aSelfId, std::move(bb))) { if (!SendDeviceAction(aSelfId, std::move(bb))) {
MOZ_CRASH("IPC failure"); MOZ_CRASH("IPC failure");
} }
@ -457,7 +462,8 @@ static ffi::WGPUDepthStencilStateDescriptor ConvertDepthStencilDescriptor(
} }
RawId WebGPUChild::DeviceCreateRenderPipeline( RawId WebGPUChild::DeviceCreateRenderPipeline(
RawId aSelfId, const dom::GPURenderPipelineDescriptor& aDesc) { RawId aSelfId, const dom::GPURenderPipelineDescriptor& aDesc,
nsTArray<RawId>* const aImplicitBindGroupLayoutIds) {
ffi::WGPURenderPipelineDescriptor desc = {}; ffi::WGPURenderPipelineDescriptor desc = {};
nsCString label, vsEntry, fsEntry; nsCString label, vsEntry, fsEntry;
ffi::WGPUProgrammableStageDescriptor vertexStage = {}; ffi::WGPUProgrammableStageDescriptor vertexStage = {};
@ -467,7 +473,10 @@ RawId WebGPUChild::DeviceCreateRenderPipeline(
LossyCopyUTF16toASCII(aDesc.mLabel.Value(), label); LossyCopyUTF16toASCII(aDesc.mLabel.Value(), label);
desc.label = label.get(); desc.label = label.get();
} }
desc.layout = aDesc.mLayout->mId; if (aDesc.mLayout.WasPassed()) {
desc.layout = aDesc.mLayout.Value().mId;
}
vertexStage.module = aDesc.mVertexStage.mModule->mId; vertexStage.module = aDesc.mVertexStage.mModule->mId;
LossyCopyUTF16toASCII(aDesc.mVertexStage.mEntryPoint, vsEntry); LossyCopyUTF16toASCII(aDesc.mVertexStage.mEntryPoint, vsEntry);
vertexStage.entry_point = vsEntry.get(); vertexStage.entry_point = vsEntry.get();
@ -537,14 +546,26 @@ RawId WebGPUChild::DeviceCreateRenderPipeline(
desc.alpha_to_coverage_enabled = aDesc.mAlphaToCoverageEnabled; desc.alpha_to_coverage_enabled = aDesc.mAlphaToCoverageEnabled;
ByteBuf bb; ByteBuf bb;
RawId id = ffi::wgpu_client_create_render_pipeline(mClient, aSelfId, &desc, RawId implicit_bgl_ids[WGPUMAX_BIND_GROUPS] = {};
ToFFI(&bb)); RawId id = ffi::wgpu_client_create_render_pipeline(
mClient, aSelfId, &desc, ToFFI(&bb), implicit_bgl_ids);
for (const auto& cur : implicit_bgl_ids) {
if (!cur) break;
aImplicitBindGroupLayoutIds->AppendElement(cur);
}
if (!SendDeviceAction(aSelfId, std::move(bb))) { if (!SendDeviceAction(aSelfId, std::move(bb))) {
MOZ_CRASH("IPC failure"); MOZ_CRASH("IPC failure");
} }
return id; return id;
} }
ipc::IPCResult WebGPUChild::RecvDropAction(const ipc::ByteBuf& aByteBuf) {
const auto* byteBuf = ToFFI(&aByteBuf);
ffi::wgpu_client_drop_action(mClient, byteBuf);
return IPC_OK();
}
ipc::IPCResult WebGPUChild::RecvFreeAdapter(RawId id) { ipc::IPCResult WebGPUChild::RecvFreeAdapter(RawId id) {
ffi::wgpu_client_kill_adapter_id(mClient, id); ffi::wgpu_client_kill_adapter_id(mClient, id);
return IPC_OK(); return IPC_OK();

View file

@ -64,9 +64,11 @@ class WebGPUChild final : public PWebGPUChild, public SupportsWeakPtr {
RawId DeviceCreateShaderModule(RawId aSelfId, RawId DeviceCreateShaderModule(RawId aSelfId,
const dom::GPUShaderModuleDescriptor& aDesc); const dom::GPUShaderModuleDescriptor& aDesc);
RawId DeviceCreateComputePipeline( RawId DeviceCreateComputePipeline(
RawId aSelfId, const dom::GPUComputePipelineDescriptor& aDesc); RawId aSelfId, const dom::GPUComputePipelineDescriptor& aDesc,
nsTArray<RawId>* const aImplicitBindGroupLayoutIds);
RawId DeviceCreateRenderPipeline( RawId DeviceCreateRenderPipeline(
RawId aSelfId, const dom::GPURenderPipelineDescriptor& aDesc); RawId aSelfId, const dom::GPURenderPipelineDescriptor& aDesc,
nsTArray<RawId>* const aImplicitBindGroupLayoutIds);
void DeviceCreateSwapChain(RawId aSelfId, const RGBDescriptor& aRgbDesc, void DeviceCreateSwapChain(RawId aSelfId, const RGBDescriptor& aRgbDesc,
size_t maxBufferCount, size_t maxBufferCount,
@ -96,6 +98,7 @@ class WebGPUChild final : public PWebGPUChild, public SupportsWeakPtr {
bool mIPCOpen; bool mIPCOpen;
public: public:
ipc::IPCResult RecvDropAction(const ipc::ByteBuf& aByteBuf);
ipc::IPCResult RecvFreeAdapter(RawId id); ipc::IPCResult RecvFreeAdapter(RawId id);
ipc::IPCResult RecvFreeDevice(RawId id); ipc::IPCResult RecvFreeDevice(RawId id);
ipc::IPCResult RecvFreePipelineLayout(RawId id); ipc::IPCResult RecvFreePipelineLayout(RawId id);

View file

@ -173,6 +173,8 @@ ipc::IPCResult WebGPUParent::RecvInstanceRequestAdapter(
ipc::IPCResult WebGPUParent::RecvAdapterRequestDevice( ipc::IPCResult WebGPUParent::RecvAdapterRequestDevice(
RawId aSelfId, const dom::GPUDeviceDescriptor& aDesc, RawId aNewId) { RawId aSelfId, const dom::GPUDeviceDescriptor& aDesc, RawId aNewId) {
ffi::WGPUDeviceDescriptor desc = {}; ffi::WGPUDeviceDescriptor desc = {};
desc.shader_validation = true; // required for implicit pipeline layouts
if (aDesc.mLimits.WasPassed()) { if (aDesc.mLimits.WasPassed()) {
const auto& lim = aDesc.mLimits.Value(); const auto& lim = aDesc.mLimits.Value();
desc.limits.max_bind_groups = lim.mMaxBindGroups; desc.limits.max_bind_groups = lim.mMaxBindGroups;
@ -194,7 +196,7 @@ ipc::IPCResult WebGPUParent::RecvAdapterRequestDevice(
} else { } else {
ffi::wgpu_server_fill_default_limits(&desc.limits); ffi::wgpu_server_fill_default_limits(&desc.limits);
} }
// TODO: fill up the descriptor
ffi::wgpu_server_adapter_request_device(mContext, aSelfId, &desc, aNewId); ffi::wgpu_server_adapter_request_device(mContext, aSelfId, &desc, aNewId);
return IPC_OK(); return IPC_OK();
} }
@ -591,22 +593,40 @@ ipc::IPCResult WebGPUParent::RecvShutdown() {
ipc::IPCResult WebGPUParent::RecvDeviceAction(RawId aSelf, ipc::IPCResult WebGPUParent::RecvDeviceAction(RawId aSelf,
const ipc::ByteBuf& aByteBuf) { const ipc::ByteBuf& aByteBuf) {
ffi::wgpu_server_device_action( ipc::ByteBuf byteBuf;
mContext, aSelf, reinterpret_cast<const ffi::WGPUByteBuf*>(&aByteBuf)); ffi::wgpu_server_device_action(mContext, aSelf, ToFFI(&aByteBuf),
ToFFI(&byteBuf));
if (byteBuf.mData) {
if (!SendDropAction(std::move(byteBuf))) {
NS_WARNING("Unable to set a drop action!");
}
}
return IPC_OK(); return IPC_OK();
} }
ipc::IPCResult WebGPUParent::RecvTextureAction(RawId aSelf, ipc::IPCResult WebGPUParent::RecvTextureAction(RawId aSelf,
const ipc::ByteBuf& aByteBuf) { const ipc::ByteBuf& aByteBuf) {
ffi::wgpu_server_texture_action( ffi::wgpu_server_texture_action(mContext, aSelf, ToFFI(&aByteBuf));
mContext, aSelf, reinterpret_cast<const ffi::WGPUByteBuf*>(&aByteBuf));
return IPC_OK(); return IPC_OK();
} }
ipc::IPCResult WebGPUParent::RecvCommandEncoderAction( ipc::IPCResult WebGPUParent::RecvCommandEncoderAction(
RawId aSelf, const ipc::ByteBuf& aByteBuf) { RawId aSelf, const ipc::ByteBuf& aByteBuf) {
ffi::wgpu_server_command_encoder_action( ffi::wgpu_server_command_encoder_action(mContext, aSelf, ToFFI(&aByteBuf));
mContext, aSelf, reinterpret_cast<const ffi::WGPUByteBuf*>(&aByteBuf)); return IPC_OK();
}
ipc::IPCResult WebGPUParent::RecvBumpImplicitBindGroupLayout(RawId pipelineId,
bool isCompute,
uint32_t index) {
if (isCompute) {
ffi::wgpu_server_compute_pipeline_get_bind_group_layout(mContext,
pipelineId, index);
} else {
ffi::wgpu_server_render_pipeline_get_bind_group_layout(mContext, pipelineId,
index);
}
return IPC_OK(); return IPC_OK();
} }

View file

@ -70,6 +70,9 @@ class WebGPUParent final : public PWebGPUParent {
ipc::IPCResult RecvTextureAction(RawId aSelf, const ipc::ByteBuf& aByteBuf); ipc::IPCResult RecvTextureAction(RawId aSelf, const ipc::ByteBuf& aByteBuf);
ipc::IPCResult RecvCommandEncoderAction(RawId aSelf, ipc::IPCResult RecvCommandEncoderAction(RawId aSelf,
const ipc::ByteBuf& aByteBuf); const ipc::ByteBuf& aByteBuf);
ipc::IPCResult RecvBumpImplicitBindGroupLayout(RawId pipelineId,
bool isCompute,
uint32_t index);
ipc::IPCResult RecvShutdown(); ipc::IPCResult RecvShutdown();

View file

@ -31,8 +31,7 @@ DEFINE_IPC_SERIALIZER_WITHOUT_FIELDS(mozilla::dom::GPUCommandBufferDescriptor);
DEFINE_IPC_SERIALIZER_WITH_FIELDS(mozilla::dom::GPURequestAdapterOptions, DEFINE_IPC_SERIALIZER_WITH_FIELDS(mozilla::dom::GPURequestAdapterOptions,
mPowerPreference); mPowerPreference);
DEFINE_IPC_SERIALIZER_WITH_FIELDS(mozilla::dom::GPUExtensions, DEFINE_IPC_SERIALIZER_WITHOUT_FIELDS(mozilla::dom::GPUExtensions);
mAnisotropicFiltering);
DEFINE_IPC_SERIALIZER_WITH_FIELDS(mozilla::dom::GPULimits, mMaxBindGroups); DEFINE_IPC_SERIALIZER_WITH_FIELDS(mozilla::dom::GPULimits, mMaxBindGroups);
DEFINE_IPC_SERIALIZER_WITH_FIELDS(mozilla::dom::GPUDeviceDescriptor, DEFINE_IPC_SERIALIZER_WITH_FIELDS(mozilla::dom::GPUDeviceDescriptor,
mExtensions, mLimits); mExtensions, mLimits);

View file

@ -100,7 +100,6 @@ interface GPUAdapter {
GPUAdapter includes GPUObjectBase; GPUAdapter includes GPUObjectBase;
dictionary GPUExtensions { dictionary GPUExtensions {
boolean anisotropicFiltering = false;
}; };
dictionary GPULimits { dictionary GPULimits {
@ -412,7 +411,8 @@ GPUSampler includes GPUObjectBase;
enum GPUTextureComponentType { enum GPUTextureComponentType {
"float", "float",
"sint", "sint",
"uint" "uint",
"depth-comparison"
}; };
// **************************************************************************** // ****************************************************************************
@ -659,7 +659,11 @@ GPUShaderModule includes GPUObjectBase;
// Common stuff for ComputePipeline and RenderPipeline // Common stuff for ComputePipeline and RenderPipeline
dictionary GPUPipelineDescriptorBase : GPUObjectDescriptorBase { dictionary GPUPipelineDescriptorBase : GPUObjectDescriptorBase {
required GPUPipelineLayout layout; GPUPipelineLayout layout;
};
interface mixin GPUPipelineBase {
GPUBindGroupLayout getBindGroupLayout(unsigned long index);
}; };
dictionary GPUProgrammableStageDescriptor { dictionary GPUProgrammableStageDescriptor {
@ -677,6 +681,7 @@ dictionary GPUComputePipelineDescriptor : GPUPipelineDescriptorBase {
interface GPUComputePipeline { interface GPUComputePipeline {
}; };
GPUComputePipeline includes GPUObjectBase; GPUComputePipeline includes GPUObjectBase;
GPUComputePipeline includes GPUPipelineBase;
// GPURenderPipeline // GPURenderPipeline
enum GPUPrimitiveTopology { enum GPUPrimitiveTopology {
@ -727,6 +732,7 @@ dictionary GPURenderPipelineDescriptor : GPUPipelineDescriptorBase {
interface GPURenderPipeline { interface GPURenderPipeline {
}; };
GPURenderPipeline includes GPUObjectBase; GPURenderPipeline includes GPUObjectBase;
GPURenderPipeline includes GPUPipelineBase;
// **************************************************************************** // ****************************************************************************
// COMMAND RECORDING (Command buffer and all relevant structures) // COMMAND RECORDING (Command buffer and all relevant structures)

34
gfx/wgpu/Cargo.lock generated
View file

@ -316,7 +316,7 @@ version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb582b60359da160a9477ee80f15c8d784c477e69c217ef2cdd4169c24ea380f" checksum = "cb582b60359da160a9477ee80f15c8d784c477e69c217ef2cdd4169c24ea380f"
dependencies = [ dependencies = [
"proc-macro2 1.0.18", "proc-macro2 1.0.24",
"quote 1.0.7", "quote 1.0.7",
"syn", "syn",
] ]
@ -867,7 +867,7 @@ dependencies = [
[[package]] [[package]]
name = "naga" name = "naga"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/gfx-rs/naga?rev=aa35110471ee7915e1f4e1de61ea41f2f32f92c4#aa35110471ee7915e1f4e1de61ea41f2f32f92c4" source = "git+https://github.com/gfx-rs/naga?rev=4d4e1cd4cbfad2b81264a7239a336b6ec1346611#4d4e1cd4cbfad2b81264a7239a336b6ec1346611"
dependencies = [ dependencies = [
"bitflags", "bitflags",
"fxhash", "fxhash",
@ -969,7 +969,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffa5a33ddddfee04c0283a7653987d634e880347e96b5b2ed64de07efb59db9d" checksum = "ffa5a33ddddfee04c0283a7653987d634e880347e96b5b2ed64de07efb59db9d"
dependencies = [ dependencies = [
"proc-macro-crate", "proc-macro-crate",
"proc-macro2 1.0.18", "proc-macro2 1.0.24",
"quote 1.0.7", "quote 1.0.7",
"syn", "syn",
] ]
@ -1117,9 +1117,9 @@ dependencies = [
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.18" version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "beae6331a816b1f65d04c45b078fd8e6c93e8071771f41b8163255bbd8d7c8fa" checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71"
dependencies = [ dependencies = [
"unicode-xid 0.2.0", "unicode-xid 0.2.0",
] ]
@ -1145,7 +1145,7 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa563d17ecb180e500da1cfd2b028310ac758de548efdd203e18f283af693f37" checksum = "aa563d17ecb180e500da1cfd2b028310ac758de548efdd203e18f283af693f37"
dependencies = [ dependencies = [
"proc-macro2 1.0.18", "proc-macro2 1.0.24",
] ]
[[package]] [[package]]
@ -1315,7 +1315,7 @@ version = "1.0.111"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f2c3ac8e6ca1e9c80b8be1023940162bf81ae3cffbb1809474152f2ce1eb250" checksum = "3f2c3ac8e6ca1e9c80b8be1023940162bf81ae3cffbb1809474152f2ce1eb250"
dependencies = [ dependencies = [
"proc-macro2 1.0.18", "proc-macro2 1.0.24",
"quote 1.0.7", "quote 1.0.7",
"syn", "syn",
] ]
@ -1409,11 +1409,11 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "1.0.31" version = "1.0.48"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5304cfdf27365b7585c25d4af91b35016ed21ef88f17ced89c7093b43dba8b6" checksum = "cc371affeffc477f42a221a1e4297aedcea33d47d19b61455588bd9d8f6b19ac"
dependencies = [ dependencies = [
"proc-macro2 1.0.18", "proc-macro2 1.0.24",
"quote 1.0.7", "quote 1.0.7",
"unicode-xid 0.2.0", "unicode-xid 0.2.0",
] ]
@ -1429,20 +1429,20 @@ dependencies = [
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.20" version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dfdd070ccd8ccb78f4ad66bf1982dc37f620ef696c6b5028fe2ed83dd3d0d08" checksum = "0e9ae34b84616eedaaf1e9dd6026dbe00dcafa92aa0c8077cb69df1fcfe5e53e"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.20" version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd80fc12f73063ac132ac92aceea36734f04a1d93c1240c6944e23a3b8841793" checksum = "9ba20f23e85b10754cd195504aebf6a27e2e6cbe28c17778a0c930724628dd56"
dependencies = [ dependencies = [
"proc-macro2 1.0.18", "proc-macro2 1.0.24",
"quote 1.0.7", "quote 1.0.7",
"syn", "syn",
] ]
@ -1600,7 +1600,7 @@ dependencies = [
"bumpalo", "bumpalo",
"lazy_static", "lazy_static",
"log", "log",
"proc-macro2 1.0.18", "proc-macro2 1.0.24",
"quote 1.0.7", "quote 1.0.7",
"syn", "syn",
"wasm-bindgen-shared", "wasm-bindgen-shared",
@ -1622,7 +1622,7 @@ version = "0.2.63"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3156052d8ec77142051a533cdd686cba889537b213f948cd1d20869926e68e92" checksum = "3156052d8ec77142051a533cdd686cba889537b213f948cd1d20869926e68e92"
dependencies = [ dependencies = [
"proc-macro2 1.0.18", "proc-macro2 1.0.24",
"quote 1.0.7", "quote 1.0.7",
"syn", "syn",
"wasm-bindgen-backend", "wasm-bindgen-backend",

View file

@ -40,7 +40,7 @@ gfx-memory = "0.2"
[dependencies.naga] [dependencies.naga]
version = "0.2" version = "0.2"
git = "https://github.com/gfx-rs/naga" git = "https://github.com/gfx-rs/naga"
rev = "aa35110471ee7915e1f4e1de61ea41f2f32f92c4" rev = "4d4e1cd4cbfad2b81264a7239a336b6ec1346611"
features = ["spv-in", "spv-out", "wgsl-in"] features = ["spv-in", "spv-out", "wgsl-in"]
[dependencies.wgt] [dependencies.wgt]

View file

@ -2512,7 +2512,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let module = if device.private_features.shader_validation { let module = if device.private_features.shader_validation {
// Parse the given shader code and store its representation. // Parse the given shader code and store its representation.
let spv_iter = spv.into_iter().cloned(); let spv_iter = spv.into_iter().cloned();
naga::front::spv::Parser::new(spv_iter) naga::front::spv::Parser::new(spv_iter, &Default::default())
.parse() .parse()
.map_err(|err| { .map_err(|err| {
// TODO: eventually, when Naga gets support for all features, // TODO: eventually, when Naga gets support for all features,

View file

@ -132,14 +132,18 @@ fn get_aligned_type_size(
Ti::Pointer { .. } => 4, Ti::Pointer { .. } => 4,
Ti::Array { Ti::Array {
base, base,
size: naga::ArraySize::Static(count), size: naga::ArraySize::Constant(const_handle),
stride, stride,
} => { } => {
let base_size = match stride { let base_size = match stride {
Some(stride) => stride.get() as wgt::BufferAddress, Some(stride) => stride.get() as wgt::BufferAddress,
None => get_aligned_type_size(module, base, false), None => get_aligned_type_size(module, base, false),
}; };
base_size * count as wgt::BufferAddress let count = match module.constants[const_handle].inner {
naga::ConstantInner::Uint(count) => count,
ref other => panic!("Unexpected array size {:?}", other),
};
base_size * count
} }
Ti::Array { Ti::Array {
base, base,
@ -786,7 +790,7 @@ fn derive_binding_type(
dynamic, dynamic,
min_binding_size: wgt::BufferSize::new(actual_size), min_binding_size: wgt::BufferSize::new(actual_size),
}, },
naga::StorageClass::StorageBuffer => BindingType::StorageBuffer { naga::StorageClass::Storage => BindingType::StorageBuffer {
dynamic, dynamic,
min_binding_size: wgt::BufferSize::new(actual_size), min_binding_size: wgt::BufferSize::new(actual_size),
readonly: !usage.contains(naga::GlobalUse::STORE), readonly: !usage.contains(naga::GlobalUse::STORE),

View file

@ -4,7 +4,6 @@
// The `broken_intra_doc_links` is a new name, and will fail if built on the old compiler. // The `broken_intra_doc_links` is a new name, and will fail if built on the old compiler.
#![allow(unknown_lints)] #![allow(unknown_lints)]
// The intra doc links to the wgpu crate in this crate actually succesfully link to the types in the wgpu crate, when built from the wgpu crate. // The intra doc links to the wgpu crate in this crate actually succesfully link to the types in the wgpu crate, when built from the wgpu crate.
// However when building from both the wgpu crate or this crate cargo doc will claim all the links cannot be resolved // However when building from both the wgpu crate or this crate cargo doc will claim all the links cannot be resolved
// despite the fact that it works fine when it needs to. // despite the fact that it works fine when it needs to.

View file

@ -16,6 +16,7 @@ typedef uint8_t WGPUOption_NonZeroU8;
typedef uint64_t WGPUOption_AdapterId; typedef uint64_t WGPUOption_AdapterId;
typedef uint64_t WGPUOption_BufferId; typedef uint64_t WGPUOption_BufferId;
typedef uint64_t WGPUOption_PipelineLayoutId; typedef uint64_t WGPUOption_PipelineLayoutId;
typedef uint64_t WGPUOption_BindGroupLayoutId;
typedef uint64_t WGPUOption_SamplerId; typedef uint64_t WGPUOption_SamplerId;
typedef uint64_t WGPUOption_SurfaceId; typedef uint64_t WGPUOption_SurfaceId;
typedef uint64_t WGPUOption_TextureViewId; typedef uint64_t WGPUOption_TextureViewId;
@ -30,7 +31,8 @@ style = "tag"
[export] [export]
prefix = "WGPU" prefix = "WGPU"
exclude = [ exclude = [
"Option_AdapterId", "Option_BufferId", "Option_PipelineLayoutId", "Option_SamplerId", "Option_SurfaceId", "Option_TextureViewId", "Option_AdapterId", "Option_BufferId", "Option_PipelineLayoutId", "Option_BindGroupLayoutId",
"Option_SamplerId", "Option_SurfaceId", "Option_TextureViewId",
"Option_BufferSize", "Option_NonZeroU32", "Option_NonZeroU8", "Option_BufferSize", "Option_NonZeroU32", "Option_NonZeroU8",
] ]

View file

@ -2,7 +2,10 @@
* License, v. 2.0. If a copy of the MPL was not distributed with this * License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
use crate::{cow_label, ByteBuf, CommandEncoderAction, DeviceAction, RawString, TextureAction}; use crate::{
cow_label, ByteBuf, CommandEncoderAction, DeviceAction, DropAction, ImplicitLayout, RawString,
TextureAction,
};
use wgc::{hub::IdentityManager, id}; use wgc::{hub::IdentityManager, id};
use wgt::Backend; use wgt::Backend;
@ -13,20 +16,13 @@ use parking_lot::Mutex;
use std::{ use std::{
borrow::Cow, borrow::Cow,
mem,
num::{NonZeroU32, NonZeroU8}, num::{NonZeroU32, NonZeroU8},
ptr, slice, ptr, slice,
}; };
fn make_byte_buf<T: serde::Serialize>(data: &T) -> ByteBuf { fn make_byte_buf<T: serde::Serialize>(data: &T) -> ByteBuf {
let vec = bincode::serialize(data).unwrap(); let vec = bincode::serialize(data).unwrap();
let bb = ByteBuf { ByteBuf::from_vec(vec)
data: vec.as_ptr(),
len: vec.len(),
capacity: vec.capacity(),
};
mem::forget(vec);
bb
} }
#[repr(C)] #[repr(C)]
@ -191,6 +187,19 @@ struct IdentityHub {
samplers: IdentityManager, samplers: IdentityManager,
} }
impl ImplicitLayout<'_> {
fn new(identities: &mut IdentityHub, backend: Backend) -> Self {
ImplicitLayout {
pipeline: identities.pipeline_layouts.alloc(backend),
bind_groups: Cow::Owned(
(0..wgc::MAX_BIND_GROUPS)
.map(|_| identities.bind_group_layouts.alloc(backend))
.collect(),
),
}
}
}
#[derive(Debug, Default)] #[derive(Debug, Default)]
struct Identities { struct Identities {
surfaces: IdentityManager, surfaces: IdentityManager,
@ -219,6 +228,22 @@ pub struct Client {
identities: Mutex<Identities>, identities: Mutex<Identities>,
} }
#[no_mangle]
pub unsafe extern "C" fn wgpu_client_drop_action(client: &mut Client, byte_buf: &ByteBuf) {
let mut cursor = std::io::Cursor::new(byte_buf.as_slice());
let mut identities = client.identities.lock();
while let Ok(action) = bincode::deserialize_from(&mut cursor) {
match action {
DropAction::Buffer(id) => identities.select(id.backend()).buffers.free(id),
DropAction::Texture(id) => identities.select(id.backend()).textures.free(id),
DropAction::Sampler(id) => identities.select(id.backend()).samplers.free(id),
DropAction::BindGroupLayout(id) => {
identities.select(id.backend()).bind_group_layouts.free(id)
}
}
}
}
#[repr(C)] #[repr(C)]
#[derive(Debug)] #[derive(Debug)]
pub struct Infrastructure { pub struct Infrastructure {
@ -613,8 +638,13 @@ pub unsafe extern "C" fn wgpu_client_create_bind_group_layout(
RawBindingType::Sampler => wgt::BindingType::Sampler { comparison: false }, RawBindingType::Sampler => wgt::BindingType::Sampler { comparison: false },
RawBindingType::ComparisonSampler => wgt::BindingType::Sampler { comparison: true }, RawBindingType::ComparisonSampler => wgt::BindingType::Sampler { comparison: true },
RawBindingType::SampledTexture => wgt::BindingType::SampledTexture { RawBindingType::SampledTexture => wgt::BindingType::SampledTexture {
dimension: *entry.view_dimension.unwrap(), //TODO: the spec has a bug here
component_type: *entry.texture_component_type.unwrap(), dimension: *entry
.view_dimension
.unwrap_or(&wgt::TextureViewDimension::D2),
component_type: *entry
.texture_component_type
.unwrap_or(&wgt::TextureComponentType::Float),
multisampled: entry.multisampled, multisampled: entry.multisampled,
}, },
RawBindingType::ReadonlyStorageTexture => wgt::BindingType::StorageTexture { RawBindingType::ReadonlyStorageTexture => wgt::BindingType::StorageTexture {
@ -763,12 +793,15 @@ pub unsafe extern "C" fn wgpu_client_create_shader_module(
.alloc(backend); .alloc(backend);
assert!(!desc.spirv_words.is_null()); assert!(!desc.spirv_words.is_null());
let data = Cow::Borrowed(slice::from_raw_parts( let spv = Cow::Borrowed(if desc.spirv_words.is_null() {
desc.spirv_words, &[][..]
desc.spirv_words_length, } else {
)); slice::from_raw_parts(desc.spirv_words, desc.spirv_words_length)
});
let action = DeviceAction::CreateShaderModule(id, data); let wgsl = cow_label(&desc.wgsl_chars).unwrap_or_default();
let action = DeviceAction::CreateShaderModule(id, spv, wgsl);
*bb = make_byte_buf(&action); *bb = make_byte_buf(&action);
id id
} }
@ -789,14 +822,11 @@ pub unsafe extern "C" fn wgpu_client_create_compute_pipeline(
device_id: id::DeviceId, device_id: id::DeviceId,
desc: &ComputePipelineDescriptor, desc: &ComputePipelineDescriptor,
bb: &mut ByteBuf, bb: &mut ByteBuf,
implicit_bind_group_layout_ids: *mut Option<id::BindGroupLayoutId>,
) -> id::ComputePipelineId { ) -> id::ComputePipelineId {
let backend = device_id.backend(); let backend = device_id.backend();
let id = client let mut identities = client.identities.lock();
.identities let id = identities.select(backend).compute_pipelines.alloc(backend);
.lock()
.select(backend)
.compute_pipelines
.alloc(backend);
let wgpu_desc = wgc::pipeline::ComputePipelineDescriptor { let wgpu_desc = wgc::pipeline::ComputePipelineDescriptor {
label: cow_label(&desc.label), label: cow_label(&desc.label),
@ -804,7 +834,18 @@ pub unsafe extern "C" fn wgpu_client_create_compute_pipeline(
compute_stage: desc.compute_stage.to_wgpu(), compute_stage: desc.compute_stage.to_wgpu(),
}; };
let action = DeviceAction::CreateComputePipeline(id, wgpu_desc); let implicit = match desc.layout {
Some(_) => None,
None => {
let implicit = ImplicitLayout::new(identities.select(backend), backend);
for (i, bgl_id) in implicit.bind_groups.iter().enumerate() {
*implicit_bind_group_layout_ids.add(i) = Some(*bgl_id);
}
Some(implicit)
}
};
let action = DeviceAction::CreateComputePipeline(id, wgpu_desc, implicit);
*bb = make_byte_buf(&action); *bb = make_byte_buf(&action);
id id
} }
@ -825,14 +866,11 @@ pub unsafe extern "C" fn wgpu_client_create_render_pipeline(
device_id: id::DeviceId, device_id: id::DeviceId,
desc: &RenderPipelineDescriptor, desc: &RenderPipelineDescriptor,
bb: &mut ByteBuf, bb: &mut ByteBuf,
implicit_bind_group_layout_ids: *mut Option<id::BindGroupLayoutId>,
) -> id::RenderPipelineId { ) -> id::RenderPipelineId {
let backend = device_id.backend(); let backend = device_id.backend();
let id = client let mut identities = client.identities.lock();
.identities let id = identities.select(backend).render_pipelines.alloc(backend);
.lock()
.select(backend)
.render_pipelines
.alloc(backend);
let wgpu_desc = wgc::pipeline::RenderPipelineDescriptor { let wgpu_desc = wgc::pipeline::RenderPipelineDescriptor {
label: cow_label(&desc.label), label: cow_label(&desc.label),
@ -875,7 +913,18 @@ pub unsafe extern "C" fn wgpu_client_create_render_pipeline(
alpha_to_coverage_enabled: desc.alpha_to_coverage_enabled, alpha_to_coverage_enabled: desc.alpha_to_coverage_enabled,
}; };
let action = DeviceAction::CreateRenderPipeline(id, wgpu_desc); let implicit = match desc.layout {
Some(_) => None,
None => {
let implicit = ImplicitLayout::new(identities.select(backend), backend);
for (i, bgl_id) in implicit.bind_groups.iter().enumerate() {
*implicit_bind_group_layout_ids.add(i) = Some(*bgl_id);
}
Some(implicit)
}
};
let action = DeviceAction::CreateRenderPipeline(id, wgpu_desc, implicit);
*bb = make_byte_buf(&action); *bb = make_byte_buf(&action);
id id
} }

View file

@ -28,6 +28,7 @@ impl<I: id::TypedId + Clone + std::fmt::Debug> wgc::hub::IdentityHandler<I>
} }
} }
//TODO: remove this in favor of `DropAction` that could be sent over IPC.
#[repr(C)] #[repr(C)]
pub struct IdentityRecyclerFactory { pub struct IdentityRecyclerFactory {
param: FactoryParam, param: FactoryParam,

View file

@ -12,7 +12,7 @@ pub mod server;
pub use wgc::device::trace::Command as CommandEncoderAction; pub use wgc::device::trace::Command as CommandEncoderAction;
use std::{borrow::Cow, slice}; use std::{borrow::Cow, mem, slice};
type RawString = *const std::os::raw::c_char; type RawString = *const std::os::raw::c_char;
@ -35,11 +35,35 @@ pub struct ByteBuf {
} }
impl ByteBuf { impl ByteBuf {
fn from_vec(vec: Vec<u8>) -> Self {
if vec.is_empty() {
ByteBuf {
data: std::ptr::null(),
len: 0,
capacity: 0,
}
} else {
let bb = ByteBuf {
data: vec.as_ptr(),
len: vec.len(),
capacity: vec.capacity(),
};
mem::forget(vec);
bb
}
}
unsafe fn as_slice(&self) -> &[u8] { unsafe fn as_slice(&self) -> &[u8] {
slice::from_raw_parts(self.data, self.len) slice::from_raw_parts(self.data, self.len)
} }
} }
#[derive(serde::Serialize, serde::Deserialize)]
struct ImplicitLayout<'a> {
pipeline: id::PipelineLayoutId,
bind_groups: Cow<'a, [id::BindGroupLayoutId]>,
}
#[derive(serde::Serialize, serde::Deserialize)] #[derive(serde::Serialize, serde::Deserialize)]
enum DeviceAction<'a> { enum DeviceAction<'a> {
CreateBuffer(id::BufferId, wgc::resource::BufferDescriptor<'a>), CreateBuffer(id::BufferId, wgc::resource::BufferDescriptor<'a>),
@ -54,14 +78,16 @@ enum DeviceAction<'a> {
wgc::binding_model::PipelineLayoutDescriptor<'a>, wgc::binding_model::PipelineLayoutDescriptor<'a>,
), ),
CreateBindGroup(id::BindGroupId, wgc::binding_model::BindGroupDescriptor<'a>), CreateBindGroup(id::BindGroupId, wgc::binding_model::BindGroupDescriptor<'a>),
CreateShaderModule(id::ShaderModuleId, Cow<'a, [u32]>), CreateShaderModule(id::ShaderModuleId, Cow<'a, [u32]>, Cow<'a, str>),
CreateComputePipeline( CreateComputePipeline(
id::ComputePipelineId, id::ComputePipelineId,
wgc::pipeline::ComputePipelineDescriptor<'a>, wgc::pipeline::ComputePipelineDescriptor<'a>,
Option<ImplicitLayout<'a>>,
), ),
CreateRenderPipeline( CreateRenderPipeline(
id::RenderPipelineId, id::RenderPipelineId,
wgc::pipeline::RenderPipelineDescriptor<'a>, wgc::pipeline::RenderPipelineDescriptor<'a>,
Option<ImplicitLayout<'a>>,
), ),
CreateRenderBundle( CreateRenderBundle(
id::RenderBundleId, id::RenderBundleId,
@ -78,3 +104,11 @@ enum DeviceAction<'a> {
enum TextureAction<'a> { enum TextureAction<'a> {
CreateView(id::TextureViewId, wgc::resource::TextureViewDescriptor<'a>), CreateView(id::TextureViewId, wgc::resource::TextureViewDescriptor<'a>),
} }
#[derive(serde::Serialize, serde::Deserialize)]
enum DropAction {
Buffer(id::BufferId),
Texture(id::TextureId),
Sampler(id::SamplerId),
BindGroupLayout(id::BindGroupLayoutId),
}

View file

@ -4,7 +4,7 @@
use crate::{ use crate::{
cow_label, identity::IdentityRecyclerFactory, ByteBuf, CommandEncoderAction, DeviceAction, cow_label, identity::IdentityRecyclerFactory, ByteBuf, CommandEncoderAction, DeviceAction,
RawString, TextureAction, DropAction, RawString, TextureAction,
}; };
use wgc::{gfx_select, id}; use wgc::{gfx_select, id};
@ -164,7 +164,11 @@ pub extern "C" fn wgpu_server_buffer_drop(global: &Global, self_id: id::BufferId
} }
trait GlobalExt { trait GlobalExt {
fn device_action<B: wgc::hub::GfxBackend>(&self, self_id: id::DeviceId, action: DeviceAction); fn device_action<B: wgc::hub::GfxBackend>(
&self,
self_id: id::DeviceId,
action: DeviceAction,
) -> Vec<u8>;
fn texture_action<B: wgc::hub::GfxBackend>( fn texture_action<B: wgc::hub::GfxBackend>(
&self, &self,
self_id: id::TextureId, self_id: id::TextureId,
@ -178,8 +182,12 @@ trait GlobalExt {
} }
impl GlobalExt for Global { impl GlobalExt for Global {
fn device_action<B: wgc::hub::GfxBackend>(&self, self_id: id::DeviceId, action: DeviceAction) { fn device_action<B: wgc::hub::GfxBackend>(
let implicit_ids = None; //TODO &self,
self_id: id::DeviceId,
action: DeviceAction,
) -> Vec<u8> {
let mut drop_actions = Vec::new();
match action { match action {
DeviceAction::CreateBuffer(id, desc) => { DeviceAction::CreateBuffer(id, desc) => {
self.device_create_buffer::<B>(self_id, &desc, id).unwrap(); self.device_create_buffer::<B>(self_id, &desc, id).unwrap();
@ -202,21 +210,54 @@ impl GlobalExt for Global {
self.device_create_bind_group::<B>(self_id, &desc, id) self.device_create_bind_group::<B>(self_id, &desc, id)
.unwrap(); .unwrap();
} }
DeviceAction::CreateShaderModule(id, spirv) => { DeviceAction::CreateShaderModule(id, spirv, wgsl) => {
self.device_create_shader_module::<B>( let source = if spirv.is_empty() {
self_id, wgc::pipeline::ShaderModuleSource::Wgsl(wgsl)
wgc::pipeline::ShaderModuleSource::SpirV(spirv), } else {
id, wgc::pipeline::ShaderModuleSource::SpirV(spirv)
};
self.device_create_shader_module::<B>(self_id, source, id)
.unwrap();
}
DeviceAction::CreateComputePipeline(id, desc, implicit) => {
let implicit_ids = implicit
.as_ref()
.map(|imp| wgc::device::ImplicitPipelineIds {
root_id: imp.pipeline,
group_ids: &imp.bind_groups,
});
let (_, group_count) = self
.device_create_compute_pipeline::<B>(self_id, &desc, id, implicit_ids)
.unwrap();
if let Some(ref imp) = implicit {
for &bgl_id in imp.bind_groups[group_count as usize..].iter() {
bincode::serialize_into(
&mut drop_actions,
&DropAction::BindGroupLayout(bgl_id),
) )
.unwrap(); .unwrap();
} }
DeviceAction::CreateComputePipeline(id, desc) => { }
self.device_create_compute_pipeline::<B>(self_id, &desc, id, implicit_ids) }
DeviceAction::CreateRenderPipeline(id, desc, implicit) => {
let implicit_ids = implicit
.as_ref()
.map(|imp| wgc::device::ImplicitPipelineIds {
root_id: imp.pipeline,
group_ids: &imp.bind_groups,
});
let (_, group_count) = self
.device_create_render_pipeline::<B>(self_id, &desc, id, implicit_ids)
.unwrap();
if let Some(ref imp) = implicit {
for &bgl_id in imp.bind_groups[group_count as usize..].iter() {
bincode::serialize_into(
&mut drop_actions,
&DropAction::BindGroupLayout(bgl_id),
)
.unwrap(); .unwrap();
} }
DeviceAction::CreateRenderPipeline(id, desc) => { }
self.device_create_render_pipeline::<B>(self_id, &desc, id, implicit_ids)
.unwrap();
} }
DeviceAction::CreateRenderBundle(_id, desc, _base) => { DeviceAction::CreateRenderBundle(_id, desc, _base) => {
wgc::command::RenderBundleEncoder::new(&desc, self_id, None).unwrap(); wgc::command::RenderBundleEncoder::new(&desc, self_id, None).unwrap();
@ -226,6 +267,7 @@ impl GlobalExt for Global {
.unwrap(); .unwrap();
} }
} }
drop_actions
} }
fn texture_action<B: wgc::hub::GfxBackend>( fn texture_action<B: wgc::hub::GfxBackend>(
@ -292,9 +334,11 @@ pub unsafe extern "C" fn wgpu_server_device_action(
global: &Global, global: &Global,
self_id: id::DeviceId, self_id: id::DeviceId,
byte_buf: &ByteBuf, byte_buf: &ByteBuf,
drop_byte_buf: &mut ByteBuf,
) { ) {
let action = bincode::deserialize(byte_buf.as_slice()).unwrap(); let action = bincode::deserialize(byte_buf.as_slice()).unwrap();
gfx_select!(self_id => global.device_action(self_id, action)); let drop_actions = gfx_select!(self_id => global.device_action(self_id, action));
*drop_byte_buf = ByteBuf::from_vec(drop_actions);
} }
#[no_mangle] #[no_mangle]
@ -465,3 +509,21 @@ pub extern "C" fn wgpu_server_texture_view_drop(global: &Global, self_id: id::Te
pub extern "C" fn wgpu_server_sampler_drop(global: &Global, self_id: id::SamplerId) { pub extern "C" fn wgpu_server_sampler_drop(global: &Global, self_id: id::SamplerId) {
gfx_select!(self_id => global.sampler_drop(self_id)); gfx_select!(self_id => global.sampler_drop(self_id));
} }
#[no_mangle]
pub extern "C" fn wgpu_server_compute_pipeline_get_bind_group_layout(
global: &Global,
self_id: id::ComputePipelineId,
index: u32,
) -> id::BindGroupLayoutId {
gfx_select!(self_id => global.compute_pipeline_get_bind_group_layout(self_id, index)).unwrap()
}
#[no_mangle]
pub extern "C" fn wgpu_server_render_pipeline_get_bind_group_layout(
global: &Global,
self_id: id::RenderPipelineId,
index: u32,
) -> id::BindGroupLayoutId {
gfx_select!(self_id => global.render_pipeline_get_bind_group_layout(self_id, index)).unwrap()
}

View file

@ -9,6 +9,9 @@
// Prelude of types necessary before including wgpu_ffi_generated.h // Prelude of types necessary before including wgpu_ffi_generated.h
namespace mozilla { namespace mozilla {
namespace ipc {
class ByteBuf;
} // namespace ipc
namespace webgpu { namespace webgpu {
namespace ffi { namespace ffi {
@ -23,6 +26,14 @@ extern "C" {
#undef WGPU_FUNC #undef WGPU_FUNC
} // namespace ffi } // namespace ffi
inline ffi::WGPUByteBuf* ToFFI(ipc::ByteBuf* x) {
return reinterpret_cast<ffi::WGPUByteBuf*>(x);
}
inline const ffi::WGPUByteBuf* ToFFI(const ipc::ByteBuf* x) {
return reinterpret_cast<const ffi::WGPUByteBuf*>(x);
}
} // namespace webgpu } // namespace webgpu
} // namespace mozilla } // namespace mozilla

File diff suppressed because one or more lines are too long

View file

@ -16,7 +16,7 @@ log = "0.4"
num-traits = "0.2" num-traits = "0.2"
spirv = { package = "spirv_headers", version = "1.4.2", optional = true } spirv = { package = "spirv_headers", version = "1.4.2", optional = true }
pomelo = { version = "0.1.4", optional = true } pomelo = { version = "0.1.4", optional = true }
thiserror = "1.0" thiserror = "1.0.21"
serde = { version = "1.0", features = ["derive"], optional = true } serde = { version = "1.0", features = ["derive"], optional = true }
petgraph = { version ="0.5", optional = true } petgraph = { version ="0.5", optional = true }

View file

@ -22,7 +22,7 @@ SPIR-V (binary) | :construction: | |
WGSL | | | WGSL | | |
Metal | :construction: | | Metal | :construction: | |
HLSL | | | HLSL | | |
GLSL | | | GLSL | :construction: | |
AIR | | | AIR | | |
DXIR | | | DXIR | | |
DXIL | | | DXIL | | |

View file

@ -1,8 +1,16 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{env, fs, path::Path}; use std::{env, fs, path::Path};
#[derive(Hash, PartialEq, Eq, Serialize, Deserialize)]
enum Stage {
Vertex,
Fragment,
Compute,
}
#[derive(Hash, PartialEq, Eq, Serialize, Deserialize)] #[derive(Hash, PartialEq, Eq, Serialize, Deserialize)]
struct BindSource { struct BindSource {
stage: Stage,
group: u32, group: u32,
binding: u32, binding: u32,
} }
@ -21,6 +29,8 @@ struct BindTarget {
#[derive(Default, Serialize, Deserialize)] #[derive(Default, Serialize, Deserialize)]
struct Parameters { struct Parameters {
#[serde(default)]
spv_flow_dump_prefix: String,
metal_bindings: naga::FastHashMap<BindSource, BindTarget>, metal_bindings: naga::FastHashMap<BindSource, BindTarget>,
} }
@ -33,6 +43,13 @@ fn main() {
println!("Call with <input> <output>"); println!("Call with <input> <output>");
return; return;
} }
let param_path = std::path::PathBuf::from(&args[1]).with_extension("param.ron");
let params = match fs::read_to_string(param_path) {
Ok(string) => ron::de::from_str(&string).unwrap(),
Err(_) => Parameters::default(),
};
let module = match Path::new(&args[1]) let module = match Path::new(&args[1])
.extension() .extension()
.expect("Input has no extension?") .expect("Input has no extension?")
@ -41,8 +58,15 @@ fn main() {
{ {
#[cfg(feature = "spv-in")] #[cfg(feature = "spv-in")]
"spv" => { "spv" => {
let options = naga::front::spv::Options {
flow_graph_dump_prefix: if params.spv_flow_dump_prefix.is_empty() {
None
} else {
Some(params.spv_flow_dump_prefix.into())
},
};
let input = fs::read(&args[1]).unwrap(); let input = fs::read(&args[1]).unwrap();
naga::front::spv::parse_u8_slice(&input).unwrap() naga::front::spv::parse_u8_slice(&input, &options).unwrap()
} }
#[cfg(feature = "wgsl-in")] #[cfg(feature = "wgsl-in")]
"wgsl" => { "wgsl" => {
@ -52,17 +76,35 @@ fn main() {
#[cfg(feature = "glsl-in")] #[cfg(feature = "glsl-in")]
"vert" => { "vert" => {
let input = fs::read_to_string(&args[1]).unwrap(); let input = fs::read_to_string(&args[1]).unwrap();
naga::front::glsl::parse_str(&input, "main", naga::ShaderStage::Vertex).unwrap() naga::front::glsl::parse_str(
&input,
"main",
naga::ShaderStage::Vertex,
Default::default(),
)
.unwrap()
} }
#[cfg(feature = "glsl-in")] #[cfg(feature = "glsl-in")]
"frag" => { "frag" => {
let input = fs::read_to_string(&args[1]).unwrap(); let input = fs::read_to_string(&args[1]).unwrap();
naga::front::glsl::parse_str(&input, "main", naga::ShaderStage::Fragment).unwrap() naga::front::glsl::parse_str(
&input,
"main",
naga::ShaderStage::Fragment,
Default::default(),
)
.unwrap()
} }
#[cfg(feature = "glsl-in")] #[cfg(feature = "glsl-in")]
"comp" => { "comp" => {
let input = fs::read_to_string(&args[1]).unwrap(); let input = fs::read_to_string(&args[1]).unwrap();
naga::front::glsl::parse_str(&input, "main", naga::ShaderStage::Compute).unwrap() naga::front::glsl::parse_str(
&input,
"main",
naga::ShaderStage::Compute,
Default::default(),
)
.unwrap()
} }
#[cfg(feature = "deserialize")] #[cfg(feature = "deserialize")]
"ron" => { "ron" => {
@ -83,12 +125,6 @@ fn main() {
return; return;
} }
let param_path = std::path::PathBuf::from(&args[1]).with_extension("param.ron");
let params = match fs::read_to_string(param_path) {
Ok(string) => ron::de::from_str(&string).unwrap(),
Err(_) => Parameters::default(),
};
match Path::new(&args[2]) match Path::new(&args[2])
.extension() .extension()
.expect("Output has no extension?") .expect("Output has no extension?")
@ -102,6 +138,11 @@ fn main() {
for (key, value) in params.metal_bindings { for (key, value) in params.metal_bindings {
binding_map.insert( binding_map.insert(
msl::BindSource { msl::BindSource {
stage: match key.stage {
Stage::Vertex => naga::ShaderStage::Vertex,
Stage::Fragment => naga::ShaderStage::Fragment,
Stage::Compute => naga::ShaderStage::Compute,
},
group: key.group, group: key.group,
binding: key.binding, binding: key.binding,
}, },
@ -114,9 +155,11 @@ fn main() {
); );
} }
let options = msl::Options { let options = msl::Options {
binding_map: &binding_map, lang_version: (1, 0),
spirv_cross_compatibility: false,
binding_map,
}; };
let msl = msl::write_string(&module, options).unwrap(); let msl = msl::write_string(&module, &options).unwrap();
fs::write(&args[2], msl).unwrap(); fs::write(&args[2], msl).unwrap();
} }
#[cfg(feature = "spv-out")] #[cfg(feature = "spv-out")]
@ -198,7 +241,10 @@ fn main() {
} }
other => { other => {
let _ = params; let _ = params;
panic!("Unknown output extension: {}", other); panic!(
"Unknown output extension: {}, forgot to enable a feature?",
other
);
} }
} }
} }

View file

@ -1,11 +1,8 @@
use std::{cmp::Ordering, fmt, hash, marker::PhantomData, num::NonZeroU32}; use std::{cmp::Ordering, fmt, hash, marker::PhantomData, num::NonZeroU32};
/// An unique index in the arena array that a handle points to. /// An unique index in the arena array that a handle points to.
/// /// The "non-zero" part ensures that an `Option<Handle<T>>` has
/// This type is independent of `spv::Word`. `spv::Word` is used in data /// the same size and representation as `Handle<T>`.
/// representation. It holds a SPIR-V and refers to that instruction. In
/// structured representation, we use Handle to refer to an SPIR-V instruction.
/// `Index` is an implementation detail to `Handle`.
type Index = NonZeroU32; type Index = NonZeroU32;
/// A strongly typed reference to a SPIR-V element. /// A strongly typed reference to a SPIR-V element.
@ -168,6 +165,10 @@ impl<T> Arena<T> {
self.fetch_if_or_append(value, T::eq) self.fetch_if_or_append(value, T::eq)
} }
pub fn try_get(&self, handle: Handle<T>) -> Option<&T> {
self.data.get(handle.index.get() as usize - 1)
}
/// Get a mutable reference to an element in the arena. /// Get a mutable reference to an element in the arena.
pub fn get_mut(&mut self, handle: Handle<T>) -> &mut T { pub fn get_mut(&mut self, handle: Handle<T>) -> &mut T {
self.data.get_mut(handle.index.get() as usize - 1).unwrap() self.data.get_mut(handle.index.get() as usize - 1).unwrap()

View file

@ -112,6 +112,7 @@ bitflags::bitflags! {
const IMAGE_LOAD_STORE = 1 << 8; const IMAGE_LOAD_STORE = 1 << 8;
const CONSERVATIVE_DEPTH = 1 << 9; const CONSERVATIVE_DEPTH = 1 << 9;
const TEXTURE_1D = 1 << 10; const TEXTURE_1D = 1 << 10;
const PUSH_CONSTANT = 1 << 11;
} }
} }
@ -364,7 +365,7 @@ pub fn write<'a>(
} }
let block = match global.class { let block = match global.class {
StorageClass::StorageBuffer | StorageClass::Uniform => true, StorageClass::Storage | StorageClass::Uniform => true,
_ => false, _ => false,
}; };
@ -409,14 +410,28 @@ pub fn write<'a>(
&mut buf, &mut buf,
"{} {}({});", "{} {}({});",
func.return_type func.return_type
.map(|ty| write_type(ty, &module.types, &structs, None, &mut manager)) .map(|ty| write_type(
ty,
&module.types,
&module.constants,
&structs,
None,
&mut manager
))
.transpose()? .transpose()?
.as_deref() .as_deref()
.unwrap_or("void"), .unwrap_or("void"),
name, name,
func.parameter_types func.arguments
.iter() .iter()
.map(|ty| write_type(*ty, &module.types, &structs, None, &mut manager)) .map(|arg| write_type(
arg.ty,
&module.types,
&module.constants,
&structs,
None,
&mut manager
))
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?
.join(","), .join(","),
)?; )?;
@ -557,14 +572,15 @@ pub fn write<'a>(
let name = if let Some(ref binding) = global.binding { let name = if let Some(ref binding) = global.binding {
let prefix = match global.class { let prefix = match global.class {
StorageClass::Constant => "const",
StorageClass::Function => "fn", StorageClass::Function => "fn",
StorageClass::Input => "in", StorageClass::Input => "in",
StorageClass::Output => "out", StorageClass::Output => "out",
StorageClass::Private => "priv", StorageClass::Private => "priv",
StorageClass::StorageBuffer => "buffer", StorageClass::Storage => "buffer",
StorageClass::Uniform => "uniform", StorageClass::Uniform => "uniform",
StorageClass::Handle => "handle",
StorageClass::WorkGroup => "wg", StorageClass::WorkGroup => "wg",
StorageClass::PushConstant => "pc",
}; };
match binding { match binding {
@ -606,7 +622,7 @@ pub fn write<'a>(
} }
let block = match global.class { let block = match global.class {
StorageClass::StorageBuffer | StorageClass::Uniform => { StorageClass::Storage | StorageClass::Uniform => {
Some(format!("global_block_{}", handle.index())) Some(format!("global_block_{}", handle.index()))
} }
_ => None, _ => None,
@ -616,7 +632,14 @@ pub fn write<'a>(
&mut buf, &mut buf,
"{}{} {};", "{}{} {};",
write_storage_class(global.class, &mut manager)?, write_storage_class(global.class, &mut manager)?,
write_type(global.ty, &module.types, &structs, block, &mut manager)?, write_type(
global.ty,
&module.types,
&module.constants,
&structs,
block,
&mut manager
)?,
name name
)?; )?;
@ -635,33 +658,53 @@ pub fn write<'a>(
global_vars: &module.global_variables, global_vars: &module.global_variables,
local_vars: &func.local_variables, local_vars: &func.local_variables,
functions: &module.functions, functions: &module.functions,
parameter_types: &func.parameter_types, arguments: &func.arguments,
}, },
)?; )?;
let args: FastHashMap<_, _> = func let args: FastHashMap<_, _> = func
.parameter_types .arguments
.iter() .iter()
.enumerate() .enumerate()
.map(|(pos, _)| (pos as u32, format!("arg_{}", pos))) .map(|(pos, arg)| {
let name = arg
.name
.clone()
.filter(|ident| is_valid_ident(ident))
.unwrap_or_else(|| format!("arg_{}", pos + 1));
(pos as u32, name)
})
.collect(); .collect();
writeln!( writeln!(
&mut buf, &mut buf,
"{} {}({}) {{", "{} {}({}) {{",
func.return_type func.return_type
.map(|ty| write_type(ty, &module.types, &structs, None, &mut manager)) .map(|ty| write_type(
ty,
&module.types,
&module.constants,
&structs,
None,
&mut manager
))
.transpose()? .transpose()?
.as_deref() .as_deref()
.unwrap_or("void"), .unwrap_or("void"),
name, name,
func.parameter_types func.arguments
.iter() .iter()
.zip(args.values()) .enumerate()
.map::<Result<_, Error>, _>(|(ty, name)| { .map::<Result<_, Error>, _>(|(pos, arg)| {
let ty = write_type(*ty, &module.types, &structs, None, &mut manager)?; let ty = write_type(
arg.ty,
Ok(format!("{} {}", ty, name)) &module.types,
&module.constants,
&structs,
None,
&mut manager,
)?;
Ok(format!("{} {}", ty, args[&(pos as u32)]))
}) })
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?
.join(","), .join(","),
@ -682,23 +725,6 @@ pub fn write<'a>(
}) })
.collect(); .collect();
for (handle, name) in locals.iter() {
writeln!(
&mut buf,
"\t{} {};",
write_type(
func.local_variables[*handle].ty,
&module.types,
&structs,
None,
&mut manager
)?,
name
)?;
}
writeln!(&mut buf)?;
let mut builder = StatementBuilder { let mut builder = StatementBuilder {
functions: &functions, functions: &functions,
globals: &globals_lookup, globals: &globals_lookup,
@ -707,14 +733,40 @@ pub fn write<'a>(
args: &args, args: &args,
expressions: &func.expressions, expressions: &func.expressions,
typifier: &typifier, typifier: &typifier,
manager: &mut manager,
}; };
for (handle, name) in locals.iter() {
let var = &func.local_variables[*handle];
write!(
&mut buf,
"\t{} {}",
write_type(
var.ty,
&module.types,
&module.constants,
&structs,
None,
&mut manager
)?,
name
)?;
if let Some(init) = var.init {
write!(
&mut buf,
" = {}",
write_constant(&module.constants[init], module, &mut builder, &mut manager)?
)?;
}
writeln!(&mut buf, ";")?;
}
writeln!(&mut buf)?;
for sta in func.body.iter() { for sta in func.body.iter() {
writeln!( writeln!(
&mut buf, &mut buf,
"{}", "{}",
write_statement(sta, module, &mut builder, 1)? write_statement(sta, module, &mut builder, &mut manager, 1)?
)?; )?;
} }
@ -746,19 +798,19 @@ struct StatementBuilder<'a> {
args: &'a FastHashMap<u32, String>, args: &'a FastHashMap<u32, String>,
expressions: &'a Arena<Expression>, expressions: &'a Arena<Expression>,
typifier: &'a Typifier, typifier: &'a Typifier,
pub manager: &'a mut FeaturesManager,
} }
fn write_statement<'a, 'b>( fn write_statement<'a, 'b>(
sta: &Statement, sta: &Statement,
module: &'a Module, module: &'a Module,
builder: &'b mut StatementBuilder<'a>, builder: &'b mut StatementBuilder<'a>,
manager: &mut FeaturesManager,
indent: usize, indent: usize,
) -> Result<String, Error> { ) -> Result<String, Error> {
Ok(match sta { Ok(match sta {
Statement::Block(block) => block Statement::Block(block) => block
.iter() .iter()
.map(|sta| write_statement(sta, module, builder, indent)) .map(|sta| write_statement(sta, module, builder, manager, indent))
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?
.join("\n"), .join("\n"),
Statement::If { Statement::If {
@ -772,14 +824,14 @@ fn write_statement<'a, 'b>(
&mut out, &mut out,
"{}if({}) {{", "{}if({}) {{",
"\t".repeat(indent), "\t".repeat(indent),
write_expression(&builder.expressions[*condition], module, builder)? write_expression(&builder.expressions[*condition], module, builder, manager)?
)?; )?;
for sta in accept { for sta in accept {
writeln!( writeln!(
&mut out, &mut out,
"{}", "{}",
write_statement(sta, module, builder, indent + 1)? write_statement(sta, module, builder, manager, indent + 1)?
)?; )?;
} }
@ -789,7 +841,7 @@ fn write_statement<'a, 'b>(
writeln!( writeln!(
&mut out, &mut out,
"{}", "{}",
write_statement(sta, module, builder, indent + 1)? write_statement(sta, module, builder, manager, indent + 1)?
)?; )?;
} }
} }
@ -809,7 +861,7 @@ fn write_statement<'a, 'b>(
&mut out, &mut out,
"{}switch({}) {{", "{}switch({}) {{",
"\t".repeat(indent), "\t".repeat(indent),
write_expression(&builder.expressions[*selector], module, builder)? write_expression(&builder.expressions[*selector], module, builder, manager)?
)?; )?;
for (label, (block, fallthrough)) in cases { for (label, (block, fallthrough)) in cases {
@ -819,7 +871,7 @@ fn write_statement<'a, 'b>(
writeln!( writeln!(
&mut out, &mut out,
"{}", "{}",
write_statement(sta, module, builder, indent + 2)? write_statement(sta, module, builder, manager, indent + 2)?
)?; )?;
} }
@ -835,7 +887,7 @@ fn write_statement<'a, 'b>(
writeln!( writeln!(
&mut out, &mut out,
"{}", "{}",
write_statement(sta, module, builder, indent + 2)? write_statement(sta, module, builder, manager, indent + 2)?
)?; )?;
} }
} }
@ -853,7 +905,7 @@ fn write_statement<'a, 'b>(
writeln!( writeln!(
&mut out, &mut out,
"{}", "{}",
write_statement(sta, module, builder, indent + 1)? write_statement(sta, module, builder, manager, indent + 1)?
)?; )?;
} }
@ -869,7 +921,7 @@ fn write_statement<'a, 'b>(
if let Some(expr) = value { if let Some(expr) = value {
format!( format!(
"return {};", "return {};",
write_expression(&builder.expressions[*expr], module, builder)? write_expression(&builder.expressions[*expr], module, builder, manager)?
) )
} else { } else {
String::from("return;") String::from("return;")
@ -879,8 +931,8 @@ fn write_statement<'a, 'b>(
Statement::Store { pointer, value } => format!( Statement::Store { pointer, value } => format!(
"{}{} = {};", "{}{} = {};",
"\t".repeat(indent), "\t".repeat(indent),
write_expression(&builder.expressions[*pointer], module, builder)?, write_expression(&builder.expressions[*pointer], module, builder, manager)?,
write_expression(&builder.expressions[*value], module, builder)? write_expression(&builder.expressions[*value], module, builder, manager)?
), ),
}) })
} }
@ -889,18 +941,19 @@ fn write_expression<'a, 'b>(
expr: &Expression, expr: &Expression,
module: &'a Module, module: &'a Module,
builder: &'b mut StatementBuilder<'a>, builder: &'b mut StatementBuilder<'a>,
manager: &mut FeaturesManager,
) -> Result<Cow<'a, str>, Error> { ) -> Result<Cow<'a, str>, Error> {
Ok(match *expr { Ok(match *expr {
Expression::Access { base, index } => { Expression::Access { base, index } => {
let base_expr = write_expression(&builder.expressions[base], module, builder)?; let base_expr = write_expression(&builder.expressions[base], module, builder, manager)?;
Cow::Owned(format!( Cow::Owned(format!(
"{}[{}]", "{}[{}]",
base_expr, base_expr,
write_expression(&builder.expressions[index], module, builder)? write_expression(&builder.expressions[index], module, builder, manager)?
)) ))
} }
Expression::AccessIndex { base, index } => { Expression::AccessIndex { base, index } => {
let base_expr = write_expression(&builder.expressions[base], module, builder)?; let base_expr = write_expression(&builder.expressions[base], module, builder, manager)?;
match *builder.typifier.get(base, &module.types) { match *builder.typifier.get(base, &module.types) {
TypeInner::Vector { .. } => Cow::Owned(format!("{}[{}]", base_expr, index)), TypeInner::Vector { .. } => Cow::Owned(format!("{}[{}]", base_expr, index)),
@ -929,12 +982,13 @@ fn write_expression<'a, 'b>(
&module.constants[constant], &module.constants[constant],
module, module,
builder, builder,
manager,
)?), )?),
Expression::Compose { ty, ref components } => { Expression::Compose { ty, ref components } => {
let constructor = match module.types[ty].inner { let constructor = match module.types[ty].inner {
TypeInner::Vector { size, kind, width } => format!( TypeInner::Vector { size, kind, width } => format!(
"{}vec{}", "{}vec{}",
map_scalar(kind, width, builder.manager)?.prefix, map_scalar(kind, width, manager)?.prefix,
size as u8, size as u8,
), ),
TypeInner::Matrix { TypeInner::Matrix {
@ -943,19 +997,31 @@ fn write_expression<'a, 'b>(
width, width,
} => format!( } => format!(
"{}mat{}x{}", "{}mat{}x{}",
map_scalar(crate::ScalarKind::Float, width, builder.manager)?.prefix, map_scalar(crate::ScalarKind::Float, width, manager)?.prefix,
columns as u8, columns as u8,
rows as u8, rows as u8,
), ),
TypeInner::Array { .. } => { TypeInner::Array { .. } => write_type(
write_type(ty, &module.types, builder.structs, None, builder.manager)? ty,
.into_owned() &module.types,
} &module.constants,
builder.structs,
None,
manager,
)?
.into_owned(),
TypeInner::Struct { .. } => builder.structs.get(&ty).unwrap().clone(), TypeInner::Struct { .. } => builder.structs.get(&ty).unwrap().clone(),
_ => { _ => {
return Err(Error::Custom(format!( return Err(Error::Custom(format!(
"Cannot compose type {}", "Cannot compose type {}",
write_type(ty, &module.types, builder.structs, None, builder.manager)? write_type(
ty,
&module.types,
&module.constants,
builder.structs,
None,
manager
)?
))) )))
} }
}; };
@ -968,19 +1034,20 @@ fn write_expression<'a, 'b>(
.map::<Result<_, Error>, _>(|arg| Ok(write_expression( .map::<Result<_, Error>, _>(|arg| Ok(write_expression(
&builder.expressions[*arg], &builder.expressions[*arg],
module, module,
builder builder,
manager,
)?)) )?))
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?
.join(","), .join(","),
)) ))
} }
Expression::FunctionParameter(pos) => Cow::Borrowed(builder.args.get(&pos).unwrap()), Expression::FunctionArgument(pos) => Cow::Borrowed(builder.args.get(&pos).unwrap()),
Expression::GlobalVariable(handle) => Cow::Borrowed(builder.globals.get(&handle).unwrap()), Expression::GlobalVariable(handle) => Cow::Borrowed(builder.globals.get(&handle).unwrap()),
Expression::LocalVariable(handle) => { Expression::LocalVariable(handle) => {
Cow::Borrowed(builder.locals_lookup.get(&handle).unwrap()) Cow::Borrowed(builder.locals_lookup.get(&handle).unwrap())
} }
Expression::Load { pointer } => { Expression::Load { pointer } => {
write_expression(&builder.expressions[pointer], module, builder)? write_expression(&builder.expressions[pointer], module, builder, manager)?
} }
Expression::ImageSample { Expression::ImageSample {
image, image,
@ -989,10 +1056,11 @@ fn write_expression<'a, 'b>(
level, level,
depth_ref, depth_ref,
} => { } => {
let image_expr = write_expression(&builder.expressions[image], module, builder)?; let image_expr =
write_expression(&builder.expressions[sampler], module, builder)?; write_expression(&builder.expressions[image], module, builder, manager)?;
write_expression(&builder.expressions[sampler], module, builder, manager)?;
let coordinate_expr = let coordinate_expr =
write_expression(&builder.expressions[coordinate], module, builder)?; write_expression(&builder.expressions[coordinate], module, builder, manager)?;
let size = match *builder.typifier.get(coordinate, &module.types) { let size = match *builder.typifier.get(coordinate, &module.types) {
TypeInner::Vector { size, .. } => size, TypeInner::Vector { size, .. } => size,
@ -1009,7 +1077,7 @@ fn write_expression<'a, 'b>(
"vec{}({},{})", "vec{}({},{})",
size as u8 + 1, size as u8 + 1,
coordinate_expr, coordinate_expr,
write_expression(&builder.expressions[depth_ref], module, builder)? write_expression(&builder.expressions[depth_ref], module, builder, manager)?
)) ))
} else { } else {
coordinate_expr coordinate_expr
@ -1022,14 +1090,16 @@ fn write_expression<'a, 'b>(
format!("textureLod({},{},0)", image_expr, coordinate_expr) format!("textureLod({},{},0)", image_expr, coordinate_expr)
} }
crate::SampleLevel::Exact(expr) => { crate::SampleLevel::Exact(expr) => {
let level_expr = write_expression(&builder.expressions[expr], module, builder)?; let level_expr =
write_expression(&builder.expressions[expr], module, builder, manager)?;
format!( format!(
"textureLod({}, {}, {})", "textureLod({}, {}, {})",
image_expr, coordinate_expr, level_expr image_expr, coordinate_expr, level_expr
) )
} }
crate::SampleLevel::Bias(bias) => { crate::SampleLevel::Bias(bias) => {
let bias_expr = write_expression(&builder.expressions[bias], module, builder)?; let bias_expr =
write_expression(&builder.expressions[bias], module, builder, manager)?;
format!("texture({},{},{})", image_expr, coordinate_expr, bias_expr) format!("texture({},{},{})", image_expr, coordinate_expr, bias_expr)
} }
}) })
@ -1039,9 +1109,10 @@ fn write_expression<'a, 'b>(
coordinate, coordinate,
index, index,
} => { } => {
let image_expr = write_expression(&builder.expressions[image], module, builder)?; let image_expr =
write_expression(&builder.expressions[image], module, builder, manager)?;
let coordinate_expr = let coordinate_expr =
write_expression(&builder.expressions[coordinate], module, builder)?; write_expression(&builder.expressions[coordinate], module, builder, manager)?;
let (dim, arrayed, class) = match *builder.typifier.get(image, &module.types) { let (dim, arrayed, class) = match *builder.typifier.get(image, &module.types) {
TypeInner::Image { TypeInner::Image {
@ -1057,15 +1128,19 @@ fn write_expression<'a, 'b>(
//TODO: fix this //TODO: fix this
let sampler_constructor = format!( let sampler_constructor = format!(
"{}sampler{}{}{}({})", "{}sampler{}{}{}({})",
map_scalar(kind, 4, builder.manager)?.prefix, map_scalar(kind, 4, manager)?.prefix,
ImageDimension(dim), ImageDimension(dim),
if multi { "MS" } else { "" }, if multi { "MS" } else { "" },
if arrayed { "Array" } else { "" }, if arrayed { "Array" } else { "" },
image_expr, image_expr,
); );
let index_expr = let index_expr = write_expression(
write_expression(&builder.expressions[index.unwrap()], module, builder)?; &builder.expressions[index.unwrap()],
module,
builder,
manager,
)?;
format!( format!(
"texelFetch({},{},{})", "texelFetch({},{},{})",
sampler_constructor, coordinate_expr, index_expr sampler_constructor, coordinate_expr, index_expr
@ -1076,7 +1151,7 @@ fn write_expression<'a, 'b>(
}) })
} }
Expression::Unary { op, expr } => { Expression::Unary { op, expr } => {
let base_expr = write_expression(&builder.expressions[expr], module, builder)?; let base_expr = write_expression(&builder.expressions[expr], module, builder, manager)?;
Cow::Owned(format!( Cow::Owned(format!(
"({} {})", "({} {})",
@ -1106,8 +1181,9 @@ fn write_expression<'a, 'b>(
)) ))
} }
Expression::Binary { op, left, right } => { Expression::Binary { op, left, right } => {
let left_expr = write_expression(&builder.expressions[left], module, builder)?; let left_expr = write_expression(&builder.expressions[left], module, builder, manager)?;
let right_expr = write_expression(&builder.expressions[right], module, builder)?; let right_expr =
write_expression(&builder.expressions[right], module, builder, manager)?;
let op_str = match op { let op_str = match op {
BinaryOperator::Add => "+", BinaryOperator::Add => "+",
@ -1126,15 +1202,30 @@ fn write_expression<'a, 'b>(
BinaryOperator::InclusiveOr => "|", BinaryOperator::InclusiveOr => "|",
BinaryOperator::LogicalAnd => "&&", BinaryOperator::LogicalAnd => "&&",
BinaryOperator::LogicalOr => "||", BinaryOperator::LogicalOr => "||",
BinaryOperator::ShiftLeftLogical => "<<", BinaryOperator::ShiftLeft => "<<",
BinaryOperator::ShiftRightLogical => todo!(), BinaryOperator::ShiftRight => ">>",
BinaryOperator::ShiftRightArithmetic => ">>",
}; };
Cow::Owned(format!("({} {} {})", left_expr, op_str, right_expr)) Cow::Owned(format!("({} {} {})", left_expr, op_str, right_expr))
} }
Expression::Select {
condition,
accept,
reject,
} => {
let cond_expr =
write_expression(&builder.expressions[condition], module, builder, manager)?;
let accept_expr =
write_expression(&builder.expressions[accept], module, builder, manager)?;
let reject_expr =
write_expression(&builder.expressions[reject], module, builder, manager)?;
Cow::Owned(format!(
"({} ? {} : {})",
cond_expr, accept_expr, reject_expr
))
}
Expression::Intrinsic { fun, argument } => { Expression::Intrinsic { fun, argument } => {
let expr = write_expression(&builder.expressions[argument], module, builder)?; let expr = write_expression(&builder.expressions[argument], module, builder, manager)?;
Cow::Owned(format!( Cow::Owned(format!(
"{:?}({})", "{:?}({})",
@ -1150,17 +1241,20 @@ fn write_expression<'a, 'b>(
)) ))
} }
Expression::Transpose(matrix) => { Expression::Transpose(matrix) => {
let matrix_expr = write_expression(&builder.expressions[matrix], module, builder)?; let matrix_expr =
write_expression(&builder.expressions[matrix], module, builder, manager)?;
Cow::Owned(format!("transpose({})", matrix_expr)) Cow::Owned(format!("transpose({})", matrix_expr))
} }
Expression::DotProduct(left, right) => { Expression::DotProduct(left, right) => {
let left_expr = write_expression(&builder.expressions[left], module, builder)?; let left_expr = write_expression(&builder.expressions[left], module, builder, manager)?;
let right_expr = write_expression(&builder.expressions[right], module, builder)?; let right_expr =
write_expression(&builder.expressions[right], module, builder, manager)?;
Cow::Owned(format!("dot({},{})", left_expr, right_expr)) Cow::Owned(format!("dot({},{})", left_expr, right_expr))
} }
Expression::CrossProduct(left, right) => { Expression::CrossProduct(left, right) => {
let left_expr = write_expression(&builder.expressions[left], module, builder)?; let left_expr = write_expression(&builder.expressions[left], module, builder, manager)?;
let right_expr = write_expression(&builder.expressions[right], module, builder)?; let right_expr =
write_expression(&builder.expressions[right], module, builder, manager)?;
Cow::Owned(format!("cross({},{})", left_expr, right_expr)) Cow::Owned(format!("cross({},{})", left_expr, right_expr))
} }
Expression::As { Expression::As {
@ -1168,7 +1262,8 @@ fn write_expression<'a, 'b>(
kind, kind,
convert, convert,
} => { } => {
let value_expr = write_expression(&builder.expressions[expr], module, builder)?; let value_expr =
write_expression(&builder.expressions[expr], module, builder, manager)?;
let (source_kind, ty_expr) = match *builder.typifier.get(expr, &module.types) { let (source_kind, ty_expr) = match *builder.typifier.get(expr, &module.types) {
TypeInner::Scalar { TypeInner::Scalar {
@ -1176,7 +1271,7 @@ fn write_expression<'a, 'b>(
kind: source_kind, kind: source_kind,
} => ( } => (
source_kind, source_kind,
Cow::Borrowed(map_scalar(kind, width, builder.manager)?.full), Cow::Borrowed(map_scalar(kind, width, manager)?.full),
), ),
TypeInner::Vector { TypeInner::Vector {
width, width,
@ -1186,7 +1281,7 @@ fn write_expression<'a, 'b>(
source_kind, source_kind,
Cow::Owned(format!( Cow::Owned(format!(
"{}vec{}", "{}vec{}",
map_scalar(kind, width, builder.manager)?.prefix, map_scalar(kind, width, manager)?.prefix,
size as u32, size as u32,
)), )),
), ),
@ -1213,7 +1308,7 @@ fn write_expression<'a, 'b>(
Cow::Owned(format!("{}({})", op, value_expr)) Cow::Owned(format!("{}({})", op, value_expr))
} }
Expression::Derivative { axis, expr } => { Expression::Derivative { axis, expr } => {
let expr = write_expression(&builder.expressions[expr], module, builder)?; let expr = write_expression(&builder.expressions[expr], module, builder, manager)?;
Cow::Owned(format!( Cow::Owned(format!(
"{}({})", "{}({})",
@ -1236,7 +1331,8 @@ fn write_expression<'a, 'b>(
.map::<Result<_, Error>, _>(|arg| write_expression( .map::<Result<_, Error>, _>(|arg| write_expression(
&builder.expressions[*arg], &builder.expressions[*arg],
module, module,
builder builder,
manager,
)) ))
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?
.join(","), .join(","),
@ -1245,41 +1341,42 @@ fn write_expression<'a, 'b>(
origin: crate::FunctionOrigin::External(ref name), origin: crate::FunctionOrigin::External(ref name),
ref arguments, ref arguments,
} => match name.as_str() { } => match name.as_str() {
"cos" | "normalize" | "sin" => { "cos" | "normalize" | "sin" | "length" | "abs" | "floor" | "inverse" => {
let expr = write_expression(&builder.expressions[arguments[0]], module, builder)?; let expr =
write_expression(&builder.expressions[arguments[0]], module, builder, manager)?;
Cow::Owned(format!("{}({})", name, expr)) Cow::Owned(format!("{}({})", name, expr))
} }
"fclamp" => { "fclamp" | "clamp" | "mix" | "smoothstep" => {
let val = write_expression(&builder.expressions[arguments[0]], module, builder)?; let x =
let min = write_expression(&builder.expressions[arguments[1]], module, builder)?; write_expression(&builder.expressions[arguments[0]], module, builder, manager)?;
let max = write_expression(&builder.expressions[arguments[2]], module, builder)?; let y =
write_expression(&builder.expressions[arguments[1]], module, builder, manager)?;
let a =
write_expression(&builder.expressions[arguments[2]], module, builder, manager)?;
Cow::Owned(format!("clamp({}, {}, {})", val, min, max)) let name = match name.as_str() {
"fclamp" => "clamp",
name => name,
};
Cow::Owned(format!("{}({}, {}, {})", name, x, y, a))
} }
"atan2" => { "atan2" => {
let x = write_expression(&builder.expressions[arguments[0]], module, builder)?; let x =
let y = write_expression(&builder.expressions[arguments[1]], module, builder)?; write_expression(&builder.expressions[arguments[0]], module, builder, manager)?;
let y =
write_expression(&builder.expressions[arguments[1]], module, builder, manager)?;
Cow::Owned(format!("atan({}, {})", y, x)) Cow::Owned(format!("atan({}, {})", y, x))
} }
"distance" => { "distance" | "dot" | "min" | "max" | "reflect" | "pow" | "step" | "cross" => {
let p0 = write_expression(&builder.expressions[arguments[0]], module, builder)?; let x =
let p1 = write_expression(&builder.expressions[arguments[1]], module, builder)?; write_expression(&builder.expressions[arguments[0]], module, builder, manager)?;
let y =
write_expression(&builder.expressions[arguments[1]], module, builder, manager)?;
Cow::Owned(format!("distance({}, {})", p0, p1)) Cow::Owned(format!("{}({}, {})", name, x, y))
}
"length" => {
let x = write_expression(&builder.expressions[arguments[0]], module, builder)?;
Cow::Owned(format!("length({})", x))
}
"mix" => {
let x = write_expression(&builder.expressions[arguments[0]], module, builder)?;
let y = write_expression(&builder.expressions[arguments[0]], module, builder)?;
let a = write_expression(&builder.expressions[arguments[0]], module, builder)?;
Cow::Owned(format!("mix({}, {}, {})", x, y, a))
} }
other => { other => {
return Err(Error::Custom(format!( return Err(Error::Custom(format!(
@ -1289,7 +1386,7 @@ fn write_expression<'a, 'b>(
} }
}, },
Expression::ArrayLength(expr) => { Expression::ArrayLength(expr) => {
let base = write_expression(&builder.expressions[expr], module, builder)?; let base = write_expression(&builder.expressions[expr], module, builder, manager)?;
Cow::Owned(format!("uint({}.length())", base)) Cow::Owned(format!("uint({}.length())", base))
} }
}) })
@ -1299,6 +1396,7 @@ fn write_constant(
constant: &Constant, constant: &Constant,
module: &Module, module: &Module,
builder: &mut StatementBuilder<'_>, builder: &mut StatementBuilder<'_>,
manager: &mut FeaturesManager,
) -> Result<String, Error> { ) -> Result<String, Error> {
Ok(match constant.inner { Ok(match constant.inner {
ConstantInner::Sint(int) => int.to_string(), ConstantInner::Sint(int) => int.to_string(),
@ -1316,9 +1414,10 @@ fn write_constant(
TypeInner::Array { .. } => write_type( TypeInner::Array { .. } => write_type(
constant.ty, constant.ty,
&module.types, &module.types,
&module.constants,
builder.structs, builder.structs,
None, None,
builder.manager manager
)?, )?,
_ => _ =>
return Err(Error::Custom(format!( return Err(Error::Custom(format!(
@ -1326,15 +1425,21 @@ fn write_constant(
write_type( write_type(
constant.ty, constant.ty,
&module.types, &module.types,
&module.constants,
builder.structs, builder.structs,
None, None,
builder.manager manager
)? )?
))), ))),
}, },
components components
.iter() .iter()
.map(|component| write_constant(&module.constants[*component], module, builder,)) .map(|component| write_constant(
&module.constants[*component],
module,
builder,
manager
))
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?
.join(","), .join(","),
), ),
@ -1390,6 +1495,7 @@ fn map_scalar(
fn write_type<'a>( fn write_type<'a>(
ty: Handle<Type>, ty: Handle<Type>,
types: &Arena<Type>, types: &Arena<Type>,
constants: &Arena<Constant>,
structs: &'a FastHashMap<Handle<Type>, String>, structs: &'a FastHashMap<Handle<Type>, String>,
block: Option<String>, block: Option<String>,
manager: &mut FeaturesManager, manager: &mut FeaturesManager,
@ -1417,7 +1523,9 @@ fn write_type<'a>(
rows as u8 rows as u8
)) ))
} }
TypeInner::Pointer { base, .. } => write_type(base, types, structs, None, manager)?, TypeInner::Pointer { base, .. } => {
write_type(base, types, constants, structs, None, manager)?
}
TypeInner::Array { base, size, .. } => { TypeInner::Array { base, size, .. } => {
if let TypeInner::Array { .. } = types[base].inner { if let TypeInner::Array { .. } = types[base].inner {
manager.request(Features::ARRAY_OF_ARRAYS) manager.request(Features::ARRAY_OF_ARRAYS)
@ -1425,8 +1533,8 @@ fn write_type<'a>(
Cow::Owned(format!( Cow::Owned(format!(
"{}[{}]", "{}[{}]",
write_type(base, types, structs, None, manager)?, write_type(base, types, constants, structs, None, manager)?,
write_array_size(size)? write_array_size(size, constants)?
)) ))
} }
TypeInner::Struct { ref members } => { TypeInner::Struct { ref members } => {
@ -1438,7 +1546,7 @@ fn write_type<'a>(
writeln!( writeln!(
&mut out, &mut out,
"\t{} {};", "\t{} {};",
write_type(member.ty, types, structs, None, manager)?, write_type(member.ty, types, constants, structs, None, manager)?,
member member
.name .name
.clone() .clone()
@ -1500,22 +1608,24 @@ fn write_storage_class(
manager: &mut FeaturesManager, manager: &mut FeaturesManager,
) -> Result<&'static str, Error> { ) -> Result<&'static str, Error> {
Ok(match class { Ok(match class {
StorageClass::Constant => "",
StorageClass::Function => "", StorageClass::Function => "",
StorageClass::Input => "in ", StorageClass::Input => "in ",
StorageClass::Output => "out ", StorageClass::Output => "out ",
StorageClass::Private => "", StorageClass::Private => "",
StorageClass::StorageBuffer => { StorageClass::Storage => {
manager.request(Features::BUFFER_STORAGE); manager.request(Features::BUFFER_STORAGE);
"buffer " "buffer "
} }
StorageClass::Uniform => "uniform ", StorageClass::Uniform => "uniform ",
StorageClass::Handle => "uniform ",
StorageClass::WorkGroup => { StorageClass::WorkGroup => {
manager.request(Features::COMPUTE_SHADER); manager.request(Features::COMPUTE_SHADER);
"shared " "shared "
} }
StorageClass::PushConstant => {
manager.request(Features::PUSH_CONSTANT);
""
}
}) })
} }
@ -1534,9 +1644,12 @@ fn write_interpolation(interpolation: Interpolation) -> Result<&'static str, Err
}) })
} }
fn write_array_size(size: ArraySize) -> Result<String, Error> { fn write_array_size(size: ArraySize, constants: &Arena<Constant>) -> Result<String, Error> {
Ok(match size { Ok(match size {
ArraySize::Static(size) => size.to_string(), ArraySize::Constant(const_handle) => match constants[const_handle].inner {
ConstantInner::Uint(size) => size.to_string(),
_ => unreachable!(),
},
ArraySize::Dynamic => String::from(""), ArraySize::Dynamic => String::from(""),
}) })
} }
@ -1598,7 +1711,14 @@ fn write_struct(
writeln!( writeln!(
&mut tmp, &mut tmp,
"\t{} {};", "\t{} {};",
write_type(member.ty, &module.types, &structs, None, manager)?, write_type(
member.ty,
&module.types,
&module.constants,
&structs,
None,
manager
)?,
member member
.name .name
.clone() .clone()
@ -1794,6 +1914,7 @@ fn collect_texture_mapping<'a>(
for func in functions { for func in functions {
let mut interface = Interface { let mut interface = Interface {
expressions: &func.expressions, expressions: &func.expressions,
local_variables: &func.local_variables,
visitor: TextureMappingVisitor { visitor: TextureMappingVisitor {
expressions: &func.expressions, expressions: &func.expressions,
map: &mut mappings, map: &mut mappings,

File diff suppressed because it is too large Load diff

View file

@ -235,6 +235,13 @@ pub(super) fn instruction_type_sampler(id: Word) -> Instruction {
instruction instruction
} }
pub(super) fn instruction_type_sampled_image(id: Word, image_type_id: Word) -> Instruction {
let mut instruction = Instruction::new(Op::TypeSampledImage);
instruction.set_result(id);
instruction.add_operand(image_type_id);
instruction
}
pub(super) fn instruction_type_array( pub(super) fn instruction_type_array(
id: Word, id: Word,
element_type_id: Word, element_type_id: Word,
@ -399,6 +406,24 @@ pub(super) fn instruction_store(
instruction instruction
} }
pub(super) fn instruction_access_chain(
result_type_id: Word,
id: Word,
base_id: Word,
index_ids: &[Word],
) -> Instruction {
let mut instruction = Instruction::new(Op::AccessChain);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(base_id);
for index_id in index_ids {
instruction.add_operand(*index_id);
}
instruction
}
// //
// Function Instructions // Function Instructions
// //
@ -449,6 +474,33 @@ pub(super) fn instruction_function_call(
// //
// Image Instructions // Image Instructions
// //
pub(super) fn instruction_sampled_image(
result_type_id: Word,
id: Word,
image: Word,
sampler: Word,
) -> Instruction {
let mut instruction = Instruction::new(Op::SampledImage);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(image);
instruction.add_operand(sampler);
instruction
}
pub(super) fn instruction_image_sample_implicit_lod(
result_type_id: Word,
id: Word,
sampled_image: Word,
coordinates: Word,
) -> Instruction {
let mut instruction = Instruction::new(Op::ImageSampleImplicitLod);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(sampled_image);
instruction.add_operand(coordinates);
instruction
}
// //
// Conversion Instructions // Conversion Instructions

View file

@ -24,6 +24,10 @@ impl PhysicalLayout {
sink.extend(iter::once(self.bound)); sink.extend(iter::once(self.bound));
sink.extend(iter::once(self.instruction_schema)); sink.extend(iter::once(self.instruction_schema));
} }
pub(super) fn supports_storage_buffers(&self) -> bool {
self.version >= 0x10300
}
} }
impl LogicalLayout { impl LogicalLayout {

View file

@ -1,10 +1,25 @@
/*! Standard Portable Intermediate Representation (SPIR-V) backend !*/ /*! Standard Portable Intermediate Representation (SPIR-V) backend !*/
use super::{Instruction, LogicalLayout, PhysicalLayout, WriterFlags}; use super::{Instruction, LogicalLayout, PhysicalLayout, WriterFlags};
use spirv::Word; use spirv::Word;
use std::collections::hash_map::Entry; use std::{collections::hash_map::Entry, ops};
use thiserror::Error;
const BITS_PER_BYTE: crate::Bytes = 8; const BITS_PER_BYTE: crate::Bytes = 8;
#[derive(Clone, Debug, Error)]
pub enum Error {
#[error("can't find local variable: {0:?}")]
UnknownLocalVariable(crate::LocalVariable),
#[error("bad image class for op: {0:?}")]
BadImageClass(crate::ImageClass),
#[error("not an image")]
NotImage,
#[error("empty value")]
EmptyValue,
#[error("feature is not yet implemented")]
FeatureNotImplemented(),
}
struct Block { struct Block {
label: Option<Instruction>, label: Option<Instruction>,
body: Vec<Instruction>, body: Vec<Instruction>,
@ -77,7 +92,10 @@ enum LocalType {
}, },
Pointer { Pointer {
base: crate::Handle<crate::Type>, base: crate::Handle<crate::Type>,
class: spirv::StorageClass, class: crate::StorageClass,
},
SampledImage {
image_type: crate::Handle<crate::Type>,
}, },
} }
@ -102,6 +120,21 @@ struct LookupFunctionType {
return_type_id: Word, return_type_id: Word,
} }
enum MaybeOwned<'a, T> {
Owned(T),
Borrowed(&'a T),
}
impl<'a, T> ops::Deref for MaybeOwned<'a, T> {
type Target = T;
fn deref(&self) -> &T {
match *self {
MaybeOwned::Owned(ref value) => value,
MaybeOwned::Borrowed(reference) => reference,
}
}
}
pub struct Writer { pub struct Writer {
physical_layout: PhysicalLayout, physical_layout: PhysicalLayout,
logical_layout: LogicalLayout, logical_layout: LogicalLayout,
@ -118,6 +151,9 @@ pub struct Writer {
lookup_global_variable: crate::FastHashMap<crate::Handle<crate::GlobalVariable>, Word>, lookup_global_variable: crate::FastHashMap<crate::Handle<crate::GlobalVariable>, Word>,
} }
// type alias, for success return of write_expression
type WriteExpressionOutput = (Word, Option<LookupType>);
impl Writer { impl Writer {
pub fn new(header: &crate::Header, writer_flags: WriterFlags) -> Self { pub fn new(header: &crate::Header, writer_flags: WriterFlags) -> Self {
Writer { Writer {
@ -176,15 +212,13 @@ impl Writer {
fn get_global_variable_id( fn get_global_variable_id(
&mut self, &mut self,
arena: &crate::Arena<crate::Type>, ir_module: &crate::Module,
global_arena: &crate::Arena<crate::GlobalVariable>,
handle: crate::Handle<crate::GlobalVariable>, handle: crate::Handle<crate::GlobalVariable>,
) -> Word { ) -> Word {
match self.lookup_global_variable.entry(handle) { match self.lookup_global_variable.entry(handle) {
Entry::Occupied(e) => *e.get(), Entry::Occupied(e) => *e.get(),
_ => { _ => {
let global_variable = &global_arena[handle]; let (instruction, id) = self.write_global_variable(ir_module, handle);
let (instruction, id) = self.write_global_variable(arena, global_variable, handle);
instruction.to_words(&mut self.logical_layout.declarations); instruction.to_words(&mut self.logical_layout.declarations);
id id
} }
@ -215,7 +249,7 @@ impl Writer {
&mut self, &mut self,
arena: &crate::Arena<crate::Type>, arena: &crate::Arena<crate::Type>,
handle: crate::Handle<crate::Type>, handle: crate::Handle<crate::Type>,
class: spirv::StorageClass, class: crate::StorageClass,
) -> Word { ) -> Word {
let ty = &arena[handle]; let ty = &arena[handle];
let ty_id = self.get_type_id(arena, LookupType::Handle(handle)); let ty_id = self.get_type_id(arena, LookupType::Handle(handle));
@ -230,24 +264,36 @@ impl Writer {
})) { })) {
Entry::Occupied(e) => *e.get(), Entry::Occupied(e) => *e.get(),
_ => { _ => {
let pointer_id = self.generate_id(); let id =
let instruction = self.create_pointer(ty_id, self.parse_to_spirv_storage_class(class));
super::instructions::instruction_type_pointer(pointer_id, class, ty_id);
instruction.to_words(&mut self.logical_layout.declarations);
self.lookup_type.insert( self.lookup_type.insert(
LookupType::Local(LocalType::Pointer { LookupType::Local(LocalType::Pointer {
base: handle, base: handle,
class, class,
}), }),
pointer_id, id,
); );
pointer_id id
} }
} }
} }
} }
} }
fn create_pointer(&mut self, ty_id: Word, class: spirv::StorageClass) -> Word {
let id = self.generate_id();
let instruction = super::instructions::instruction_type_pointer(id, class, ty_id);
instruction.to_words(&mut self.logical_layout.declarations);
id
}
fn create_constant(&mut self, type_id: Word, value: &[Word]) -> Word {
let id = self.generate_id();
let instruction = super::instructions::instruction_constant(type_id, id, value);
instruction.to_words(&mut self.logical_layout.declarations);
id
}
fn write_function( fn write_function(
&mut self, &mut self,
ir_function: &crate::Function, ir_function: &crate::Function,
@ -258,18 +304,12 @@ impl Writer {
for (_, variable) in ir_function.local_variables.iter() { for (_, variable) in ir_function.local_variables.iter() {
let id = self.generate_id(); let id = self.generate_id();
let init_word = match variable.init { let init_word = variable
Some(exp) => match &ir_function.expressions[exp] { .init
crate::Expression::Constant(handle) => { .map(|constant| self.get_constant_id(constant, ir_module));
Some(self.get_constant_id(*handle, ir_module))
}
_ => unreachable!(),
},
None => None,
};
let pointer_id = let pointer_id =
self.get_pointer_id(&ir_module.types, variable.ty, spirv::StorageClass::Function); self.get_pointer_id(&ir_module.types, variable.ty, crate::StorageClass::Function);
function.variables.push(LocalVariable { function.variables.push(LocalVariable {
id, id,
name: variable.name.clone(), name: variable.name.clone(),
@ -284,21 +324,18 @@ impl Writer {
let return_type_id = let return_type_id =
self.get_function_return_type(ir_function.return_type, &ir_module.types); self.get_function_return_type(ir_function.return_type, &ir_module.types);
let mut parameter_type_ids = Vec::with_capacity(ir_function.parameter_types.len()); let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
let mut function_parameter_pointer_ids = vec![]; let mut function_parameter_pointer_ids = vec![];
for parameter_type in ir_function.parameter_types.iter() { for argument in ir_function.arguments.iter() {
let id = self.generate_id(); let id = self.generate_id();
let pointer_id = self.get_pointer_id( let pointer_id =
&ir_module.types, self.get_pointer_id(&ir_module.types, argument.ty, crate::StorageClass::Function);
*parameter_type,
spirv::StorageClass::Function,
);
function_parameter_pointer_ids.push(pointer_id); function_parameter_pointer_ids.push(pointer_id);
parameter_type_ids parameter_type_ids
.push(self.get_type_id(&ir_module.types, LookupType::Handle(*parameter_type))); .push(self.get_type_id(&ir_module.types, LookupType::Handle(argument.ty)));
function function
.parameters .parameters
.push(super::instructions::instruction_function_parameter( .push(super::instructions::instruction_function_parameter(
@ -350,14 +387,13 @@ impl Writer {
for ((handle, _), &usage) in ir_module for ((handle, _), &usage) in ir_module
.global_variables .global_variables
.iter() .iter()
.filter(|&(_, var)| {
var.class == crate::StorageClass::Input || var.class == crate::StorageClass::Output
})
.zip(&entry_point.function.global_usage) .zip(&entry_point.function.global_usage)
{ {
if usage.contains(crate::GlobalUse::STORE) || usage.contains(crate::GlobalUse::LOAD) { if usage.contains(crate::GlobalUse::STORE) || usage.contains(crate::GlobalUse::LOAD) {
let id = self.get_global_variable_id( let id = self.get_global_variable_id(ir_module, handle);
&ir_module.types,
&ir_module.global_variables,
handle,
);
interface_ids.push(id); interface_ids.push(id);
} }
} }
@ -407,14 +443,19 @@ impl Writer {
fn parse_to_spirv_storage_class(&self, class: crate::StorageClass) -> spirv::StorageClass { fn parse_to_spirv_storage_class(&self, class: crate::StorageClass) -> spirv::StorageClass {
match class { match class {
crate::StorageClass::Constant => spirv::StorageClass::UniformConstant, crate::StorageClass::Handle => spirv::StorageClass::UniformConstant,
crate::StorageClass::Function => spirv::StorageClass::Function, crate::StorageClass::Function => spirv::StorageClass::Function,
crate::StorageClass::Input => spirv::StorageClass::Input, crate::StorageClass::Input => spirv::StorageClass::Input,
crate::StorageClass::Output => spirv::StorageClass::Output, crate::StorageClass::Output => spirv::StorageClass::Output,
crate::StorageClass::Private => spirv::StorageClass::Private, crate::StorageClass::Private => spirv::StorageClass::Private,
crate::StorageClass::StorageBuffer => spirv::StorageClass::StorageBuffer, crate::StorageClass::Storage if self.physical_layout.supports_storage_buffers() => {
crate::StorageClass::Uniform => spirv::StorageClass::Uniform, spirv::StorageClass::StorageBuffer
}
crate::StorageClass::Storage | crate::StorageClass::Uniform => {
spirv::StorageClass::Uniform
}
crate::StorageClass::WorkGroup => spirv::StorageClass::Workgroup, crate::StorageClass::WorkGroup => spirv::StorageClass::Workgroup,
crate::StorageClass::PushConstant => spirv::StorageClass::PushConstant,
} }
} }
@ -432,6 +473,10 @@ impl Writer {
super::instructions::instruction_type_vector(id, scalar_id, size) super::instructions::instruction_type_vector(id, scalar_id, size)
} }
LocalType::Pointer { .. } => unimplemented!(), LocalType::Pointer { .. } => unimplemented!(),
LocalType::SampledImage { image_type } => {
let image_type_id = self.get_type_id(arena, LookupType::Handle(image_type));
super::instructions::instruction_type_sampled_image(id, image_type_id)
}
}; };
self.lookup_type.insert(LookupType::Local(local_ty), id); self.lookup_type.insert(LookupType::Local(local_ty), id);
@ -484,17 +529,14 @@ impl Writer {
} => { } => {
let width = 4; let width = 4;
let local_type = match class { let local_type = match class {
crate::ImageClass::Sampled { kind, multi: _ } => LocalType::Vector { crate::ImageClass::Sampled { kind, multi: _ } => {
size: crate::VectorSize::Quad, LocalType::Scalar { kind, width }
kind, }
width,
},
crate::ImageClass::Depth => LocalType::Scalar { crate::ImageClass::Depth => LocalType::Scalar {
kind: crate::ScalarKind::Float, kind: crate::ScalarKind::Float,
width, width,
}, },
crate::ImageClass::Storage(format) => LocalType::Vector { crate::ImageClass::Storage(format) => LocalType::Scalar {
size: crate::VectorSize::Quad,
kind: format.into(), kind: format.into(),
width, width,
}, },
@ -507,7 +549,7 @@ impl Writer {
crate::TypeInner::Sampler { comparison: _ } => { crate::TypeInner::Sampler { comparison: _ } => {
super::instructions::instruction_type_sampler(id) super::instructions::instruction_type_sampler(id)
} }
crate::TypeInner::Array { size, stride, .. } => { crate::TypeInner::Array { base, size, stride } => {
if let Some(array_stride) = stride { if let Some(array_stride) = stride {
self.annotations self.annotations
.push(super::instructions::instruction_decorate( .push(super::instructions::instruction_decorate(
@ -517,10 +559,11 @@ impl Writer {
)); ));
} }
let type_id = self.get_type_id(arena, LookupType::Handle(handle)); let type_id = self.get_type_id(arena, LookupType::Handle(base));
match size { match size {
crate::ArraySize::Static(length) => { crate::ArraySize::Constant(const_handle) => {
super::instructions::instruction_type_array(id, type_id, length) let length_id = self.lookup_constant[&const_handle];
super::instructions::instruction_type_array(id, type_id, length_id)
} }
crate::ArraySize::Dynamic => { crate::ArraySize::Dynamic => {
super::instructions::instruction_type_runtime_array(id, type_id) super::instructions::instruction_type_runtime_array(id, type_id)
@ -537,13 +580,8 @@ impl Writer {
} }
crate::TypeInner::Pointer { base, class } => { crate::TypeInner::Pointer { base, class } => {
let type_id = self.get_type_id(arena, LookupType::Handle(base)); let type_id = self.get_type_id(arena, LookupType::Handle(base));
self.lookup_type.insert( self.lookup_type
LookupType::Local(LocalType::Pointer { .insert(LookupType::Local(LocalType::Pointer { base, class }), id);
base,
class: self.parse_to_spirv_storage_class(class),
}),
id,
);
super::instructions::instruction_type_pointer( super::instructions::instruction_type_pointer(
id, id,
self.parse_to_spirv_storage_class(class), self.parse_to_spirv_storage_class(class),
@ -656,17 +694,22 @@ impl Writer {
fn write_global_variable( fn write_global_variable(
&mut self, &mut self,
arena: &crate::Arena<crate::Type>, ir_module: &crate::Module,
global_variable: &crate::GlobalVariable,
handle: crate::Handle<crate::GlobalVariable>, handle: crate::Handle<crate::GlobalVariable>,
) -> (Instruction, Word) { ) -> (Instruction, Word) {
let global_variable = &ir_module.global_variables[handle];
let id = self.generate_id(); let id = self.generate_id();
let class = self.parse_to_spirv_storage_class(global_variable.class); let class = self.parse_to_spirv_storage_class(global_variable.class);
self.try_add_capabilities(class.required_capabilities()); self.try_add_capabilities(class.required_capabilities());
let pointer_id = self.get_pointer_id(arena, global_variable.ty, class); let init_word = global_variable
let instruction = super::instructions::instruction_variable(pointer_id, id, class, None); .init
.map(|constant| self.get_constant_id(constant, ir_module));
let pointer_id =
self.get_pointer_id(&ir_module.types, global_variable.ty, global_variable.class);
let instruction =
super::instructions::instruction_variable(pointer_id, id, class, init_word);
if self.writer_flags.contains(WriterFlags::DEBUG) { if self.writer_flags.contains(WriterFlags::DEBUG) {
if let Some(ref name) = global_variable.name { if let Some(ref name) = global_variable.name {
@ -792,6 +835,28 @@ impl Writer {
id id
} }
fn get_type_inner<'a>(
&self,
ty_arena: &'a crate::Arena<crate::Type>,
lookup_ty: LookupType,
) -> MaybeOwned<'a, crate::TypeInner> {
match lookup_ty {
LookupType::Handle(handle) => MaybeOwned::Borrowed(&ty_arena[handle].inner),
LookupType::Local(local_ty) => match local_ty {
LocalType::Scalar { kind, width } => {
MaybeOwned::Owned(crate::TypeInner::Scalar { kind, width })
}
LocalType::Vector { size, kind, width } => {
MaybeOwned::Owned(crate::TypeInner::Vector { size, kind, width })
}
LocalType::Pointer { base, class } => {
MaybeOwned::Owned(crate::TypeInner::Pointer { base, class })
}
_ => unreachable!(),
},
}
}
fn write_expression<'a>( fn write_expression<'a>(
&mut self, &mut self,
ir_module: &'a crate::Module, ir_module: &'a crate::Module,
@ -799,36 +864,153 @@ impl Writer {
expression: &crate::Expression, expression: &crate::Expression,
block: &mut Block, block: &mut Block,
function: &mut Function, function: &mut Function,
) -> Option<(Word, Option<crate::Handle<crate::Type>>)> { ) -> Result<WriteExpressionOutput, Error> {
match expression { match *expression {
crate::Expression::GlobalVariable(handle) => { crate::Expression::Access { base, index } => {
let var = &ir_module.global_variables[*handle]; let id = self.generate_id();
let id = self.get_global_variable_id(
let (base_id, base_lookup_ty) = self.write_expression(
ir_module,
ir_function,
&ir_function.expressions[base],
block,
function,
)?;
let (index_id, _) = self.write_expression(
ir_module,
ir_function,
&ir_function.expressions[index],
block,
function,
)?;
let base_ty_inner = self.get_type_inner(&ir_module.types, base_lookup_ty.unwrap());
let (pointer_id, type_id, lookup_ty) = match *base_ty_inner {
crate::TypeInner::Vector { kind, width, .. } => {
let scalar_id = self.get_type_id(
&ir_module.types, &ir_module.types,
&ir_module.global_variables, LookupType::Local(LocalType::Scalar { kind, width }),
*handle,
); );
Some((id, Some(var.ty))) (
self.create_pointer(scalar_id, spirv::StorageClass::Function),
scalar_id,
LookupType::Local(LocalType::Scalar { kind, width }),
)
}
_ => unimplemented!(),
};
block
.body
.push(super::instructions::instruction_access_chain(
pointer_id,
id,
base_id,
&[index_id],
));
let load_id = self.generate_id();
block.body.push(super::instructions::instruction_load(
type_id, load_id, id, None,
));
Ok((load_id, Some(lookup_ty)))
}
crate::Expression::AccessIndex { base, index } => {
let id = self.generate_id();
let (base_id, base_lookup_ty) = self
.write_expression(
ir_module,
ir_function,
&ir_function.expressions[base],
block,
function,
)
.unwrap();
let base_ty_inner = self.get_type_inner(&ir_module.types, base_lookup_ty.unwrap());
let (pointer_id, type_id, lookup_ty) = match *base_ty_inner {
crate::TypeInner::Vector { kind, width, .. } => {
let scalar_id = self.get_type_id(
&ir_module.types,
LookupType::Local(LocalType::Scalar { kind, width }),
);
(
self.create_pointer(scalar_id, spirv::StorageClass::Function),
scalar_id,
LookupType::Local(LocalType::Scalar { kind, width }),
)
}
crate::TypeInner::Struct { ref members } => {
let member = &members[index as usize];
let type_id =
self.get_type_id(&ir_module.types, LookupType::Handle(member.ty));
(
self.create_pointer(type_id, spirv::StorageClass::Uniform),
type_id,
LookupType::Handle(member.ty),
)
}
_ => unimplemented!(),
};
let const_ty_id = self.get_type_id(
&ir_module.types,
LookupType::Local(LocalType::Scalar {
kind: crate::ScalarKind::Sint,
width: 4,
}),
);
let const_id = self.create_constant(const_ty_id, &[index]);
block
.body
.push(super::instructions::instruction_access_chain(
pointer_id,
id,
base_id,
&[const_id],
));
let load_id = self.generate_id();
block.body.push(super::instructions::instruction_load(
type_id, load_id, id, None,
));
Ok((load_id, Some(lookup_ty)))
}
crate::Expression::GlobalVariable(handle) => {
let var = &ir_module.global_variables[handle];
let id = self.get_global_variable_id(&ir_module, handle);
Ok((id, Some(LookupType::Handle(var.ty))))
} }
crate::Expression::Constant(handle) => { crate::Expression::Constant(handle) => {
let var = &ir_module.constants[*handle]; let var = &ir_module.constants[handle];
let id = self.get_constant_id(*handle, ir_module); let id = self.get_constant_id(handle, ir_module);
Some((id, Some(var.ty))) Ok((id, Some(LookupType::Handle(var.ty))))
} }
crate::Expression::Compose { ty, components } => { crate::Expression::Compose { ty, ref components } => {
let base_type_id = self.get_type_id(&ir_module.types, LookupType::Handle(*ty)); let base_type_id = self.get_type_id(&ir_module.types, LookupType::Handle(ty));
let mut constituent_ids = Vec::with_capacity(components.len()); let mut constituent_ids = Vec::with_capacity(components.len());
for component in components { for component in components {
let expression = &ir_function.expressions[*component]; let expression = &ir_function.expressions[*component];
let (component_id, _) = self let (component_id, _) = self.write_expression(
.write_expression(ir_module, &ir_function, expression, block, function) ir_module,
.unwrap(); &ir_function,
expression,
block,
function,
)?;
constituent_ids.push(component_id); constituent_ids.push(component_id);
} }
let constituent_ids_slice = constituent_ids.as_slice(); let constituent_ids_slice = constituent_ids.as_slice();
let id = match ir_module.types[*ty].inner { let id = match ir_module.types[ty].inner {
crate::TypeInner::Vector { .. } => { crate::TypeInner::Vector { .. } => {
self.write_composite_construct(base_type_id, constituent_ids_slice, block) self.write_composite_construct(base_type_id, constituent_ids_slice, block)
} }
@ -868,44 +1050,40 @@ impl Writer {
_ => unreachable!(), _ => unreachable!(),
}; };
Some((id, Some(*ty))) Ok((id, Some(LookupType::Handle(ty))))
} }
crate::Expression::Binary { op, left, right } => { crate::Expression::Binary { op, left, right } => {
match op { match op {
crate::BinaryOperator::Multiply => { crate::BinaryOperator::Multiply => {
let id = self.generate_id(); let id = self.generate_id();
let left_expression = &ir_function.expressions[*left]; let left_expression = &ir_function.expressions[left];
let right_expression = &ir_function.expressions[*right]; let right_expression = &ir_function.expressions[right];
let (left_id, left_ty) = self let (left_id, left_lookup_ty) = self.write_expression(
.write_expression(
ir_module, ir_module,
ir_function, ir_function,
left_expression, left_expression,
block, block,
function, function,
) )?;
.unwrap(); let (right_id, right_lookup_ty) = self.write_expression(
let (right_id, right_ty) = self
.write_expression(
ir_module, ir_module,
ir_function, ir_function,
right_expression, right_expression,
block, block,
function, function,
) )?;
.unwrap();
let left_ty = left_ty.unwrap(); let left_lookup_ty = left_lookup_ty.unwrap();
let right_ty = right_ty.unwrap(); let right_lookup_ty = right_lookup_ty.unwrap();
let left_ty_inner = &ir_module.types[left_ty].inner; let left_ty_inner = self.get_type_inner(&ir_module.types, left_lookup_ty);
let right_ty_inner = &ir_module.types[right_ty].inner; let right_ty_inner = self.get_type_inner(&ir_module.types, right_lookup_ty);
let left_result_type_id = let left_result_type_id =
self.get_type_id(&ir_module.types, LookupType::Handle(left_ty)); self.get_type_id(&ir_module.types, left_lookup_ty);
let right_result_type_id = let right_result_type_id =
self.get_type_id(&ir_module.types, LookupType::Handle(right_ty)); self.get_type_id(&ir_module.types, right_lookup_ty);
let left_id = match *left_expression { let left_id = match *left_expression {
crate::Expression::LocalVariable(_) crate::Expression::LocalVariable(_)
@ -937,8 +1115,8 @@ impl Writer {
_ => right_id, _ => right_id,
}; };
let (instruction, ty) = match left_ty_inner { let (instruction, lookup_ty) = match *left_ty_inner {
crate::TypeInner::Vector { .. } => match right_ty_inner { crate::TypeInner::Vector { .. } => match *right_ty_inner {
crate::TypeInner::Scalar { .. } => ( crate::TypeInner::Scalar { .. } => (
super::instructions::instruction_vector_times_scalar( super::instructions::instruction_vector_times_scalar(
left_result_type_id, left_result_type_id,
@ -946,7 +1124,7 @@ impl Writer {
left_id, left_id,
right_id, right_id,
), ),
left_ty, left_lookup_ty,
), ),
crate::TypeInner::Matrix { .. } => ( crate::TypeInner::Matrix { .. } => (
super::instructions::instruction_vector_times_matrix( super::instructions::instruction_vector_times_matrix(
@ -955,11 +1133,11 @@ impl Writer {
left_id, left_id,
right_id, right_id,
), ),
left_ty, left_lookup_ty,
), ),
_ => unreachable!(), _ => unreachable!(),
}, },
crate::TypeInner::Matrix { .. } => match right_ty_inner { crate::TypeInner::Matrix { .. } => match *right_ty_inner {
crate::TypeInner::Scalar { .. } => ( crate::TypeInner::Scalar { .. } => (
super::instructions::instruction_matrix_times_scalar( super::instructions::instruction_matrix_times_scalar(
left_result_type_id, left_result_type_id,
@ -967,7 +1145,7 @@ impl Writer {
left_id, left_id,
right_id, right_id,
), ),
left_ty, left_lookup_ty,
), ),
crate::TypeInner::Vector { .. } => ( crate::TypeInner::Vector { .. } => (
super::instructions::instruction_matrix_times_vector( super::instructions::instruction_matrix_times_vector(
@ -976,7 +1154,7 @@ impl Writer {
left_id, left_id,
right_id, right_id,
), ),
right_ty, right_lookup_ty,
), ),
crate::TypeInner::Matrix { .. } => ( crate::TypeInner::Matrix { .. } => (
super::instructions::instruction_matrix_times_matrix( super::instructions::instruction_matrix_times_matrix(
@ -985,7 +1163,7 @@ impl Writer {
left_id, left_id,
right_id, right_id,
), ),
left_ty, left_lookup_ty,
), ),
_ => unreachable!(), _ => unreachable!(),
}, },
@ -999,7 +1177,7 @@ impl Writer {
left_id, left_id,
right_id, right_id,
), ),
left_ty, left_lookup_ty,
), ),
crate::ScalarKind::Sint | crate::ScalarKind::Uint => ( crate::ScalarKind::Sint | crate::ScalarKind::Uint => (
super::instructions::instruction_i_mul( super::instructions::instruction_i_mul(
@ -1008,7 +1186,7 @@ impl Writer {
left_id, left_id,
right_id, right_id,
), ),
left_ty, left_lookup_ty,
), ),
_ => unreachable!(), _ => unreachable!(),
} }
@ -1017,59 +1195,66 @@ impl Writer {
}; };
block.body.push(instruction); block.body.push(instruction);
Some((id, Some(ty))) Ok((id, Some(lookup_ty)))
} }
_ => unimplemented!("{:?}", op), _ => unimplemented!("{:?}", op),
} }
} }
crate::Expression::LocalVariable(variable) => { crate::Expression::LocalVariable(variable) => {
let var = &ir_function.local_variables[*variable]; let var = &ir_function.local_variables[variable];
let id = if let Some(local_var) = function function
.variables .variables
.iter() .iter()
.find(|&v| v.name.as_ref().unwrap() == var.name.as_ref().unwrap()) .find(|&v| v.name.as_ref().unwrap() == var.name.as_ref().unwrap())
{ .map(|local_var| (local_var.id, Some(LookupType::Handle(var.ty))))
local_var.id .ok_or_else(|| Error::UnknownLocalVariable(var.clone()))
} else {
panic!("Could not find: {:?}", var)
};
Some((id, Some(var.ty)))
} }
crate::Expression::FunctionParameter(index) => { crate::Expression::FunctionArgument(index) => {
let handle = ir_function.parameter_types.get(*index as usize).unwrap(); let handle = ir_function.arguments[index as usize].ty;
let type_id = self.get_type_id(&ir_module.types, LookupType::Handle(*handle)); let type_id = self.get_type_id(&ir_module.types, LookupType::Handle(handle));
let load_id = self.generate_id(); let load_id = self.generate_id();
block.body.push(super::instructions::instruction_load( block.body.push(super::instructions::instruction_load(
type_id, type_id,
load_id, load_id,
function.parameters[*index as usize].result_id.unwrap(), function.parameters[index as usize].result_id.unwrap(),
None, None,
)); ));
Some((load_id, Some(*handle))) Ok((load_id, Some(LookupType::Handle(handle))))
} }
crate::Expression::Call { origin, arguments } => match origin { crate::Expression::Call {
ref origin,
ref arguments,
} => match *origin {
crate::FunctionOrigin::Local(local_function) => { crate::FunctionOrigin::Local(local_function) => {
let origin_function = &ir_module.functions[*local_function]; let origin_function = &ir_module.functions[local_function];
let id = self.generate_id(); let id = self.generate_id();
let mut argument_ids = vec![]; let mut argument_ids = vec![];
for argument in arguments { for argument in arguments {
let expression = &ir_function.expressions[*argument]; let expression = &ir_function.expressions[*argument];
let (id, ty) = self let (id, lookup_ty) = self.write_expression(
.write_expression(ir_module, ir_function, expression, block, function) ir_module,
.unwrap(); ir_function,
expression,
block,
function,
)?;
// Create variable - OpVariable // Create variable - OpVariable
// Store value to variable - OpStore // Store value to variable - OpStore
// Use id of variable // Use id of variable
let handle = match lookup_ty.unwrap() {
LookupType::Handle(handle) => handle,
LookupType::Local(_) => unreachable!(),
};
let pointer_id = self.get_pointer_id( let pointer_id = self.get_pointer_id(
&ir_module.types, &ir_module.types,
ty.unwrap(), handle,
spirv::StorageClass::Function, crate::StorageClass::Function,
); );
let variable_id = self.generate_id(); let variable_id = self.generate_id();
@ -1099,10 +1284,10 @@ impl Writer {
.push(super::instructions::instruction_function_call( .push(super::instructions::instruction_function_call(
return_type_id, return_type_id,
id, id,
*self.lookup_function.get(local_function).unwrap(), *self.lookup_function.get(&local_function).unwrap(),
argument_ids.as_slice(), argument_ids.as_slice(),
)); ));
Some((id, None)) Ok((id, None))
} }
_ => unimplemented!("{:?}", origin), _ => unimplemented!("{:?}", origin),
}, },
@ -1112,31 +1297,31 @@ impl Writer {
convert, convert,
} => { } => {
if !convert { if !convert {
return None; return Err(Error::FeatureNotImplemented());
} }
let (expr_id, expr_type) = self let (expr_id, expr_type) = self.write_expression(
.write_expression(
ir_module, ir_module,
ir_function, ir_function,
&ir_function.expressions[*expr], &ir_function.expressions[expr],
block, block,
function, function,
) )?;
.unwrap();
let id = self.generate_id(); let id = self.generate_id();
let instruction = match ir_module.types[expr_type.unwrap()].inner { let expr_type_inner = self.get_type_inner(&ir_module.types, expr_type.unwrap());
let instruction = match *expr_type_inner {
crate::TypeInner::Scalar { crate::TypeInner::Scalar {
kind: expr_kind, kind: expr_kind,
width, width,
} => { } => {
let kind_type_id = self.get_type_id( let kind_type_id = self.get_type_id(
&ir_module.types, &ir_module.types,
LookupType::Local(LocalType::Scalar { kind: *kind, width }), LookupType::Local(LocalType::Scalar { kind, width }),
); );
if *convert { if convert {
super::instructions::instruction_bit_cast(kind_type_id, id, expr_id) super::instructions::instruction_bit_cast(kind_type_id, id, expr_id)
} else { } else {
match (expr_kind, kind) { match (expr_kind, kind) {
@ -1177,7 +1362,147 @@ impl Writer {
block.body.push(instruction); block.body.push(instruction);
Some((id, None)) Ok((id, None))
}
crate::Expression::ImageSample {
image,
sampler,
coordinate,
level: _,
depth_ref: _,
} => {
// image
let image_expression = &ir_function.expressions[image];
let (image_id, image_lookup_ty) = self.write_expression(
ir_module,
ir_function,
image_expression,
block,
function,
)?;
let image_lookup_ty = image_lookup_ty.ok_or(Error::EmptyValue)?;
let image_result_type_id = self.get_type_id(&ir_module.types, image_lookup_ty);
let image_id = match *image_expression {
crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => {
let load_id = self.generate_id();
block.body.push(super::instructions::instruction_load(
image_result_type_id,
load_id,
image_id,
None,
));
load_id
}
_ => image_id,
};
let image_ty = match image_lookup_ty {
LookupType::Handle(handle) => handle,
LookupType::Local(_) => unreachable!(),
};
// OpTypeSampledImage
let sampled_image_type_id = self.get_type_id(
&ir_module.types,
LookupType::Local(LocalType::SampledImage {
image_type: image_ty,
}),
);
// sampler
let sampler_expression = &ir_function.expressions[sampler];
let (sampler_id, sampler_lookup_ty) = self.write_expression(
ir_module,
ir_function,
sampler_expression,
block,
function,
)?;
let sampler_result_type_id =
self.get_type_id(&ir_module.types, sampler_lookup_ty.unwrap());
let sampler_id = match *sampler_expression {
crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => {
let load_id = self.generate_id();
block.body.push(super::instructions::instruction_load(
sampler_result_type_id,
load_id,
sampler_id,
None,
));
load_id
}
_ => sampler_id,
};
// coordinate
let coordinate_expression = &ir_function.expressions[coordinate];
let (coordinate_id, coordinate_lookup_ty) = self.write_expression(
ir_module,
ir_function,
coordinate_expression,
block,
function,
)?;
let coordinate_result_type_id =
self.get_type_id(&ir_module.types, coordinate_lookup_ty.unwrap());
let coordinate_id = match *coordinate_expression {
crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => {
let load_id = self.generate_id();
block.body.push(super::instructions::instruction_load(
coordinate_result_type_id,
load_id,
coordinate_id,
None,
));
load_id
}
_ => coordinate_id,
};
// component kind
let image_type = &ir_module.types[image_ty];
let image_sample_result_type =
if let crate::TypeInner::Image { class, .. } = image_type.inner {
let width = 4;
let local_type = match class {
crate::ImageClass::Sampled { kind, multi: _ } => LocalType::Vector {
kind,
width,
size: crate::VectorSize::Quad,
},
crate::ImageClass::Depth => LocalType::Scalar {
kind: crate::ScalarKind::Float,
width,
},
_ => return Err(Error::BadImageClass(class)),
};
self.get_type_id(&ir_module.types, LookupType::Local(local_type))
} else {
return Err(Error::NotImage);
};
let sampled_image_id = self.generate_id();
block
.body
.push(super::instructions::instruction_sampled_image(
sampled_image_type_id,
sampled_image_id,
image_id,
sampler_id,
));
let id = self.generate_id();
block
.body
.push(super::instructions::instruction_image_sample_implicit_lod(
image_sample_result_type,
id,
sampled_image_id,
coordinate_id,
));
Ok((id, None))
} }
_ => unimplemented!("{:?}", expression), _ => unimplemented!("{:?}", expression),
} }
@ -1206,7 +1531,7 @@ impl Writer {
block.termination = Some(match ir_function.return_type { block.termination = Some(match ir_function.return_type {
Some(_) => { Some(_) => {
let expression = &ir_function.expressions[value.unwrap()]; let expression = &ir_function.expressions[value.unwrap()];
let (id, ty) = self let (id, lookup_ty) = self
.write_expression( .write_expression(
ir_module, ir_module,
ir_function, ir_function,
@ -1220,10 +1545,8 @@ impl Writer {
crate::Expression::LocalVariable(_) crate::Expression::LocalVariable(_)
| crate::Expression::GlobalVariable(_) => { | crate::Expression::GlobalVariable(_) => {
let load_id = self.generate_id(); let load_id = self.generate_id();
let value_ty_id = self.get_type_id( let value_ty_id =
&ir_module.types, self.get_type_id(&ir_module.types, lookup_ty.unwrap());
LookupType::Handle(ty.unwrap()),
);
block.body.push(super::instructions::instruction_load( block.body.push(super::instructions::instruction_load(
value_ty_id, value_ty_id,
load_id, load_id,
@ -1252,7 +1575,7 @@ impl Writer {
function, function,
) )
.unwrap(); .unwrap();
let (value_id, value_ty) = self let (value_id, value_lookup_ty) = self
.write_expression( .write_expression(
ir_module, ir_module,
ir_function, ir_function,
@ -1266,10 +1589,8 @@ impl Writer {
crate::Expression::LocalVariable(_) crate::Expression::LocalVariable(_)
| crate::Expression::GlobalVariable(_) => { | crate::Expression::GlobalVariable(_) => {
let load_id = self.generate_id(); let load_id = self.generate_id();
let value_ty_id = self.get_type_id( let value_ty_id =
&ir_module.types, self.get_type_id(&ir_module.types, value_lookup_ty.unwrap());
LookupType::Handle(value_ty.unwrap()),
);
block.body.push(super::instructions::instruction_load( block.body.push(super::instructions::instruction_load(
value_ty_id, value_ty_id,
load_id, load_id,
@ -1309,21 +1630,6 @@ impl Writer {
)); ));
} }
// Looking through all global variable, types, constants.
// Doing this because we also want to include not used parts of the module
// to be included in the output
for (handle, _) in ir_module.types.iter() {
self.get_type_id(&ir_module.types, LookupType::Handle(handle));
}
for (handle, _) in ir_module.global_variables.iter() {
self.get_global_variable_id(&ir_module.types, &ir_module.global_variables, handle);
}
for (handle, _) in ir_module.constants.iter() {
self.get_constant_id(handle, &ir_module);
}
for annotation in self.annotations.iter() { for annotation in self.annotations.iter() {
annotation.to_words(&mut self.logical_layout.annotations); annotation.to_words(&mut self.logical_layout.annotations);
} }

View file

@ -42,8 +42,8 @@ impl Program {
pub fn binary_expr( pub fn binary_expr(
&mut self, &mut self,
op: BinaryOperator, op: BinaryOperator,
left: ExpressionRule, left: &ExpressionRule,
right: ExpressionRule, right: &ExpressionRule,
) -> ExpressionRule { ) -> ExpressionRule {
ExpressionRule::from_expression(self.context.expressions.append(Expression::Binary { ExpressionRule::from_expression(self.context.expressions.append(Expression::Binary {
op, op,
@ -57,13 +57,13 @@ impl Program {
handle: Handle<crate::Expression>, handle: Handle<crate::Expression>,
) -> Result<&crate::TypeInner, ErrorKind> { ) -> Result<&crate::TypeInner, ErrorKind> {
let functions = Arena::new(); //TODO let functions = Arena::new(); //TODO
let parameter_types: Vec<Handle<Type>> = vec![]; //TODO let arguments = Vec::new(); //TODO
let resolve_ctx = ResolveContext { let resolve_ctx = ResolveContext {
constants: &self.module.constants, constants: &self.module.constants,
global_vars: &self.module.global_variables, global_vars: &self.module.global_variables,
local_vars: &self.context.local_variables, local_vars: &self.context.local_variables,
functions: &functions, functions: &functions,
parameter_types: &parameter_types, arguments: &arguments,
}; };
match self.context.typifier.grow( match self.context.typifier.grow(
handle, handle,
@ -138,6 +138,7 @@ impl Context {
pub struct ExpressionRule { pub struct ExpressionRule {
pub expression: Handle<Expression>, pub expression: Handle<Expression>,
pub statements: Vec<Statement>, pub statements: Vec<Statement>,
pub sampler: Option<Handle<Expression>>,
} }
impl ExpressionRule { impl ExpressionRule {
@ -145,6 +146,7 @@ impl ExpressionRule {
ExpressionRule { ExpressionRule {
expression, expression,
statements: vec![], statements: vec![],
sampler: None,
} }
} }
} }
@ -166,12 +168,11 @@ pub struct VarDeclaration {
#[derive(Debug)] #[derive(Debug)]
pub enum FunctionCallKind { pub enum FunctionCallKind {
TypeConstructor(Handle<Type>), TypeConstructor(Handle<Type>),
Function(Handle<Expression>), Function(String),
} }
#[derive(Debug)] #[derive(Debug)]
pub struct FunctionCall { pub struct FunctionCall {
pub kind: FunctionCallKind, pub kind: FunctionCallKind,
pub args: Vec<Handle<Expression>>, pub args: Vec<ExpressionRule>,
pub statements: Vec<Statement>,
} }

View file

@ -21,6 +21,7 @@ pub enum ErrorKind {
VariableNotAvailable(String), VariableNotAvailable(String),
ExpectedConstant, ExpectedConstant,
SemanticError(&'static str), SemanticError(&'static str),
PreprocessorError(String),
} }
impl fmt::Display for ErrorKind { impl fmt::Display for ErrorKind {
@ -53,6 +54,7 @@ impl fmt::Display for ErrorKind {
} }
ErrorKind::ExpectedConstant => write!(f, "Expected constant"), ErrorKind::ExpectedConstant => write!(f, "Expected constant"),
ErrorKind::SemanticError(msg) => write!(f, "Semantic error: {}", msg), ErrorKind::SemanticError(msg) => write!(f, "Semantic error: {}", msg),
ErrorKind::PreprocessorError(val) => write!(f, "Preprocessor error: {}", val),
} }
} }
} }

View file

@ -1,5 +1,5 @@
use super::parser::Token; use super::parser::Token;
use super::{token::TokenMetadata, types::parse_type}; use super::{preprocess::LinePreProcessor, token::TokenMetadata, types::parse_type};
use std::{iter::Enumerate, str::Lines}; use std::{iter::Enumerate, str::Lines};
fn _consume_str<'a>(input: &'a str, what: &str) -> Option<&'a str> { fn _consume_str<'a>(input: &'a str, what: &str) -> Option<&'a str> {
@ -23,6 +23,7 @@ pub struct Lexer<'a> {
line: usize, line: usize,
offset: usize, offset: usize,
inside_comment: bool, inside_comment: bool,
pub pp: LinePreProcessor,
} }
impl<'a> Lexer<'a> { impl<'a> Lexer<'a> {
@ -139,6 +140,16 @@ impl<'a> Lexer<'a> {
"break" => Some(Token::Break(meta)), "break" => Some(Token::Break(meta)),
"return" => Some(Token::Return(meta)), "return" => Some(Token::Return(meta)),
"discard" => Some(Token::Discard(meta)), "discard" => Some(Token::Discard(meta)),
// selection statements
"if" => Some(Token::If(meta)),
"else" => Some(Token::Else(meta)),
"switch" => Some(Token::Switch(meta)),
"case" => Some(Token::Case(meta)),
"default" => Some(Token::Default(meta)),
// iteration statements
"while" => Some(Token::While(meta)),
"do" => Some(Token::Do(meta)),
"for" => Some(Token::For(meta)),
// types // types
"void" => Some(Token::Void(meta)), "void" => Some(Token::Void(meta)),
word => { word => {
@ -283,12 +294,24 @@ impl<'a> Lexer<'a> {
} }
pub fn new(input: &'a str) -> Self { pub fn new(input: &'a str) -> Self {
let mut lines = input.lines().enumerate(); let mut lexer = Lexer {
let (line, input) = lines.next().unwrap_or((0, "")); lines: input.lines().enumerate(),
input: "".to_string(),
line: 0,
offset: 0,
inside_comment: false,
pp: LinePreProcessor::new(),
};
lexer.next_line();
lexer
}
fn next_line(&mut self) -> bool {
if let Some((line, input)) = self.lines.next() {
let mut input = String::from(input); let mut input = String::from(input);
while input.ends_with('\\') { while input.ends_with('\\') {
if let Some((_, next)) = lines.next() { if let Some((_, next)) = self.lines.next() {
input.pop(); input.pop();
input.push_str(next); input.push_str(next);
} else { } else {
@ -296,12 +319,17 @@ impl<'a> Lexer<'a> {
} }
} }
Lexer { if let Ok(processed) = self.pp.process_line(&input) {
lines, self.input = processed.unwrap_or_default();
input, self.line = line;
line, self.offset = 0;
offset: 0, true
inside_comment: false, } else {
//TODO: handle preprocessor error
false
}
} else {
false
} }
} }
@ -331,22 +359,9 @@ impl<'a> Lexer<'a> {
self.next() self.next()
} }
} else { } else {
let (line, input) = self.lines.next()?; if !self.next_line() {
return None;
let mut input = String::from(input);
while input.ends_with('\\') {
if let Some((_, next)) = self.lines.next() {
input.pop();
input.push_str(next);
} else {
break;
} }
}
self.input = input;
self.line = line;
self.offset = 0;
self.next() self.next()
} }
} }

View file

@ -1,9 +1,13 @@
use crate::{Module, ShaderStage}; use crate::{FastHashMap, Module, ShaderStage};
mod lex; mod lex;
#[cfg(test)] #[cfg(test)]
mod lex_tests; mod lex_tests;
mod preprocess;
#[cfg(test)]
mod preprocess_tests;
mod ast; mod ast;
use ast::Program; use ast::Program;
@ -17,11 +21,17 @@ mod token;
mod types; mod types;
mod variables; mod variables;
pub fn parse_str(source: &str, entry: &str, stage: ShaderStage) -> Result<Module, ParseError> { pub fn parse_str(
log::debug!("------ GLSL-pomelo ------"); source: &str,
entry: &str,
stage: ShaderStage,
defines: FastHashMap<String, String>,
) -> Result<Module, ParseError> {
let mut program = Program::new(stage, entry); let mut program = Program::new(stage, entry);
let lex = Lexer::new(source);
let mut lex = Lexer::new(source);
lex.pp.defines = defines;
let mut parser = parser::Parser::new(&mut program); let mut parser = parser::Parser::new(&mut program);
for token in lex { for token in lex {

View file

@ -6,9 +6,9 @@ pomelo! {
%include { %include {
use super::super::{error::ErrorKind, token::*, ast::*}; use super::super::{error::ErrorKind, token::*, ast::*};
use crate::{proc::Typifier, Arena, BinaryOperator, Binding, Block, Constant, use crate::{proc::Typifier, Arena, BinaryOperator, Binding, Block, Constant,
ConstantInner, EntryPoint, Expression, Function, GlobalVariable, Handle, Interpolation, ConstantInner, EntryPoint, Expression, FallThrough, FastHashMap, Function, GlobalVariable, Handle, Interpolation,
LocalVariable, MemberOrigin, ScalarKind, Statement, StorageAccess, LocalVariable, MemberOrigin, SampleLevel, ScalarKind, Statement, StorageAccess,
StorageClass, StructMember, Type, TypeInner}; StorageClass, StructMember, Type, TypeInner, UnaryOperator};
} }
%token #[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] pub enum Token {}; %token #[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] pub enum Token {};
%parser pub struct Parser<'a> {}; %parser pub struct Parser<'a> {};
@ -55,6 +55,13 @@ pomelo! {
%type expression_statement Statement; %type expression_statement Statement;
%type declaration_statement Statement; %type declaration_statement Statement;
%type jump_statement Statement; %type jump_statement Statement;
%type iteration_statement Statement;
%type selection_statement Statement;
%type switch_statement_list Vec<(Option<i32>, Block, Option<FallThrough>)>;
%type switch_statement (Option<i32>, Block, Option<FallThrough>);
%type for_init_statement Statement;
%type for_rest_statement (Option<ExpressionRule>, Option<ExpressionRule>);
%type condition_opt Option<ExpressionRule>;
// expressions // expressions
%type unary_expression ExpressionRule; %type unary_expression ExpressionRule;
@ -90,7 +97,7 @@ pomelo! {
%type initializer ExpressionRule; %type initializer ExpressionRule;
// decalartions // declarations
%type declaration VarDeclaration; %type declaration VarDeclaration;
%type init_declarator_list VarDeclaration; %type init_declarator_list VarDeclaration;
%type single_declaration VarDeclaration; %type single_declaration VarDeclaration;
@ -115,6 +122,9 @@ pomelo! {
%type TypeName Type; %type TypeName Type;
// precedence
%right Else;
root ::= version_pragma translation_unit; root ::= version_pragma translation_unit;
version_pragma ::= Version IntConstant(V) Identifier?(P) { version_pragma ::= Version IntConstant(V) Identifier?(P) {
match V.1 { match V.1 {
@ -140,9 +150,7 @@ pomelo! {
let var = extra.lookup_variable(&v.1)?; let var = extra.lookup_variable(&v.1)?;
match var { match var {
Some(expression) => { Some(expression) => {
ExpressionRule::from_expression( ExpressionRule::from_expression(expression)
expression
)
}, },
None => { None => {
return Err(ErrorKind::UnknownVariable(v.0, v.1)); return Err(ErrorKind::UnknownVariable(v.0, v.1));
@ -220,7 +228,7 @@ pomelo! {
postfix_expression ::= postfix_expression(e) Dot Identifier(i) /* FieldSelection in spec */ { postfix_expression ::= postfix_expression(e) Dot Identifier(i) /* FieldSelection in spec */ {
//TODO: how will this work as l-value? //TODO: how will this work as l-value?
let expression = extra.field_selection(e.expression, &*i.1, i.0)?; let expression = extra.field_selection(e.expression, &*i.1, i.0)?;
ExpressionRule { expression, statements: e.statements } ExpressionRule { expression, statements: e.statements, sampler: None }
} }
postfix_expression ::= postfix_expression(pe) IncOp { postfix_expression ::= postfix_expression(pe) IncOp {
//TODO //TODO
@ -234,17 +242,49 @@ pomelo! {
integer_expression ::= expression; integer_expression ::= expression;
function_call ::= function_call_or_method(fc) { function_call ::= function_call_or_method(fc) {
if let FunctionCallKind::TypeConstructor(ty) = fc.kind { match fc.kind {
FunctionCallKind::TypeConstructor(ty) => {
let h = extra.context.expressions.append(Expression::Compose { let h = extra.context.expressions.append(Expression::Compose {
ty, ty,
components: fc.args, components: fc.args.iter().map(|a| a.expression).collect(),
}); });
ExpressionRule { ExpressionRule {
expression: h, expression: h,
statements: fc.statements, statements: fc.args.into_iter().map(|a| a.statements).flatten().collect(),
sampler: None
}
}
FunctionCallKind::Function(name) => {
match name.as_str() {
"sampler2D" => {
//TODO: check args len
ExpressionRule{
expression: fc.args[0].expression,
sampler: Some(fc.args[1].expression),
statements: fc.args.into_iter().map(|a| a.statements).flatten().collect(),
}
}
"texture" => {
//TODO: check args len
if let Some(sampler) = fc.args[0].sampler {
ExpressionRule{
expression: extra.context.expressions.append(Expression::ImageSample {
image: fc.args[0].expression,
sampler,
coordinate: fc.args[1].expression,
level: SampleLevel::Auto,
depth_ref: None,
}),
sampler: None,
statements: fc.args.into_iter().map(|a| a.statements).flatten().collect(),
} }
} else { } else {
return Err(ErrorKind::NotImplemented("Function call")); return Err(ErrorKind::SemanticError("Bad call to texture"));
}
}
_ => { return Err(ErrorKind::NotImplemented("Function call")); }
}
}
} }
} }
function_call_or_method ::= function_call_generic; function_call_or_method ::= function_call_generic;
@ -259,26 +299,22 @@ pomelo! {
} }
function_call_header_no_parameters ::= function_call_header; function_call_header_no_parameters ::= function_call_header;
function_call_header_with_parameters ::= function_call_header(mut h) assignment_expression(ae) { function_call_header_with_parameters ::= function_call_header(mut h) assignment_expression(ae) {
h.args.push(ae.expression); h.args.push(ae);
h.statements.extend(ae.statements);
h h
} }
function_call_header_with_parameters ::= function_call_header_with_parameters(mut h) Comma assignment_expression(ae) { function_call_header_with_parameters ::= function_call_header_with_parameters(mut h) Comma assignment_expression(ae) {
h.args.push(ae.expression); h.args.push(ae);
h.statements.extend(ae.statements);
h h
} }
function_call_header ::= function_identifier(i) LeftParen { function_call_header ::= function_identifier(i) LeftParen {
FunctionCall { FunctionCall {
kind: i, kind: i,
args: vec![], args: vec![],
statements: vec![],
} }
} }
// Grammar Note: Constructors look like functions, but lexical analysis recognized most of them as // Grammar Note: Constructors look like functions, but lexical analysis recognized most of them as
// keywords. They are now recognized through “type_specifier”. // keywords. They are now recognized through “type_specifier”.
// Methods (.length), subroutine array calls, and identifiers are recognized through postfix_expression.
function_identifier ::= type_specifier(t) { function_identifier ::= type_specifier(t) {
if let Some(ty) = t { if let Some(ty) = t {
FunctionCallKind::TypeConstructor(ty) FunctionCallKind::TypeConstructor(ty)
@ -286,10 +322,19 @@ pomelo! {
return Err(ErrorKind::NotImplemented("bad type ctor")) return Err(ErrorKind::NotImplemented("bad type ctor"))
} }
} }
function_identifier ::= postfix_expression(e) {
FunctionCallKind::Function(e.expression) //TODO
// Methods (.length), subroutine array calls, and identifiers are recognized through postfix_expression.
// function_identifier ::= postfix_expression(e) {
// FunctionCallKind::Function(e.expression)
// }
// Simplification of above
function_identifier ::= Identifier(i) {
FunctionCallKind::Function(i.1)
} }
unary_expression ::= postfix_expression; unary_expression ::= postfix_expression;
unary_expression ::= IncOp unary_expression { unary_expression ::= IncOp unary_expression {
@ -311,74 +356,76 @@ pomelo! {
unary_operator ::= Tilde; unary_operator ::= Tilde;
multiplicative_expression ::= unary_expression; multiplicative_expression ::= unary_expression;
multiplicative_expression ::= multiplicative_expression(left) Star unary_expression(right) { multiplicative_expression ::= multiplicative_expression(left) Star unary_expression(right) {
extra.binary_expr(BinaryOperator::Multiply, left, right) extra.binary_expr(BinaryOperator::Multiply, &left, &right)
} }
multiplicative_expression ::= multiplicative_expression(left) Slash unary_expression(right) { multiplicative_expression ::= multiplicative_expression(left) Slash unary_expression(right) {
extra.binary_expr(BinaryOperator::Divide, left, right) extra.binary_expr(BinaryOperator::Divide, &left, &right)
} }
multiplicative_expression ::= multiplicative_expression(left) Percent unary_expression(right) { multiplicative_expression ::= multiplicative_expression(left) Percent unary_expression(right) {
extra.binary_expr(BinaryOperator::Modulo, left, right) extra.binary_expr(BinaryOperator::Modulo, &left, &right)
} }
additive_expression ::= multiplicative_expression; additive_expression ::= multiplicative_expression;
additive_expression ::= additive_expression(left) Plus multiplicative_expression(right) { additive_expression ::= additive_expression(left) Plus multiplicative_expression(right) {
extra.binary_expr(BinaryOperator::Add, left, right) extra.binary_expr(BinaryOperator::Add, &left, &right)
} }
additive_expression ::= additive_expression(left) Dash multiplicative_expression(right) { additive_expression ::= additive_expression(left) Dash multiplicative_expression(right) {
extra.binary_expr(BinaryOperator::Subtract, left, right) extra.binary_expr(BinaryOperator::Subtract, &left, &right)
} }
shift_expression ::= additive_expression; shift_expression ::= additive_expression;
shift_expression ::= shift_expression(left) LeftOp additive_expression(right) { shift_expression ::= shift_expression(left) LeftOp additive_expression(right) {
extra.binary_expr(BinaryOperator::ShiftLeftLogical, left, right) extra.binary_expr(BinaryOperator::ShiftLeft, &left, &right)
} }
shift_expression ::= shift_expression(left) RightOp additive_expression(right) { shift_expression ::= shift_expression(left) RightOp additive_expression(right) {
//TODO: when to use ShiftRightArithmetic extra.binary_expr(BinaryOperator::ShiftRight, &left, &right)
extra.binary_expr(BinaryOperator::ShiftRightLogical, left, right)
} }
relational_expression ::= shift_expression; relational_expression ::= shift_expression;
relational_expression ::= relational_expression(left) LeftAngle shift_expression(right) { relational_expression ::= relational_expression(left) LeftAngle shift_expression(right) {
extra.binary_expr(BinaryOperator::Less, left, right) extra.binary_expr(BinaryOperator::Less, &left, &right)
} }
relational_expression ::= relational_expression(left) RightAngle shift_expression(right) { relational_expression ::= relational_expression(left) RightAngle shift_expression(right) {
extra.binary_expr(BinaryOperator::Greater, left, right) extra.binary_expr(BinaryOperator::Greater, &left, &right)
} }
relational_expression ::= relational_expression(left) LeOp shift_expression(right) { relational_expression ::= relational_expression(left) LeOp shift_expression(right) {
extra.binary_expr(BinaryOperator::LessEqual, left, right) extra.binary_expr(BinaryOperator::LessEqual, &left, &right)
} }
relational_expression ::= relational_expression(left) GeOp shift_expression(right) { relational_expression ::= relational_expression(left) GeOp shift_expression(right) {
extra.binary_expr(BinaryOperator::GreaterEqual, left, right) extra.binary_expr(BinaryOperator::GreaterEqual, &left, &right)
} }
equality_expression ::= relational_expression; equality_expression ::= relational_expression;
equality_expression ::= equality_expression(left) EqOp relational_expression(right) { equality_expression ::= equality_expression(left) EqOp relational_expression(right) {
extra.binary_expr(BinaryOperator::Equal, left, right) extra.binary_expr(BinaryOperator::Equal, &left, &right)
} }
equality_expression ::= equality_expression(left) NeOp relational_expression(right) { equality_expression ::= equality_expression(left) NeOp relational_expression(right) {
extra.binary_expr(BinaryOperator::NotEqual, left, right) extra.binary_expr(BinaryOperator::NotEqual, &left, &right)
} }
and_expression ::= equality_expression; and_expression ::= equality_expression;
and_expression ::= and_expression(left) Ampersand equality_expression(right) { and_expression ::= and_expression(left) Ampersand equality_expression(right) {
extra.binary_expr(BinaryOperator::And, left, right) extra.binary_expr(BinaryOperator::And, &left, &right)
} }
exclusive_or_expression ::= and_expression; exclusive_or_expression ::= and_expression;
exclusive_or_expression ::= exclusive_or_expression(left) Caret and_expression(right) { exclusive_or_expression ::= exclusive_or_expression(left) Caret and_expression(right) {
extra.binary_expr(BinaryOperator::ExclusiveOr, left, right) extra.binary_expr(BinaryOperator::ExclusiveOr, &left, &right)
} }
inclusive_or_expression ::= exclusive_or_expression; inclusive_or_expression ::= exclusive_or_expression;
inclusive_or_expression ::= inclusive_or_expression(left) VerticalBar exclusive_or_expression(right) { inclusive_or_expression ::= inclusive_or_expression(left) VerticalBar exclusive_or_expression(right) {
extra.binary_expr(BinaryOperator::InclusiveOr, left, right) extra.binary_expr(BinaryOperator::InclusiveOr, &left, &right)
} }
logical_and_expression ::= inclusive_or_expression; logical_and_expression ::= inclusive_or_expression;
logical_and_expression ::= logical_and_expression(left) AndOp inclusive_or_expression(right) { logical_and_expression ::= logical_and_expression(left) AndOp inclusive_or_expression(right) {
extra.binary_expr(BinaryOperator::LogicalAnd, left, right) extra.binary_expr(BinaryOperator::LogicalAnd, &left, &right)
} }
logical_xor_expression ::= logical_and_expression; logical_xor_expression ::= logical_and_expression;
logical_xor_expression ::= logical_xor_expression(left) XorOp logical_and_expression(right) { logical_xor_expression ::= logical_xor_expression(left) XorOp logical_and_expression(right) {
return Err(ErrorKind::NotImplemented("logical xor")) let exp1 = extra.binary_expr(BinaryOperator::LogicalOr, &left, &right);
//TODO: naga doesn't have BinaryOperator::LogicalXor let exp2 = {
// extra.context.expressions.append(Expression::Binary{op: BinaryOperator::LogicalXor, left, right}) let tmp = extra.binary_expr(BinaryOperator::LogicalAnd, &left, &right).expression;
ExpressionRule::from_expression(extra.context.expressions.append(Expression::Unary { op: UnaryOperator::Not, expr: tmp }))
};
extra.binary_expr(BinaryOperator::LogicalAnd, &exp1, &exp2)
} }
logical_or_expression ::= logical_xor_expression; logical_or_expression ::= logical_xor_expression;
logical_or_expression ::= logical_or_expression(left) OrOp logical_xor_expression(right) { logical_or_expression ::= logical_or_expression(left) OrOp logical_xor_expression(right) {
extra.binary_expr(BinaryOperator::LogicalOr, left, right) extra.binary_expr(BinaryOperator::LogicalOr, &left, &right)
} }
conditional_expression ::= logical_or_expression; conditional_expression ::= logical_or_expression;
@ -389,17 +436,29 @@ pomelo! {
assignment_expression ::= conditional_expression; assignment_expression ::= conditional_expression;
assignment_expression ::= unary_expression(mut pointer) assignment_operator(op) assignment_expression(value) { assignment_expression ::= unary_expression(mut pointer) assignment_operator(op) assignment_expression(value) {
pointer.statements.extend(value.statements);
match op { match op {
BinaryOperator::Equal => { BinaryOperator::Equal => {
pointer.statements.extend(value.statements);
pointer.statements.push(Statement::Store{ pointer.statements.push(Statement::Store{
pointer: pointer.expression, pointer: pointer.expression,
value: value.expression value: value.expression
}); });
pointer pointer
}, },
//TODO: op != Equal _ => {
_ => {return Err(ErrorKind::NotImplemented("assign op"))} let h = extra.context.expressions.append(
Expression::Binary{
op,
left: pointer.expression,
right: value.expression,
}
);
pointer.statements.push(Statement::Store{
pointer: pointer.expression,
value: h,
});
pointer
}
} }
} }
@ -422,10 +481,10 @@ pomelo! {
BinaryOperator::Subtract BinaryOperator::Subtract
} }
assignment_operator ::= LeftAssign { assignment_operator ::= LeftAssign {
BinaryOperator::ShiftLeftLogical BinaryOperator::ShiftLeft
} }
assignment_operator ::= RightAssign { assignment_operator ::= RightAssign {
BinaryOperator::ShiftRightLogical BinaryOperator::ShiftRight
} }
assignment_operator ::= AndAssign { assignment_operator ::= AndAssign {
BinaryOperator::And BinaryOperator::And
@ -443,6 +502,7 @@ pomelo! {
ExpressionRule { ExpressionRule {
expression: e.expression, expression: e.expression,
statements: ae.statements, statements: ae.statements,
sampler: None,
} }
} }
@ -597,9 +657,7 @@ pomelo! {
// single_type_qualifier ::= invariant_qualifier; // single_type_qualifier ::= invariant_qualifier;
// single_type_qualifier ::= precise_qualifier; // single_type_qualifier ::= precise_qualifier;
storage_qualifier ::= Const { // storage_qualifier ::= Const
StorageClass::Constant
}
// storage_qualifier ::= InOut; // storage_qualifier ::= InOut;
storage_qualifier ::= In { storage_qualifier ::= In {
StorageClass::Input StorageClass::Input
@ -702,18 +760,33 @@ pomelo! {
return Err(ErrorKind::VariableAlreadyDeclared(id)) return Err(ErrorKind::VariableAlreadyDeclared(id))
} }
} }
let mut init_exp: Option<Handle<Expression>> = None;
let localVar = extra.context.local_variables.append( let localVar = extra.context.local_variables.append(
LocalVariable { LocalVariable {
name: Some(id.clone()), name: Some(id.clone()),
ty: d.ty, ty: d.ty,
init: initializer.map(|i| { init: initializer.map(|i| {
statements.extend(i.statements); statements.extend(i.statements);
i.expression if let Expression::Constant(constant) = extra.context.expressions[i.expression] {
}), Some(constant)
} else {
init_exp = Some(i.expression);
None
}
}).flatten(),
} }
); );
let exp = extra.context.expressions.append(Expression::LocalVariable(localVar)); let exp = extra.context.expressions.append(Expression::LocalVariable(localVar));
extra.context.add_local_var(id, exp); extra.context.add_local_var(id, exp);
if let Some(value) = init_exp {
statements.push(
Statement::Store {
pointer: exp,
value,
}
);
}
} }
match statements.len() { match statements.len() {
1 => statements.remove(0), 1 => statements.remove(0),
@ -727,14 +800,138 @@ pomelo! {
} }
statement ::= simple_statement; statement ::= simple_statement;
// Grammar Note: labeled statements for SWITCH only; 'goto' is not supported.
simple_statement ::= declaration_statement; simple_statement ::= declaration_statement;
simple_statement ::= expression_statement; simple_statement ::= expression_statement;
//simple_statement ::= selection_statement; simple_statement ::= selection_statement;
//simple_statement ::= switch_statement;
//simple_statement ::= case_label;
//simple_statement ::= iteration_statement;
simple_statement ::= jump_statement; simple_statement ::= jump_statement;
simple_statement ::= iteration_statement;
selection_statement ::= If LeftParen expression(e) RightParen statement(s1) Else statement(s2) {
Statement::If {
condition: e.expression,
accept: vec![s1],
reject: vec![s2],
}
}
selection_statement ::= If LeftParen expression(e) RightParen statement(s) [Else] {
Statement::If {
condition: e.expression,
accept: vec![s],
reject: vec![],
}
}
selection_statement ::= Switch LeftParen expression(e) RightParen LeftBrace switch_statement_list(ls) RightBrace {
let mut default = Vec::new();
let mut cases = FastHashMap::default();
for (v, s, ft) in ls {
if let Some(v) = v {
cases.insert(v, (s, ft));
} else {
default.extend_from_slice(&s);
}
}
Statement::Switch {
selector: e.expression,
cases,
default,
}
}
switch_statement_list ::= {
vec![]
}
switch_statement_list ::= switch_statement_list(mut ssl) switch_statement((v, sl, ft)) {
ssl.push((v, sl, ft));
ssl
}
switch_statement ::= Case IntConstant(v) Colon statement_list(sl) {
let fallthrough = match sl.last() {
Some(Statement::Break) => None,
_ => Some(FallThrough),
};
(Some(v.1 as i32), sl, fallthrough)
}
switch_statement ::= Default Colon statement_list(sl) {
let fallthrough = match sl.last() {
Some(Statement::Break) => Some(FallThrough),
_ => None,
};
(None, sl, fallthrough)
}
iteration_statement ::= While LeftParen expression(e) RightParen compound_statement_no_new_scope(sl) {
let mut body = Vec::with_capacity(sl.len() + 1);
body.push(
Statement::If {
condition: e.expression,
accept: vec![Statement::Break],
reject: vec![],
}
);
body.extend_from_slice(&sl);
Statement::Loop {
body,
continuing: vec![],
}
}
iteration_statement ::= Do compound_statement(sl) While LeftParen expression(e) RightParen {
let mut body = sl;
body.push(
Statement::If {
condition: e.expression,
accept: vec![Statement::Break],
reject: vec![],
}
);
Statement::Loop {
body,
continuing: vec![],
}
}
iteration_statement ::= For LeftParen for_init_statement(s_init) for_rest_statement((cond_e, loop_e)) RightParen compound_statement_no_new_scope(sl) {
let mut body = Vec::with_capacity(sl.len() + 2);
if let Some(cond_e) = cond_e {
body.push(
Statement::If {
condition: cond_e.expression,
accept: vec![Statement::Break],
reject: vec![],
}
);
}
body.extend_from_slice(&sl);
if let Some(loop_e) = loop_e {
body.extend_from_slice(&loop_e.statements);
}
Statement::Block(vec![
s_init,
Statement::Loop {
body,
continuing: vec![],
}
])
}
for_init_statement ::= expression_statement;
for_init_statement ::= declaration_statement;
for_rest_statement ::= condition_opt(c) Semicolon {
(c, None)
}
for_rest_statement ::= condition_opt(c) Semicolon expression(e) {
(c, Some(e))
}
condition_opt ::= {
None
}
condition_opt ::= conditional_expression(c) {
Some(c)
}
compound_statement ::= LeftBrace RightBrace { compound_statement ::= LeftBrace RightBrace {
vec![] vec![]
@ -810,7 +1007,7 @@ pomelo! {
function_header ::= fully_specified_type(t) Identifier(n) LeftParen { function_header ::= fully_specified_type(t) Identifier(n) LeftParen {
Function { Function {
name: Some(n.1), name: Some(n.1),
parameter_types: vec![], arguments: vec![],
return_type: t.1, return_type: t.1,
global_usage: vec![], global_usage: vec![],
local_variables: Arena::<LocalVariable>::new(), local_variables: Arena::<LocalVariable>::new(),
@ -884,6 +1081,7 @@ pomelo! {
class, class,
binding: binding.clone(), binding: binding.clone(),
ty: d.ty, ty: d.ty,
init: None,
interpolation, interpolation,
storage_access: StorageAccess::empty(), //TODO storage_access: StorageAccess::empty(), //TODO
}, },
@ -894,12 +1092,17 @@ pomelo! {
} }
} }
function_definition ::= function_prototype(mut f) compound_statement_no_new_scope(cs) { function_definition ::= function_prototype(mut f) compound_statement_no_new_scope(mut cs) {
std::mem::swap(&mut f.expressions, &mut extra.context.expressions); std::mem::swap(&mut f.expressions, &mut extra.context.expressions);
std::mem::swap(&mut f.local_variables, &mut extra.context.local_variables); std::mem::swap(&mut f.local_variables, &mut extra.context.local_variables);
extra.context.clear_scopes(); extra.context.clear_scopes();
extra.context.lookup_global_var_exps.clear(); extra.context.lookup_global_var_exps.clear();
extra.context.typifier = Typifier::new(); extra.context.typifier = Typifier::new();
// make sure function ends with return
match cs.last() {
Some(Statement::Return {..}) => {}
_ => {cs.push(Statement::Return { value:None });}
}
f.body = cs; f.body = cs;
f.fill_global_use(&extra.module.global_variables); f.fill_global_use(&extra.module.global_variables);
f f

View file

@ -78,3 +78,105 @@ fn version() {
"(450, Core)" "(450, Core)"
); );
} }
#[test]
fn control_flow() {
let _program = parse_program(
r#"
# version 450
void main() {
if (true) {
return 1;
} else {
return 2;
}
}
"#,
ShaderStage::Vertex,
)
.unwrap();
let _program = parse_program(
r#"
# version 450
void main() {
if (true) {
return 1;
}
}
"#,
ShaderStage::Vertex,
)
.unwrap();
let _program = parse_program(
r#"
# version 450
void main() {
int x;
int y = 3;
switch (5) {
case 2:
x = 2;
case 5:
x = 5;
y = 2;
break;
default:
x = 0;
}
}
"#,
ShaderStage::Vertex,
)
.unwrap();
let _program = parse_program(
r#"
# version 450
void main() {
int x = 0;
while(x < 5) {
x = x + 1;
}
do {
x = x - 1;
} while(x >= 4)
}
"#,
ShaderStage::Vertex,
)
.unwrap();
let _program = parse_program(
r#"
# version 450
void main() {
int x = 0;
for(int i = 0; i < 10;) {
x = x + 2;
}
return x;
}
"#,
ShaderStage::Vertex,
)
.unwrap();
}
#[test]
fn textures() {
let _program = parse_program(
r#"
#version 450
layout(location = 0) in vec2 v_uv;
layout(location = 0) out vec4 o_color;
layout(set = 1, binding = 1) uniform texture2D tex;
layout(set = 1, binding = 2) uniform sampler tex_sampler;
void main() {
o_color = texture(sampler2D(tex, tex_sampler), v_uv);
}
"#,
ShaderStage::Fragment,
)
.unwrap();
}

View file

@ -0,0 +1,152 @@
use crate::FastHashMap;
use thiserror::Error;
#[derive(Clone, Debug, Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum Error {
#[error("unmatched else")]
UnmatchedElse,
#[error("unmatched endif")]
UnmatchedEndif,
#[error("missing macro name")]
MissingMacro,
}
#[derive(Clone, Debug)]
pub struct IfState {
true_branch: bool,
else_seen: bool,
}
#[derive(Clone, Debug)]
pub struct LinePreProcessor {
pub defines: FastHashMap<String, String>,
if_stack: Vec<IfState>,
inside_comment: bool,
in_preprocess: bool,
}
impl LinePreProcessor {
pub fn new() -> Self {
LinePreProcessor {
defines: FastHashMap::default(),
if_stack: vec![],
inside_comment: false,
in_preprocess: false,
}
}
fn subst_defines(&self, input: &str) -> String {
//TODO: don't subst in commments, strings literals?
self.defines
.iter()
.fold(input.to_string(), |acc, (k, v)| acc.replace(k, v))
}
pub fn process_line(&mut self, line: &str) -> Result<Option<String>, Error> {
let mut skip = !self.if_stack.last().map(|i| i.true_branch).unwrap_or(true);
let mut inside_comment = self.inside_comment;
let mut in_preprocess = inside_comment && self.in_preprocess;
// single-line comment
let mut processed = line;
if let Some(pos) = line.find("//") {
processed = line.split_at(pos).0;
}
// multi-line comment
let mut processed_string: String;
loop {
if inside_comment {
if let Some(pos) = processed.find("*/") {
processed = processed.split_at(pos + 2).1;
inside_comment = false;
self.inside_comment = false;
continue;
}
} else if let Some(pos) = processed.find("/*") {
if let Some(end_pos) = processed[pos + 2..].find("*/") {
// comment ends during this line
processed_string = processed.to_string();
processed_string.replace_range(pos..pos + end_pos + 4, "");
processed = &processed_string;
} else {
processed = processed.split_at(pos).0;
inside_comment = true;
}
continue;
}
break;
}
// strip leading whitespace
processed = processed.trim_start();
if processed.starts_with('#') && !self.inside_comment {
let mut iter = processed[1..]
.trim_start()
.splitn(2, |c: char| c.is_whitespace());
if let Some(directive) = iter.next() {
skip = true;
in_preprocess = true;
match directive {
"version" => {
skip = false;
}
"define" => {
let rest = iter.next().ok_or(Error::MissingMacro)?;
let pos = rest
.find(|c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '(')
.unwrap_or_else(|| rest.len());
let (key, mut value) = rest.split_at(pos);
value = value.trim();
self.defines.insert(key.into(), self.subst_defines(value));
}
"undef" => {
let rest = iter.next().ok_or(Error::MissingMacro)?;
let key = rest.trim();
self.defines.remove(key);
}
"ifdef" => {
let rest = iter.next().ok_or(Error::MissingMacro)?;
let key = rest.trim();
self.if_stack.push(IfState {
true_branch: self.defines.contains_key(key),
else_seen: false,
});
}
"ifndef" => {
let rest = iter.next().ok_or(Error::MissingMacro)?;
let key = rest.trim();
self.if_stack.push(IfState {
true_branch: !self.defines.contains_key(key),
else_seen: false,
});
}
"else" => {
let if_state = self.if_stack.last_mut().ok_or(Error::UnmatchedElse)?;
if !if_state.else_seen {
// this is first else
if_state.true_branch = !if_state.true_branch;
if_state.else_seen = true;
} else {
return Err(Error::UnmatchedElse);
}
}
"endif" => {
self.if_stack.pop().ok_or(Error::UnmatchedEndif)?;
}
_ => {}
}
}
}
let res = if !skip && !self.inside_comment {
Ok(Some(self.subst_defines(&line)))
} else {
Ok(if in_preprocess && !self.in_preprocess {
Some("".to_string())
} else {
None
})
};
self.in_preprocess = in_preprocess || skip;
self.inside_comment = inside_comment;
res
}
}

View file

@ -0,0 +1,218 @@
use super::preprocess::{Error, LinePreProcessor};
use std::{iter::Enumerate, str::Lines};
#[derive(Clone, Debug)]
pub struct PreProcessor<'a> {
lines: Enumerate<Lines<'a>>,
input: String,
line: usize,
offset: usize,
line_pp: LinePreProcessor,
}
impl<'a> PreProcessor<'a> {
pub fn new(input: &'a str) -> Self {
let mut lexer = PreProcessor {
lines: input.lines().enumerate(),
input: "".to_string(),
line: 0,
offset: 0,
line_pp: LinePreProcessor::new(),
};
lexer.next_line();
lexer
}
fn next_line(&mut self) -> bool {
if let Some((line, input)) = self.lines.next() {
let mut input = String::from(input);
while input.ends_with('\\') {
if let Some((_, next)) = self.lines.next() {
input.pop();
input.push_str(next);
} else {
break;
}
}
self.input = input;
self.line = line;
self.offset = 0;
true
} else {
false
}
}
pub fn process(&mut self) -> Result<String, Error> {
let mut res = String::new();
loop {
let line = &self.line_pp.process_line(&self.input)?;
if let Some(line) = line {
res.push_str(line);
}
if !self.next_line() {
break;
}
if line.is_some() {
res.push_str("\n");
}
}
Ok(res)
}
}
#[test]
fn preprocess() {
// line continuation
let mut pp = PreProcessor::new(
"void main my_\
func",
);
assert_eq!(pp.process().unwrap(), "void main my_func");
// preserve #version
let mut pp = PreProcessor::new(
"#version 450 core\n\
void main()",
);
assert_eq!(pp.process().unwrap(), "#version 450 core\nvoid main()");
// simple define
let mut pp = PreProcessor::new(
"#define FOO 42 \n\
fun=FOO",
);
assert_eq!(pp.process().unwrap(), "\nfun=42");
// ifdef with else
let mut pp = PreProcessor::new(
"#define FOO\n\
#ifdef FOO\n\
foo=42\n\
#endif\n\
some=17\n\
#ifdef BAR\n\
bar=88\n\
#else\n\
mm=49\n\
#endif\n\
done=1",
);
assert_eq!(
pp.process().unwrap(),
"\n\
foo=42\n\
\n\
some=17\n\
\n\
mm=49\n\
\n\
done=1"
);
// nested ifdef/ifndef
let mut pp = PreProcessor::new(
"#define FOO\n\
#define BOO\n\
#ifdef FOO\n\
foo=42\n\
#ifdef BOO\n\
boo=44\n\
#endif\n\
ifd=0\n\
#ifndef XYZ\n\
nxyz=8\n\
#endif\n\
#endif\n\
some=17\n\
#ifdef BAR\n\
bar=88\n\
#else\n\
mm=49\n\
#endif\n\
done=1",
);
assert_eq!(
pp.process().unwrap(),
"\n\
foo=42\n\
\n\
boo=44\n\
\n\
ifd=0\n\
\n\
nxyz=8\n\
\n\
some=17\n\
\n\
mm=49\n\
\n\
done=1"
);
// undef
let mut pp = PreProcessor::new(
"#define FOO\n\
#ifdef FOO\n\
foo=42\n\
#endif\n\
some=17\n\
#undef FOO\n\
#ifdef FOO\n\
foo=88\n\
#else\n\
nofoo=66\n\
#endif\n\
done=1",
);
assert_eq!(
pp.process().unwrap(),
"\n\
foo=42\n\
\n\
some=17\n\
\n\
nofoo=66\n\
\n\
done=1"
);
// single-line comment
let mut pp = PreProcessor::new(
"#define FOO 42//1234\n\
fun=FOO",
);
assert_eq!(pp.process().unwrap(), "\nfun=42");
// multi-line comments
let mut pp = PreProcessor::new(
"#define FOO 52/*/1234\n\
#define FOO 88\n\
end of comment*/ /* one more comment */ #define FOO 56\n\
fun=FOO",
);
assert_eq!(pp.process().unwrap(), "\nfun=56");
// unmatched endif
let mut pp = PreProcessor::new(
"#ifdef FOO\n\
foo=42\n\
#endif\n\
#endif",
);
assert_eq!(pp.process(), Err(Error::UnmatchedEndif));
// unmatched else
let mut pp = PreProcessor::new(
"#ifdef FOO\n\
foo=42\n\
#else\n\
bar=88\n\
#else\n\
bad=true\n\
#endif",
);
assert_eq!(pp.process(), Err(Error::UnmatchedElse));
}

View file

@ -37,6 +37,21 @@ pub fn parse_type(type_name: &str) -> Option<Type> {
width: 4, width: 4,
}, },
}), }),
"texture2D" => Some(Type {
name: None,
inner: TypeInner::Image {
dim: crate::ImageDimension::D2,
arrayed: false,
class: crate::ImageClass::Sampled {
kind: ScalarKind::Float,
multi: false,
},
},
}),
"sampler" => Some(Type {
name: None,
inner: TypeInner::Sampler { comparison: false },
}),
word => { word => {
fn kind_width_parse(ty: &str) -> Option<(ScalarKind, u8)> { fn kind_width_parse(ty: &str) -> Option<(ScalarKind, u8)> {
Some(match ty { Some(match ty {

View file

@ -38,6 +38,7 @@ impl Program {
width: 4, width: 4,
}, },
}), }),
init: None,
interpolation: None, interpolation: None,
storage_access: StorageAccess::empty(), storage_access: StorageAccess::empty(),
}); });
@ -72,6 +73,7 @@ impl Program {
width: 4, width: 4,
}, },
}), }),
init: None,
interpolation: None, interpolation: None,
storage_access: StorageAccess::empty(), storage_access: StorageAccess::empty(),
}); });

View file

@ -43,21 +43,6 @@ pub fn map_vector_size(word: spirv::Word) -> Result<crate::VectorSize, Error> {
} }
} }
pub fn map_storage_class(word: spirv::Word) -> Result<crate::StorageClass, Error> {
use spirv::StorageClass as Sc;
match Sc::from_u32(word) {
Some(Sc::UniformConstant) => Ok(crate::StorageClass::Constant),
Some(Sc::Function) => Ok(crate::StorageClass::Function),
Some(Sc::Input) => Ok(crate::StorageClass::Input),
Some(Sc::Output) => Ok(crate::StorageClass::Output),
Some(Sc::Private) => Ok(crate::StorageClass::Private),
Some(Sc::StorageBuffer) => Ok(crate::StorageClass::StorageBuffer),
Some(Sc::Uniform) => Ok(crate::StorageClass::Uniform),
Some(Sc::Workgroup) => Ok(crate::StorageClass::WorkGroup),
_ => Err(Error::UnsupportedStorageClass(word)),
}
}
pub fn map_image_dim(word: spirv::Word) -> Result<crate::ImageDimension, Error> { pub fn map_image_dim(word: spirv::Word) -> Result<crate::ImageDimension, Error> {
use spirv::Dim as D; use spirv::Dim as D;
match D::from_u32(word) { match D::from_u32(word) {

View file

@ -46,9 +46,11 @@ pub enum Error {
InvalidAsType(Handle<crate::Type>), InvalidAsType(Handle<crate::Type>),
InconsistentComparisonSampling(Handle<crate::Type>), InconsistentComparisonSampling(Handle<crate::Type>),
WrongFunctionResultType(spirv::Word), WrongFunctionResultType(spirv::Word),
WrongFunctionParameterType(spirv::Word), WrongFunctionArgumentType(spirv::Word),
MissingDecoration(spirv::Decoration), MissingDecoration(spirv::Decoration),
BadString, BadString,
IncompleteData, IncompleteData,
InvalidTerminator, InvalidTerminator,
InvalidEdgeClassification,
UnexpectedComparisonType(Handle<crate::Type>),
} }

View file

@ -152,6 +152,12 @@ impl FlowGraph {
let (node_source_index, node_target_index) = let (node_source_index, node_target_index) =
self.flow.edge_endpoints(edge_index).unwrap(); self.flow.edge_endpoints(edge_index).unwrap();
if self.flow[node_source_index].ty == Some(ControlFlowNodeType::Header)
|| self.flow[node_source_index].ty == Some(ControlFlowNodeType::Loop)
{
continue;
}
// Back // Back
if self.flow[node_target_index].ty == Some(ControlFlowNodeType::Loop) if self.flow[node_target_index].ty == Some(ControlFlowNodeType::Loop)
&& self.flow[node_source_index].id > self.flow[node_target_index].id && self.flow[node_source_index].id > self.flow[node_target_index].id
@ -219,11 +225,9 @@ impl FlowGraph {
node_index: BlockNodeIndex, node_index: BlockNodeIndex,
stop_node_index: Option<BlockNodeIndex>, stop_node_index: Option<BlockNodeIndex>,
) -> Result<crate::Block, Error> { ) -> Result<crate::Block, Error> {
if let Some(stop_node_index) = stop_node_index { if stop_node_index == Some(node_index) {
if stop_node_index == node_index {
return Ok(vec![]); return Ok(vec![]);
} }
}
let node = &self.flow[node_index]; let node = &self.flow[node_index];
@ -246,7 +250,7 @@ impl FlowGraph {
accept: self.naga_traverse(true_node_index, Some(merge_node_index))?, accept: self.naga_traverse(true_node_index, Some(merge_node_index))?,
reject: self.naga_traverse(false_node_index, Some(merge_node_index))?, reject: self.naga_traverse(false_node_index, Some(merge_node_index))?,
}); });
result.extend(self.naga_traverse(merge_node_index, None)?); result.extend(self.naga_traverse(merge_node_index, stop_node_index)?);
} else { } else {
result.push(crate::Statement::If { result.push(crate::Statement::If {
condition, condition,
@ -254,7 +258,7 @@ impl FlowGraph {
self.block_to_node[&true_id], self.block_to_node[&true_id],
Some(merge_node_index), Some(merge_node_index),
)?, )?,
reject: self.naga_traverse(merge_node_index, None)?, reject: self.naga_traverse(merge_node_index, stop_node_index)?,
}); });
} }
@ -305,7 +309,7 @@ impl FlowGraph {
.naga_traverse(self.block_to_node[&default], Some(merge_node_index))?, .naga_traverse(self.block_to_node[&default], Some(merge_node_index))?,
}); });
result.extend(self.naga_traverse(merge_node_index, None)?); result.extend(self.naga_traverse(merge_node_index, stop_node_index)?);
Ok(result) Ok(result)
} }
@ -323,7 +327,7 @@ impl FlowGraph {
self.flow[continue_edge.target()].block.clone() self.flow[continue_edge.target()].block.clone()
}; };
let mut body: crate::Block = node.block.clone(); let mut body = node.block.clone();
match node.terminator { match node.terminator {
Terminator::BranchConditional { Terminator::BranchConditional {
condition, condition,
@ -342,7 +346,10 @@ impl FlowGraph {
_ => return Err(Error::InvalidTerminator), _ => return Err(Error::InvalidTerminator),
}; };
Ok(vec![crate::Statement::Loop { body, continuing }]) let mut result = vec![crate::Statement::Loop { body, continuing }];
result.extend(self.naga_traverse(merge_node_index, stop_node_index)?);
Ok(result)
} }
Some(ControlFlowNodeType::Break) => { Some(ControlFlowNodeType::Break) => {
let mut result = node.block.clone(); let mut result = node.block.clone();
@ -351,25 +358,52 @@ impl FlowGraph {
condition, condition,
true_id, true_id,
false_id, false_id,
} => result.push(crate::Statement::If { } => {
let true_node_id = self.block_to_node[&true_id];
let false_node_id = self.block_to_node[&false_id];
let true_edge =
self.flow[self.flow.find_edge(node_index, true_node_id).unwrap()];
let false_edge =
self.flow[self.flow.find_edge(node_index, false_node_id).unwrap()];
if true_edge == ControlFlowEdgeType::LoopBreak {
result.push(crate::Statement::If {
condition, condition,
accept: self accept: vec![crate::Statement::Break],
.naga_traverse(self.block_to_node[&true_id], stop_node_index)?, reject: self.naga_traverse(false_node_id, stop_node_index)?,
reject: self });
.naga_traverse(self.block_to_node[&false_id], stop_node_index)?, } else if false_edge == ControlFlowEdgeType::LoopBreak {
}), result.push(crate::Statement::If {
condition,
accept: self.naga_traverse(true_node_id, stop_node_index)?,
reject: vec![crate::Statement::Break],
});
} else {
return Err(Error::InvalidEdgeClassification);
}
}
Terminator::Branch { .. } => {
result.push(crate::Statement::Break);
}
_ => return Err(Error::InvalidTerminator), _ => return Err(Error::InvalidTerminator),
}; };
Ok(result) Ok(result)
} }
Some(ControlFlowNodeType::Continue) => { Some(ControlFlowNodeType::Continue) => {
let back_block = match node.terminator {
Terminator::Branch { target_id } => {
self.naga_traverse(self.block_to_node[&target_id], None)?
}
_ => return Err(Error::InvalidTerminator),
};
let mut result = node.block.clone(); let mut result = node.block.clone();
result.extend(back_block);
result.push(crate::Statement::Continue); result.push(crate::Statement::Continue);
Ok(result) Ok(result)
} }
Some(ControlFlowNodeType::Back) | Some(ControlFlowNodeType::Merge) => { Some(ControlFlowNodeType::Back) => Ok(node.block.clone()),
Ok(node.block.clone())
}
Some(ControlFlowNodeType::Kill) => { Some(ControlFlowNodeType::Kill) => {
let mut result = node.block.clone(); let mut result = node.block.clone();
result.push(crate::Statement::Kill); result.push(crate::Statement::Kill);
@ -384,7 +418,7 @@ impl FlowGraph {
result.push(crate::Statement::Return { value }); result.push(crate::Statement::Return { value });
Ok(result) Ok(result)
} }
None => match node.terminator { Some(ControlFlowNodeType::Merge) | None => match node.terminator {
Terminator::Branch { target_id } => { Terminator::Branch { target_id } => {
let mut result = node.block.clone(); let mut result = node.block.clone();
result.extend( result.extend(
@ -401,7 +435,7 @@ impl FlowGraph {
pub(super) fn to_graphviz(&self) -> Result<String, std::fmt::Error> { pub(super) fn to_graphviz(&self) -> Result<String, std::fmt::Error> {
let mut output = String::new(); let mut output = String::new();
output += "digraph ControlFlowGraph {"; output += "digraph ControlFlowGraph {\n";
for node_index in self.flow.node_indices() { for node_index in self.flow.node_indices() {
let node = &self.flow[node_index]; let node = &self.flow[node_index];
@ -419,10 +453,16 @@ impl FlowGraph {
let target = edge.target(); let target = edge.target();
let style = match edge.weight { let style = match edge.weight {
ControlFlowEdgeType::Forward => "",
ControlFlowEdgeType::ForwardMerge => "style=dotted",
ControlFlowEdgeType::ForwardContinue => "color=green",
ControlFlowEdgeType::Back => "style=dashed",
ControlFlowEdgeType::LoopBreak => "color=yellow",
ControlFlowEdgeType::LoopContinue => "color=green",
ControlFlowEdgeType::IfTrue => "color=blue", ControlFlowEdgeType::IfTrue => "color=blue",
ControlFlowEdgeType::IfFalse => "color=red", ControlFlowEdgeType::IfFalse => "color=red",
ControlFlowEdgeType::ForwardMerge => "style=dotted", ControlFlowEdgeType::SwitchBreak => "color=yellow",
_ => "", ControlFlowEdgeType::CaseFallThrough => "style=dotted",
}; };
writeln!( writeln!(

View file

@ -69,7 +69,7 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
} }
crate::Function { crate::Function {
name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name), name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name),
parameter_types: Vec::with_capacity(ft.parameter_type_ids.len()), arguments: Vec::with_capacity(ft.parameter_type_ids.len()),
return_type: if self.lookup_void_type.contains(&result_type) { return_type: if self.lookup_void_type.contains(&result_type) {
None None
} else { } else {
@ -83,7 +83,7 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
}; };
// read parameters // read parameters
for i in 0..fun.parameter_types.capacity() { for i in 0..fun.arguments.capacity() {
match self.next_inst()? { match self.next_inst()? {
Instruction { Instruction {
op: spirv::Op::FunctionParameter, op: spirv::Op::FunctionParameter,
@ -93,7 +93,7 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
let id = self.next()?; let id = self.next()?;
let handle = fun let handle = fun
.expressions .expressions
.append(crate::Expression::FunctionParameter(i as u32)); .append(crate::Expression::FunctionArgument(i as u32));
self.lookup_expression self.lookup_expression
.insert(id, LookupExpression { type_id, handle }); .insert(id, LookupExpression { type_id, handle });
//Note: we redo the lookup in order to work around `self` borrowing //Note: we redo the lookup in order to work around `self` borrowing
@ -104,10 +104,11 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
.lookup(fun_type)? .lookup(fun_type)?
.parameter_type_ids[i] .parameter_type_ids[i]
{ {
return Err(Error::WrongFunctionParameterType(type_id)); return Err(Error::WrongFunctionArgumentType(type_id));
} }
let ty = self.lookup_type.lookup(type_id)?.handle; let ty = self.lookup_type.lookup(type_id)?.handle;
fun.parameter_types.push(ty); fun.arguments
.push(crate::FunctionArgument { name: None, ty });
} }
Instruction { op, .. } => return Err(Error::InvalidParameter(op)), Instruction { op, .. } => return Err(Error::InvalidParameter(op)),
} }
@ -175,6 +176,17 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
} }
}; };
if let Some(ref prefix) = self.options.flow_graph_dump_prefix {
let dump = flow_graph.to_graphviz().unwrap_or_default();
let suffix = match source {
DeferredSource::EntryPoint(stage, ref name) => {
format!("flow.{:?}-{}.dot", stage, name)
}
DeferredSource::Function(handle) => format!("flow.Fun-{}.dot", handle.index()),
};
let _ = std::fs::write(prefix.join(suffix), dump);
}
for (expr_handle, dst_id) in local_function_calls { for (expr_handle, dst_id) in local_function_calls {
self.deferred_function_calls.push(DeferredFunctionCall { self.deferred_function_calls.push(DeferredFunctionCall {
source: source.clone(), source: source.clone(),

View file

@ -29,7 +29,7 @@ use crate::{
}; };
use num_traits::cast::FromPrimitive; use num_traits::cast::FromPrimitive;
use std::{convert::TryInto, num::NonZeroU32}; use std::{convert::TryInto, num::NonZeroU32, path::PathBuf};
pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[ pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[
spirv::Capability::Shader, spirv::Capability::Shader,
@ -304,6 +304,11 @@ pub struct Assignment {
value: Handle<crate::Expression>, value: Handle<crate::Expression>,
} }
#[derive(Clone, Debug, Default)]
pub struct Options {
pub flow_graph_dump_prefix: Option<PathBuf>,
}
pub struct Parser<I> { pub struct Parser<I> {
data: I, data: I,
state: ModuleState, state: ModuleState,
@ -325,10 +330,11 @@ pub struct Parser<I> {
lookup_function: FastHashMap<spirv::Word, Handle<crate::Function>>, lookup_function: FastHashMap<spirv::Word, Handle<crate::Function>>,
lookup_entry_point: FastHashMap<spirv::Word, EntryPoint>, lookup_entry_point: FastHashMap<spirv::Word, EntryPoint>,
deferred_function_calls: Vec<DeferredFunctionCall>, deferred_function_calls: Vec<DeferredFunctionCall>,
options: Options,
} }
impl<I: Iterator<Item = u32>> Parser<I> { impl<I: Iterator<Item = u32>> Parser<I> {
pub fn new(data: I) -> Self { pub fn new(data: I, options: &Options) -> Self {
Parser { Parser {
data, data,
state: ModuleState::Empty, state: ModuleState::Empty,
@ -349,6 +355,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
lookup_function: FastHashMap::default(), lookup_function: FastHashMap::default(),
lookup_entry_point: FastHashMap::default(), lookup_entry_point: FastHashMap::default(),
deferred_function_calls: Vec::new(), deferred_function_calls: Vec::new(),
options: options.clone(),
} }
} }
@ -547,8 +554,8 @@ impl<I: Iterator<Item = u32>> Parser<I> {
let init = if inst.wc > 4 { let init = if inst.wc > 4 {
inst.expect(5)?; inst.expect(5)?;
let init_id = self.next()?; let init_id = self.next()?;
let lexp = self.lookup_expression.lookup(init_id)?; let lconst = self.lookup_constant.lookup(init_id)?;
Some(lexp.handle) Some(lconst.handle)
} else { } else {
None None
}; };
@ -852,6 +859,38 @@ impl<I: Iterator<Item = u32>> Parser<I> {
}, },
); );
} }
// Bitwise instructions
Op::Not => {
inst.expect(4)?;
self.parse_expr_unary_op(expressions, crate::UnaryOperator::Not)?;
}
Op::BitwiseOr => {
inst.expect(5)?;
self.parse_expr_binary_op(expressions, crate::BinaryOperator::InclusiveOr)?;
}
Op::BitwiseXor => {
inst.expect(5)?;
self.parse_expr_binary_op(expressions, crate::BinaryOperator::ExclusiveOr)?;
}
Op::BitwiseAnd => {
inst.expect(5)?;
self.parse_expr_binary_op(expressions, crate::BinaryOperator::And)?;
}
Op::ShiftRightLogical => {
inst.expect(5)?;
//TODO: convert input and result to usigned
self.parse_expr_binary_op(expressions, crate::BinaryOperator::ShiftRight)?;
}
Op::ShiftRightArithmetic => {
inst.expect(5)?;
//TODO: convert input and result to signed
self.parse_expr_binary_op(expressions, crate::BinaryOperator::ShiftRight)?;
}
Op::ShiftLeftLogical => {
inst.expect(5)?;
self.parse_expr_binary_op(expressions, crate::BinaryOperator::ShiftLeft)?;
}
// Sampling
Op::SampledImage => { Op::SampledImage => {
inst.expect(5)?; inst.expect(5)?;
let _result_type_id = self.next()?; let _result_type_id = self.next()?;
@ -1028,6 +1067,31 @@ impl<I: Iterator<Item = u32>> Parser<I> {
}, },
); );
} }
Op::Select => {
inst.expect(6)?;
let result_type_id = self.next()?;
let result_id = self.next()?;
let condition = self.next()?;
let o1_id = self.next()?;
let o2_id = self.next()?;
let cond_lexp = self.lookup_expression.lookup(condition)?;
let o1_lexp = self.lookup_expression.lookup(o1_id)?;
let o2_lexp = self.lookup_expression.lookup(o2_id)?;
let expr = crate::Expression::Select {
condition: cond_lexp.handle,
accept: o1_lexp.handle,
reject: o2_lexp.handle,
};
self.lookup_expression.insert(
result_id,
LookupExpression {
handle: expressions.append(expr),
type_id: result_type_id,
},
);
}
Op::VectorShuffle => { Op::VectorShuffle => {
inst.expect_at_least(5)?; inst.expect_at_least(5)?;
let result_type_id = self.next()?; let result_type_id = self.next()?;
@ -1092,11 +1156,10 @@ impl<I: Iterator<Item = u32>> Parser<I> {
let value_lexp = self.lookup_expression.lookup(value_id)?; let value_lexp = self.lookup_expression.lookup(value_id)?;
let ty_lookup = self.lookup_type.lookup(result_type_id)?; let ty_lookup = self.lookup_type.lookup(result_type_id)?;
let kind = match type_arena[ty_lookup.handle].inner { let kind = type_arena[ty_lookup.handle]
crate::TypeInner::Scalar { kind, .. } .inner
| crate::TypeInner::Vector { kind, .. } => kind, .scalar_kind()
_ => return Err(Error::InvalidAsType(ty_lookup.handle)), .ok_or(Error::InvalidAsType(ty_lookup.handle))?;
};
let expr = crate::Expression::As { let expr = crate::Expression::As {
expr: value_lexp.handle, expr: value_lexp.handle,
@ -1220,6 +1283,10 @@ impl<I: Iterator<Item = u32>> Parser<I> {
inst.expect(base_wc + 1)?; inst.expect(base_wc + 1)?;
"length" "length"
} }
Some(spirv::GLOp::Distance) => {
inst.expect(base_wc + 2)?;
"distance"
}
Some(spirv::GLOp::Cross) => { Some(spirv::GLOp::Cross) => {
inst.expect(base_wc + 2)?; inst.expect(base_wc + 2)?;
"cross" "cross"
@ -1477,7 +1544,9 @@ impl<I: Iterator<Item = u32>> Parser<I> {
}; };
*comparison = true; *comparison = true;
} }
_ => panic!("Unexpected comparison type {:?}", ty), _ => {
return Err(Error::UnexpectedComparisonType(handle));
}
} }
} }
@ -1906,12 +1975,13 @@ impl<I: Iterator<Item = u32>> Parser<I> {
inst.expect(4)?; inst.expect(4)?;
let id = self.next()?; let id = self.next()?;
let type_id = self.next()?; let type_id = self.next()?;
let length = self.next()?; let length_id = self.next()?;
let length_const = self.lookup_constant.lookup(length_id)?;
let decor = self.future_decor.remove(&id); let decor = self.future_decor.remove(&id);
let inner = crate::TypeInner::Array { let inner = crate::TypeInner::Array {
base: self.lookup_type.lookup(type_id)?.handle, base: self.lookup_type.lookup(type_id)?.handle,
size: crate::ArraySize::Static(length), size: crate::ArraySize::Constant(length_const.handle),
stride: decor.as_ref().and_then(|dec| dec.array_stride), stride: decor.as_ref().and_then(|dec| dec.array_stride),
}; };
self.lookup_type.insert( self.lookup_type.insert(
@ -2030,10 +2100,10 @@ impl<I: Iterator<Item = u32>> Parser<I> {
let format = self.next()?; let format = self.next()?;
let base_handle = self.lookup_type.lookup(sample_type_id)?.handle; let base_handle = self.lookup_type.lookup(sample_type_id)?.handle;
let kind = match module.types[base_handle].inner { let kind = module.types[base_handle]
crate::TypeInner::Scalar { kind, .. } | crate::TypeInner::Vector { kind, .. } => kind, .inner
_ => return Err(Error::InvalidImageBaseType(base_handle)), .scalar_kind()
}; .ok_or(Error::InvalidImageBaseType(base_handle))?;
let class = if format != 0 { let class = if format != 0 {
crate::ImageClass::Storage(map_image_format(format)?) crate::ImageClass::Storage(map_image_format(format)?)
@ -2231,28 +2301,52 @@ impl<I: Iterator<Item = u32>> Parser<I> {
let type_id = self.next()?; let type_id = self.next()?;
let id = self.next()?; let id = self.next()?;
let storage_class = self.next()?; let storage_class = self.next()?;
if inst.wc != 4 { let init = if inst.wc > 4 {
inst.expect(5)?; inst.expect(5)?;
let _init = self.next()?; //TODO let init_id = self.next()?;
} let lconst = self.lookup_constant.lookup(init_id)?;
Some(lconst.handle)
} else {
None
};
let lookup_type = self.lookup_type.lookup(type_id)?; let lookup_type = self.lookup_type.lookup(type_id)?;
let dec = self let dec = self
.future_decor .future_decor
.remove(&id) .remove(&id)
.ok_or(Error::InvalidBinding(id))?; .ok_or(Error::InvalidBinding(id))?;
let class = map_storage_class(storage_class)?;
let class = {
use spirv::StorageClass as Sc;
match Sc::from_u32(storage_class) {
Some(Sc::Function) => crate::StorageClass::Function,
Some(Sc::Input) => crate::StorageClass::Input,
Some(Sc::Output) => crate::StorageClass::Output,
Some(Sc::Private) => crate::StorageClass::Private,
Some(Sc::UniformConstant) => crate::StorageClass::Handle,
Some(Sc::StorageBuffer) => crate::StorageClass::Storage,
Some(Sc::Uniform) => {
if self
.lookup_storage_buffer_types
.contains(&lookup_type.handle)
{
crate::StorageClass::Storage
} else {
crate::StorageClass::Uniform
}
}
Some(Sc::Workgroup) => crate::StorageClass::WorkGroup,
Some(Sc::PushConstant) => crate::StorageClass::PushConstant,
_ => return Err(Error::UnsupportedStorageClass(storage_class)),
}
};
let binding = match (class, &module.types[lookup_type.handle].inner) { let binding = match (class, &module.types[lookup_type.handle].inner) {
(crate::StorageClass::Input, &crate::TypeInner::Struct { .. }) (crate::StorageClass::Input, &crate::TypeInner::Struct { .. })
| (crate::StorageClass::Output, &crate::TypeInner::Struct { .. }) => None, | (crate::StorageClass::Output, &crate::TypeInner::Struct { .. }) => None,
_ => Some(dec.get_binding().ok_or(Error::InvalidBinding(id))?), _ => Some(dec.get_binding().ok_or(Error::InvalidBinding(id))?),
}; };
let is_storage = match module.types[lookup_type.handle].inner { let is_storage = match module.types[lookup_type.handle].inner {
crate::TypeInner::Struct { .. } => match class { crate::TypeInner::Struct { .. } => class == crate::StorageClass::Storage,
crate::StorageClass::StorageBuffer => true,
_ => self
.lookup_storage_buffer_types
.contains(&lookup_type.handle),
},
crate::TypeInner::Image { crate::TypeInner::Image {
class: crate::ImageClass::Storage(_), class: crate::ImageClass::Storage(_),
.. ..
@ -2278,6 +2372,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
class, class,
binding, binding,
ty: lookup_type.handle, ty: lookup_type.handle,
init,
interpolation: dec.interpolation, interpolation: dec.interpolation,
storage_access, storage_access,
}; };
@ -2292,7 +2387,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
} }
} }
pub fn parse_u8_slice(data: &[u8]) -> Result<crate::Module, Error> { pub fn parse_u8_slice(data: &[u8], options: &Options) -> Result<crate::Module, Error> {
if data.len() % 4 != 0 { if data.len() % 4 != 0 {
return Err(Error::IncompleteData); return Err(Error::IncompleteData);
} }
@ -2300,7 +2395,7 @@ pub fn parse_u8_slice(data: &[u8]) -> Result<crate::Module, Error> {
let words = data let words = data
.chunks(4) .chunks(4)
.map(|c| u32::from_le_bytes(c.try_into().unwrap())); .map(|c| u32::from_le_bytes(c.try_into().unwrap()));
Parser::new(words).parse() Parser::new(words, options).parse()
} }
#[cfg(test)] #[cfg(test)]
@ -2316,6 +2411,6 @@ mod test {
0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // GLSL450. 0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // GLSL450.
0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
]; ];
let _ = super::parse_u8_slice(&bin).unwrap(); let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap();
} }
} }

View file

@ -9,7 +9,7 @@ fn rosetta_test(file_name: &str) {
let file_path = Path::new(TEST_PATH).join(file_name); let file_path = Path::new(TEST_PATH).join(file_name);
let input = fs::read(&file_path).unwrap(); let input = fs::read(&file_path).unwrap();
let module = super::parse_u8_slice(&input).unwrap(); let module = super::parse_u8_slice(&input, &Default::default()).unwrap();
let output = ron::ser::to_string_pretty(&module, Default::default()).unwrap(); let output = ron::ser::to_string_pretty(&module, Default::default()).unwrap();
let expected = fs::read_to_string(file_path.with_extension("expected.ron")).unwrap(); let expected = fs::read_to_string(file_path.with_extension("expected.ron")).unwrap();

View file

@ -4,8 +4,9 @@ pub fn map_storage_class(word: &str) -> Result<crate::StorageClass, Error<'_>> {
match word { match word {
"in" => Ok(crate::StorageClass::Input), "in" => Ok(crate::StorageClass::Input),
"out" => Ok(crate::StorageClass::Output), "out" => Ok(crate::StorageClass::Output),
"private" => Ok(crate::StorageClass::Private),
"uniform" => Ok(crate::StorageClass::Uniform), "uniform" => Ok(crate::StorageClass::Uniform),
"storage_buffer" => Ok(crate::StorageClass::StorageBuffer), "storage" => Ok(crate::StorageClass::Storage),
_ => Err(Error::UnknownStorageClass(word)), _ => Err(Error::UnknownStorageClass(word)),
} }
} }

View file

@ -69,12 +69,7 @@ fn consume_token(mut input: &str) -> (Token<'_>, &str) {
if next == Some('=') { if next == Some('=') {
(Token::LogicalOperation(cur), chars.as_str()) (Token::LogicalOperation(cur), chars.as_str())
} else if next == Some(cur) { } else if next == Some(cur) {
input = chars.as_str(); (Token::ShiftOperation(cur), chars.as_str())
if chars.next() == Some(cur) {
(Token::ArithmeticShiftOperation(cur), chars.as_str())
} else {
(Token::ShiftOperation(cur), input)
}
} else { } else {
(Token::Paren(cur), input) (Token::Paren(cur), input)
} }

View file

@ -26,7 +26,6 @@ pub enum Token<'a> {
Operation(char), Operation(char),
LogicalOperation(char), LogicalOperation(char),
ShiftOperation(char), ShiftOperation(char),
ArithmeticShiftOperation(char),
Arrow, Arrow,
Unknown(char), Unknown(char),
UnterminatedString, UnterminatedString,
@ -37,8 +36,6 @@ pub enum Token<'a> {
pub enum Error<'a> { pub enum Error<'a> {
#[error("unexpected token: {0:?}")] #[error("unexpected token: {0:?}")]
Unexpected(Token<'a>), Unexpected(Token<'a>),
#[error("constant {0:?} doesn't match its type {1:?}")]
UnexpectedConstantType(crate::ConstantInner, Handle<crate::Type>),
#[error("unable to parse `{0}` as integer: {1}")] #[error("unable to parse `{0}` as integer: {1}")]
BadInteger(&'a str, std::num::ParseIntError), BadInteger(&'a str, std::num::ParseIntError),
#[error("unable to parse `{1}` as float: {1}")] #[error("unable to parse `{1}` as float: {1}")]
@ -100,7 +97,7 @@ struct StatementContext<'input, 'temp, 'out> {
types: &'out mut Arena<crate::Type>, types: &'out mut Arena<crate::Type>,
constants: &'out mut Arena<crate::Constant>, constants: &'out mut Arena<crate::Constant>,
global_vars: &'out Arena<crate::GlobalVariable>, global_vars: &'out Arena<crate::GlobalVariable>,
parameter_types: &'out [Handle<crate::Type>], arguments: &'out [crate::FunctionArgument],
} }
impl<'a> StatementContext<'a, '_, '_> { impl<'a> StatementContext<'a, '_, '_> {
@ -113,7 +110,7 @@ impl<'a> StatementContext<'a, '_, '_> {
types: self.types, types: self.types,
constants: self.constants, constants: self.constants,
global_vars: self.global_vars, global_vars: self.global_vars,
parameter_types: self.parameter_types, arguments: self.arguments,
} }
} }
@ -126,7 +123,7 @@ impl<'a> StatementContext<'a, '_, '_> {
constants: self.constants, constants: self.constants,
global_vars: self.global_vars, global_vars: self.global_vars,
local_vars: self.variables, local_vars: self.variables,
parameter_types: self.parameter_types, arguments: self.arguments,
} }
} }
} }
@ -139,7 +136,7 @@ struct ExpressionContext<'input, 'temp, 'out> {
constants: &'out mut Arena<crate::Constant>, constants: &'out mut Arena<crate::Constant>,
global_vars: &'out Arena<crate::GlobalVariable>, global_vars: &'out Arena<crate::GlobalVariable>,
local_vars: &'out Arena<crate::LocalVariable>, local_vars: &'out Arena<crate::LocalVariable>,
parameter_types: &'out [Handle<crate::Type>], arguments: &'out [crate::FunctionArgument],
} }
impl<'a> ExpressionContext<'a, '_, '_> { impl<'a> ExpressionContext<'a, '_, '_> {
@ -152,7 +149,7 @@ impl<'a> ExpressionContext<'a, '_, '_> {
constants: self.constants, constants: self.constants,
global_vars: self.global_vars, global_vars: self.global_vars,
local_vars: self.local_vars, local_vars: self.local_vars,
parameter_types: self.parameter_types, arguments: self.arguments,
} }
} }
@ -166,7 +163,7 @@ impl<'a> ExpressionContext<'a, '_, '_> {
global_vars: self.global_vars, global_vars: self.global_vars,
local_vars: self.local_vars, local_vars: self.local_vars,
functions: &functions, functions: &functions,
parameter_types: self.parameter_types, arguments: self.arguments,
}; };
match self match self
.typifier .typifier
@ -265,6 +262,7 @@ struct ParsedVariable<'a> {
class: Option<crate::StorageClass>, class: Option<crate::StorageClass>,
ty: Handle<crate::Type>, ty: Handle<crate::Type>,
access: crate::StorageAccess, access: crate::StorageAccess,
init: Option<Handle<crate::Constant>>,
} }
#[derive(Clone, Debug, Error)] #[derive(Clone, Debug, Error)]
@ -375,9 +373,10 @@ impl Parser {
fn parse_const_expression<'a>( fn parse_const_expression<'a>(
&mut self, &mut self,
lexer: &mut Lexer<'a>, lexer: &mut Lexer<'a>,
self_ty: Handle<crate::Type>,
type_arena: &mut Arena<crate::Type>, type_arena: &mut Arena<crate::Type>,
const_arena: &mut Arena<crate::Constant>, const_arena: &mut Arena<crate::Constant>,
) -> Result<crate::ConstantInner, Error<'a>> { ) -> Result<Handle<crate::Constant>, Error<'a>> {
self.scopes.push(Scope::ConstantExpr); self.scopes.push(Scope::ConstantExpr);
let inner = match lexer.peek() { let inner = match lexer.peek() {
Token::Word("true") => { Token::Word("true") => {
@ -394,7 +393,7 @@ impl Parser {
inner inner
} }
_ => { _ => {
let composite_ty = self.parse_type_decl(lexer, None, type_arena)?; let composite_ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
lexer.expect(Token::Paren('('))?; lexer.expect(Token::Paren('('))?;
let mut components = Vec::new(); let mut components = Vec::new();
while !lexer.skip(Token::Paren(')')) { while !lexer.skip(Token::Paren(')')) {
@ -406,19 +405,21 @@ impl Parser {
composite_ty, composite_ty,
components.len(), components.len(),
)?; )?;
let inner = self.parse_const_expression(lexer, type_arena, const_arena)?; let component =
components.push(const_arena.fetch_or_append(crate::Constant { self.parse_const_expression(lexer, ty, type_arena, const_arena)?;
name: None, components.push(component);
specialization: None,
inner,
ty,
}));
} }
crate::ConstantInner::Composite(components) crate::ConstantInner::Composite(components)
} }
}; };
let handle = const_arena.fetch_or_append(crate::Constant {
name: None,
specialization: None,
inner,
ty: self_ty,
});
self.scopes.pop(); self.scopes.pop();
Ok(inner) Ok(handle)
} }
fn parse_primary_expression<'a>( fn parse_primary_expression<'a>(
@ -490,7 +491,7 @@ impl Parser {
expr expr
} else { } else {
*lexer = backup; *lexer = backup;
let ty = self.parse_type_decl(lexer, None, ctx.types)?; let ty = self.parse_type_decl(lexer, None, ctx.types, ctx.constants)?;
lexer.expect(Token::Paren('('))?; lexer.expect(Token::Paren('('))?;
let mut components = Vec::new(); let mut components = Vec::new();
while !lexer.skip(Token::Paren(')')) { while !lexer.skip(Token::Paren(')')) {
@ -790,13 +791,10 @@ impl Parser {
lexer, lexer,
|token| match token { |token| match token {
Token::ShiftOperation('<') => { Token::ShiftOperation('<') => {
Some(crate::BinaryOperator::ShiftLeftLogical) Some(crate::BinaryOperator::ShiftLeft)
} }
Token::ShiftOperation('>') => { Token::ShiftOperation('>') => {
Some(crate::BinaryOperator::ShiftRightLogical) Some(crate::BinaryOperator::ShiftRight)
}
Token::ArithmeticShiftOperation('>') => {
Some(crate::BinaryOperator::ShiftRightArithmetic)
} }
_ => None, _ => None,
}, },
@ -910,10 +908,11 @@ impl Parser {
&mut self, &mut self,
lexer: &mut Lexer<'a>, lexer: &mut Lexer<'a>,
type_arena: &mut Arena<crate::Type>, type_arena: &mut Arena<crate::Type>,
const_arena: &mut Arena<crate::Constant>,
) -> Result<(&'a str, Handle<crate::Type>), Error<'a>> { ) -> Result<(&'a str, Handle<crate::Type>), Error<'a>> {
let name = lexer.next_ident()?; let name = lexer.next_ident()?;
lexer.expect(Token::Separator(':'))?; lexer.expect(Token::Separator(':'))?;
let ty = self.parse_type_decl(lexer, None, type_arena)?; let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
Ok((name, ty)) Ok((name, ty))
} }
@ -932,16 +931,27 @@ impl Parser {
} }
let name = lexer.next_ident()?; let name = lexer.next_ident()?;
lexer.expect(Token::Separator(':'))?; lexer.expect(Token::Separator(':'))?;
let ty = self.parse_type_decl(lexer, None, type_arena)?; let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
let access = match class { let access = match class {
Some(crate::StorageClass::StorageBuffer) => crate::StorageAccess::all(), Some(crate::StorageClass::Storage) => crate::StorageAccess::all(),
Some(crate::StorageClass::Constant) => crate::StorageAccess::LOAD, Some(crate::StorageClass::Handle) => {
match type_arena[ty].inner {
//TODO: RW textures
crate::TypeInner::Image {
class: crate::ImageClass::Storage(_),
..
} => crate::StorageAccess::LOAD,
_ => crate::StorageAccess::empty(),
}
}
_ => crate::StorageAccess::empty(), _ => crate::StorageAccess::empty(),
}; };
if lexer.skip(Token::Operation('=')) { let init = if lexer.skip(Token::Operation('=')) {
let _inner = self.parse_const_expression(lexer, type_arena, const_arena)?; let handle = self.parse_const_expression(lexer, ty, type_arena, const_arena)?;
//TODO Some(handle)
} } else {
None
};
lexer.expect(Token::Separator(';'))?; lexer.expect(Token::Separator(';'))?;
self.scopes.pop(); self.scopes.pop();
Ok(ParsedVariable { Ok(ParsedVariable {
@ -949,6 +959,7 @@ impl Parser {
class, class,
ty, ty,
access, access,
init,
}) })
} }
@ -956,6 +967,7 @@ impl Parser {
&mut self, &mut self,
lexer: &mut Lexer<'a>, lexer: &mut Lexer<'a>,
type_arena: &mut Arena<crate::Type>, type_arena: &mut Arena<crate::Type>,
const_arena: &mut Arena<crate::Constant>,
) -> Result<Vec<crate::StructMember>, Error<'a>> { ) -> Result<Vec<crate::StructMember>, Error<'a>> {
let mut members = Vec::new(); let mut members = Vec::new();
lexer.expect(Token::Paren('{'))?; lexer.expect(Token::Paren('{'))?;
@ -992,7 +1004,7 @@ impl Parser {
return Err(Error::MissingMemberOffset(name)); return Err(Error::MissingMemberOffset(name));
} }
lexer.expect(Token::Separator(':'))?; lexer.expect(Token::Separator(':'))?;
let ty = self.parse_type_decl(lexer, None, type_arena)?; let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
lexer.expect(Token::Separator(';'))?; lexer.expect(Token::Separator(';'))?;
members.push(crate::StructMember { members.push(crate::StructMember {
name: Some(name.to_owned()), name: Some(name.to_owned()),
@ -1007,6 +1019,7 @@ impl Parser {
lexer: &mut Lexer<'a>, lexer: &mut Lexer<'a>,
self_name: Option<&'a str>, self_name: Option<&'a str>,
type_arena: &mut Arena<crate::Type>, type_arena: &mut Arena<crate::Type>,
const_arena: &mut Arena<crate::Constant>,
) -> Result<Handle<crate::Type>, Error<'a>> { ) -> Result<Handle<crate::Type>, Error<'a>> {
self.scopes.push(Scope::TypeDecl); self.scopes.push(Scope::TypeDecl);
let decoration_lexer = if lexer.skip(Token::DoubleParen('[')) { let decoration_lexer = if lexer.skip(Token::DoubleParen('[')) {
@ -1128,18 +1141,30 @@ impl Parser {
lexer.expect(Token::Paren('<'))?; lexer.expect(Token::Paren('<'))?;
let class = conv::map_storage_class(lexer.next_ident()?)?; let class = conv::map_storage_class(lexer.next_ident()?)?;
lexer.expect(Token::Separator(','))?; lexer.expect(Token::Separator(','))?;
let base = self.parse_type_decl(lexer, None, type_arena)?; let base = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
lexer.expect(Token::Paren('>'))?; lexer.expect(Token::Paren('>'))?;
crate::TypeInner::Pointer { base, class } crate::TypeInner::Pointer { base, class }
} }
Token::Word("array") => { Token::Word("array") => {
lexer.expect(Token::Paren('<'))?; lexer.expect(Token::Paren('<'))?;
let base = self.parse_type_decl(lexer, None, type_arena)?; let base = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
let size = match lexer.next() { let size = match lexer.next() {
Token::Separator(',') => { Token::Separator(',') => {
let value = lexer.next_uint_literal()?; let value = lexer.next_uint_literal()?;
lexer.expect(Token::Paren('>'))?; lexer.expect(Token::Paren('>'))?;
crate::ArraySize::Static(value) let const_handle = const_arena.fetch_or_append(crate::Constant {
name: None,
specialization: None,
inner: crate::ConstantInner::Uint(value as u64),
ty: type_arena.fetch_or_append(crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
width: 4,
},
}),
});
crate::ArraySize::Constant(const_handle)
} }
Token::Paren('>') => crate::ArraySize::Dynamic, Token::Paren('>') => crate::ArraySize::Dynamic,
other => return Err(Error::Unexpected(other)), other => return Err(Error::Unexpected(other)),
@ -1167,7 +1192,7 @@ impl Parser {
crate::TypeInner::Array { base, size, stride } crate::TypeInner::Array { base, size, stride }
} }
Token::Word("struct") => { Token::Word("struct") => {
let members = self.parse_struct_body(lexer, type_arena)?; let members = self.parse_struct_body(lexer, type_arena, const_arena)?;
crate::TypeInner::Struct { members } crate::TypeInner::Struct { members }
} }
Token::Word("sampler") => crate::TypeInner::Sampler { comparison: false }, Token::Word("sampler") => crate::TypeInner::Sampler { comparison: false },
@ -1368,15 +1393,20 @@ impl Parser {
"var" => { "var" => {
enum Init { enum Init {
Empty, Empty,
Uniform(Handle<crate::Expression>), Constant(Handle<crate::Constant>),
Variable(Handle<crate::Expression>), Variable(Handle<crate::Expression>),
} }
let (name, ty) = self.parse_variable_ident_decl(lexer, context.types)?; let (name, ty) = self.parse_variable_ident_decl(
lexer,
context.types,
context.constants,
)?;
let init = if lexer.skip(Token::Operation('=')) { let init = if lexer.skip(Token::Operation('=')) {
let value = let value =
self.parse_general_expression(lexer, context.as_expression())?; self.parse_general_expression(lexer, context.as_expression())?;
if let crate::Expression::Constant(_) = context.expressions[value] { if let crate::Expression::Constant(handle) = context.expressions[value]
Init::Uniform(value) {
Init::Constant(handle)
} else { } else {
Init::Variable(value) Init::Variable(value)
} }
@ -1388,7 +1418,7 @@ impl Parser {
name: Some(name.to_owned()), name: Some(name.to_owned()),
ty, ty,
init: match init { init: match init {
Init::Uniform(value) => Some(value), Init::Constant(value) => Some(value),
_ => None, _ => None,
}, },
}); });
@ -1515,31 +1545,34 @@ impl Parser {
lookup_ident.insert(name, expr_handle); lookup_ident.insert(name, expr_handle);
} }
// read parameter list // read parameter list
let mut parameter_types = Vec::new(); let mut arguments = Vec::new();
lexer.expect(Token::Paren('('))?; lexer.expect(Token::Paren('('))?;
while !lexer.skip(Token::Paren(')')) { while !lexer.skip(Token::Paren(')')) {
if !parameter_types.is_empty() { if !arguments.is_empty() {
lexer.expect(Token::Separator(','))?; lexer.expect(Token::Separator(','))?;
} }
let (param_name, param_type) = let (param_name, param_type) =
self.parse_variable_ident_decl(lexer, &mut module.types)?; self.parse_variable_ident_decl(lexer, &mut module.types, &mut module.constants)?;
let param_index = parameter_types.len() as u32; let param_index = arguments.len() as u32;
let expression_token = let expression_token =
expressions.append(crate::Expression::FunctionParameter(param_index)); expressions.append(crate::Expression::FunctionArgument(param_index));
lookup_ident.insert(param_name, expression_token); lookup_ident.insert(param_name, expression_token);
parameter_types.push(param_type); arguments.push(crate::FunctionArgument {
name: Some(param_name.to_string()),
ty: param_type,
});
} }
// read return type // read return type
lexer.expect(Token::Arrow)?; lexer.expect(Token::Arrow)?;
let return_type = if lexer.skip(Token::Word("void")) { let return_type = if lexer.skip(Token::Word("void")) {
None None
} else { } else {
Some(self.parse_type_decl(lexer, None, &mut module.types)?) Some(self.parse_type_decl(lexer, None, &mut module.types, &mut module.constants)?)
}; };
let mut fun = crate::Function { let mut fun = crate::Function {
name: Some(fun_name.to_string()), name: Some(fun_name.to_string()),
parameter_types, arguments,
return_type, return_type,
global_usage: Vec::new(), global_usage: Vec::new(),
local_variables: Arena::new(), local_variables: Arena::new(),
@ -1559,7 +1592,7 @@ impl Parser {
types: &mut module.types, types: &mut module.types,
constants: &mut module.constants, constants: &mut module.constants,
global_vars: &module.global_variables, global_vars: &module.global_variables,
parameter_types: &fun.parameter_types, arguments: &fun.arguments,
}, },
)?; )?;
// done // done
@ -1680,25 +1713,29 @@ impl Parser {
Token::Word("type") => { Token::Word("type") => {
let name = lexer.next_ident()?; let name = lexer.next_ident()?;
lexer.expect(Token::Operation('='))?; lexer.expect(Token::Operation('='))?;
let ty = self.parse_type_decl(lexer, Some(name), &mut module.types)?; let ty = self.parse_type_decl(
lexer,
Some(name),
&mut module.types,
&mut module.constants,
)?;
self.lookup_type.insert(name.to_owned(), ty); self.lookup_type.insert(name.to_owned(), ty);
lexer.expect(Token::Separator(';'))?; lexer.expect(Token::Separator(';'))?;
} }
Token::Word("const") => { Token::Word("const") => {
let (name, ty) = self.parse_variable_ident_decl(lexer, &mut module.types)?; let (name, ty) = self.parse_variable_ident_decl(
lexer,
&mut module.types,
&mut module.constants,
)?;
lexer.expect(Token::Operation('='))?; lexer.expect(Token::Operation('='))?;
let inner = let const_handle = self.parse_const_expression(
self.parse_const_expression(lexer, &mut module.types, &mut module.constants)?; lexer,
lexer.expect(Token::Separator(';'))?;
if !crate::proc::check_constant_type(&inner, &module.types[ty].inner) {
return Err(Error::UnexpectedConstantType(inner, ty));
}
let const_handle = module.constants.append(crate::Constant {
name: Some(name.to_owned()),
specialization: None,
inner,
ty, ty,
}); &mut module.types,
&mut module.constants,
)?;
lexer.expect(Token::Separator(';'))?;
lookup_global_expression.insert(name, crate::Expression::Constant(const_handle)); lookup_global_expression.insert(name, crate::Expression::Constant(const_handle));
} }
Token::Word("var") => { Token::Word("var") => {
@ -1712,7 +1749,7 @@ impl Parser {
crate::BuiltIn::Position => crate::StorageClass::Output, crate::BuiltIn::Position => crate::StorageClass::Output,
_ => unimplemented!(), _ => unimplemented!(),
}, },
_ => crate::StorageClass::Private, _ => crate::StorageClass::Handle,
}, },
}; };
let var_handle = module.global_variables.append(crate::GlobalVariable { let var_handle = module.global_variables.append(crate::GlobalVariable {
@ -1720,6 +1757,7 @@ impl Parser {
class, class,
binding: binding.take(), binding: binding.take(),
ty: pvar.ty, ty: pvar.ty,
init: pvar.init,
interpolation, interpolation,
storage_access: pvar.access, storage_access: pvar.access,
}); });
@ -1787,7 +1825,13 @@ impl Parser {
} }
Ok(true) => {} Ok(true) => {}
Ok(false) => { Ok(false) => {
assert_eq!(self.scopes, Vec::new()); if !self.scopes.is_empty() {
return Err(ParseError {
error: Error::Other,
scopes: std::mem::replace(&mut self.scopes, Vec::new()),
pos: (0, 0),
});
};
return Ok(module); return Ok(module);
} }
} }
@ -1802,5 +1846,5 @@ pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> {
#[test] #[test]
fn parse_types() { fn parse_types() {
assert!(parse_str("const a : i32 = 2;").is_ok()); assert!(parse_str("const a : i32 = 2;").is_ok());
assert!(parse_str("const a : i32 = 2.0;").is_err()); assert!(parse_str("const a : x32 = 2;").is_err());
} }

View file

@ -4,7 +4,11 @@
//! //!
//! To improve performance and reduce memory usage, most structures are stored //! To improve performance and reduce memory usage, most structures are stored
//! in an [`Arena`], and can be retrieved using the corresponding [`Handle`]. //! in an [`Arena`], and can be retrieved using the corresponding [`Handle`].
#![allow(clippy::new_without_default, clippy::unneeded_field_pattern)] #![allow(
clippy::new_without_default,
clippy::unneeded_field_pattern,
clippy::match_like_matches_macro
)]
#![deny(clippy::panic)] #![deny(clippy::panic)]
mod arena; mod arena;
@ -57,7 +61,7 @@ pub struct Header {
/// For more, see: /// For more, see:
/// - https://www.khronos.org/opengl/wiki/Early_Fragment_Test#Explicit_specification /// - https://www.khronos.org/opengl/wiki/Early_Fragment_Test#Explicit_specification
/// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-attributes-earlydepthstencil /// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-attributes-earlydepthstencil
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub struct EarlyDepthTest { pub struct EarlyDepthTest {
@ -73,7 +77,7 @@ pub struct EarlyDepthTest {
/// For more, see: /// For more, see:
/// - https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt /// - https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt
/// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-semantics#system-value-semantics /// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-semantics#system-value-semantics
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ConservativeDepth { pub enum ConservativeDepth {
@ -88,7 +92,7 @@ pub enum ConservativeDepth {
} }
/// Stage of the programmable pipeline. /// Stage of the programmable pipeline.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[allow(missing_docs)] // The names are self evident #[allow(missing_docs)] // The names are self evident
@ -99,23 +103,33 @@ pub enum ShaderStage {
} }
/// Class of storage for variables. /// Class of storage for variables.
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[allow(missing_docs)] // The names are self evident #[allow(missing_docs)] // The names are self evident
pub enum StorageClass { pub enum StorageClass {
Constant, /// Function locals.
Function, Function,
/// Pipeline input, per invocation.
Input, Input,
/// Pipeline output, per invocation, mutable.
Output, Output,
/// Private data, per invocation, mutable.
Private, Private,
StorageBuffer, /// Workgroup shared data, mutable.
Uniform,
WorkGroup, WorkGroup,
/// Uniform buffer data.
Uniform,
/// Storage buffer data, potentially mutable.
Storage,
/// Opaque handles, such as samplers and images.
Handle,
/// Push constants.
PushConstant,
} }
/// Built-in inputs and outputs. /// Built-in inputs and outputs.
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum BuiltIn { pub enum BuiltIn {
@ -144,7 +158,7 @@ pub type Bytes = u8;
/// Number of components in a vector. /// Number of components in a vector.
#[repr(u8)] #[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum VectorSize { pub enum VectorSize {
@ -158,7 +172,7 @@ pub enum VectorSize {
/// Primitive type for a scalar. /// Primitive type for a scalar.
#[repr(u8)] #[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ScalarKind { pub enum ScalarKind {
@ -174,18 +188,18 @@ pub enum ScalarKind {
/// Size of an array. /// Size of an array.
#[repr(u8)] #[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ArraySize { pub enum ArraySize {
/// The array size is known at compilation. /// The array size is constant.
Static(u32), Constant(Handle<Constant>),
/// The array size can change at runtime. /// The array size can change at runtime.
Dynamic, Dynamic,
} }
/// Describes where a struct member is placed. /// Describes where a struct member is placed.
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum MemberOrigin { pub enum MemberOrigin {
@ -198,7 +212,7 @@ pub enum MemberOrigin {
} }
/// The interpolation qualifier of a binding or struct field. /// The interpolation qualifier of a binding or struct field.
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum Interpolation { pub enum Interpolation {
@ -233,7 +247,7 @@ pub struct StructMember {
} }
/// The number of dimensions an image has. /// The number of dimensions an image has.
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ImageDimension { pub enum ImageDimension {
@ -260,7 +274,7 @@ bitflags::bitflags! {
} }
// Storage image format. // Storage image format.
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum StorageFormat { pub enum StorageFormat {
@ -310,7 +324,7 @@ pub enum StorageFormat {
} }
/// Sub-class of the image type. /// Sub-class of the image type.
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ImageClass { pub enum ImageClass {
@ -392,8 +406,7 @@ pub struct Constant {
} }
/// Additional information, dependendent on the kind of constant. /// Additional information, dependendent on the kind of constant.
// Clone is used only for error reporting and is not intended for end users #[derive(Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ConstantInner { pub enum ConstantInner {
@ -442,6 +455,8 @@ pub struct GlobalVariable {
pub binding: Option<Binding>, pub binding: Option<Binding>,
/// The type of this variable. /// The type of this variable.
pub ty: Handle<Type>, pub ty: Handle<Type>,
/// Initial value for this variable.
pub init: Option<Handle<Constant>>,
/// The interpolation qualifier, if any. /// The interpolation qualifier, if any.
/// If the this `GlobalVariable` is a vertex output /// If the this `GlobalVariable` is a vertex output
/// or fragment input, `None` corresponds to the /// or fragment input, `None` corresponds to the
@ -461,11 +476,11 @@ pub struct LocalVariable {
/// The type of this variable. /// The type of this variable.
pub ty: Handle<Type>, pub ty: Handle<Type>,
/// Initial value for this variable. /// Initial value for this variable.
pub init: Option<Handle<Expression>>, pub init: Option<Handle<Constant>>,
} }
/// Operation that can be applied on a single value. /// Operation that can be applied on a single value.
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum UnaryOperator { pub enum UnaryOperator {
@ -474,7 +489,7 @@ pub enum UnaryOperator {
} }
/// Operation that can be applied on two values. /// Operation that can be applied on two values.
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum BinaryOperator { pub enum BinaryOperator {
@ -494,13 +509,13 @@ pub enum BinaryOperator {
InclusiveOr, InclusiveOr,
LogicalAnd, LogicalAnd,
LogicalOr, LogicalOr,
ShiftLeftLogical, ShiftLeft,
ShiftRightLogical, /// Right shift carries the sign of signed integers only.
ShiftRightArithmetic, ShiftRight,
} }
/// Built-in shader function. /// Built-in shader function.
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum IntrinsicFunction { pub enum IntrinsicFunction {
@ -513,7 +528,7 @@ pub enum IntrinsicFunction {
} }
/// Axis on which to compute a derivative. /// Axis on which to compute a derivative.
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum DerivativeAxis { pub enum DerivativeAxis {
@ -569,7 +584,7 @@ pub enum Expression {
components: Vec<Handle<Expression>>, components: Vec<Handle<Expression>>,
}, },
/// Reference a function parameter, by its index. /// Reference a function parameter, by its index.
FunctionParameter(u32), FunctionArgument(u32),
/// Reference a global variable. /// Reference a global variable.
GlobalVariable(Handle<GlobalVariable>), GlobalVariable(Handle<GlobalVariable>),
/// Reference a local variable. /// Reference a local variable.
@ -604,6 +619,13 @@ pub enum Expression {
left: Handle<Expression>, left: Handle<Expression>,
right: Handle<Expression>, right: Handle<Expression>,
}, },
/// Select between two values based on a condition.
Select {
/// Boolean expression
condition: Handle<Expression>,
accept: Handle<Expression>,
reject: Handle<Expression>,
},
/// Call an intrinsic function. /// Call an intrinsic function.
Intrinsic { Intrinsic {
fun: IntrinsicFunction, fun: IntrinsicFunction,
@ -687,6 +709,17 @@ pub enum Statement {
}, },
} }
/// A function argument.
#[derive(Debug)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub struct FunctionArgument {
/// Name of the argument, if any.
pub name: Option<String>,
/// Type of the argument.
pub ty: Handle<Type>,
}
/// A function defined in the module. /// A function defined in the module.
#[derive(Debug)] #[derive(Debug)]
#[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "serialize", derive(Serialize))]
@ -694,9 +727,8 @@ pub enum Statement {
pub struct Function { pub struct Function {
/// Name of the function, if any. /// Name of the function, if any.
pub name: Option<String>, pub name: Option<String>,
//pub control: spirv::FunctionControl, /// Information about function argument.
/// The types of the parameters of this function. pub arguments: Vec<FunctionArgument>,
pub parameter_types: Vec<Handle<Type>>,
/// The return type of this function, if any. /// The return type of this function, if any.
pub return_type: Option<Handle<Type>>, pub return_type: Option<Handle<Type>>,
/// Vector of global variable usages. /// Vector of global variable usages.

View file

@ -2,6 +2,7 @@ use crate::arena::{Arena, Handle};
pub struct Interface<'a, T> { pub struct Interface<'a, T> {
pub expressions: &'a Arena<crate::Expression>, pub expressions: &'a Arena<crate::Expression>,
pub local_variables: &'a Arena<crate::LocalVariable>,
pub visitor: T, pub visitor: T,
} }
@ -36,7 +37,7 @@ where
self.traverse_expr(comp); self.traverse_expr(comp);
} }
} }
E::FunctionParameter(_) | E::GlobalVariable(_) | E::LocalVariable(_) => {} E::FunctionArgument(_) | E::GlobalVariable(_) | E::LocalVariable(_) => {}
E::Load { pointer } => { E::Load { pointer } => {
self.traverse_expr(pointer); self.traverse_expr(pointer);
} }
@ -78,6 +79,15 @@ where
self.traverse_expr(left); self.traverse_expr(left);
self.traverse_expr(right); self.traverse_expr(right);
} }
E::Select {
condition,
accept,
reject,
} => {
self.traverse_expr(condition);
self.traverse_expr(accept);
self.traverse_expr(reject);
}
E::Intrinsic { argument, .. } => { E::Intrinsic { argument, .. } => {
self.traverse_expr(argument); self.traverse_expr(argument);
} }
@ -201,6 +211,7 @@ impl crate::Function {
let mut io = Interface { let mut io = Interface {
expressions: &self.expressions, expressions: &self.expressions,
local_variables: &self.local_variables,
visitor: GlobalUseVisitor(&mut self.global_usage), visitor: GlobalUseVisitor(&mut self.global_usage),
}; };
io.traverse(&self.body); io.traverse(&self.body);
@ -218,9 +229,10 @@ mod tests {
fn global_use_scan() { fn global_use_scan() {
let test_global = GlobalVariable { let test_global = GlobalVariable {
name: None, name: None,
class: StorageClass::Constant, class: StorageClass::Uniform,
binding: None, binding: None,
ty: Handle::new(std::num::NonZeroU32::new(1).unwrap()), ty: Handle::new(std::num::NonZeroU32::new(1).unwrap()),
init: None,
interpolation: None, interpolation: None,
storage_access: StorageAccess::empty(), storage_access: StorageAccess::empty(),
}; };
@ -256,7 +268,7 @@ mod tests {
let mut function = crate::Function { let mut function = crate::Function {
name: None, name: None,
parameter_types: Vec::new(), arguments: Vec::new(),
return_type: None, return_type: None,
local_variables: Arena::new(), local_variables: Arena::new(),
expressions, expressions,

View file

@ -1,10 +1,12 @@
//! Module processing functionality. //! Module processing functionality.
mod interface; mod interface;
mod namer;
mod typifier; mod typifier;
mod validator; mod validator;
pub use interface::{Interface, Visitor}; pub use interface::{Interface, Visitor};
pub use namer::{EntryPointIndex, NameKey, Namer};
pub use typifier::{check_constant_type, ResolveContext, ResolveError, Typifier}; pub use typifier::{check_constant_type, ResolveContext, ResolveError, Typifier};
pub use validator::{ValidationError, Validator}; pub use validator::{ValidationError, Validator};
@ -47,3 +49,15 @@ impl From<super::StorageFormat> for super::ScalarKind {
} }
} }
} }
impl crate::TypeInner {
pub fn scalar_kind(&self) -> Option<super::ScalarKind> {
match *self {
super::TypeInner::Scalar { kind, .. } | super::TypeInner::Vector { kind, .. } => {
Some(kind)
}
super::TypeInner::Matrix { .. } => Some(super::ScalarKind::Float),
_ => None,
}
}
}

113
third_party/rust/naga/src/proc/namer.rs vendored Normal file
View file

@ -0,0 +1,113 @@
use crate::{arena::Handle, FastHashMap};
use std::collections::hash_map::Entry;
pub type EntryPointIndex = u16;
#[derive(Debug, Eq, Hash, PartialEq)]
pub enum NameKey {
GlobalVariable(Handle<crate::GlobalVariable>),
Type(Handle<crate::Type>),
StructMember(Handle<crate::Type>, u32),
Function(Handle<crate::Function>),
FunctionArgument(Handle<crate::Function>, u32),
FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>),
EntryPoint(EntryPointIndex),
EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>),
}
/// This processor assigns names to all the things in a module
/// that may need identifiers in a textual backend.
pub struct Namer {
unique: FastHashMap<String, u32>,
}
impl Namer {
fn sanitize(string: &str) -> String {
let mut base = string
.chars()
.skip_while(|c| c.is_numeric())
.filter(|&c| c.is_ascii_alphanumeric() || c == '_')
.collect::<String>();
// close the name by '_' if the re is a number, so that
// we can have our own number!
match base.chars().next_back() {
Some(c) if !c.is_numeric() => {}
_ => base.push('_'),
};
base
}
fn call(&mut self, label_raw: &str) -> String {
let base = Self::sanitize(label_raw);
match self.unique.entry(base) {
Entry::Occupied(mut e) => {
*e.get_mut() += 1;
format!("{}{}", e.key(), e.get())
}
Entry::Vacant(e) => {
let name = e.key().to_string();
e.insert(0);
name
}
}
}
fn call_or(&mut self, label: &Option<String>, fallback: &str) -> String {
self.call(match *label {
Some(ref name) => name,
None => fallback,
})
}
pub fn process(
module: &crate::Module,
reserved: &[&str],
output: &mut FastHashMap<NameKey, String>,
) {
let mut this = Namer {
unique: reserved
.iter()
.map(|string| (string.to_string(), 0))
.collect(),
};
for (handle, var) in module.global_variables.iter() {
let name = this.call_or(&var.name, "global");
output.insert(NameKey::GlobalVariable(handle), name);
}
for (ty_handle, ty) in module.types.iter() {
let ty_name = this.call_or(&ty.name, "type");
output.insert(NameKey::Type(ty_handle), ty_name);
if let crate::TypeInner::Struct { ref members } = ty.inner {
for (index, member) in members.iter().enumerate() {
let name = this.call_or(&member.name, "member");
output.insert(NameKey::StructMember(ty_handle, index as u32), name);
}
}
}
for (fun_handle, fun) in module.functions.iter() {
let fun_name = this.call_or(&fun.name, "function");
output.insert(NameKey::Function(fun_handle), fun_name);
for (index, arg) in fun.arguments.iter().enumerate() {
let name = this.call_or(&arg.name, "param");
output.insert(NameKey::FunctionArgument(fun_handle, index as u32), name);
}
for (handle, var) in fun.local_variables.iter() {
let name = this.call_or(&var.name, "local");
output.insert(NameKey::FunctionLocal(fun_handle, handle), name);
}
}
for (ep_index, (&(_, ref base_name), ep)) in module.entry_points.iter().enumerate() {
let ep_name = this.call(base_name);
output.insert(NameKey::EntryPoint(ep_index as _), ep_name);
for (handle, var) in ep.function.local_variables.iter() {
let name = this.call_or(&var.name, "local");
output.insert(NameKey::EntryPointLocal(ep_index as _, handle), name);
}
}
}
}

View file

@ -29,6 +29,7 @@ impl Clone for Resolution {
columns, columns,
width, width,
}, },
#[allow(clippy::panic)]
_ => panic!("Unepxected clone type: {:?}", v), _ => panic!("Unepxected clone type: {:?}", v),
}), }),
} }
@ -50,6 +51,14 @@ pub enum ResolveError {
FunctionReturnsVoid, FunctionReturnsVoid,
#[error("Type is not found in the given immutable arena")] #[error("Type is not found in the given immutable arena")]
TypeNotFound, TypeNotFound,
#[error("Incompatible operand: {op} {operand}")]
IncompatibleOperand { op: String, operand: String },
#[error("Incompatible operands: {left} {op} {right}")]
IncompatibleOperands {
op: String,
left: String,
right: String,
},
} }
pub struct ResolveContext<'a> { pub struct ResolveContext<'a> {
@ -57,7 +66,7 @@ pub struct ResolveContext<'a> {
pub global_vars: &'a Arena<crate::GlobalVariable>, pub global_vars: &'a Arena<crate::GlobalVariable>,
pub local_vars: &'a Arena<crate::LocalVariable>, pub local_vars: &'a Arena<crate::LocalVariable>,
pub functions: &'a Arena<crate::Function>, pub functions: &'a Arena<crate::Function>,
pub parameter_types: &'a [Handle<crate::Type>], pub arguments: &'a [crate::FunctionArgument],
} }
impl Typifier { impl Typifier {
@ -82,6 +91,16 @@ impl Typifier {
} }
} }
pub fn get_handle(
&self,
expr_handle: Handle<crate::Expression>,
) -> Option<Handle<crate::Type>> {
match self.resolutions[expr_handle.index()] {
Resolution::Handle(ty_handle) => Some(ty_handle),
Resolution::Value(_) => None,
}
}
fn resolve_impl( fn resolve_impl(
&self, &self,
expr: &crate::Expression, expr: &crate::Expression,
@ -105,7 +124,12 @@ impl Typifier {
kind: crate::ScalarKind::Float, kind: crate::ScalarKind::Float,
width, width,
}), }),
ref other => panic!("Can't access into {:?}", other), ref other => {
return Err(ResolveError::IncompatibleOperand {
op: "access".to_string(),
operand: format!("{:?}", other),
})
}
}, },
crate::Expression::AccessIndex { base, index } => match *self.get(base, types) { crate::Expression::AccessIndex { base, index } => match *self.get(base, types) {
crate::TypeInner::Vector { size, kind, width } => { crate::TypeInner::Vector { size, kind, width } => {
@ -135,12 +159,17 @@ impl Typifier {
.ok_or(ResolveError::InvalidAccessIndex)?; .ok_or(ResolveError::InvalidAccessIndex)?;
Resolution::Handle(member.ty) Resolution::Handle(member.ty)
} }
ref other => panic!("Can't access into {:?}", other), ref other => {
return Err(ResolveError::IncompatibleOperand {
op: "access index".to_string(),
operand: format!("{:?}", other),
})
}
}, },
crate::Expression::Constant(h) => Resolution::Handle(ctx.constants[h].ty), crate::Expression::Constant(h) => Resolution::Handle(ctx.constants[h].ty),
crate::Expression::Compose { ty, .. } => Resolution::Handle(ty), crate::Expression::Compose { ty, .. } => Resolution::Handle(ty),
crate::Expression::FunctionParameter(index) => { crate::Expression::FunctionArgument(index) => {
Resolution::Handle(ctx.parameter_types[index as usize]) Resolution::Handle(ctx.arguments[index as usize].ty)
} }
crate::Expression::GlobalVariable(h) => Resolution::Handle(ctx.global_vars[h].ty), crate::Expression::GlobalVariable(h) => Resolution::Handle(ctx.global_vars[h].ty),
crate::Expression::LocalVariable(h) => Resolution::Handle(ctx.local_vars[h].ty), crate::Expression::LocalVariable(h) => Resolution::Handle(ctx.local_vars[h].ty),
@ -192,7 +221,13 @@ impl Typifier {
kind: crate::ScalarKind::Float, kind: crate::ScalarKind::Float,
width, width,
}), }),
_ => panic!("Incompatible arguments {:?} x {:?}", ty_left, ty_right), _ => {
return Err(ResolveError::IncompatibleOperands {
op: "x".to_string(),
left: format!("{:?}", ty_left),
right: format!("{:?}", ty_right),
})
}
} }
} }
} }
@ -207,12 +242,10 @@ impl Typifier {
crate::BinaryOperator::And crate::BinaryOperator::And
| crate::BinaryOperator::ExclusiveOr | crate::BinaryOperator::ExclusiveOr
| crate::BinaryOperator::InclusiveOr | crate::BinaryOperator::InclusiveOr
| crate::BinaryOperator::ShiftLeftLogical | crate::BinaryOperator::ShiftLeft
| crate::BinaryOperator::ShiftRightLogical | crate::BinaryOperator::ShiftRight => self.resolutions[left.index()].clone(),
| crate::BinaryOperator::ShiftRightArithmetic => {
self.resolutions[left.index()].clone()
}
}, },
crate::Expression::Select { accept, .. } => self.resolutions[accept.index()].clone(),
crate::Expression::Intrinsic { .. } => unimplemented!(), crate::Expression::Intrinsic { .. } => unimplemented!(),
crate::Expression::Transpose(expr) => match *self.get(expr, types) { crate::Expression::Transpose(expr) => match *self.get(expr, types) {
crate::TypeInner::Matrix { crate::TypeInner::Matrix {
@ -224,7 +257,12 @@ impl Typifier {
rows: columns, rows: columns,
width, width,
}), }),
ref other => panic!("incompatible transpose of {:?}", other), ref other => {
return Err(ResolveError::IncompatibleOperand {
op: "transpose".to_string(),
operand: format!("{:?}", other),
})
}
}, },
crate::Expression::DotProduct(left_expr, _) => match *self.get(left_expr, types) { crate::Expression::DotProduct(left_expr, _) => match *self.get(left_expr, types) {
crate::TypeInner::Vector { crate::TypeInner::Vector {
@ -232,7 +270,12 @@ impl Typifier {
size: _, size: _,
width, width,
} => Resolution::Value(crate::TypeInner::Scalar { kind, width }), } => Resolution::Value(crate::TypeInner::Scalar { kind, width }),
ref other => panic!("incompatible dot of {:?}", other), ref other => {
return Err(ResolveError::IncompatibleOperand {
op: "dot product".to_string(),
operand: format!("{:?}", other),
})
}
}, },
crate::Expression::CrossProduct(_, _) => unimplemented!(), crate::Expression::CrossProduct(_, _) => unimplemented!(),
crate::Expression::As { crate::Expression::As {
@ -248,7 +291,12 @@ impl Typifier {
size, size,
width, width,
} => Resolution::Value(crate::TypeInner::Vector { kind, size, width }), } => Resolution::Value(crate::TypeInner::Vector { kind, size, width }),
ref other => panic!("incompatible as of {:?}", other), ref other => {
return Err(ResolveError::IncompatibleOperand {
op: "as".to_string(),
operand: format!("{:?}", other),
})
}
}, },
crate::Expression::Derivative { .. } => unimplemented!(), crate::Expression::Derivative { .. } => unimplemented!(),
crate::Expression::Call { crate::Expression::Call {
@ -260,13 +308,23 @@ impl Typifier {
| crate::TypeInner::Scalar { kind, width } => { | crate::TypeInner::Scalar { kind, width } => {
Resolution::Value(crate::TypeInner::Scalar { kind, width }) Resolution::Value(crate::TypeInner::Scalar { kind, width })
} }
ref other => panic!("Unexpected argument {:?} on {}", other, name), ref other => {
return Err(ResolveError::IncompatibleOperand {
op: name.clone(),
operand: format!("{:?}", other),
})
}
}, },
"dot" => match *self.get(arguments[0], types) { "dot" => match *self.get(arguments[0], types) {
crate::TypeInner::Vector { kind, width, .. } => { crate::TypeInner::Vector { kind, width, .. } => {
Resolution::Value(crate::TypeInner::Scalar { kind, width }) Resolution::Value(crate::TypeInner::Scalar { kind, width })
} }
ref other => panic!("Unexpected argument {:?} on {}", other, name), ref other => {
return Err(ResolveError::IncompatibleOperand {
op: name.clone(),
operand: format!("{:?}", other),
})
}
}, },
//Note: `cross` is here too, we still need to figure out what to do with it //Note: `cross` is here too, we still need to figure out what to do with it
"abs" | "atan2" | "cos" | "sin" | "floor" | "inverse" | "normalize" | "min" "abs" | "atan2" | "cos" | "sin" | "floor" | "inverse" | "normalize" | "min"

View file

@ -21,20 +21,37 @@ pub enum GlobalVariableError {
InvalidType, InvalidType,
#[error("Interpolation is not valid")] #[error("Interpolation is not valid")]
InvalidInterpolation, InvalidInterpolation,
#[error("Storage access flags are invalid")] #[error("Storage access {seen:?} exceed the allowed {allowed:?}")]
InvalidStorageAccess, InvalidStorageAccess {
allowed: crate::StorageAccess,
seen: crate::StorageAccess,
},
#[error("Binding decoration is missing or not applicable")] #[error("Binding decoration is missing or not applicable")]
InvalidBinding, InvalidBinding,
#[error("Binding is out of range")] #[error("Binding is out of range")]
OutOfRangeBinding, OutOfRangeBinding,
} }
#[derive(Clone, Debug, thiserror::Error)]
pub enum LocalVariableError {
#[error("Initializer is not a constant expression")]
InitializerConst,
#[error("Initializer doesn't match the variable type")]
InitializerType,
}
#[derive(Clone, Debug, thiserror::Error)] #[derive(Clone, Debug, thiserror::Error)]
pub enum FunctionError { pub enum FunctionError {
#[error(transparent)] #[error(transparent)]
Resolve(#[from] ResolveError), Resolve(#[from] ResolveError),
#[error("There are instructions after `return`/`break`/`continue`")] #[error("There are instructions after `return`/`break`/`continue`")]
InvalidControlFlowExitTail, InvalidControlFlowExitTail,
#[error("Local variable {handle:?} '{name}' is invalid: {error:?}")]
LocalVariable {
handle: Handle<crate::LocalVariable>,
name: String,
error: LocalVariableError,
},
} }
#[derive(Clone, Debug, thiserror::Error)] #[derive(Clone, Debug, thiserror::Error)]
@ -63,8 +80,14 @@ pub enum ValidationError {
InvalidTypeWidth(crate::ScalarKind, crate::Bytes), InvalidTypeWidth(crate::ScalarKind, crate::Bytes),
#[error("The type handle {0:?} can not be resolved")] #[error("The type handle {0:?} can not be resolved")]
UnresolvedType(Handle<crate::Type>), UnresolvedType(Handle<crate::Type>),
#[error("Global variable {0:?} is invalid: {1:?}")] #[error("The constant {0:?} can not be used for an array size")]
GlobalVariable(Handle<crate::GlobalVariable>, GlobalVariableError), InvalidArraySizeConstant(Handle<crate::Constant>),
#[error("Global variable {handle:?} '{name}' is invalid: {error:?}")]
GlobalVariable {
handle: Handle<crate::GlobalVariable>,
name: String,
error: GlobalVariableError,
},
#[error("Function {0:?} is invalid: {1:?}")] #[error("Function {0:?} is invalid: {1:?}")]
Function(Handle<crate::Function>, FunctionError), Function(Handle<crate::Function>, FunctionError),
#[error("Entry point {name} at {stage:?} is invalid: {error:?}")] #[error("Entry point {name} at {stage:?} is invalid: {error:?}")]
@ -73,6 +96,43 @@ pub enum ValidationError {
name: String, name: String,
error: EntryPointError, error: EntryPointError,
}, },
#[error("Module is corrupted")]
Corrupted,
}
impl crate::GlobalVariable {
fn forbid_interpolation(&self) -> Result<(), GlobalVariableError> {
match self.interpolation {
Some(_) => Err(GlobalVariableError::InvalidInterpolation),
None => Ok(()),
}
}
fn check_resource(&self) -> Result<(), GlobalVariableError> {
match self.binding {
Some(crate::Binding::BuiltIn(_)) => {} // validated per entry point
Some(crate::Binding::Resource { group, binding }) => {
if group > MAX_BIND_GROUPS || binding > MAX_BIND_INDICES {
return Err(GlobalVariableError::OutOfRangeBinding);
}
}
Some(crate::Binding::Location(_)) | None => {
return Err(GlobalVariableError::InvalidBinding)
}
}
self.forbid_interpolation()
}
}
fn storage_usage(access: crate::StorageAccess) -> crate::GlobalUse {
let mut storage_usage = crate::GlobalUse::empty();
if access.contains(crate::StorageAccess::LOAD) {
storage_usage |= crate::GlobalUse::LOAD;
}
if access.contains(crate::StorageAccess::STORE) {
storage_usage |= crate::GlobalUse::STORE;
}
storage_usage
} }
impl Validator { impl Validator {
@ -89,15 +149,13 @@ impl Validator {
types: &Arena<crate::Type>, types: &Arena<crate::Type>,
) -> Result<(), GlobalVariableError> { ) -> Result<(), GlobalVariableError> {
log::debug!("var {:?}", var); log::debug!("var {:?}", var);
let is_storage = match var.class { let allowed_storage_access = match var.class {
crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage), crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage),
crate::StorageClass::Input | crate::StorageClass::Output => { crate::StorageClass::Input | crate::StorageClass::Output => {
match var.binding { match var.binding {
Some(crate::Binding::BuiltIn(_)) => { Some(crate::Binding::BuiltIn(_)) => {
// validated per entry point // validated per entry point
if var.interpolation.is_some() { var.forbid_interpolation()?
return Err(GlobalVariableError::InvalidInterpolation);
}
} }
Some(crate::Binding::Location(loc)) => { Some(crate::Binding::Location(loc)) => {
if loc > MAX_LOCATIONS { if loc > MAX_LOCATIONS {
@ -117,61 +175,73 @@ impl Validator {
match types[var.ty].inner { match types[var.ty].inner {
//TODO: check the member types //TODO: check the member types
crate::TypeInner::Struct { members: _ } => { crate::TypeInner::Struct { members: _ } => {
if var.interpolation.is_some() { var.forbid_interpolation()?
return Err(GlobalVariableError::InvalidInterpolation);
}
} }
_ => return Err(GlobalVariableError::InvalidType), _ => return Err(GlobalVariableError::InvalidType),
} }
} }
} }
false crate::StorageAccess::empty()
} }
crate::StorageClass::Constant crate::StorageClass::Storage => {
| crate::StorageClass::StorageBuffer var.check_resource()?;
| crate::StorageClass::Uniform => { crate::StorageAccess::all()
match var.binding {
Some(crate::Binding::BuiltIn(_)) => {} // validated per entry point
Some(crate::Binding::Resource { group, binding }) => {
if group > MAX_BIND_GROUPS || binding > MAX_BIND_INDICES {
return Err(GlobalVariableError::OutOfRangeBinding);
} }
crate::StorageClass::Uniform => {
var.check_resource()?;
crate::StorageAccess::empty()
} }
Some(crate::Binding::Location(_)) | None => { crate::StorageClass::Handle => {
return Err(GlobalVariableError::InvalidBinding) var.check_resource()?;
}
}
if var.interpolation.is_some() {
return Err(GlobalVariableError::InvalidInterpolation);
}
//TODO: prevent `Uniform` storage class with `STORE` access
match types[var.ty].inner { match types[var.ty].inner {
crate::TypeInner::Struct { .. } crate::TypeInner::Image {
| crate::TypeInner::Image {
class: crate::ImageClass::Storage(_), class: crate::ImageClass::Storage(_),
.. ..
} => true, } => crate::StorageAccess::all(),
_ => false, _ => crate::StorageAccess::empty(),
} }
} }
crate::StorageClass::Private | crate::StorageClass::WorkGroup => { crate::StorageClass::Private | crate::StorageClass::WorkGroup => {
if var.binding.is_some() { if var.binding.is_some() {
return Err(GlobalVariableError::InvalidBinding); return Err(GlobalVariableError::InvalidBinding);
} }
if var.interpolation.is_some() { var.forbid_interpolation()?;
return Err(GlobalVariableError::InvalidInterpolation); crate::StorageAccess::empty()
} }
false crate::StorageClass::PushConstant => {
//TODO
return Err(GlobalVariableError::InvalidStorageAccess {
allowed: crate::StorageAccess::empty(),
seen: crate::StorageAccess::empty(),
});
} }
}; };
if !is_storage && !var.storage_access.is_empty() { if !allowed_storage_access.contains(var.storage_access) {
return Err(GlobalVariableError::InvalidStorageAccess); return Err(GlobalVariableError::InvalidStorageAccess {
allowed: allowed_storage_access,
seen: var.storage_access,
});
} }
Ok(()) Ok(())
} }
fn validate_local_var(
&self,
var: &crate::LocalVariable,
_fun: &crate::Function,
_types: &Arena<crate::Type>,
) -> Result<(), LocalVariableError> {
log::debug!("var {:?}", var);
if let Some(_expr_handle) = var.init {
if false {
return Err(LocalVariableError::InitializerConst);
}
}
Ok(())
}
fn validate_function( fn validate_function(
&mut self, &mut self,
fun: &crate::Function, fun: &crate::Function,
@ -182,10 +252,19 @@ impl Validator {
global_vars: &module.global_variables, global_vars: &module.global_variables,
local_vars: &fun.local_variables, local_vars: &fun.local_variables,
functions: &module.functions, functions: &module.functions,
parameter_types: &fun.parameter_types, arguments: &fun.arguments,
}; };
self.typifier self.typifier
.resolve_all(&fun.expressions, &module.types, &resolve_ctx)?; .resolve_all(&fun.expressions, &module.types, &resolve_ctx)?;
for (var_handle, var) in fun.local_variables.iter() {
self.validate_local_var(var, fun, &module.types)
.map_err(|error| FunctionError::LocalVariable {
handle: var_handle,
name: var.name.clone().unwrap_or_default(),
error,
})?;
}
Ok(()) Ok(())
} }
@ -226,17 +305,12 @@ impl Validator {
match (stage, var.class) { match (stage, var.class) {
(crate::ShaderStage::Vertex, crate::StorageClass::Output) (crate::ShaderStage::Vertex, crate::StorageClass::Output)
| (crate::ShaderStage::Fragment, crate::StorageClass::Input) => { | (crate::ShaderStage::Fragment, crate::StorageClass::Input) => {
match module.types[var.ty].inner { match module.types[var.ty].inner.scalar_kind() {
crate::TypeInner::Scalar { kind, .. } Some(crate::ScalarKind::Float) => {}
| crate::TypeInner::Vector { kind, .. } => { Some(_) if var.interpolation != Some(crate::Interpolation::Flat) => {
if kind != crate::ScalarKind::Float
&& var.interpolation != Some(crate::Interpolation::Flat)
{
return Err(EntryPointError::InvalidIntegerInterpolation); return Err(EntryPointError::InvalidIntegerInterpolation);
} }
} _ => {}
crate::TypeInner::Matrix { .. } => {}
_ => unreachable!(),
} }
} }
_ => {} _ => {}
@ -291,26 +365,19 @@ impl Validator {
location_out_mask |= mask; location_out_mask |= mask;
crate::GlobalUse::LOAD | crate::GlobalUse::STORE crate::GlobalUse::LOAD | crate::GlobalUse::STORE
} }
crate::StorageClass::Constant => crate::GlobalUse::LOAD, crate::StorageClass::Uniform => crate::GlobalUse::LOAD,
crate::StorageClass::Uniform | crate::StorageClass::StorageBuffer => { crate::StorageClass::Storage => storage_usage(var.storage_access),
//TODO: built-in checks? crate::StorageClass::Handle => match module.types[var.ty].inner {
let mut storage_usage = crate::GlobalUse::empty(); crate::TypeInner::Image {
if var.storage_access.contains(crate::StorageAccess::LOAD) { class: crate::ImageClass::Storage(_),
storage_usage |= crate::GlobalUse::LOAD; ..
} } => storage_usage(var.storage_access),
if var.storage_access.contains(crate::StorageAccess::STORE) { _ => crate::GlobalUse::LOAD,
storage_usage |= crate::GlobalUse::STORE; },
}
if storage_usage.is_empty() {
// its a uniform buffer
crate::GlobalUse::LOAD
} else {
storage_usage
}
}
crate::StorageClass::Private | crate::StorageClass::WorkGroup => { crate::StorageClass::Private | crate::StorageClass::WorkGroup => {
crate::GlobalUse::all() crate::GlobalUse::all()
} }
crate::StorageClass::PushConstant => crate::GlobalUse::LOAD,
}; };
if !allowed_usage.contains(usage) { if !allowed_usage.contains(usage) {
log::warn!("\tUsage error for: {:?}", var); log::warn!("\tUsage error for: {:?}", var);
@ -364,10 +431,22 @@ impl Validator {
return Err(ValidationError::UnresolvedType(base)); return Err(ValidationError::UnresolvedType(base));
} }
} }
Ti::Array { base, .. } => { Ti::Array { base, size, .. } => {
if base >= handle { if base >= handle {
return Err(ValidationError::UnresolvedType(base)); return Err(ValidationError::UnresolvedType(base));
} }
if let crate::ArraySize::Constant(const_handle) = size {
let constant = module
.constants
.try_get(const_handle)
.ok_or(ValidationError::Corrupted)?;
match constant.inner {
crate::ConstantInner::Uint(_) => {}
_ => {
return Err(ValidationError::InvalidArraySizeConstant(const_handle))
}
}
}
} }
Ti::Struct { ref members } => { Ti::Struct { ref members } => {
//TODO: check that offsets are not intersecting? //TODO: check that offsets are not intersecting?
@ -384,7 +463,11 @@ impl Validator {
for (var_handle, var) in module.global_variables.iter() { for (var_handle, var) in module.global_variables.iter() {
self.validate_global_var(var, &module.types) self.validate_global_var(var, &module.types)
.map_err(|e| ValidationError::GlobalVariable(var_handle, e))?; .map_err(|error| ValidationError::GlobalVariable {
handle: var_handle,
name: var.name.clone().unwrap_or_default(),
error,
})?;
} }
for (fun_handle, fun) in module.functions.iter() { for (fun_handle, fun) in module.functions.iter() {

View file

@ -1,7 +1,8 @@
( (
spv_flow_dump_prefix: "",
metal_bindings: { metal_bindings: {
(group: 0, binding: 0): (buffer: Some(0), mutable: false), (stage: Compute, group: 0, binding: 0): (buffer: Some(0), mutable: false),
(group: 0, binding: 1): (buffer: Some(1), mutable: true), (stage: Compute, group: 0, binding: 1): (buffer: Some(1), mutable: true),
(group: 0, binding: 2): (buffer: Some(2), mutable: true), (stage: Compute, group: 0, binding: 2): (buffer: Some(2), mutable: true),
} }
) )

View file

@ -61,8 +61,8 @@ type Particles = struct {
}; };
[[group(0), binding(0)]] var<uniform> params : SimParams; [[group(0), binding(0)]] var<uniform> params : SimParams;
[[group(0), binding(1)]] var<storage_buffer> particlesA : Particles; [[group(0), binding(1)]] var<storage> particlesA : Particles;
[[group(0), binding(2)]] var<storage_buffer> particlesB : Particles; [[group(0), binding(2)]] var<storage> particlesB : Particles;
[[builtin(global_invocation_id)]] var gl_GlobalInvocationID : vec3<u32>; [[builtin(global_invocation_id)]] var gl_GlobalInvocationID : vec3<u32>;

View file

@ -1,6 +1,6 @@
( (
metal_bindings: { metal_bindings: {
(group: 0, binding: 0): (texture: Some(0)), (stage: Fragment, group: 0, binding: 0): (texture: Some(0)),
(group: 0, binding: 1): (sampler: Some(0)), (stage: Fragment, group: 0, binding: 1): (sampler: Some(0)),
} }
) )

View file

@ -14,8 +14,8 @@ fn main() -> void {
# fragment # fragment
[[location(0)]] var<in> v_uv : vec2<f32>; [[location(0)]] var<in> v_uv : vec2<f32>;
[[group(0), binding(0)]] var<uniform> u_texture : texture_sampled_2d<f32>; [[group(0), binding(0)]] var u_texture : texture_sampled_2d<f32>;
[[group(0), binding(1)]] var<uniform> u_sampler : sampler; [[group(0), binding(1)]] var u_sampler : sampler;
[[location(0)]] var<out> o_color : vec4<f32>; [[location(0)]] var<out> o_color : vec4<f32>;
[[stage(fragment)]] [[stage(fragment)]]

View file

@ -48,6 +48,7 @@
class: Input, class: Input,
binding: Some(Location(0)), binding: Some(Location(0)),
ty: 1, ty: 1,
init: None,
interpolation: None, interpolation: None,
storage_access: ( storage_access: (
bits: 0, bits: 0,
@ -58,6 +59,7 @@
class: Output, class: Output,
binding: Some(Location(0)), binding: Some(Location(0)),
ty: 2, ty: 2,
init: None,
interpolation: None, interpolation: None,
storage_access: ( storage_access: (
bits: 0, bits: 0,
@ -71,7 +73,7 @@
workgroup_size: (0, 0, 0), workgroup_size: (0, 0, 0),
function: ( function: (
name: Some("main"), name: Some("main"),
parameter_types: [], arguments: [],
return_type: None, return_type: None,
global_usage: [ global_usage: [
( (
@ -85,7 +87,7 @@
( (
name: Some("w"), name: Some("w"),
ty: 3, ty: 3,
init: Some(3), init: Some(1),
), ),
], ],
expressions: [ expressions: [

View file

@ -14,13 +14,13 @@ fn load_wgsl(name: &str) -> naga::Module {
fn load_spv(name: &str) -> naga::Module { fn load_spv(name: &str) -> naga::Module {
let path = format!("{}/test-data/spv/{}", env!("CARGO_MANIFEST_DIR"), name); let path = format!("{}/test-data/spv/{}", env!("CARGO_MANIFEST_DIR"), name);
let input = std::fs::read(path).unwrap(); let input = std::fs::read(path).unwrap();
naga::front::spv::parse_u8_slice(&input).unwrap() naga::front::spv::parse_u8_slice(&input, &Default::default()).unwrap()
} }
#[cfg(feature = "glsl-in")] #[cfg(feature = "glsl-in")]
fn load_glsl(name: &str, entry: &str, stage: naga::ShaderStage) -> naga::Module { fn load_glsl(name: &str, entry: &str, stage: naga::ShaderStage) -> naga::Module {
let input = load_test_data(name); let input = load_test_data(name);
naga::front::glsl::parse_str(&input, entry, stage).unwrap() naga::front::glsl::parse_str(&input, entry, stage, Default::default()).unwrap()
} }
#[cfg(feature = "wgsl-in")] #[cfg(feature = "wgsl-in")]
@ -34,6 +34,7 @@ fn convert_quad() {
let mut binding_map = msl::BindingMap::default(); let mut binding_map = msl::BindingMap::default();
binding_map.insert( binding_map.insert(
msl::BindSource { msl::BindSource {
stage: naga::ShaderStage::Fragment,
group: 0, group: 0,
binding: 0, binding: 0,
}, },
@ -46,6 +47,7 @@ fn convert_quad() {
); );
binding_map.insert( binding_map.insert(
msl::BindSource { msl::BindSource {
stage: naga::ShaderStage::Fragment,
group: 0, group: 0,
binding: 1, binding: 1,
}, },
@ -57,9 +59,11 @@ fn convert_quad() {
}, },
); );
let options = msl::Options { let options = msl::Options {
binding_map: &binding_map, lang_version: (1, 0),
spirv_cross_compatibility: false,
binding_map,
}; };
msl::write_string(&module, options).unwrap(); msl::write_string(&module, &options).unwrap();
} }
} }
@ -74,6 +78,7 @@ fn convert_boids() {
let mut binding_map = msl::BindingMap::default(); let mut binding_map = msl::BindingMap::default();
binding_map.insert( binding_map.insert(
msl::BindSource { msl::BindSource {
stage: naga::ShaderStage::Compute,
group: 0, group: 0,
binding: 0, binding: 0,
}, },
@ -86,6 +91,7 @@ fn convert_boids() {
); );
binding_map.insert( binding_map.insert(
msl::BindSource { msl::BindSource {
stage: naga::ShaderStage::Compute,
group: 0, group: 0,
binding: 1, binding: 1,
}, },
@ -98,6 +104,7 @@ fn convert_boids() {
); );
binding_map.insert( binding_map.insert(
msl::BindSource { msl::BindSource {
stage: naga::ShaderStage::Compute,
group: 0, group: 0,
binding: 2, binding: 2,
}, },
@ -109,9 +116,11 @@ fn convert_boids() {
}, },
); );
let options = msl::Options { let options = msl::Options {
binding_map: &binding_map, lang_version: (1, 0),
spirv_cross_compatibility: false,
binding_map,
}; };
msl::write_string(&module, options).unwrap(); msl::write_string(&module, &options).unwrap();
} }
} }
@ -129,6 +138,7 @@ fn convert_cube() {
let mut binding_map = msl::BindingMap::default(); let mut binding_map = msl::BindingMap::default();
binding_map.insert( binding_map.insert(
msl::BindSource { msl::BindSource {
stage: naga::ShaderStage::Vertex,
group: 0, group: 0,
binding: 0, binding: 0,
}, },
@ -141,6 +151,7 @@ fn convert_cube() {
); );
binding_map.insert( binding_map.insert(
msl::BindSource { msl::BindSource {
stage: naga::ShaderStage::Fragment,
group: 0, group: 0,
binding: 1, binding: 1,
}, },
@ -153,6 +164,7 @@ fn convert_cube() {
); );
binding_map.insert( binding_map.insert(
msl::BindSource { msl::BindSource {
stage: naga::ShaderStage::Fragment,
group: 0, group: 0,
binding: 2, binding: 2,
}, },
@ -164,10 +176,12 @@ fn convert_cube() {
}, },
); );
let options = msl::Options { let options = msl::Options {
binding_map: &binding_map, lang_version: (1, 0),
spirv_cross_compatibility: false,
binding_map,
}; };
msl::write_string(&vs, options).unwrap(); msl::write_string(&vs, &options).unwrap();
msl::write_string(&fs, options).unwrap(); msl::write_string(&fs, &options).unwrap();
} }
} }

View file

@ -40,15 +40,33 @@ fn test_rosetta(dir_name: &str) {
#[cfg(feature = "glsl-in")] #[cfg(feature = "glsl-in")]
{ {
if let Ok(input) = fs::read_to_string(dir_path.join("x.vert")) { if let Ok(input) = fs::read_to_string(dir_path.join("x.vert")) {
let module = glsl::parse_str(&input, "main", naga::ShaderStage::Vertex).unwrap(); let module = glsl::parse_str(
&input,
"main",
naga::ShaderStage::Vertex,
Default::default(),
)
.unwrap();
check("vert", &module, &expected); check("vert", &module, &expected);
} }
if let Ok(input) = fs::read_to_string(dir_path.join("x.frag")) { if let Ok(input) = fs::read_to_string(dir_path.join("x.frag")) {
let module = glsl::parse_str(&input, "main", naga::ShaderStage::Fragment).unwrap(); let module = glsl::parse_str(
&input,
"main",
naga::ShaderStage::Fragment,
Default::default(),
)
.unwrap();
check("frag", &module, &expected); check("frag", &module, &expected);
} }
if let Ok(input) = fs::read_to_string(dir_path.join("x.comp")) { if let Ok(input) = fs::read_to_string(dir_path.join("x.comp")) {
let module = glsl::parse_str(&input, "main", naga::ShaderStage::Compute).unwrap(); let module = glsl::parse_str(
&input,
"main",
naga::ShaderStage::Compute,
Default::default(),
)
.unwrap();
check("comp", &module, &expected); check("comp", &module, &expected);
} }
} }